@@ -62,14 +62,35 @@ decltype(auto) transform_kernel_arg(const SourceContext& context, T&& arg) {
   }
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// Verify Kernel Argument
+//
+// Verify certain arguments before and after kernel invocation
+////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+decltype(auto) check_kernel_arg(const SourceContext& context, T&& arg) {
+  if constexpr (is_tensor_accessor_builder_v<std::decay_t<T>>) {
+    // If the arg is a TensorAccessorBuilder, run verifications on the tensor
+    // it is ref-wrapping, e.g. NaN value checks.
+    return arg.checkValues(context.description());
+  } else {
+    // Otherwise, perfect-forward the argument as is
+    return std::forward<T>(arg);
+  }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // GPU Kernel Launcher
 //
 // This class encapsulates the common ceremonial pre- and post-execution
 // routines when launching GPU kernels.
 ////////////////////////////////////////////////////////////////////////////////
 
-template <bool EnableDSA = false, bool EnableBarrierIsolation = false>
+template <
+    bool EnableDSA = false,
+    bool EnableBarrierIsolation = false,
+    bool EnableNaNChecks = false>
 struct KernelLauncher {
   const SourceContext context;
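The `check_kernel_arg` helper above uses the standard trait-plus-`if constexpr` dispatch idiom: a type trait decides at compile time whether an argument is transformed or perfect-forwarded untouched. A minimal self-contained sketch of the same idiom, with a hypothetical `Builder` type standing in for `TensorAccessorBuilder`:

```cpp
#include <iostream>
#include <type_traits>
#include <utility>

struct Builder {  // hypothetical stand-in for TensorAccessorBuilder
  int check_and_build() const { return 42; }
};

template <typename T>
inline constexpr bool is_builder_v = std::is_same_v<std::decay_t<T>, Builder>;

// Same shape as check_kernel_arg: transform recognized types at compile time,
// perfect-forward everything else unchanged.
template <typename T>
decltype(auto) check_arg(T&& arg) {
  if constexpr (is_builder_v<T>) {
    return arg.check_and_build();  // returns the verification result by value
  } else {
    return std::forward<T>(arg);   // preserves the original value category
  }
}

int main() {
  Builder b;
  std::cout << check_arg(b) << ' ' << check_arg(7) << '\n';  // prints "42 7"
}
```

Because the non-matching branch is discarded at compile time, arguments without a `check_and_build()` member still compile cleanly.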
@@ -234,6 +255,21 @@ struct KernelLauncher {
     // device associated with the compute stream
     checkSharedMemoryPerBlockNotExceeded(properties, shared_mem_per_block);
 
+    // If NaN checks are enabled, run verifications on all kernel arguments
+    // that are tensors
+    if constexpr (EnableNaNChecks) {
+      const auto summary = std::string(context.summary) + " (pre-execution)";
+      (check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
+       ...);
+    }
+
+    // If barrier isolation is enabled, synchronize the device before launching
+    // the kernel. This has roughly the same effect as setting the
+    // `CUDA_LAUNCH_BLOCKING=1` environment variable.
+    if constexpr (EnableBarrierIsolation) {
+      cudaDeviceSynchronize();
+    }
+
     if constexpr (EnableDSA) {
       // This launch code here is essentially the same as the contents of
       // TORCH_USE_CUDA_DSA macro, but with the addition of kernel argument
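The `(check_kernel_arg(...), ...)` construct above is a C++17 unary fold over the comma operator: it invokes `check_kernel_arg` once per element of the `args` pack, left to right, and discards each result. A toy illustration of the same expansion (names here are illustrative):

```cpp
#include <iostream>
#include <utility>

template <typename T>
void inspect(T&& x) {
  std::cout << "checked: " << x << '\n';
}

template <typename... Args>
void check_all(Args&&... args) {
  // Expands to (inspect(a1), (inspect(a2), inspect(a3))); the built-in comma
  // operator guarantees left-to-right evaluation, and results are discarded.
  (inspect(std::forward<Args>(args)), ...);
}

int main() {
  check_all(1, 2.5, "three");  // prints each argument in declaration order
}
```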
@@ -251,13 +287,6 @@ struct KernelLauncher {
         c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
 #endif
 
-    // If barrier isolation is enabled, synchronize the stream first before
-    // launching the kernel. This has roughly the same effect as setting
-    // `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
-    if constexpr (EnableBarrierIsolation) {
-      cudaDeviceSynchronize();
-    }
-
     // Launch the kernel
     kernel<<<grid, block, shared_mem_per_block, stream>>>(
         // Transform arguments to the kernel before forwarding them.
@@ -285,6 +314,14 @@ struct KernelLauncher {
 
     // Check for CUDA errors
     C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    // If NaN checks are enabled, run post-kernel verifications on all kernel
+    // arguments that are tensors
+    if constexpr (EnableNaNChecks) {
+      const auto summary = std::string(context.summary) + " (post-execution)";
+      (check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
+       ...);
+    }
   }
 };
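The diff does not show `TensorAccessorBuilder::checkValues` itself; presumably it scans the wrapped tensor for invalid values. As a rough, hypothetical sketch of what such a verification could look like with the ATen API (not the actual FBGEMM implementation):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <string>

// Hypothetical approximation of a checkValues()-style verification: reject
// NaN/Inf entries in floating-point tensors, tagged with a context string.
inline void check_tensor_values(const at::Tensor& t, const std::string& context) {
  if (!t.is_floating_point()) {
    return;  // NaN/Inf only exist for floating-point dtypes
  }
  TORCH_CHECK(
      !at::isnan(t).any().item<bool>(),
      context, ": tensor contains NaN values");
  TORCH_CHECK(
      !at::isinf(t).any().item<bool>(),
      context, ": tensor contains Inf values");
}
```

Note that `.item<bool>()` forces a device-to-host copy, i.e. a synchronization for CUDA tensors, which is one reason such checks are gated behind an opt-in template flag rather than always on.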
@@ -320,30 +357,38 @@ struct KernelLauncher {
 #define _FKL_TFILE_ ""
 #endif
 
-#ifdef FBGEMM_GPU_KERNEL_DEBUG
-#define _FKL_KDEBUG_ true
+#ifdef FBGEMM_GPU_ISOLATE_KERNEL_LAUNCH
+#define _FKL_BLOCKING_ true
+#else
+#define _FKL_BLOCKING_ false
+#endif
+
+#ifdef FBGEMM_GPU_TENSORCHECK
+#define _FKL_TENSORCHECK_ true
 #else
-#define _FKL_KDEBUG_ false
+#define _FKL_TENSORCHECK_ false
 #endif
 
-#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
-  ([&] { \
-    using source_location = fbgemm_gpu::utils::source_location; \
-    constexpr auto location = source_location::current(); \
-    decltype(KERNEL)& kernel = KERNEL; \
- \
-    return fbgemm_gpu::utils::KernelLauncher<false, _FKL_KDEBUG_>( \
-        location, #KERNEL, _FKL_TFILE_) \
-        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
+#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
+  ([&] { \
+    using source_location = fbgemm_gpu::utils::source_location; \
+    constexpr auto location = source_location::current(); \
+    decltype(KERNEL)& kernel = KERNEL; \
+ \
+    return fbgemm_gpu::utils:: \
+        KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_>( \
+            location, #KERNEL, _FKL_TFILE_) \
+        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
 }())
 
-#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
-  ([&] { \
-    using source_location = fbgemm_gpu::utils::source_location; \
-    constexpr auto location = source_location::current(); \
-    decltype(KERNEL)& kernel = KERNEL; \
- \
-    return fbgemm_gpu::utils::KernelLauncher<true, _FKL_KDEBUG_>( \
-        location, #KERNEL, _FKL_TFILE_) \
-        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
+#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
+  ([&] { \
+    using source_location = fbgemm_gpu::utils::source_location; \
+    constexpr auto location = source_location::current(); \
+    decltype(KERNEL)& kernel = KERNEL; \
+ \
+    return fbgemm_gpu::utils:: \
+        KernelLauncher<true, _FKL_BLOCKING_, _FKL_TENSORCHECK_>( \
+            location, #KERNEL, _FKL_TFILE_) \
+        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
 }())
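With this wiring, blocking launches and tensor checks are selected at build time via the `FBGEMM_GPU_ISOLATE_KERNEL_LAUNCH` and `FBGEMM_GPU_TENSORCHECK` preprocessor flags, so call sites stay unchanged. An illustrative call site (the kernel name and arguments are hypothetical; only the macro shape `(KERNEL, GRID, BLOCK, SMEM, STREAM, args...)` comes from the definition above):

```cpp
// Hypothetical usage of the launcher macro defined above.
FBGEMM_LAUNCH_KERNEL(
    my_vector_add_kernel,              // kernel symbol; also stringified via #KERNEL
    dim3(num_blocks),                  // grid dimensions
    dim3(threads_per_block),           // block dimensions
    0,                                 // dynamic shared memory in bytes
    at::cuda::getCurrentCUDAStream(),  // stream to launch on
    a_ptr,                             // kernel arguments, perfect-forwarded
    b_ptr,
    out_ptr);
```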