
Commit 247016b

Fix the insufficient-memory bug when creating cuBLAS handles in the multi-GPU case, and reduce the memory fraction in the KV cache config.
Signed-off-by: Wangshanshan <[email protected]>
1 parent bbf1175 commit 247016b

File tree: 2 files changed, +34 −24 lines


cpp/tensorrt_llm/common/opUtils.cpp

Lines changed: 30 additions & 20 deletions
@@ -238,28 +238,35 @@ class PerCudaCtxPerThreadSingletonCreator
     std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>> mObservers;
 };
 
+// Helper function to log memory usage - returns the memory values for potential error handling
+static std::pair<size_t, size_t> logMemoryUsage(char const* operation, CUcontext ctx)
+{
+    size_t free_mem = 0, total_mem = 0;
+    TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
+
+    TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx,
+        free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
+
+    return {free_mem, total_mem};
+}
+
 } // namespace
 
 std::shared_ptr<cublasHandle_t> getCublasHandle()
 {
     static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
         []() -> auto
         {
-            size_t free_mem = 0, total_mem = 0;
-            cudaMemGetInfo(&free_mem, &total_mem);
-
-            CUcontext ctx;
-            cuCtxGetCurrent(&ctx);
+            CUcontext ctx = getCurrentCudaCtx();
+            auto [free_mem, total_mem] = logMemoryUsage("Creating cublas handle", ctx);
 
-            TLLM_LOG_DEBUG("Creating cublas handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", ctx,
-                free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
-
-            auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
+            auto handle = std::make_unique<cublasHandle_t>();
 
             cublasStatus_t status = cublasCreate(handle.get());
 
             if (status != CUBLAS_STATUS_SUCCESS)
             {
+                // Re-fetch memory info for error message (memory state might have changed)
                 cudaMemGetInfo(&free_mem, &total_mem);
                 TLLM_THROW(
                     "Failed to create cublas handle. "
@@ -273,7 +280,11 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
         },
         [](cublasHandle_t* handle)
         {
-            TLLM_CUDA_CHECK(cublasDestroy(*handle));
+            cublasStatus_t status = cublasDestroy(*handle);
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
+            }
             delete handle;
         });
     return creator();
@@ -284,21 +295,16 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
     static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
         []() -> auto
         {
-            size_t free_mem = 0, total_mem = 0;
-            cudaMemGetInfo(&free_mem, &total_mem);
-
-            CUcontext ctx;
-            cuCtxGetCurrent(&ctx);
-
-            TLLM_LOG_DEBUG("Creating cublasLt handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", ctx,
-                free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
+            CUcontext ctx = getCurrentCudaCtx();
+            auto [free_mem, total_mem] = logMemoryUsage("Creating cublasLt handle", ctx);
 
-            auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
+            auto handle = std::make_unique<cublasLtHandle_t>();
 
             cublasStatus_t status = cublasLtCreate(handle.get());
 
             if (status != CUBLAS_STATUS_SUCCESS)
             {
+                // Re-fetch memory info for error message (memory state might have changed)
                 cudaMemGetInfo(&free_mem, &total_mem);
                 TLLM_THROW(
                     "Failed to create cublasLt handle. "
@@ -312,7 +318,11 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
         },
         [](cublasLtHandle_t* handle)
        {
-            TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
+            cublasStatus_t status = cublasLtDestroy(*handle);
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
+            }
             delete handle;
         });
     return creator();
tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 4 additions & 4 deletions
@@ -164,8 +164,8 @@ def test_fp8(self, fp8kv, attn_backend, torch_compile):
         if fp8kv:
             pytorch_config["kv_cache_config"] = KvCacheConfig(
                 dtype="fp8",
-                max_tokens=
-                100000,  # Limit tokens to prevent no room for cublas/cublasLt handles
+                free_gpu_memory_fraction=
+                0.8,  # Prevent cublas/cublasLt handle allocation memory insufficient errors
             )
         with LLM(
                 f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
@@ -202,8 +202,8 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
         if fp8kv:
             pytorch_config["kv_cache_config"] = KvCacheConfig(
                 dtype="fp8",
-                max_tokens=
-                100000,  # Limit tokens to prevent no room for cublas/cublasLt handles
+                free_gpu_memory_fraction=
+                0.8,  # Prevent cublas/cublasLt handle allocation memory insufficient errors
             )
         with LLM(
                 f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
