diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh index b4f155f735..1a97f09de7 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh @@ -228,6 +228,40 @@ struct KernelLauncher { "]"); } + inline void kernelLaunchCheck() const { + // This is a replacement for C10_CUDA_KERNEL_LAUNCH_CHECK() that adds more + // context information to the error message. See: + // https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAException.cpp + + const auto cuda_error = cudaGetLastError(); + + const auto cuda_kernel_failure = + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().has_failed(); + + if (C10_LIKELY(cuda_error == cudaSuccess && !cuda_kernel_failure)) { + return; + } + + // Inject the context information into the error message on CUDA failures + TORCH_CHECK( + false, + context.description(), + " CUDA Error: ", + cudaGetErrorString(cuda_error), +#ifdef __HIPCC__ + // c10::cuda::get_cuda_check_suffix has only been recently added to + // Torch HIPify mappings, so wrap with __HIPCC__ until the mapping land + // in PyTorch OSS. + // + // TODO: Remove when HIPify mappings are updated in PyTorch OSS + c10::hip::get_hip_check_suffix(), +#else + c10::cuda::get_cuda_check_suffix(), +#endif + "\n", + c10::cuda::c10_retrieve_device_side_assertion_info()); + } + template inline void launch_kernel( const KernelFunc& kernel, @@ -304,8 +338,10 @@ struct KernelLauncher { cudaDeviceSynchronize(); } - // Check for CUDA errors - C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Check for CUDA errors. This is a replacement for + // C10_CUDA_KERNEL_LAUNCH_CHECK() that adds more context information to the + // error message. + kernelLaunchCheck(); // If NaN checks are enabled, run post-kernel verifications on all kernel // arguments that are tensors