78 changes: 72 additions & 6 deletions cpp/tensorrt_llm/common/opUtils.cpp
@@ -238,20 +238,73 @@ class PerCudaCtxPerThreadSingletonCreator
std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>> mObservers;
};

+// Structure to hold memory information
+struct MemoryInfo
+{
+size_t free_mb;
+size_t total_mb;
+float free_percent;
+};
+
+// Helper function to get current memory information
+MemoryInfo getMemoryInfo()
+{
+size_t free_mem = 0, total_mem = 0;
+TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
+
+size_t const free_mb = free_mem / (1024 * 1024);
+size_t const total_mb = total_mem / (1024 * 1024);
+float const free_percent = (total_mem > 0) ? (static_cast<float>(free_mem) / total_mem * 100.0f) : 0.0f;
+
+return {free_mb, total_mb, free_percent};
+}
+
+// Helper function to log current memory usage
+void logMemoryUsage(char const* operation, CUcontext ctx)
+{
+auto const mem = getMemoryInfo();
+TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, mem.free_mb,
+mem.free_percent, mem.total_mb);
+}
+
+// Helper function to throw a cuBLAS creation error with current memory information
+void throwCublasErrorWithMemInfo(char const* operation, CUcontext ctx, cublasStatus_t status)
+{
+auto const mem = getMemoryInfo();
+TLLM_THROW(
+"Failed to create %s. "
+"Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
+"Consider reducing kv_cache_config.free_gpu_memory_fraction.",
+operation, status, ctx, mem.free_mb, mem.free_percent, mem.total_mb);
+}
+
} // namespace

std::shared_ptr<cublasHandle_t> getCublasHandle()
{
static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
[]() -> auto
{
-auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
-TLLM_CUDA_CHECK(cublasCreate(handle.get()));
+CUcontext ctx = getCurrentCudaCtx();
+logMemoryUsage("Creating cublas handle", ctx);
+
+auto handle = std::make_unique<cublasHandle_t>();
+cublasStatus_t status = cublasCreate(handle.get());
+
+if (status != CUBLAS_STATUS_SUCCESS)
+{
+throwCublasErrorWithMemInfo("cublas handle", ctx, status);
+}
+
return handle;
},
[](cublasHandle_t* handle)
{
-TLLM_CUDA_CHECK(cublasDestroy(*handle));
+cublasStatus_t status = cublasDestroy(*handle);
+if (status != CUBLAS_STATUS_SUCCESS)
+{
+TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
+}
delete handle;
});
return creator();
@@ -262,13 +315,26 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
[]() -> auto
{
-auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
-TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
+CUcontext ctx = getCurrentCudaCtx();
+logMemoryUsage("Creating cublasLt handle", ctx);
+
+auto handle = std::make_unique<cublasLtHandle_t>();
+cublasStatus_t status = cublasLtCreate(handle.get());
+
+if (status != CUBLAS_STATUS_SUCCESS)
+{
+throwCublasErrorWithMemInfo("cublasLt handle", ctx, status);
+}
+
return handle;
},
[](cublasLtHandle_t* handle)
{
-TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
+cublasStatus_t status = cublasLtDestroy(*handle);
+if (status != CUBLAS_STATUS_SUCCESS)
+{
+TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
+}
delete handle;
});
return creator();
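A minimal standalone sketch of the pattern the new helpers follow: query free/total device memory with cudaMemGetInfo() and check the cublasCreate() status explicitly instead of assuming success. This is illustrative only and not part of the diff; it assumes a CUDA toolkit with cuBLAS and builds with something like `nvcc -lcublas`.

```cpp
// Illustrative sketch only (not part of this PR): query device memory and
// check the cuBLAS handle-creation status instead of assuming success.
#include <cstdio>

#include <cublas_v2.h>
#include <cuda_runtime.h>

int main()
{
    // Same query getMemoryInfo() performs above.
    size_t freeBytes = 0, totalBytes = 0;
    if (cudaMemGetInfo(&freeBytes, &totalBytes) != cudaSuccess)
    {
        std::fprintf(stderr, "cudaMemGetInfo failed\n");
        return 1;
    }
    double const freePercent
        = totalBytes > 0 ? 100.0 * static_cast<double>(freeBytes) / static_cast<double>(totalBytes) : 0.0;
    std::printf("Free: %zu MB (%.1f%%), Total: %zu MB\n", freeBytes / (1024 * 1024), freePercent,
        totalBytes / (1024 * 1024));

    // Check the creation status explicitly; cublasCreate() allocates device-side
    // resources and can fail when the GPU is nearly full.
    cublasHandle_t handle = nullptr;
    cublasStatus_t const status = cublasCreate(&handle);
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        std::fprintf(stderr, "cublasCreate failed: status %d, %zu MB free\n", static_cast<int>(status),
            freeBytes / (1024 * 1024));
        return 1;
    }
    cublasDestroy(handle);
    return 0;
}
```

The error path mirrors the message added above: handle creation reserves device-side resources, so it can fail once the KV cache has claimed most of the GPU, and reporting the free-memory numbers alongside the free_gpu_memory_fraction hint makes that failure actionable.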
12 changes: 10 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -162,7 +162,11 @@ def test_fp8(self, fp8kv, attn_backend, torch_compile):
disable_overlap_scheduler=torch_compile,
)
if fp8kv:
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
pytorch_config["kv_cache_config"] = KvCacheConfig(
dtype="fp8",
free_gpu_memory_fraction=
0.8, # Prevent cublas/cublasLt handle allocation memory insufficient errors
)
with LLM(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
**pytorch_config) as llm:
@@ -196,7 +200,11 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
disable_overlap_scheduler=torch_compile,
)
if fp8kv:
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
pytorch_config["kv_cache_config"] = KvCacheConfig(
dtype="fp8",
free_gpu_memory_fraction=
0.8, # Prevent cublas/cublasLt handle allocation memory insufficient errors
)
with LLM(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
tensor_parallel_size=tp_size,