Commit 51210b8

henrylhtsang authored and meta-codesync[bot] committed
Bring 4.2.1 changes to FBGEMM blackwell cutlass fmha (#5052)
Summary:
Pull Request resolved: #5052
X-link: https://github.com/facebookresearch/FBGEMM/pull/2061

As titled. Mostly grabbing changes from 4.2.1. This also removes the changes from D84954166.

One thing is still TBD: how to incorporate the equivalent of D79534034, which was also made upstream. I would prefer to stay close to upstream, but I tested that version and it fails, so I am settling for Aya-ZIbra's D84970563 for now.

Reviewed By: Aya-ZIbra

Differential Revision: D84961921

fbshipit-source-id: ff3ab1746951d3b91126d7a1b75e4d7cf2c92b69
1 parent 9b5af57 commit 51210b8

16 files changed: +233, −99 lines

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/77_blackwell_fmha.cu

Lines changed: 10 additions & 9 deletions
```diff
@@ -426,16 +426,16 @@ struct FwdRunner {
   using ElementOut = cutlass::half_t;
 #endif
 
-  // Q K D (B H)
+  // Q K D ((H_R, H_K) B)
   using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
   using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
   using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
 
-  using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>;  // Q D (H_G H_R B)
-  using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>;   // K D (H_G H_R B)
+  using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>;  // Q D ((H_R, H_K), B)
+  using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>;   // K D ((H_R, H_K), B)
   using StrideV = StrideK;
   using StrideO = StrideQ;
-  using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>;     // Q (H_G H_R B)
+  using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>;     // Q ((H_R, H_K), B)
 
   static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
   using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
@@ -618,8 +618,8 @@ struct FwdRunner {
 
     ProblemShapeType problem_size_for_launch;
 
-    get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
-    get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
+    get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
+    get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
     get<2>(problem_size_for_launch) = get<2>(problem_size);
     get<3>(problem_size_for_launch) = get<3>(problem_size);
 
@@ -676,9 +676,9 @@ struct FwdRunner {
     }
 
     auto buffer_init_fn = [&](auto& buffer) {
-      buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
-      buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
-      buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
+      buffer.block_Q.reset(size(shape_QO));
+      buffer.block_K.reset(size(shape_KV));
+      buffer.block_V.reset(size(shape_KV));
       buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
       buffer.block_LSE.reset(size(shape_LSE));
       buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
@@ -725,6 +725,7 @@ struct FwdRunner {
     }
     typename Operation::Arguments arguments{
       problem_shape_,
+      // local changes
       nullptr,
       {{ buffers[buffer_index]->block_Q.get(), stride_Q,
          buffers[buffer_index]->block_K.get(), stride_K,
```

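The comment relabeling above is the substance of the first hunk: in `((H_R, H_K), B)`, `H_K` indexes the KV heads and `H_R` the query heads that share each of them, and the `_0` stride in `StrideK` is what broadcasts one KV head across its group. A minimal sketch of how such a stride could be assembled on the host (the helper and its packing assumptions are illustrative, not part of this commit):

```cpp
#include <cute/tensor.hpp>

using namespace cute;

// Hypothetical stride builder for a packed [B, K, H_K, D] key tensor
// (names and packing are illustrative). The _0 over H_R means every
// query head in a group reads the same KV head: GQA falls out of the
// layout with no data duplication.
auto make_gqa_stride_K(int seqlen_k, int h_k, int d) {
  return make_tuple(
      h_k * d,                          // next K position
      _1{},                             // contiguous along D
      make_tuple(make_tuple(_0{},       // H_R: broadcast across the group
                            d),         // H_K: next KV head
                 seqlen_k * h_k * d));  // B: next sequence in the batch
}
```

The same `_0` trick covers V as well, since `StrideV = StrideK`.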
fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/77_blackwell_mla_fwd.cu

Lines changed: 7 additions & 6 deletions
```diff
@@ -591,8 +591,8 @@ struct MlaFwdRunner {
 
     ProblemShapeType problem_size_for_launch;
 
-    get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
-    get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
+    get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
+    get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
     get<2>(problem_size_for_launch) = get<2>(problem_size);
     get<3>(problem_size_for_launch) = get<3>(problem_size);
 
@@ -652,9 +652,9 @@ struct MlaFwdRunner {
     }
 
     auto buffer_init_fn = [&](auto& buffer) {
-      buffer.block_Q.reset(size(shape_Q), kIsVarlen ? D_latent_rope*SQ*H : 0);
-      buffer.block_K.reset(size(shape_K), kIsVarlen ? D_latent_rope*SK*H_K : 0);
-      buffer.block_V.reset(size(shape_V), kIsVarlen ? D*SK*H_K : 0);
+      buffer.block_Q.reset(size(shape_Q));
+      buffer.block_K.reset(size(shape_K));
+      buffer.block_V.reset(size(shape_V));
       buffer.block_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0);
       buffer.block_LSE.reset(size(shape_LSE));
       buffer.block_ref_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0);
@@ -850,7 +850,8 @@ struct MlaFwdRunner {
       flops *= static_cast<double>(size<3,1>(problem_shape));
     }
 
-    flops *= 2.0 * (std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
+    flops *= 2.0 * (std::is_same_v<ActiveMask, CausalMask<false>> ||
+                    std::is_same_v<ActiveMask, CausalMask<true>> ? 0.5 : 1.0);
     flops *= static_cast<double>(size<3,0>(problem_shape));
 
     double flops0 = flops * static_cast<double>(size<2, 0>(problem_shape) + size<2, 1>(problem_shape));
```

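The FLOPs fix in the last hunk is easy to misread: the 0.5 causal discount now applies to both `CausalMask` specializations instead of only `CausalMask<false>`. A self-contained sketch of the intended factor (the helper and the stand-in mask type are illustrative, not code from this commit):

```cpp
#include <type_traits>

// Stand-in for the real mask type in fmha_fusion.hpp.
template <bool B> struct CausalMask {};

// A causal mask of either alignment computes roughly half of the
// S = Q*K^T score matrix, so the FLOP estimate is halved for both.
template <class ActiveMask>
constexpr double causal_flops_factor() {
  return (std::is_same_v<ActiveMask, CausalMask<false>> ||
          std::is_same_v<ActiveMask, CausalMask<true>>) ? 0.5 : 1.0;
}

static_assert(causal_flops_factor<CausalMask<true>>() == 0.5);
static_assert(causal_flops_factor<CausalMask<false>>() == 0.5);
static_assert(causal_flops_factor<int>() == 1.0);  // any non-causal mask
```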
fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/CMakeLists.txt

Lines changed: 98 additions & 2 deletions
```diff
@@ -33,12 +33,14 @@ set_property(
   77_blackwell_fmha_gen.cu
   77_blackwell_mla.cu
   77_blackwell_fmha_bwd.cu
+  77_blackwell_mla_fwd.cu
   PROPERTY
   COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
 )
 
 set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
-set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
+set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
+set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
 set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
 set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
 set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
@@ -58,6 +60,41 @@ set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2
 set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
 set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
 set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
+set(TEST_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
+set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
+set(TEST_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
+set(TEST_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
+
+
+
+set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_MLA_FWD_VARLEN_02 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
+set(TEST_MLA_FWD_VARLEN_03 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
+set(TEST_MLA_FWD_VARLEN_04 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
+set(TEST_MLA_FWD_VARLEN_05 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
+set(TEST_MLA_FWD_VARLEN_06 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
+set(TEST_MLA_FWD_VARLEN_07 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
+set(TEST_MLA_FWD_VARLEN_08 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
+set(TEST_MLA_FWD_VARLEN_09 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
+set(TEST_MLA_FWD_VARLEN_10 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=2:3 --varlen-k=2:5)
+set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=11:10 --varlen-k=13:10)
+set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
+set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
+set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
+set(TEST_MLA_FWD_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_MLA_FWD_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
+set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
+set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
+set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
+set(TEST_MLA_FWD_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
+set(TEST_MLA_FWD_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
+
 
 set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
 set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
@@ -67,6 +104,11 @@ set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
 set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
 
 set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
+set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
+set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
+
+set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify)
+set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify)
 
 if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
 
@@ -78,7 +120,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
     77_blackwell_fmha.cu
     TEST_COMMAND_OPTIONS
     TEST_BASIC
-    TEST_CAUSAL
+    TEST_CAUSAL_00
+    TEST_CAUSAL_01
     TEST_VARLEN
     TEST_HDIM64
     TEST_GQA
@@ -97,6 +140,14 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
     TEST_VARLEN_12
     TEST_VARLEN_13
     TEST_VARLEN_14
+    TEST_VARLEN_15
+    TEST_VARLEN_16
+    TEST_VARLEN_17
+    TEST_VARLEN_18
+    TEST_VARLEN_19
+    TEST_VARLEN_20
+    TEST_VARLEN_21
+    TEST_VARLEN_22
   )
   target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
   target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
@@ -120,6 +171,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
     77_blackwell_mla.cu
     TEST_COMMAND_OPTIONS
     TEST_MLA_BASIC
+    TEST_MLA_SEP_REDUCTION
+    TEST_MLA_FUSE_REDUCTION
   )
   target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
   target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
@@ -130,6 +183,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
     77_blackwell_mla.cu
     TEST_COMMAND_OPTIONS
     TEST_MLA_BASIC
+    TEST_MLA_SEP_REDUCTION
+    TEST_MLA_FUSE_REDUCTION
   )
   target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
   target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
@@ -157,10 +212,49 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
     TEST_VARLEN_12
     TEST_VARLEN_13
     TEST_VARLEN_14
+    TEST_BWD_MLA_BASIC
+    TEST_BWD_MLA_VARLEN
   )
   target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
   target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
   target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
+
+  cutlass_example_add_executable(
+    77_blackwell_mla_fwd_${PREC}
+    77_blackwell_mla_fwd.cu
+    TEST_COMMAND_OPTIONS
+    TEST_BASIC
+    TEST_CAUSAL_00
+    TEST_VARLEN
+    TEST_HDIM64
+    TEST_GQA
+    TEST_MLA_FWD_VARLEN_00
+    TEST_MLA_FWD_VARLEN_01
+    TEST_MLA_FWD_VARLEN_02
+    TEST_MLA_FWD_VARLEN_03
+    TEST_MLA_FWD_VARLEN_04
+    TEST_MLA_FWD_VARLEN_05
+    TEST_MLA_FWD_VARLEN_06
+    TEST_MLA_FWD_VARLEN_07
+    TEST_MLA_FWD_VARLEN_08
+    TEST_MLA_FWD_VARLEN_09
+    TEST_MLA_FWD_VARLEN_10
+    TEST_MLA_FWD_VARLEN_11
+    TEST_MLA_FWD_VARLEN_12
+    TEST_MLA_FWD_VARLEN_13
+    TEST_MLA_FWD_VARLEN_14
+    TEST_MLA_FWD_VARLEN_15
+    TEST_MLA_FWD_VARLEN_16
+    TEST_MLA_FWD_VARLEN_17
+    TEST_MLA_FWD_VARLEN_18
+    TEST_MLA_FWD_VARLEN_19
+    TEST_MLA_FWD_VARLEN_20
+    TEST_MLA_FWD_VARLEN_21
+    TEST_MLA_FWD_VARLEN_22
+  )
+  target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+  target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})
+  target_compile_options(77_blackwell_mla_fwd_${PREC} PRIVATE -Xptxas -v)
 endforeach()
 
 # Add a target that builds all examples
@@ -176,5 +270,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
     77_blackwell_mla_2sm_cpasync_fp16
     77_blackwell_fmha_bwd_fp8
     77_blackwell_fmha_bwd_fp16
+    77_blackwell_mla_fwd_fp8
+    77_blackwell_mla_fwd_fp16
   )
 endif()
```

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/README.md

Lines changed: 10 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@ For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit
 
 Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.
 
-For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
+For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
 
 The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
 The kernel and collective layer are then formulated to be fmha-specific.
@@ -37,13 +37,19 @@ There are three kernels to compute backwards:
 
 `Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.
 
+## MLA Blackwell Backward
+
+The sample also provides the feature of MLA backward (d=192, d_vo=128). To enable MLA backward, please specify `--d=192 --d_vo=128` when running the bwd sample.
+
+`Sm100FmhaBwdMlaKernelTmaWarpSpecialized` is the main point for MLA backward. The MLA approach is slightly different from the original one to enable high performance with the MLA shape.
+
 # MLA Inference for Blackwell
 
 This sample provides code for fused multi-head latent attention inference in
 the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
 It supports fp16, bf16, and fp8 input and output types.
 
-To accomodate the large output accumulator due to the large latent head dimension,
+To accommodate the large output accumulator due to the large latent head dimension,
 the sample demonstrates how to leverage 2Sm Blackwell tensor cores.
 
 Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
@@ -61,6 +67,8 @@ For detailed information on how to invoke them, check out either the tests in `C
 to simplify the sample, clarified that `fmha_gen` sample only supports head
 dim 128.
 
+* 4.3.0: For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
+
 # Copyright
 
 Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_fusion.hpp

Lines changed: 3 additions & 4 deletions
```diff
@@ -317,9 +317,8 @@ struct CausalMask : NoMask {
     if constexpr (IsQBegin) {
       return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
     } else {
-      // Local changes (to be upstreamed https://github.com/NVIDIA/cutlass/pull/2480)
-      const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
-      return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
+      const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
+      return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
     }
   }
 
```
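To see what the new qend bound computes, it helps to plug in the shape that the new `TEST_CAUSAL_01` case exercises; the standalone sketch below mirrors the expression in the hunk:

```cpp
#include <cstdio>

// Worked example for the qend bound above (1013/1024 is the shape that
// TEST_CAUSAL_01 exercises). With qend alignment the causal diagonal is
// shifted by (K - Q) % tile_k, so a Q tile can straddle one extra K tile.
int main() {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const int seqlen_q = 1013, seqlen_k = 1024;  // get<0>, get<1>(problem_size)
  const int tile_q = 128, tile_k = 128;        // get<0>, get<1>(tile_shape)
  const int offset_tile_q = (seqlen_k - seqlen_q) % tile_k;            // 11
  const int masked_trips  = ceil_div(tile_q + offset_tile_q, tile_k);  // 2
  std::printf("masked K-tile trips per Q tile: %d\n", masked_trips);
  return 0;
}
```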
```diff
@@ -676,7 +675,7 @@ struct LocalMaskForBackward : LocalMask<kIsQBegin>, ResidualMaskForBackward {
 };
 
 struct VariableLength {
-  int max_length = 0;
+  int max_length;
   int* cumulative_length = nullptr;
   int total_length = -1;
```

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -199,6 +199,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
     // scaling factor to quantize O
     float inv_scale_o = 1.0f;
 
+    // local changes
     int window_size_left = -1;
     int window_size_right = -1;
   };
@@ -211,6 +212,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
 
     float scale_output;
 
+    // local changes
    int window_size_left;
    int window_size_right;
   };
```

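`window_size_left` and `window_size_right` are the local-attention bounds threaded through these argument structs; the `-1` defaults read as the usual "unbounded on this side" convention, though that semantic is an assumption here rather than something the diff states. An illustrative setting (`mainloop_args` is hypothetical):

```cpp
// Hypothetical setup of the mainloop arguments shown above; assumes the
// common convention that -1 leaves a side unbounded. A sliding-window
// pattern bounds both sides of the diagonal:
mainloop_args.window_size_left  = 128;  // look back at most 128 tokens
mainloop_args.window_size_right = 0;    // nothing ahead of the diagonal
```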
fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp

Lines changed: 3 additions & 1 deletion
```diff
@@ -86,6 +86,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   using Mask = Mask_;
 
   static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
+  // local changes
   static constexpr int StageCountKV = StageCountQ * (sizeof(Element) == 1 ? 11 : 5) ;
 
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
@@ -540,6 +541,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
     Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
+    // local changes
     // Each thread owns a single row
     using TMEM_LOAD = conditional_t<
       size<1>(TileShapeQK{}) < _128{},
@@ -802,7 +804,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     // good values would be either 32 or 64
     const int kCorrectionTileSize = 32 / sizeof(ElementOut);
 
-    using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
+    using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>;  // 4x32 threads with 64 cols of 32b elem
 
     typename CollectiveMmaPV::TiledMma mma;
     Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
```

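The staging constants in the first hunk reward a quick back-of-envelope check: one-byte elements leave enough shared memory to more than double the KV stage count. A small sketch (the helper is illustrative; the constants are the diff's own):

```cpp
// Mirrors StageCountKV = StageCountQ * (sizeof(Element) == 1 ? 11 : 5).
constexpr int stage_count_kv(int stage_count_q, int elem_bytes) {
  return stage_count_q * (elem_bytes == 1 ? 11 : 5);
}

static_assert(stage_count_kv(2, 1) == 22, "fp8, TileShape N != 256");
static_assert(stage_count_kv(2, 2) == 10, "fp16/bf16, TileShape N != 256");
static_assert(stage_count_kv(1, 1) == 11, "fp8, TileShape N == 256");
```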