diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp
index 053f9d9ece7..a8b1d146aaa 100644
--- a/cpp/tensorrt_llm/common/opUtils.cpp
+++ b/cpp/tensorrt_llm/common/opUtils.cpp
@@ -238,6 +238,46 @@ class PerCudaCtxPerThreadSingletonCreator
     std::unordered_map, hash> mObservers;
 };
 
+// Structure to hold memory information
+struct MemoryInfo
+{
+    size_t free_mb;
+    size_t total_mb;
+    float free_percent;
+};
+
+// Helper function to get current memory information
+MemoryInfo getMemoryInfo()
+{
+    size_t free_mem = 0, total_mem = 0;
+    TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
+
+    size_t const free_mb = free_mem / (1024 * 1024);
+    size_t const total_mb = total_mem / (1024 * 1024);
+    float const free_percent = (total_mem > 0) ? (static_cast<float>(free_mem) / total_mem * 100.0f) : 0.0f;
+
+    return {free_mb, total_mb, free_percent};
+}
+
+// Helper function to log current memory usage
+void logMemoryUsage(char const* operation, CUcontext ctx)
+{
+    auto const mem = getMemoryInfo();
+    TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, mem.free_mb,
+        mem.free_percent, mem.total_mb);
+}
+
+// Helper function to throw an informative error when handle creation fails
+void throwCublasErrorWithMemInfo(char const* operation, CUcontext ctx, cublasStatus_t status)
+{
+    auto const mem = getMemoryInfo();
+    TLLM_THROW(
+        "Failed to create %s. "
+        "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
+        "Consider reducing kv_cache_config.free_gpu_memory_fraction.",
+        operation, status, ctx, mem.free_mb, mem.free_percent, mem.total_mb);
+}
+
 } // namespace
 
 std::shared_ptr<cublasHandle_t> getCublasHandle()
@@ -245,13 +285,26 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
     static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
         []() -> auto
         {
-            auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
-            TLLM_CUDA_CHECK(cublasCreate(handle.get()));
+            CUcontext ctx = getCurrentCudaCtx();
+            logMemoryUsage("Creating cublas handle", ctx);
+
+            auto handle = std::make_unique<cublasHandle_t>();
+            cublasStatus_t status = cublasCreate(handle.get());
+
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                throwCublasErrorWithMemInfo("cublas handle", ctx, status);
+            }
+
             return handle;
         },
         [](cublasHandle_t* handle)
         {
-            TLLM_CUDA_CHECK(cublasDestroy(*handle));
+            cublasStatus_t status = cublasDestroy(*handle);
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
+            }
             delete handle;
         });
     return creator();
@@ -262,13 +315,26 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
     static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
         []() -> auto
         {
-            auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
-            TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
+            CUcontext ctx = getCurrentCudaCtx();
+            logMemoryUsage("Creating cublasLt handle", ctx);
+
+            auto handle = std::make_unique<cublasLtHandle_t>();
+            cublasStatus_t status = cublasLtCreate(handle.get());
+
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                throwCublasErrorWithMemInfo("cublasLt handle", ctx, status);
+            }
+
             return handle;
         },
         [](cublasLtHandle_t* handle)
         {
-            TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
+            cublasStatus_t status = cublasLtDestroy(*handle);
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
+            }
             delete handle;
         });
     return creator();
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index e4089aa8c47..122a08d73f0 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -163,7 +163,11 @@ def test_fp8(self, fp8kv, attn_backend, torch_compile):
             disable_overlap_scheduler=torch_compile,
         )
         if fp8kv:
-            pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
+            pytorch_config["kv_cache_config"] = KvCacheConfig(
+                dtype="fp8",
+                # Avoid OOM errors during cublas/cublasLt handle creation
+                free_gpu_memory_fraction=0.8,
+            )
         with LLM(
                 f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
                 **pytorch_config) as llm:
@@ -198,7 +202,11 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
             disable_overlap_scheduler=torch_compile,
         )
         if fp8kv:
-            pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
+            pytorch_config["kv_cache_config"] = KvCacheConfig(
+                dtype="fp8",
+                # Avoid OOM errors during cublas/cublasLt handle creation
+                free_gpu_memory_fraction=0.8,
+            )
         with LLM(
                 f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
                 tensor_parallel_size=tp_size,