From f4536a854ac02ec5e034e43b7b1c2efbcc88c1d7 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 9 Sep 2025 16:12:05 -0700 Subject: [PATCH 01/12] ignore build --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e7a026874b..7efb940e79 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # PyCache files __pycache__/ cutlass_library.egg-info/ -/build* +build* +*.log From 61a5c287ca2b5ecff267236f13f9cccee5d5f8fc Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 9 Sep 2025 16:49:34 -0700 Subject: [PATCH 02/12] add my run.sh --- run.sh | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 run.sh diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000..1a8ef14b48 --- /dev/null +++ b/run.sh @@ -0,0 +1,42 @@ +rm ~/cutlass/output.log +# rm -rf ~/cutlass/build +mkdir -p ~/cutlass/build +cd ~/cutlass/build + +export CUDA_VISIBLE_DEVICES=7 +# export REF_PRINT_DIFF=1 + +cmake .. -DCUTLASS_NVCC_ARCHS=100a + +# Record start time for e2e timing +start_time=$(date +%s) +echo "E2E Test Run Started at: $(date)" | tee -a ~/cutlass/output.log + +targets=( + test_examples_77_blackwell_fmha_bwd_fp16_test_basic + test_examples_77_blackwell_fmha_bwd_fp16_test_varlen +) + +for test in "${targets[@]}" +do + echo "Running $test" + make $test 2>&1 | tee -a ~/cutlass/output.log + echo "Running compute sanitizer $test" + compute-sanitizer make $test 2>&1 | tee -a ~/cutlass/output.log + echo "Running compute sanitizer memcheck $test" + compute-sanitizer --tool=memcheck make $test 2>&1 | tee -a ~/cutlass/output.log + echo "Running compute sanitizer racecheck $test" + compute-sanitizer --tool=racecheck make $test 2>&1 | tee -a ~/cutlass/output.log + echo "Running compute sanitizer synccheck $test" + compute-sanitizer --tool=synccheck make $test 2>&1 | tee -a ~/cutlass/output.log +done + +# Record end time and calculate e2e duration +end_time=$(date +%s) +duration=$((end_time - start_time)) + +echo "E2E Test Run Completed at: $(date)" | tee -a ~/cutlass/output.log +echo "Total E2E Duration: ${duration} seconds" | tee -a ~/cutlass/output.log + +unset CUDA_VISIBLE_DEVICES +unset REF_PRINT_DIFF From 9c567895ad2bf66e33b065ccfbe4e4d97dd12787 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 10:42:11 -0700 Subject: [PATCH 03/12] port fmha_fusion.hpp --- .../collective/fmha_fusion.hpp | 368 +++++++++++++++++- 1 file changed, 364 insertions(+), 4 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index a33ce2d2ce..df68ae38c5 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -39,6 +39,20 @@ namespace cutlass::fmha::collective { using namespace cute; struct NoMask { + CUTLASS_DEVICE + NoMask(int left = -1, int right = -1) {} + + template + CUTLASS_DEVICE + cute::tuple get_n_block_min_max( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + const int n_block_max = ceil_div(get<1>(problem_size), get<1>(tile_shape)); + return cute::make_tuple(0, n_block_max); + } + template CUTLASS_DEVICE int get_trip_count( @@ -69,6 +83,26 @@ struct NoMask { return get_trip_count(blk_coord, tile_shape, problem_size); } + template + CUTLASS_DEVICE + int get_n_block_start_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_n_block_stop_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + template CUTLASS_DEVICE void apply_mask( @@ -84,6 +118,21 @@ struct ResidualMask : NoMask { using Base = NoMask; + CUTLASS_DEVICE + ResidualMask(int left = -1, int right = -1) : NoMask(left, right) {} + + template + CUTLASS_DEVICE + cute::tuple get_n_block_min_max( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + const int n_block_min = 0; + const int n_block_max = ceil_div(get<1>(problem_size), get<1>(tile_shape)); + return cute::make_tuple(n_block_min, n_block_max); + } + template CUTLASS_DEVICE int get_masked_trip_count( BlkCoord const& blk_coord, @@ -110,6 +159,26 @@ struct ResidualMask : NoMask { return get_trip_count(blk_coord, tile_shape, problem_size); } + template + CUTLASS_DEVICE + int get_n_block_start_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_n_block_stop_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_unmasked_trip_count(blk_coord, tile_shape, problem_size); + } + template CUTLASS_DEVICE void apply_mask( @@ -136,6 +205,9 @@ struct ResidualMaskForBackward : NoMask { using Base = NoMask; + CUTLASS_DEVICE + ResidualMaskForBackward(int left = -1, int right = -1) {} + template CUTLASS_DEVICE int get_masked_trip_count( BlkCoord const& blk_coord, @@ -192,26 +264,45 @@ struct CausalMask : NoMask { using Base = NoMask; + CUTLASS_DEVICE + CausalMask(int left = -1, int right = -1) : NoMask(left, right) {} + static constexpr bool IsQBegin = kIsQBegin; template CUTLASS_DEVICE - int get_trip_count( + cute::tuple get_n_block_min_max( BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { + constexpr int n_block_min = 0; + int n_block_max; + // See note below on different ways to think about causal attention // Again, we'd add the offset_q into the max_blocks_q calculation int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); if constexpr (IsQBegin) { int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); - return std::min(max_blocks_k, max_blocks_q); + n_block_max = std::min(max_blocks_k, max_blocks_q); } else { const int offset_q = get<1>(problem_size) - get<0>(problem_size); int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); - return std::min(max_blocks_k, max_blocks_q); + n_block_max = std::min(max_blocks_k, max_blocks_q); } + + return cute::make_tuple(n_block_min, n_block_max); + } + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + auto min_max = get_n_block_min_max(blk_coord, tile_shape, problem_size); + return get<1>(min_max); } template @@ -220,7 +311,7 @@ struct CausalMask : NoMask { BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { - + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); @@ -240,6 +331,26 @@ struct CausalMask : NoMask { return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); } + template + CUTLASS_DEVICE + int get_n_block_start_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_n_block_stop_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_unmasked_trip_count(blk_coord, tile_shape, problem_size); + } + template CUTLASS_DEVICE void apply_mask( @@ -281,6 +392,9 @@ struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { using Base = CausalMask; + CUTLASS_DEVICE + CausalForBackwardMask(int left = -1, int right = -1) {} + template CUTLASS_DEVICE void apply_mask( @@ -313,6 +427,252 @@ struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { }; +template +struct LocalMask : NoMask { + + using Base = NoMask; + + static constexpr bool IsQBegin = kIsQBegin; + + int window_size_left; + int window_size_right; + + CUTLASS_DEVICE + LocalMask(int left = -1, int right = -1) + : window_size_left(left), window_size_right(right) {} + + template + CUTLASS_DEVICE + cute::tuple get_n_block_min_max( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // tile shape + const int kBlockM = get<0>(tile_shape); + const int kBlockN = get<1>(tile_shape); + const int seq_len_k = get<1>(problem_size); + + // max + const int m_block = get<0>(blk_coord); + const int m_idx_max = (m_block + 1) * kBlockM; + const int offset_q = get<1>(problem_size) - get<0>(problem_size); + int n_idx_max; + if constexpr (IsQBegin) { + n_idx_max = m_idx_max + window_size_right; + } else { + n_idx_max = m_idx_max + offset_q + window_size_right; + } + n_idx_max = std::min(n_idx_max, seq_len_k); + const int n_block_max = ceil_div(n_idx_max, kBlockN); + + // min + const int m_idx_min = m_block * kBlockM; + int n_idx_min; + if constexpr (IsQBegin) { + n_idx_min = m_idx_min - window_size_left; + } else { + n_idx_min = m_idx_min + offset_q - window_size_left; + } + n_idx_min = std::max(n_idx_min, 0); + const int n_block_min = n_idx_min / kBlockN; + + return cute::make_tuple(n_block_min, n_block_max); + } + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + auto min_max = get_n_block_min_max(blk_coord, tile_shape, problem_size); + int n_block_min = get<0>(min_max); + int n_block_max = get<1>(min_max); + + return n_block_max - n_block_min; + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // TODO: follow CausalMask to improve this + + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + return trip_count; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + const int n_block_start_unmask = get_n_block_start_unmask(blk_coord, tile_shape, problem_size); + const int n_block_stop_unmask = get_n_block_stop_unmask(blk_coord, tile_shape, problem_size); + + return n_block_stop_unmask - n_block_start_unmask; + } + + template + CUTLASS_DEVICE + int get_n_block_start_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + // this does not guarantee to be smaller than n_block_stop_unmask + + const int kBlockM = get<0>(tile_shape); + const int kBlockN = get<1>(tile_shape); + const int seq_len_k = get<1>(problem_size); + + const int m_block = get<0>(blk_coord); + const int offset_q = IsQBegin? 0 : get<1>(problem_size) - get<0>(problem_size); + + const int m_idx_max = (m_block + 1) * kBlockM; + + // -1 to make this inclusive + const int n_idx_max_left = std::max(m_idx_max + offset_q - window_size_left - 1, 0); + + return ceil_div(n_idx_max_left, kBlockN); + } + + template + CUTLASS_DEVICE + int get_n_block_stop_unmask( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + // this does not guarantee to be larger than n_block_start_unmask + + const int kBlockM = get<0>(tile_shape); + const int kBlockN = get<1>(tile_shape); + const int seq_len_k = get<1>(problem_size); + + const int m_block = get<0>(blk_coord); + const int offset_q = IsQBegin? 0 : get<1>(problem_size) - get<0>(problem_size); + + const int m_idx_min = m_block * kBlockM; + // +1 to make this exclusive + const int n_idx_min_right = std::min(m_idx_min + offset_q + window_size_right + 1, seq_len_k); + + return n_idx_min_right / kBlockN; + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is the default setting. + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to set kIsQBegin=false + + const int K = get<1>(problem_size); + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + const int pos_i = get<0>(pos); + const int pos_j = get<1>(pos); + + const int window_left_bound = pos_i - window_size_left; + const int window_right_bound = pos_i + window_size_right; + + bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || (pos_j >= K); + + acc_qk(i) = masked ? -INFINITY : acc_qk(i); + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + const int pos_i = get<0>(pos); + const int pos_j = get<1>(pos); + + const int offset_pos_i = pos_i + offset_q; + + const int window_left_bound = offset_pos_i - window_size_left; + const int window_right_bound = offset_pos_i + window_size_right; + + bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || (pos_j >= K); + + acc_qk(i) = masked ? -INFINITY : acc_qk(i); + } + } + } +}; + +template +struct LocalMaskForBackward : LocalMask, ResidualMaskForBackward { + + using Base = LocalMask; + + static constexpr bool IsQBegin = kIsQBegin; + + int window_size_left; + int window_size_right; + + CUTLASS_DEVICE + LocalMaskForBackward(int left = -1, int right = -1) + : window_size_left(left), window_size_right(right) {} + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + const int pos_i = get<0>(pos); + const int pos_j = get<1>(pos); + + const int window_left_bound = pos_i - window_size_left; + const int window_right_bound = pos_i + window_size_right; + + bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || !elem_less(pos, problem_size); + + acc_qk(i) = masked ? -INFINITY : acc_qk(i); + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + const int pos_i = get<0>(pos); + const int pos_j = get<1>(pos); + + const int offset_pos_i = pos_i + offset_q; + const int window_left_bound = offset_pos_i - window_size_left; + const int window_right_bound = offset_pos_i + window_size_right; + + bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || !elem_less(pos, problem_size); + + acc_qk(i) = masked ? -INFINITY : acc_qk(i); + } + } + } +}; + struct VariableLength { int max_length; int* cumulative_length = nullptr; From bd9cde35f0f2f59e2725eda8726bcd6bc2872230 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 10:45:04 -0700 Subject: [PATCH 04/12] port fwd mainloop --- ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 119 +++++++++++++----- 1 file changed, 85 insertions(+), 34 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 1e094bf42d..6f114e64b3 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -197,6 +197,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // scaling factor to quantize O float inv_scale_o = 1.0f; + + int window_size_left = -1; + int window_size_right = -1; }; struct Params { @@ -206,6 +209,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { float scale_softmax_log2; float scale_output; + + int window_size_left; + int window_size_right; }; template @@ -229,7 +235,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Load::to_underlying_arguments(problem_shape, args.load, workspace), args.scale_q * args.scale_k * scale_softmax, args.scale_q * args.scale_k * log2_e * scale_softmax, - args.scale_v * args.inv_scale_o + args.scale_v * args.inv_scale_o, + args.window_size_left, + args.window_size_right }; } @@ -269,7 +277,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { auto pipeline_q_release_state = pipeline_q_consumer_state; auto pipeline_kv_release_state = pipeline_kv_consumer_state; - int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int mask_tile_count = Mask(params.window_size_left, params.window_size_right).get_trip_count(blk_coord, TileShape{}, problem_shape); typename CollectiveMmaQK::TiledMma mma_qk; ThrMMA thr_mma_qk = mma_qk.get_slice(0); @@ -569,7 +577,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); if constexpr (need_apply_mask) { - Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + Mask(params.window_size_left, params.window_size_right).apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); } ElementQK old_row_max = row_max; @@ -720,7 +728,12 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { - int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + Mask mask(params.window_size_left, params.window_size_right); + auto min_max = mask.get_n_block_min_max(blk_coord, TileShape{}, problem_shape); + int n_block_min = get<0>(min_max); + const int n_block_max = get<1>(min_max); + const int n_block_start_unmask = mask.get_n_block_start_unmask(blk_coord, TileShape{}, problem_shape); + const int n_block_stop_unmask = mask.get_n_block_stop_unmask(blk_coord, TileShape{}, problem_shape); ElementQK row_max = -INFINITY; ElementQK row_sum = 0; @@ -728,41 +741,79 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); auto logical_offset = make_coord( get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), - 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + n_block_min * get<1>(TileShape{}) ); Tensor cS = domain_offset(logical_offset, cS_base); pipeline_c.producer_acquire(pipeline_c_producer_state); - CUTLASS_PRAGMA_NO_UNROLL - for (; mask_tile_count > 0; mask_tile_count -= 1) { - softmax_step( - row_max, row_sum, stage, - (mask_tile_count == 1) && - (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), - blk_coord, cS, params, problem_shape, - pipeline_s, pipeline_s_consumer_state, - pipeline_c, pipeline_c_producer_state, - order_s - ); - - cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); - } - - // Masked iterations - mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); - - CUTLASS_PRAGMA_NO_UNROLL - for (; mask_tile_count > 0; mask_tile_count -= 1) { - softmax_step( - row_max, row_sum, stage, mask_tile_count == 1, - blk_coord, cS, params, problem_shape, - pipeline_s, pipeline_s_consumer_state, - pipeline_c, pipeline_c_producer_state, - order_s - ); + // from observation, dispatch is better for the mask -> unmask -> mask pattern and when the number of tiles is small + if constexpr (std::is_base_of_v, Mask> + || std::is_base_of_v, Mask>) { + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + CUTLASS_PRAGMA_NO_UNROLL + for (; n_block_min < n_block_max; n_block_min += 1) { + // Apply mask only for tiles outside the attention window + // for local mask, we don't guarantee n_block_start_unmask <= n_block_stop_unmask <= n_block_max + bool need_apply_mask = warp_uniform(n_block_min < n_block_start_unmask || n_block_min >= n_block_stop_unmask); + + dispatch_bool(need_apply_mask, [&](auto is_masked_tile) { + if constexpr (decltype(is_masked_tile)::value) { + softmax_step( + row_max, row_sum, stage, (n_block_min == n_block_max - 1), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + } else { + softmax_step( + row_max, row_sum, stage, (n_block_min == n_block_max - 1), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + } + }); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + } else { + CUTLASS_PRAGMA_NO_UNROLL + for (; n_block_min < n_block_stop_unmask; n_block_min += 1) { + softmax_step( + row_max, row_sum, stage, + (n_block_min == n_block_max - 1), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } - cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + CUTLASS_PRAGMA_NO_UNROLL + for (; n_block_min < n_block_max; n_block_min += 1) { + softmax_step( + row_max, row_sum, stage, n_block_min == n_block_max - 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } } pipeline_c.producer_commit(pipeline_c_producer_state); @@ -963,7 +1014,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, CollectiveEpilogue& epilogue) { - int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int mask_tile_count = Mask(params.window_size_left, params.window_size_right).get_trip_count(blk_coord, TileShape{}, problem_shape); int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); From ff896d8b74a7a3a5fd99f4e22595dc4a51563765 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 10:50:32 -0700 Subject: [PATCH 05/12] port fwd tma load --- .../sm100_fmha_load_tma_warpspecialized.hpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp index 3606dcc7d9..ed028cf814 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -71,6 +71,9 @@ struct Sm100FmhaLoadTmaWarpspecialized { StrideK dK; const Element* ptr_V; StrideV dV; + + int window_size_left = -1; + int window_size_right = -1; }; using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; @@ -81,6 +84,8 @@ struct Sm100FmhaLoadTmaWarpspecialized { TMA_Q tma_load_q; TMA_K tma_load_k; TMA_V tma_load_v; + int window_size_left; + int window_size_right; }; template @@ -130,7 +135,9 @@ struct Sm100FmhaLoadTmaWarpspecialized { return Params{ params_qk.tma_load_a, params_qk.tma_load_b, - params_pv.tma_load_b + params_pv.tma_load_b, + args.window_size_left, + args.window_size_right }; } @@ -154,7 +161,9 @@ struct Sm100FmhaLoadTmaWarpspecialized { BlkCoord blk_coord_q = blk_coord_in; BlkCoord blk_coord_kv = blk_coord_in; - int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + auto min_max = Mask(params.window_size_left, params.window_size_right).get_n_block_min_max(blk_coord_in, TileShape{}, problem_shape); + int n_block_min = get<0>(min_max); + int n_block_max = get<1>(min_max); using X = Underscore; @@ -247,7 +256,7 @@ struct Sm100FmhaLoadTmaWarpspecialized { ++pipeline_q_producer_state; // K1 - int k_index = 0; + int k_index = n_block_min; pipeline_kv.producer_acquire(pipeline_kv_producer_state); if (lane_predicate) { auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); @@ -273,8 +282,7 @@ struct Sm100FmhaLoadTmaWarpspecialized { k_index += 1; // loop: - mask_tile_count -= 1; - for (; mask_tile_count > 0; mask_tile_count -= 1) { + for (; k_index < n_block_max; k_index += 1) { // Ki pipeline_kv.producer_acquire(pipeline_kv_producer_state); @@ -291,7 +299,6 @@ struct Sm100FmhaLoadTmaWarpspecialized { copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); } ++pipeline_kv_producer_state; - k_index += 1; } } }; From 08e72511a254dde1172e36d62f5991176750774b Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 14:36:09 -0700 Subject: [PATCH 06/12] port bwd kernel --- ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 51 +++++++++++++++++-- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 742b507d4d..9136bff2c3 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -303,6 +303,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TensorStride stride_dq_acc; ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + + int window_size_left = -1; + int window_size_right = -1; }; using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; @@ -321,6 +324,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TMA_Q tma_load_q; TMA_DO tma_load_do; TMA_DQ tma_red_dq; + + int window_size_left; + int window_size_right; }; struct EpilogueArguments { @@ -405,7 +411,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { params_vdo.tma_load_a, params_kq.tma_load_b, params_vdo.tma_load_b, - tma_red_dq + tma_red_dq, + args.mainloop.window_size_left, + args.mainloop.window_size_right }, args.epilogue, args.hw_info @@ -1008,7 +1016,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { make_coord(blk_coord_k * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); - + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { if (elem_less(cDK(i), select<1,2>(problem_shape))) { gDK(i) = Element(0); @@ -1272,16 +1280,36 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } bool trailing_residual_masking = false; if constexpr (std::is_base_of_v) { + // this matters for causal and local masking too trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k); } + bool local_masking = false; + if constexpr ( + std::is_base_of_v, Mask> + || std::is_base_of_v, Mask> + ) { + const int offset = std::is_base_of_v, Mask> ? (get<1>(problem_shape) - get<0>(problem_shape)) : 0; + const int kv_left = get<1>(blk_coord) * TileShapeK{}; + const int kv_right = kv_left + TileShapeK{} - 1; + // index for j + const int q_left = iter_index * TileShapeQ{} + offset; + const int q_right = q_left + TileShapeQ{} - 1; + + const int q_right_window_left = q_right - mainloop_args.window_size_left; + const int q_left_window_right = q_left + mainloop_args.window_size_right; + + const bool local_unmasked = (q_right_window_left < kv_left) && (q_left_window_right > kv_right); + + local_masking = warp_uniform(!local_unmasked); + } - dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + dispatch_bool(leading_causal_masking || trailing_residual_masking || local_masking, [&](auto is_masked_tile) { // compute P = softmax(S, LSE) cute::copy(tiled_t2r, tTR_tST, tTR_rST); if constexpr (decltype(is_masked_tile)::value) { - Mask{}.apply_mask(tTR_rST, [&](int i) { + Mask(mainloop_args.window_size_left, mainloop_args.window_size_right).apply_mask(tTR_rST, [&](int i) { auto c_transpose = tTR_cST(i); return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); }, problem_shape); @@ -1720,6 +1748,21 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else if constexpr (std::is_base_of_v, Mask>) { int offset = get<1>(problem_shape) - get<0>(problem_shape); iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } else if constexpr ( + std::is_base_of_v, Mask> || + std::is_base_of_v, Mask> + ) { + int offset = std::is_base_of_v, Mask> + ? get<1>(problem_shape) - get<0>(problem_shape) + : 0; + + int k_max = (get<1>(blk_coord) + 1) * TileShapeK{}; + int q_max = min(get<0>(problem_shape), k_max - offset + params.mainloop_params.window_size_left); + iter_end = ceil_div(q_max, TileShapeQ{}); + + int k_min = get<1>(blk_coord) * TileShapeK{}; + int q_min = max(0, k_min - offset - params.mainloop_params.window_size_right); + iter_start = q_min / (int)TileShapeQ{}; } if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { return; From d10a9f5f5fc3b458b125e922db76dc74680eff0a Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 14:37:55 -0700 Subject: [PATCH 07/12] port device bwd --- examples/77_blackwell_fmha/device/fmha_device_bwd.hpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp index 9e4efb34ec..4d987a7e06 100644 --- a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -122,6 +122,9 @@ class Sm100FmhaBwd { ElementAccumulator softmax_scale; + int window_size_left = -1; + int window_size_right = -1; + cutlass::KernelHardwareInfo hw_info; }; @@ -219,7 +222,9 @@ class Sm100FmhaBwd { scaled_lse, to_bwd_stride(stride_scaled_lse), sum_OdO, to_bwd_stride(stride_sum_OdO), dQ_acc, to_bwd_stride(stride_dQ), - args.softmax_scale }, + args.softmax_scale, + args.window_size_left, + args.window_size_right }, { args.ptr_dK, to_bwd_stride(args.stride_dK), args.ptr_dV, to_bwd_stride(args.stride_dV) }, args.hw_info From 99c2c319afba75567f54bb2d3f5404ac1aa1ad24 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 16:28:31 -0700 Subject: [PATCH 08/12] port 77_blackwell_fmha.cu --- .../77_blackwell_fmha/77_blackwell_fmha.cu | 89 ++++++++++++++----- 1 file changed, 69 insertions(+), 20 deletions(-) diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index cc09e68a42..7de1ebc865 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -125,6 +125,10 @@ struct Options { bool verify = false; bool verbose = false; + int window_size_left = -1; + int window_size_right = -1; + bool local = false; + bool causal = false; bool causal_q_begin = true; bool residual = false; @@ -261,6 +265,9 @@ struct Options { cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers); + cmd.get_cmd_line_argument("window_size_left", window_size_left, defaults.window_size_left); + cmd.get_cmd_line_argument("window_size_right", window_size_right, defaults.window_size_right); + verify = cmd.check_cmd_line_flag("verify"); verbose = cmd.check_cmd_line_flag("verbose"); persistent = cmd.check_cmd_line_flag("persistent"); @@ -270,12 +277,13 @@ struct Options { std::string causal_type; cmd.get_cmd_line_argument("causal-type", causal_type, ""); if (mask == "no" || mask == "") { - causal = residual = false; + local = causal = residual = false; if (varlen) { residual = true; } } else if (mask == "causal") { + local = false; residual = false; causal = true; if(causal_type == "qend") { @@ -285,9 +293,16 @@ struct Options { } } else if (mask == "residual") { + local = false; residual = true; causal = false; } + else if (mask == "local") { + local = true; + residual = false; + causal = false; + } + cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q); @@ -423,7 +438,7 @@ struct FwdRunner { using ProblemShapeRegular = cute::tuple, int>>; using ProblemShapeVarlen = cute::tuple, int>>; using ProblemShapeType = std::conditional_t; - + using StrideQ = cute::tuple, int>>; // Q D ((H_R, H_K), B) using StrideK = cute::tuple, int>>; // K D ((H_R, H_K), B) using StrideV = StrideK; @@ -433,7 +448,7 @@ struct FwdRunner { static constexpr bool kIsPersistent = find_option_t::value; using TileScheduler = std::conditional_t; - using Mainloop = + using Mainloop = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized< Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShape, StrideQ, StrideK, StrideV, @@ -494,7 +509,7 @@ struct FwdRunner { // // Methods // - bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) { + bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer, Options const & options) { Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()), select<0,2,3>(problem_shape), stride_Q); @@ -514,12 +529,12 @@ struct FwdRunner { Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()), select<0,3>(problem_shape), stride_LSE); - + auto [Q, K, D, HB] = problem_shape; auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB); - fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{}); + fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask(options.window_size_left, options.window_size_right)); cudaError_t result = cudaDeviceSynchronize(); if (result != cudaSuccess) { @@ -538,7 +553,7 @@ struct FwdRunner { bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if (! passed_O) { - std::cerr << "failed O: max diff " << max_diff + std::cerr << "failed O: max diff " << max_diff << " mean " << mean_diff << std::endl; } @@ -546,7 +561,7 @@ struct FwdRunner { bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if ( ! passed_LSE) { - std::cerr << "failed LSE: max diff " << max_diff + std::cerr << "failed LSE: max diff " << max_diff << " mean " << mean_diff << std::endl; } @@ -562,7 +577,7 @@ struct FwdRunner { // generate Q as --b times // gaussian (--Q, --Q / 2) sampled positive - // track cumulative + // track cumulative std::mt19937 rng(0x202305151552ull); std::normal_distribution dist_q(get<0>(problem_size), get<0>(problem_size) / 2); std::normal_distribution dist_kv(get<1>(problem_size), get<1>(problem_size) / 2); @@ -585,7 +600,7 @@ struct FwdRunner { int max_seqlen_kv = 0; for (int i = 0; i < num_batches; i++) { - int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) : + int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) : kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng); int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) : @@ -626,7 +641,7 @@ struct FwdRunner { int h_r = options.h / options.h_k; assert(options.h % options.h_k == 0); auto problem_shape_in = cute::make_tuple(options.q, options.k, options.d, cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); - + ProblemShapeType problem_shape; decltype(problem_shape_in) problem_size; @@ -690,7 +705,7 @@ struct FwdRunner { buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); buffer.device_cumulative_seqlen_kv.copy_from_host( cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); - } + } }; buffers.push_back(std::make_unique()); @@ -710,7 +725,7 @@ struct FwdRunner { return problem_shape; } - auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) { + auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index, const Options& options) { auto problem_shape_ = problem_shape; if constexpr (kIsVarlen) { get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get(); @@ -733,7 +748,7 @@ struct FwdRunner { ProblemShapeType problem_shape = initialize(options); int buffer_index = 0; - typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index); + typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index, options); Operation op; @@ -769,7 +784,7 @@ struct FwdRunner { return example_result; } buffer_index = (buffer_index + 1) % buffers.size(); - arguments = get_arguments(problem_shape, hw_info, buffer_index); + arguments = get_arguments(problem_shape, hw_info, buffer_index, options); status = op.update(arguments, workspace.get()); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: " @@ -814,7 +829,7 @@ struct FwdRunner { return example_result; } buffer_index = (buffer_index + 1) % buffers.size(); - arguments = get_arguments(problem_shape, hw_info, buffer_index); + arguments = get_arguments(problem_shape, hw_info, buffer_index, options); status = op.update(arguments, workspace.get()); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: " @@ -866,7 +881,30 @@ struct FwdRunner { flops *= static_cast(size<1>(problem_shape)); flops *= static_cast(size<3,1>(problem_shape)); } - flops *= 4.0 * (std::is_same_v> || std::is_same_v> ? 0.5 : 1.0); + + double flops_ratio = 1.0; + if (std::is_same_v> || std::is_same_v>) { + flops_ratio = 0.5; + } + if (std::is_same_v> || std::is_same_v>) { + // For regular sequences + int seqlen_q = size<0>(problem_shape); + int seqlen_k = size<1>(problem_shape); + + double total_valid_pairs = 0.0; + for (int row_idx = 0; row_idx < seqlen_q; row_idx++) { + int col_left = std::max(row_idx - options.window_size_left, 0); + int col_right = std::min(row_idx + options.window_size_right, seqlen_k - 1); + // Valid positions in this row + if (col_right >= col_left) { + total_valid_pairs += (col_right - col_left + 1); + } + } + double total_positions = static_cast(seqlen_q) * static_cast(seqlen_k); + flops_ratio = (total_positions > 0) ? (total_valid_pairs / total_positions) : 1.0; + flops_ratio = std::min(flops_ratio, 1.0); + } + flops *= 4.0 * flops_ratio; flops *= static_cast(size<2>(problem_shape)); flops *= static_cast(size<3,0>(problem_shape)); double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); @@ -883,7 +921,7 @@ struct FwdRunner { // Verify that the result is correct bool passed = true; if (options.verify) { - passed = verify(problem_shape, *buffers[0]); + passed = verify(problem_shape, *buffers[0], options); if (passed) example_result.verified = true; } @@ -935,7 +973,7 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn auto result = runner.run(options, hw_info); print_result(name, result, options.verbose); } - else + else { FwdRunner runner; auto result = runner.run(options, hw_info); @@ -968,7 +1006,7 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf auto result = runner.run(options, hw_info); print_result(name, result, options.verbose); } - else + else { FwdRunner runner; auto result = runner.run(options, hw_info); @@ -1083,6 +1121,7 @@ int main_single(int argc, char const **args) { std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " "; std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " "; + std::cout << (options.local ? ("Local with window size " + std::to_string(options.window_size_left) + " " + std::to_string(options.window_size_right)) : "Not local") << " "; std::cout << "#SM " << hw_info.sm_count << std::endl; auto with_mask = [&](auto fn) { @@ -1096,6 +1135,16 @@ int main_single(int argc, char const **args) { else if (options.residual) { fn(ResidualMask{}); } + else if (options.local) { + if (options.window_size_left == -1 || options.window_size_right == -1) { + throw std::runtime_error("Error: --window_size_left and --window_size_right must be set for local attention."); + } + if(options.causal_q_begin) { + fn(LocalMask{}); + } else { + fn(LocalMask{}); + } + } else { fn(NoMask{}); } From c51c3808e7e86abbc55826c6c57c242fc3b95330 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 18:00:31 -0700 Subject: [PATCH 09/12] forward is fine --- .../77_blackwell_fmha/77_blackwell_fmha.cu | 48 +++++++++++++++---- examples/77_blackwell_fmha/CMakeLists.txt | 34 ++++++++++++- run.sh | 38 ++++++++++----- 3 files changed, 98 insertions(+), 22 deletions(-) diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 7de1ebc865..1d2ca139bc 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -301,6 +301,14 @@ struct Options { local = true; residual = false; causal = false; + if(causal_type == "qend") { + causal_q_begin = false; + } else { + causal_q_begin = true; + } + if (varlen) { + residual = true; + } } cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); @@ -732,13 +740,33 @@ struct FwdRunner { get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get(); } typename Operation::Arguments arguments{ - problem_shape_, - { buffers[buffer_index]->block_Q.get(), stride_Q, - buffers[buffer_index]->block_K.get(), stride_K, - buffers[buffer_index]->block_V.get(), stride_V }, - { buffers[buffer_index]->block_O.get(), stride_O, - buffers[buffer_index]->block_LSE.get(), stride_LSE }, - hw_info + problem_shape_, // 1st field: Problem dimensions + + // 2nd field: Mainloop arguments - input tensor data and scaling parameters + { + // Nested Load arguments for tensor pointers and strides + { buffers[buffer_index]->block_Q.get(), stride_Q, // Query tensor pointer and stride + buffers[buffer_index]->block_K.get(), stride_K, // Key tensor pointer and stride + buffers[buffer_index]->block_V.get(), stride_V, // Value tensor pointer and stride + options.window_size_left, // window_size_left: for local attention + options.window_size_right // window_size_right: for local attention + }, + + // Scaling parameters for attention computation + 0.0f, // scale_softmax: 0.0f means use default 1/sqrt(D) + 1.0f, // scale_q: scaling factor for Q tensor dequantization + 1.0f, // scale_k: scaling factor for K tensor dequantization + 1.0f, // scale_v: scaling factor for V tensor dequantization + 1.0f, // inv_scale_o: inverse scaling factor for O tensor quantization + options.window_size_left, // window_size_left: for local attention + options.window_size_right // window_size_right: for local attention + }, + + // 3rd field: Epilogue arguments - output tensors O, LSE with their memory pointers and strides + { buffers[buffer_index]->block_O.get(), stride_O, // Output tensor pointer and stride + buffers[buffer_index]->block_LSE.get(), stride_LSE },// Log-sum-exp tensor pointer and stride + + hw_info // 4th field: Hardware info (SM count, etc.) }; return arguments; } @@ -1132,9 +1160,6 @@ int main_single(int argc, char const **args) { fn(CausalMask{}); } } - else if (options.residual) { - fn(ResidualMask{}); - } else if (options.local) { if (options.window_size_left == -1 || options.window_size_right == -1) { throw std::runtime_error("Error: --window_size_left and --window_size_right must be set for local attention."); @@ -1145,6 +1170,9 @@ int main_single(int argc, char const **args) { fn(LocalMask{}); } } + else if (options.residual) { + fn(ResidualMask{}); + } else { fn(NoMask{}); } diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 65034d3d8e..3127998a73 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -41,6 +41,7 @@ set_property( set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no) set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend) +set(TEST_LOCAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=local --window_size_right=512 --window_size_left=512) set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen) set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify) set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify) @@ -69,7 +70,16 @@ set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 -- set(TEST_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024) set(TEST_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035) - +set(TEST_LOCAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=local --window_size_right=128 --window_size_left=0) +set(TEST_LOCAL_01 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=local --window_size_right=128 --window_size_left=128) +set(TEST_LOCAL_02 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=local --window_size_right=0 --window_size_left=0) +set(TEST_LOCAL_03 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=local --window_size_right=32 --window_size_left=0) +set(TEST_LOCAL_04 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=local --causal-type=qend --window_size_right=0 --window_size_left=32) +set(TEST_LOCAL_05 --b=1 --h=4 --d=128 --verify --varlen --mask=local --window_size_right=0 --window_size_left=128 --varlen-q=128 --varlen-k=128) +set(TEST_LOCAL_06 --b=1 --h=4 --d=128 --verify --varlen --mask=local --window_size_right=0 --window_size_left=128 --varlen-q=17 --varlen-k=257) +set(TEST_LOCAL_07 --h=4 --d=128 --verify --varlen --mask=local --window_size_right=0 --window_size_left=128 --varlen-q=100:300 --varlen-k=100:300) +set(TEST_LOCAL_08 --h=4 --d=128 --verify --varlen --mask=local --window_size_right=0 --window_size_left=128 --varlen-q=177:366:479 --varlen-k=257:0:766) +set(TEST_LOCAL_09 --h=4 --d=128 --verify --varlen --mask=local --causal-type=qend --window_size_right=0 --window_size_left=128 --varlen-q=177:366:479 --varlen-k=257:0:766) set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128) set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128) @@ -122,6 +132,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_BASIC TEST_CAUSAL_00 TEST_CAUSAL_01 + TEST_LOCAL TEST_VARLEN TEST_HDIM64 TEST_GQA @@ -148,6 +159,16 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_VARLEN_20 TEST_VARLEN_21 TEST_VARLEN_22 + TEST_LOCAL_00 + TEST_LOCAL_01 + TEST_LOCAL_02 + TEST_LOCAL_03 + TEST_LOCAL_04 + TEST_LOCAL_05 + TEST_LOCAL_06 + TEST_LOCAL_07 + TEST_LOCAL_08 + TEST_LOCAL_09 ) target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO}) @@ -195,6 +216,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_fmha_bwd.cu TEST_COMMAND_OPTIONS TEST_BASIC + TEST_LOCAL TEST_VARLEN # NOTE: bwd doesn't support GQA yet, --h_k will just get ignored in these tests TEST_VARLEN_00 @@ -214,6 +236,16 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_VARLEN_14 TEST_BWD_MLA_BASIC TEST_BWD_MLA_VARLEN + TEST_LOCAL_00 + TEST_LOCAL_01 + TEST_LOCAL_02 + TEST_LOCAL_03 + TEST_LOCAL_04 + TEST_LOCAL_05 + TEST_LOCAL_06 + TEST_LOCAL_07 + TEST_LOCAL_08 + TEST_LOCAL_09 ) target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) diff --git a/run.sh b/run.sh index 1a8ef14b48..df906ab22e 100644 --- a/run.sh +++ b/run.sh @@ -4,7 +4,7 @@ mkdir -p ~/cutlass/build cd ~/cutlass/build export CUDA_VISIBLE_DEVICES=7 -# export REF_PRINT_DIFF=1 +export REF_PRINT_DIFF=1 cmake .. -DCUTLASS_NVCC_ARCHS=100a @@ -13,22 +13,38 @@ start_time=$(date +%s) echo "E2E Test Run Started at: $(date)" | tee -a ~/cutlass/output.log targets=( - test_examples_77_blackwell_fmha_bwd_fp16_test_basic - test_examples_77_blackwell_fmha_bwd_fp16_test_varlen + # test_examples_77_blackwell_fmha_fp16_test_causal + # test_examples_77_blackwell_fmha_fp16_test_varlen + # test_examples_77_blackwell_fmha_fp16_test_local + # test_examples_77_blackwell_fmha_fp16_test_local_00 + # test_examples_77_blackwell_fmha_fp16_test_local_01 + # test_examples_77_blackwell_fmha_fp16_test_local_02 + # test_examples_77_blackwell_fmha_fp16_test_local_03 + # test_examples_77_blackwell_fmha_fp16_test_local_04 + # test_examples_77_blackwell_fmha_fp16_test_local_05 + # test_examples_77_blackwell_fmha_fp16_test_local_06 + # test_examples_77_blackwell_fmha_fp16_test_local_07 + # test_examples_77_blackwell_fmha_fp16_test_local_08 + # test_examples_77_blackwell_fmha_fp16_test_local_09 + # test_examples_77_blackwell_fmha_bwd_fp16_test_causal + # test_examples_77_blackwell_fmha_bwd_fp16_test_varlen + # test_examples_77_blackwell_fmha_bwd_fp16_test_local + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_00 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_01 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_02 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_03 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_04 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_05 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_06 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_07 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_08 + # test_examples_77_blackwell_fmha_bwd_fp16_test_local_09 ) for test in "${targets[@]}" do echo "Running $test" make $test 2>&1 | tee -a ~/cutlass/output.log - echo "Running compute sanitizer $test" - compute-sanitizer make $test 2>&1 | tee -a ~/cutlass/output.log - echo "Running compute sanitizer memcheck $test" - compute-sanitizer --tool=memcheck make $test 2>&1 | tee -a ~/cutlass/output.log - echo "Running compute sanitizer racecheck $test" - compute-sanitizer --tool=racecheck make $test 2>&1 | tee -a ~/cutlass/output.log - echo "Running compute sanitizer synccheck $test" - compute-sanitizer --tool=synccheck make $test 2>&1 | tee -a ~/cutlass/output.log done # Record end time and calculate e2e duration From ebb7eeec451625fd0ab57af5eda386cfd95ffb83 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 19:05:26 -0700 Subject: [PATCH 10/12] test backward --- .../77_blackwell_fmha/77_blackwell_fmha.cu | 4 +- .../77_blackwell_fmha_bwd.cu | 47 ++++++++++++++----- examples/77_blackwell_fmha/CMakeLists.txt | 1 - ...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 9 ++-- run.sh | 26 +++++----- 5 files changed, 55 insertions(+), 32 deletions(-) diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 1d2ca139bc..3ce0e91ade 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -341,9 +341,11 @@ struct Options { << " --tensor_ring_buffers= Sets the number of tensor ring buffers\n" << " --warmup_iterations= Sets the warmup iterations\n" << " --iterations= Benchmarking iterations\n" + << " --window_size_left= Window size left for local attention\n" + << " --window_size_right= Window size right for local attention\n" << " --verify Verify results\n" << " --verbose Print smem and execution time per kernel\n" - << " --mask= Enables masking\n" + << " --mask= Enables masking\n" << " --causal-type= Causal mask type\n" << " --persistent Enables persistent scheduler\n" << " --varlen Enables variable sequence length\n" diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu index 4521d87faf..e67733d326 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu @@ -122,6 +122,10 @@ struct Options { bool verify = false; bool verbose = false; + int window_size_left = -1; + int window_size_right = -1; + bool local = false; + bool causal = false; bool residual = false; bool varlen = false; @@ -259,6 +263,10 @@ struct Options { if (b == 0) b = 1; cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + + cmd.get_cmd_line_argument("window_size_left", window_size_left, defaults.window_size_left); + cmd.get_cmd_line_argument("window_size_right", window_size_right, defaults.window_size_right); + verify = cmd.check_cmd_line_flag("verify"); verbose = cmd.check_cmd_line_flag("verbose"); std::string mask; @@ -269,6 +277,9 @@ struct Options { else if (mask == "residual") { residual = true; } + else if (mask == "local") { + local = true; + } else { causal = defaults.causal; } @@ -309,9 +320,11 @@ struct Options { << " --d= Sets the D extent\n" << " --d_vo= Sets the D_VO extent\n" << " --iterations= Benchmarking iterations\n" + << " --window_size_left= Window size left for local attention\n" + << " --window_size_right= Window size right for local attention\n" << " --verify Verify results\n" << " --verbose Print smem and execution time per kernel\n" - << " --mask= Enables masking\n" + << " --mask= Enables masking\n" << " --varlen Enables variable sequence length\n" << " B*Q and B*K become the total sequence length\n" << " and are split B-ways, alternatingly +10% and -10%\n" @@ -415,7 +428,7 @@ struct BwdRunner { cute::tuple, int>>, cute::tuple, int>> >; - + using StrideQ = Stride, int>>; // Q D ((H_R, H_K), B) using StrideK = Stride, int>>; // K D ((H_R, H_K), B) using StrideV = StrideK; // K D_VO ((H_R, H_K), B) @@ -467,7 +480,7 @@ struct BwdRunner { // // Methods // - bool verify(const ProblemShape& problem_shape) { + bool verify(const ProblemShape& problem_shape, const Options& options) { auto [Q, K, D, D_VO, HB] = problem_shape; auto [H, B] = HB; @@ -481,7 +494,7 @@ struct BwdRunner { Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV); Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO); - fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{}); + fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask(options.window_size_left, options.window_size_right)); cudaError_t result = cudaDeviceSynchronize(); if (result != cudaSuccess) { @@ -500,7 +513,7 @@ struct BwdRunner { bool passed_dQ = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if (! passed_dQ) { - std::cerr << "failed dQ: max diff " << max_diff + std::cerr << "failed dQ: max diff " << max_diff << " mean " << mean_diff << std::endl; } @@ -508,7 +521,7 @@ struct BwdRunner { bool passed_dK = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if (! passed_dK) { - std::cerr << "failed dK: max diff " << max_diff + std::cerr << "failed dK: max diff " << max_diff << " mean " << mean_diff << std::endl; } @@ -516,7 +529,7 @@ struct BwdRunner { bool passed_dV = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if (! passed_dV) { - std::cerr << "failed dV: max diff " << max_diff + std::cerr << "failed dV: max diff " << max_diff << " mean " << mean_diff << std::endl; } @@ -532,7 +545,7 @@ struct BwdRunner { // generate Q as --b times // gaussian (--Q, --Q / 2) sampled positive - // track cumulative + // track cumulative std::mt19937 rng(0x202305151552ull); std::normal_distribution dist_q(options.q, options.q / 2); std::normal_distribution dist_kv(options.k, options.k / 2); @@ -552,7 +565,7 @@ struct BwdRunner { const bool kVarlenSame = false; for (int i = 0; i < num_batches; i++) { - int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) : + int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) : kVarlenSame ? options.q : generate_positive_int(dist_q, rng); int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) : @@ -672,7 +685,7 @@ struct BwdRunner { stride_LSE); if (not options.skip_reference) { - fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); + fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask(options.window_size_left, options.window_size_right)); } return problem_shape; @@ -699,6 +712,8 @@ struct BwdRunner { block_dK.get(), stride_dK, block_dV.get(), stride_dV, softmax_scale, + options.window_size_left, + options.window_size_right, hw_info }; @@ -819,10 +834,10 @@ struct BwdRunner { // Verify that the result is correct bool passed = true; if (options.verify) { - passed = verify(problem_shape); + passed = verify(problem_shape, options); if (passed) example_result.verified = true; } - + if (!passed) { std::cerr << "Reference check failed" << std::endl; return example_result; @@ -945,7 +960,7 @@ int main_single(int argc, char const **args) { << "(compute capability 100a) and CUDA 12.8 or greater.\n"; return 0; } - + // // Parse options // @@ -992,6 +1007,12 @@ int main_single(int argc, char const **args) { if (options.causal) { fn(CausalForBackwardMask{}); } + else if (options.local) { + if (options.window_size_left == -1 || options.window_size_right == -1) { + throw std::runtime_error("Error: --window_size_left and --window_size_right must be set for local attention."); + } + fn(LocalMaskForBackward{}); + } else if (options.residual) { fn(ResidualMaskForBackward{}); } diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 3127998a73..ac7137c783 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -245,7 +245,6 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC TEST_LOCAL_06 TEST_LOCAL_07 TEST_LOCAL_08 - TEST_LOCAL_09 ) target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp index bf72843a43..70c3f608ab 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -307,6 +307,9 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { TensorStride stride_dq_acc; ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + + int window_size_left = -1; + int window_size_right = -1; }; using TMA_K = typename CollectiveMmaQK::Params::TMA_B; @@ -1001,7 +1004,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { make_coord(blk_coord_k * TileShapeK{}, _0{}), make_identity_tensor(take<0,2>(TileShapePDO{})) ); - + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { if (elem_less(cDK(i), select<1,2>(problem_shape))) { gDK(i) = Element(0); @@ -1278,7 +1281,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { } auto tRT_rST = quantize(tTR_rST); - + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) (_, _, _, pipeline_compute_mma_p_producer_state.index()); @@ -1482,7 +1485,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { #if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); -#else +#else int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); diff --git a/run.sh b/run.sh index df906ab22e..c4a4754ea4 100644 --- a/run.sh +++ b/run.sh @@ -4,7 +4,7 @@ mkdir -p ~/cutlass/build cd ~/cutlass/build export CUDA_VISIBLE_DEVICES=7 -export REF_PRINT_DIFF=1 +# export REF_PRINT_DIFF=1 cmake .. -DCUTLASS_NVCC_ARCHS=100a @@ -26,19 +26,17 @@ targets=( # test_examples_77_blackwell_fmha_fp16_test_local_07 # test_examples_77_blackwell_fmha_fp16_test_local_08 # test_examples_77_blackwell_fmha_fp16_test_local_09 - # test_examples_77_blackwell_fmha_bwd_fp16_test_causal - # test_examples_77_blackwell_fmha_bwd_fp16_test_varlen - # test_examples_77_blackwell_fmha_bwd_fp16_test_local - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_00 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_01 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_02 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_03 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_04 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_05 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_06 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_07 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_08 - # test_examples_77_blackwell_fmha_bwd_fp16_test_local_09 + test_examples_77_blackwell_fmha_bwd_fp16_test_varlen + test_examples_77_blackwell_fmha_bwd_fp16_test_local + test_examples_77_blackwell_fmha_bwd_fp16_test_local_00 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_01 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_02 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_03 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_04 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_05 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_06 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_07 + test_examples_77_blackwell_fmha_bwd_fp16_test_local_08 ) for test in "${targets[@]}" From b86b5cb5a43ff567631e0bdc3bb2adab946a7015 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Tue, 23 Sep 2025 22:26:35 -0700 Subject: [PATCH 11/12] remove run script --- run.sh | 56 -------------------------------------------------------- 1 file changed, 56 deletions(-) delete mode 100644 run.sh diff --git a/run.sh b/run.sh deleted file mode 100644 index c4a4754ea4..0000000000 --- a/run.sh +++ /dev/null @@ -1,56 +0,0 @@ -rm ~/cutlass/output.log -# rm -rf ~/cutlass/build -mkdir -p ~/cutlass/build -cd ~/cutlass/build - -export CUDA_VISIBLE_DEVICES=7 -# export REF_PRINT_DIFF=1 - -cmake .. -DCUTLASS_NVCC_ARCHS=100a - -# Record start time for e2e timing -start_time=$(date +%s) -echo "E2E Test Run Started at: $(date)" | tee -a ~/cutlass/output.log - -targets=( - # test_examples_77_blackwell_fmha_fp16_test_causal - # test_examples_77_blackwell_fmha_fp16_test_varlen - # test_examples_77_blackwell_fmha_fp16_test_local - # test_examples_77_blackwell_fmha_fp16_test_local_00 - # test_examples_77_blackwell_fmha_fp16_test_local_01 - # test_examples_77_blackwell_fmha_fp16_test_local_02 - # test_examples_77_blackwell_fmha_fp16_test_local_03 - # test_examples_77_blackwell_fmha_fp16_test_local_04 - # test_examples_77_blackwell_fmha_fp16_test_local_05 - # test_examples_77_blackwell_fmha_fp16_test_local_06 - # test_examples_77_blackwell_fmha_fp16_test_local_07 - # test_examples_77_blackwell_fmha_fp16_test_local_08 - # test_examples_77_blackwell_fmha_fp16_test_local_09 - test_examples_77_blackwell_fmha_bwd_fp16_test_varlen - test_examples_77_blackwell_fmha_bwd_fp16_test_local - test_examples_77_blackwell_fmha_bwd_fp16_test_local_00 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_01 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_02 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_03 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_04 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_05 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_06 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_07 - test_examples_77_blackwell_fmha_bwd_fp16_test_local_08 -) - -for test in "${targets[@]}" -do - echo "Running $test" - make $test 2>&1 | tee -a ~/cutlass/output.log -done - -# Record end time and calculate e2e duration -end_time=$(date +%s) -duration=$((end_time - start_time)) - -echo "E2E Test Run Completed at: $(date)" | tee -a ~/cutlass/output.log -echo "Total E2E Duration: ${duration} seconds" | tee -a ~/cutlass/output.log - -unset CUDA_VISIBLE_DEVICES -unset REF_PRINT_DIFF From 87db76d462fb2f9c37b8330d4ab33abbd0cee051 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Tue, 23 Sep 2025 22:26:41 -0700 Subject: [PATCH 12/12] Revert "ignore build" This reverts commit 1908dd839fc20dc3bdd472f695475b79df66cc68. --- .gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 7efb940e79..e7a026874b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ # PyCache files __pycache__/ cutlass_library.egg-info/ -build* -*.log +/build*