Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions rtp_llm/async_decoder_engine/base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,43 @@ def pause(self) -> None:
def restart(self) -> None:
    """Restarts the engine's execution.

    Raises
    ------
    NotImplementedError
        Always; concrete engine subclasses must override this method.
    """
    raise NotImplementedError()

@abstractmethod
def detach_physical_memory(self) -> bool:
    """
    Release physical GPU memory while retaining the virtual address space.

    This method is intended for engines that support virtual memory. It
    immediately unmaps and frees all **physical** backing memory without
    releasing the reserved **virtual** addresses. If any requests are still
    in flight, the engine **must** wait for them to complete before
    performing the detach operation.

    Returns
    -------
    bool
        ``True`` – physical memory was successfully released.
        ``False`` – the engine does not support virtual memory **or** the
        detach operation failed.

    Notes
    -----
    After a successful detach, the virtual addresses remain valid but
    accessing them will raise a device page-fault until
    :meth:`attach_physical_memory` is called.
    """
    raise NotImplementedError()

@abstractmethod
def attach_physical_memory(self) -> bool:
    """
    Re-attach / map physical memory to previously reserved virtual addresses.

    For every virtual address range that was **reserved but not mapped**
    (e.g., after :meth:`detach_physical_memory`), this method allocates
    physical GPU memory and binds it to those ranges. Virtual addresses that
    already have physical backing are **not** re-allocated.

    Returns
    -------
    bool
        ``True`` – physical memory was successfully (re-)mapped.
        ``False`` – the engine lacks virtual-memory support **or** the
        mapping operation failed.
    """
    raise NotImplementedError()
8 changes: 8 additions & 0 deletions rtp_llm/async_decoder_engine/rpc_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,11 @@ def pause(self) -> None:
@override
def restart(self) -> None:
    """Restarts the engine by delegating to the underlying RTP-LLM op."""
    return self.rtp_llm_op_.restart()

@override
def detach_physical_memory(self) -> bool:
    """Releases physical GPU memory; delegates to the underlying RTP-LLM op.

    Returns the op's result: ``True`` on success, ``False`` if virtual
    memory is unsupported or the detach failed (see the base class docs).
    """
    return self.rtp_llm_op_.detach_physical_memory()

@override
def attach_physical_memory(self) -> bool:
    """Re-attaches physical GPU memory; delegates to the underlying RTP-LLM op.

    Returns the op's result: ``True`` on success, ``False`` if virtual
    memory is unsupported or the mapping failed (see the base class docs).
    """
    return self.rtp_llm_op_.attach_physical_memory()
24 changes: 24 additions & 0 deletions rtp_llm/cpp/core/TrackerAllocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,28 @@ std::vector<MemoryChunk*> TrackerAllocator::getChunks() const {
return {};
}

void TrackerAllocator::map() {
if (auto allocator_ = dynamic_cast<IVirtualMemAllocator*>(real_allocator_)) {
return allocator_->map();
} else {
return map(); // this will throw not impl.
}
}

void TrackerAllocator::unmap() {
if (auto allocator_ = dynamic_cast<IVirtualMemAllocator*>(real_allocator_)) {
return allocator_->unmap();
} else {
return unmap(); // this will throw not impl.
}
}

/// @brief Allocates physical (non-virtually-mapped) memory. Delegates to the
///        real allocator's mallocPhysical() when it supports virtual memory;
///        otherwise falls back to a plain malloc().
void* TrackerAllocator::mallocPhysical(size_t size) {
    auto* vmem_allocator = dynamic_cast<IVirtualMemAllocator*>(real_allocator_);
    if (vmem_allocator != nullptr) {
        return vmem_allocator->mallocPhysical(size);
    }
    return real_allocator_->malloc(size);
}

} // namespace rtp_llm
6 changes: 5 additions & 1 deletion rtp_llm/cpp/core/TrackerAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct TrackerAllocatorParams {
size_t align_size = 1024;
};

class TrackerAllocator: public IAllocator {
class TrackerAllocator: public IVirtualMemAllocator {
public:
TrackerAllocator(const TrackerAllocatorParams& params);
~TrackerAllocator();
Expand All @@ -29,6 +29,10 @@ class TrackerAllocator: public IAllocator {

TrackerStatus getTrackerStatus() const;

void* mallocPhysical(size_t size) override;
void map() override;
void unmap() override;

private:
std::vector<MemoryChunk*> getChunks() const;
friend class DeviceBase;
Expand Down
22 changes: 22 additions & 0 deletions rtp_llm/cpp/core/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ class IAllocator {
virtual void* mallocPrivate(size_t size);
};

/**
 * @brief Interface for allocators that reserve virtual address ranges and
 *        optionally bind / unbind physical memory to them.
 *
 * Intended for virtual-memory managers that need to:
 *  - reserve large contiguous virtual address regions up-front,
 *  - attach / detach physical backing (BAR, VRAM, etc.) on demand,
 *  - support over-commitment or memory oversubscription scenarios.
 */
class IVirtualMemAllocator: virtual public IAllocator {
public:
    /// @brief Maps physical memory to the virtual address ranges owned by this
    ///        allocator (idempotent).
    virtual void map() = 0;

    /// @brief Unmaps physical memory from the virtual address ranges owned by
    ///        this allocator, without releasing the virtual addresses.
    virtual void unmap() = 0;

    /// @brief Allocates a block of actual physical memory that does not undergo
    ///        virtual memory mapping. map()/unmap() will not affect memory
    ///        allocated via mallocPhysical().
    virtual void* mallocPhysical(size_t size) = 0;
};

template<AllocatorType AllocType_>
class TypedAllocator: virtual public IAllocator {
public:
Expand Down
157 changes: 149 additions & 8 deletions rtp_llm/cpp/cuda/allocator_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,173 @@ void PurePointerCudaAllocator::free(void** ptr) {
return;
}

// Constructs the CUDA device allocator and prepares the bookkeeping table
// that tracks every virtual-memory block handed out by this allocator.
Allocator<AllocatorType::CUDA>::Allocator(int device_id)
    : PurePointerCudaAllocator(device_id)
    , IVirtualMemAllocator() {
    pointer_mapping_ = std::make_unique<std::unordered_map<CUdeviceptr, VmemBlock>>();
}

// Tears down the allocator; destroy() releases any outstanding allocations.
Allocator<AllocatorType::CUDA>::~Allocator() {
    destroy();
}

// Allocates a device memory block that is "pinned": it is excluded from the
// bulk map()/unmap() passes and therefore keeps its physical backing.
void* Allocator<AllocatorType::CUDA>::mallocPhysical(size_t size) {
    RTP_LLM_LOG_DEBUG("malloc physical memory with size %lu\n", size);
    // doMallocSync reserves a virtual range, creates physical memory, maps it,
    // and registers the block in pointer_mapping_.
    auto address = doMallocSync(size);
    if (!address) {
        return nullptr;
    }
    CUdeviceptr dptr = reinterpret_cast<CUdeviceptr>(address);
    // NOTE(review): the lock is re-acquired after doMallocSync() released it,
    // so another thread could in principle run unmap() on this block before it
    // is pinned below — confirm callers serialize allocation vs. unmap.
    std::lock_guard<std::mutex> lock(lock_);
    auto it = pointer_mapping_->find(dptr);
    if (it == pointer_mapping_->end()) {
        // Should not happen: doMalloc registers every block it returns.
        RTP_LLM_LOG_ERROR("Unexpected allocation, pointer mapping missing.");
        return address;
    }
    auto& block = it->second;
    // pin = true keeps this block resident across map()/unmap() cycles.
    block.pin = true;
    block.mapped = true;
    return address;
}

// Allocates device memory through the CUDA virtual-memory (VMM) driver API:
// reserves a virtual address range, creates physical memory, maps it, grants
// access, and records the block so its physical backing can later be detached
// via unmap() and restored via map(). Returns nullptr on failure.
void* Allocator<AllocatorType::CUDA>::doMalloc(size_t size) {
    RTP_LLM_LOG_DEBUG("Malloc virtual memory with size %lu\n", size);
    size_t granularity = 0;
    std::lock_guard<std::mutex> lock(lock_);

    CUmemAllocationProp prop{};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device_id_;

    check_cuda_value(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

    if (granularity > 0) {
        // Round the request up to the device's minimum allocation granularity.
        const size_t padded_size = (size + granularity - 1) / granularity * granularity;

        RTP_LLM_LOG_DEBUG("Malloc virtual memory with padded size %lu\n", padded_size);

        // 1. Reserve the virtual address range first.
        CUdeviceptr dptr = 0;
        check_cuda_value(cuMemAddressReserve(&dptr, padded_size, 0, 0, 0));

        // 2. Create the physical device memory.
        CUmemGenericAllocationHandle handle{};
        check_cuda_value(cuMemCreate(&handle, padded_size, &prop, 0));

        // 3. Map the physical memory into the reserved range.
        check_cuda_value(cuMemMap(dptr, padded_size, 0, handle, 0));

        // 4. Grant read/write access for this device.
        CUmemAccessDesc access{};
        access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        access.location.id = device_id_;
        access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
        check_cuda_value(cuMemSetAccess(dptr, padded_size, &access, 1));

        // Register the block: {pin = false, mapped = true, size, handle}.
        (*pointer_mapping_)[dptr] = {false, true, padded_size, handle};
        return reinterpret_cast<void*>(dptr);

    } else {
        RTP_LLM_LOG_ERROR("Get system granularity failed\n");
        return nullptr;
    }
}

// Synchronous allocation variant. The VMM path in doMalloc() is already
// synchronous (no stream-ordered allocation), so this simply forwards.
void* Allocator<AllocatorType::CUDA>::doMallocSync(size_t size) {
    return doMalloc(size);
}

// Frees a pointer previously returned by doMalloc()/mallocPhysical(): unmaps
// and releases its physical backing (if still mapped) and frees the reserved
// virtual address range. No-op (with a log) for null or unknown pointers.
void Allocator<AllocatorType::CUDA>::doFree(void* address) {
    if (!address) {
        RTP_LLM_LOG_WARNING("Try to free an empty pointer\n");
        return;
    }

    CUdeviceptr dptr = reinterpret_cast<CUdeviceptr>(address);

    std::lock_guard<std::mutex> lock(lock_);
    auto it = pointer_mapping_->find(dptr);
    if (it == pointer_mapping_->end()) {
        RTP_LLM_LOG_ERROR("Free Pointer Failed, Pointer is not managed by this alloctor %p\n", address);
        return;
    }

    RTP_LLM_LOG_DEBUG("Vmem allocator free pointer %p\n", address);
    // tmp sync to avoid memory free before kernel run. cudaFree will not perform any implicit synchronization when the
    // pointer was allocated with cudaMallocAsync or cudaMallocFromPoolAsync
    cudaStreamSynchronize(stream_);
    const auto& block = it->second;
    // BUGFIX: only unmap/release the physical memory when the block is still
    // mapped. After unmap(), the handle was already released; unconditionally
    // calling cuMemUnmap/cuMemRelease again would fail the CUDA checks.
    if (block.mapped) {
        check_cuda_value(cuMemUnmap(dptr, block.size));
        check_cuda_value(cuMemRelease(block.handle));
    }
    check_cuda_value(cuMemAddressFree(dptr, block.size));
    // BUGFIX: erase the bookkeeping entry so subsequent map()/unmap() passes
    // do not operate on the now-freed virtual address range.
    pointer_mapping_->erase(it);

    RTP_LLM_LOG_DEBUG("Vmem allocator free pointer %p successfully\n", address);
    return;
}

// Releases the physical backing of every non-pinned, currently-mapped block
// while keeping its virtual address reservation intact. Thread-safe.
void Allocator<AllocatorType::CUDA>::unmap() {
    std::lock_guard<std::mutex> lock(lock_);
    RTP_LLM_LOG_INFO("Vmem allocator unmap all allocated buffer\n");

    for (auto& [dptr, block] : *pointer_mapping_) {
        // Pinned blocks keep their backing; already-unmapped blocks are skipped.
        if (block.pin || !block.mapped) {
            continue;
        }
        // BUGFIX: %p expects a void*; CUdeviceptr is an integer type, so cast
        // explicitly to avoid undefined behavior in the varargs call.
        RTP_LLM_LOG_INFO("Vmem allocator unmap %p[%lu]\n", reinterpret_cast<void*>(dptr), block.size);

        // 1. Remove the virtual-to-physical mapping.
        check_cuda_value(cuMemUnmap(dptr, block.size));
        // 2. Release the corresponding physical memory.
        check_cuda_value(cuMemRelease(block.handle));
        block.mapped = false;
    }
}

// Re-attaches freshly allocated physical memory to every non-pinned virtual
// range that was previously detached with unmap(). Thread-safe.
void Allocator<AllocatorType::CUDA>::map() {
    std::lock_guard<std::mutex> lock(lock_);

    RTP_LLM_LOG_INFO("Vmem allocator map all allocated buffer\n");

    for (auto& [dptr, block] : *pointer_mapping_) {
        // Pinned blocks never lost their backing; mapped blocks need no work.
        if (block.pin || block.mapped) {
            continue;
        }

        // BUGFIX: %p expects a void*; CUdeviceptr is an integer type, so cast
        // explicitly to avoid undefined behavior in the varargs call.
        RTP_LLM_LOG_INFO("Vmem allocator map %p[%lu]\n", reinterpret_cast<void*>(dptr), block.size);

        size_t padded_size = block.size;  // reuse the previously aligned size
        CUmemAllocationProp prop{};
        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        prop.location.id = device_id_;

        // 1. Create a new physical allocation of identical size.
        CUmemGenericAllocationHandle new_handle{};
        check_cuda_value(cuMemCreate(&new_handle, padded_size, &prop, 0));

        // 2. Map it to the still-reserved virtual address.
        check_cuda_value(cuMemMap(dptr, padded_size, 0, new_handle, 0));

        // 3. Restore read/write access for this device.
        CUmemAccessDesc access{};
        access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        access.location.id = device_id_;
        access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
        check_cuda_value(cuMemSetAccess(dptr, padded_size, &access, 1));

        // 4. Replace the stale handle in the bookkeeping entry.
        block.handle = new_handle;

        block.mapped = true;
    }
}

// Host-memory allocator constructor; pointer bookkeeping comes from the base.
Allocator<AllocatorType::CUDA_HOST>::Allocator(int device_id): PurePointerCudaAllocator(device_id) {}

Allocator<AllocatorType::CUDA_HOST>::~Allocator() {
Expand Down
61 changes: 60 additions & 1 deletion rtp_llm/cpp/cuda/allocator_cuda.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include "rtp_llm/cpp/core/allocator.h"
#include "rtp_llm/cpp/cuda/cuda_host_utils.h"
#include <cuda.h>
#include <mutex>
#include <unordered_set>

Expand Down Expand Up @@ -60,14 +61,72 @@ class PurePointerCudaAllocator: public ICudaAllocator {
};

/**
 * @brief CUDA device-memory allocator specialization backed by the CUDA
 *        virtual-memory management (VMM) driver API.
 *
 * Each allocation reserves a virtual address range and maps physical memory
 * to it; `unmap()`/`map()` can detach and re-attach the physical backing of
 * non-pinned blocks while keeping virtual addresses stable.
 */
template<>
class Allocator<AllocatorType::CUDA>:
    public PurePointerCudaAllocator, public IVirtualMemAllocator,
    public TypedAllocator<AllocatorType::CUDA> {
public:
    Allocator(int device_id);
    ~Allocator();

    void* doMalloc(size_t size) override;
    void* doMallocSync(size_t size) override;
    void doFree(void* ptr) override;

    /**
     * @brief Remaps all non-pinned virtual address ranges to freshly allocated
     * physical memory.
     *
     * This function is intended to be called after `unmap()` has been used to
     * release underlying physical allocations (for instance, when temporarily
     * returning memory to the OS or across suspend/resume cycles). For every
     * virtual block that is **not** marked as pinned, it:
     * 1. Creates a new physical allocation of identical size.
     * 2. Maps that allocation to the previously reserved virtual address.
     * 3. Re-establishes read/write access for the current device.
     * 4. Updates internal bookkeeping with the new handle.
     *
     * Pinned blocks are skipped entirely, preserving their original physical
     * backing.
     *
     * Thread-safe: protected by an internal mutex.
     *
     * @note Must be preceded by a matching `unmap()` call; otherwise the
     * virtual addresses remain in an unmapped state.
     */
    void map() override;
    /**
     * @brief Releases the physical backing of all non-pinned device allocations
     * while preserving their virtual address reservations.
     *
     * For each block managed by this allocator:
     * - If the block is **pinned**, it is left untouched.
     * - Otherwise:
     *   1. The virtual-to-physical mapping is removed (`cuMemUnmap`).
     *   2. The physical allocation handle is destroyed (`cuMemRelease`).
     *
     * Virtual address ranges remain reserved and can later be re-populated with
     * new physical memory by calling `map()`.
     *
     * Thread-safe: protected by an internal mutex.
     *
     * @warning After this call, accessing device pointers associated with
     *          non-pinned blocks results in undefined behavior until `map()`
     *          is invoked.
     */
    void unmap() override;
    /// @brief Allocates a pinned block whose physical backing is untouched by
    ///        map()/unmap().
    void* mallocPhysical(size_t size) override;

private:
    // Per-allocation control block for the VMM bookkeeping table.
    struct VmemBlock {
        bool pin; // Whether the memory block is pinned (resident)
        bool mapped; // Whether the memory has been mapped to virtual memory
        size_t size; // Size of the block
        CUmemGenericAllocationHandle handle; // Physical memory handle (CUDA memory handle)
    };

    // Mapping from virtual device pointer (CUdeviceptr) to the memory control block
    std::unique_ptr<std::unordered_map<CUdeviceptr, VmemBlock>> pointer_mapping_;
    // Guards pointer_mapping_ and all map/unmap/free sequences.
    std::mutex lock_;
};

template<>
Expand Down
Loading