@@ -238,28 +238,35 @@ class PerCudaCtxPerThreadSingletonCreator
238238 std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>> mObservers ;
239239};
240240
241+ // Queries free/total memory with cudaMemGetInfo, logs both at debug level tagged with `operation` and `ctx`, and returns {free, total} unscaled so callers can reuse the values (e.g. in an error message).
242+ static std::pair<size_t , size_t > logMemoryUsage (char const * operation, CUcontext ctx)
243+ {
244+ size_t free_mem = 0 , total_mem = 0 ;
245+ TLLM_CUDA_CHECK (cudaMemGetInfo (&free_mem, &total_mem)); // ctx is not passed to the query; TLLM_CUDA_CHECK is defined elsewhere - presumably it reports a failed call, confirm
246+
247+ TLLM_LOG_DEBUG (" %s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB" , operation, ctx, // ctx is only printed (via %p)
248+ free_mem / (1024 * 1024 ), (float ) free_mem / total_mem * 100.0 , total_mem / (1024 * 1024 )); // "MB" figures use integer division by 1024 * 1024; NOTE(review): total_mem == 0 would divide by zero in the percentage - presumably impossible after a successful query, confirm
249+
250+ return {free_mem, total_mem}; // first = free, second = total, exactly as written by cudaMemGetInfo
251+ }
252+
241253} // namespace
242254
243255std::shared_ptr<cublasHandle_t> getCublasHandle ()
244256{
245257 static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator (
246258 []() -> auto
247259 {
248- size_t free_mem = 0 , total_mem = 0 ;
249- cudaMemGetInfo (&free_mem, &total_mem);
250-
251- CUcontext ctx;
252- cuCtxGetCurrent (&ctx);
260+ CUcontext ctx = getCurrentCudaCtx ();
261+ auto [free_mem, total_mem] = logMemoryUsage (" Creating cublas handle" , ctx);
253262
254- TLLM_LOG_DEBUG (" Creating cublas handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB" , ctx,
255- free_mem / (1024 * 1024 ), (float ) free_mem / total_mem * 100.0 , total_mem / (1024 * 1024 ));
256-
257- auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
263+ auto handle = std::make_unique<cublasHandle_t>();
258264
259265 cublasStatus_t status = cublasCreate (handle.get ());
260266
261267 if (status != CUBLAS_STATUS_SUCCESS)
262268 {
269+ // Re-fetch memory info for error message (memory state might have changed)
263270 cudaMemGetInfo (&free_mem, &total_mem);
264271 TLLM_THROW (
265272 " Failed to create cublas handle. "
@@ -273,7 +280,11 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
273280 },
274281 [](cublasHandle_t* handle)
275282 {
276- TLLM_CUDA_CHECK (cublasDestroy (*handle));
283+ cublasStatus_t status = cublasDestroy (*handle);
284+ if (status != CUBLAS_STATUS_SUCCESS)
285+ {
286+ TLLM_LOG_WARNING (" Failed to destroy cublas handle. Status: %d" , status);
287+ }
277288 delete handle;
278289 });
279290 return creator ();
@@ -284,21 +295,16 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
284295 static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator (
285296 []() -> auto
286297 {
287- size_t free_mem = 0 , total_mem = 0 ;
288- cudaMemGetInfo (&free_mem, &total_mem);
289-
290- CUcontext ctx;
291- cuCtxGetCurrent (&ctx);
292-
293- TLLM_LOG_DEBUG (" Creating cublasLt handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB" , ctx,
294- free_mem / (1024 * 1024 ), (float ) free_mem / total_mem * 100.0 , total_mem / (1024 * 1024 ));
298+ CUcontext ctx = getCurrentCudaCtx ();
299+ auto [free_mem, total_mem] = logMemoryUsage (" Creating cublasLt handle" , ctx);
295300
296- auto handle = std::unique_ptr <cublasLtHandle_t>(new cublasLtHandle_t );
301+ auto handle = std::make_unique <cublasLtHandle_t>();
297302
298303 cublasStatus_t status = cublasLtCreate (handle.get ());
299304
300305 if (status != CUBLAS_STATUS_SUCCESS)
301306 {
307+ // Re-fetch memory info for error message (memory state might have changed)
302308 cudaMemGetInfo (&free_mem, &total_mem);
303309 TLLM_THROW (
304310 " Failed to create cublasLt handle. "
@@ -312,7 +318,11 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
312318 },
313319 [](cublasLtHandle_t* handle)
314320 {
315- TLLM_CUDA_CHECK (cublasLtDestroy (*handle));
321+ cublasStatus_t status = cublasLtDestroy (*handle);
322+ if (status != CUBLAS_STATUS_SUCCESS)
323+ {
324+ TLLM_LOG_WARNING (" Failed to destroy cublasLt handle. Status: %d" , status);
325+ }
316326 delete handle;
317327 });
318328 return creator ();
0 commit comments