diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
index a88b9a871..a03beac96 100644
--- a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
+++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
@@ -468,7 +468,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
     else if (role == WarpRole::Correction) {
       cutlass::arch::warpgroup_reg_dealloc();
 
-      bool has_valid = false;
+      bool tmem_allocated = false;
 
       CUTLASS_PRAGMA_NO_UNROLL
       for (; tile_scheduler.is_valid(); ++tile_scheduler) {
@@ -481,8 +481,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
           continue;
         }
 
-        has_valid = true;
-
         if (get<1>(logical_problem_shape) == 0) {
           mainloop.correction_empty(
               blk_coord,
@@ -495,6 +493,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
           continue;
         }
 
+        tmem_allocated = true;
+
         mainloop.correction(
             blk_coord,
             params.mainloop, logical_problem_shape,
@@ -512,7 +512,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
       if constexpr (NumWarpsEpilogue == 0) {
         static_assert(NumWarpsCorrection == 1);
 
-        if (has_valid) {
+        if (tmem_allocated) {
           uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
           tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
         }
@@ -522,7 +522,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
     else if (role == WarpRole::MMA) {
       warpgroup_reg_set();
 
-      bool allocated = false;
+      bool tmem_allocated = false;
 
       CUTLASS_PRAGMA_NO_UNROLL
       for (; tile_scheduler.is_valid(); ++tile_scheduler) {
@@ -531,18 +531,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
 
         auto logical_problem_shape = apply_batch(params, params.problem_shape, get<2,1>(blk_coord));
 
-        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
+        if (
+          (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape))
+          || (get<1>(logical_problem_shape) == 0)
+        ) {
           continue;
         }
 
-        if (!allocated) {
+        if (!tmem_allocated) {
           tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
           __syncwarp();
-          allocated = true;
-        }
-
-        if (get<1>(logical_problem_shape) == 0) {
-          continue;
+          tmem_allocated = true;
         }
 
         mainloop.mma(
@@ -573,11 +572,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
 
        auto logical_problem_shape = apply_batch(params, params.problem_shape, get<2,1>(blk_coord));
 
-        if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
-          continue;
-        }
-
-        if (get<1>(logical_problem_shape) == 0) {
+        if (
+          (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape))
+          || (get<1>(logical_problem_shape) == 0)
+        ) {
           continue;
         }
 
@@ -594,7 +592,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
     else if (role == WarpRole::Epilogue) {
       warpgroup_reg_set();
 
-      bool has_valid = false;
+      bool tmem_allocated = false;
 
       CUTLASS_PRAGMA_NO_UNROLL
       for (; tile_scheduler.is_valid(); ++tile_scheduler) {
@@ -607,7 +605,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
           continue;
         }
 
-        has_valid = true;
+        if (get<1>(logical_problem_shape) != 0) {
+          tmem_allocated = true;
+        }
 
         epilogue.store(
             blk_coord, logical_problem_shape,
@@ -620,7 +620,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
 
         static_assert(NumWarpsEpilogue <= 1);
         if constexpr (NumWarpsEpilogue == 1) {
-          if(has_valid) {
+          if(tmem_allocated) {
             uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
             tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
           }