Bring 4.2.1 changes to FBGEMM blackwell cutlass fmha (#5052)
Summary:
Pull Request resolved: #5052
X-link: https://github.com/facebookresearch/FBGEMM/pull/2061
As titled.
Mostly grabbing changes from 4.2.1.
This also removes the changes from D84954166.
One thing is still TBD: how to incorporate changes similar to D79534034 that were also made upstream. I would prefer to stay close to upstream, but that approach failed when I tested it, so I am settling on Aya-ZIbra's D84970563 for now.
Reviewed By: Aya-ZIbra
Differential Revision: D84961921
fbshipit-source-id: ff3ab1746951d3b91126d7a1b75e4d7cf2c92b69
Changed file: fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/README.md (10 additions, 2 deletions)
@@ -8,7 +8,7 @@ For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit

Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.

-For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
+For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.

The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
The kernel and collective layer are then formulated to be fmha-specific.
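A minimal sketch of one way to satisfy the variable-sequence-length requirement added above, assuming packed `(total_length, num_heads, head_dim)` tensors; the padding size, tensor names, and the `blackwell_fmha_varlen_fwd` entry point are illustrative assumptions rather than the actual FBGEMM API:

```python
import torch

num_heads, head_dim = 8, 128
seq_lens = [13, 70, 5]            # example per-sequence lengths
total_length = sum(seq_lens)      # README: pass `total_length` as the `problem_shape`
max_seq_len = max(seq_lens)

# Inputs are packed with no leading padding, but must not contain NaN or Inf values.
q = torch.randn(total_length, num_heads, head_dim, dtype=torch.bfloat16)
k = torch.randn(total_length, num_heads, head_dim, dtype=torch.bfloat16)
v = torch.randn(total_length, num_heads, head_dim, dtype=torch.bfloat16)
assert all(torch.isfinite(t).all() for t in (q, k, v))

# One batch of valid (but never used) padding memory ahead of the first output
# batch: allocate extra leading rows (conservatively max_seq_len here) and hand
# the kernel a view that starts after them.
out_storage = torch.empty(max_seq_len + total_length, num_heads, head_dim, dtype=torch.bfloat16)
out = out_storage[max_seq_len:]

# Hypothetical invocation; the real entry points and argument names are in the
# cutlass_blackwell_fmha sources and tests.
# blackwell_fmha_varlen_fwd(q, k, v, out, seq_lens, total_length=total_length)
```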
@@ -37,13 +37,19 @@ There are three kernels to compute backwards:

`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.

+## MLA Blackwell Backward
+
+The sample also provides the feature of MLA backward(d=192, d_vo=128). To enable MLA backward, please specify `--d=192 --d_vo=128` when running the bwd sample.
+
+`Sm100FmhaBwdMlaKernelTmaWarpSpecialized`is the main point for MLA backward. The MLA approach is slightly different from the original one to enable high performance with the MLA shape.
+
# MLA Inference for Blackwell

This sample provides code for fused multi-head latent attention inference in
the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
It supports fp16, bf16, and fp8 input and output types.

-To accomodate the large output accumulator due to the large latent head dimension,
+To accommodate the large output accumulator due to the large latent head dimension,
the sample demonstrates how to leverage 2Sm Blackwell tensor cores.

Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
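As a shape-only reference for the two MLA paths described above (d=192, d_vo=128 for the backward path; latent 512 plus rope 64 for weight-absorbed inference), here is a small sketch; the tensor names and layout order are illustrative assumptions, not the sample's actual argument layout:

```python
import torch

B, H, S = 2, 16, 1024   # example batch size, heads, sequence length

# MLA backward (`--d=192 --d_vo=128`): Q and K use head dim d=192, while V, the
# output, and the incoming output gradient use head dim d_vo=128.
d, d_vo = 192, 128
q  = torch.randn(B, S, H, d,    dtype=torch.bfloat16)
k  = torch.randn(B, S, H, d,    dtype=torch.bfloat16)
v  = torch.randn(B, S, H, d_vo, dtype=torch.bfloat16)
dO = torch.randn(B, S, H, d_vo, dtype=torch.bfloat16)   # gradient of the output

# Weight-absorbed MLA inference: latent head dim 512 plus rope head dim 64, so
# Q and K carry 512 + 64 = 576 features per head while V and the output keep
# the 512 latent features (the large accumulator the README refers to).
d_latent, d_rope = 512, 64
q_gen = torch.randn(B, 1, H, d_latent + d_rope, dtype=torch.bfloat16)   # single generation step
k_gen = torch.randn(B, S, H, d_latent + d_rope, dtype=torch.bfloat16)
v_gen = torch.randn(B, S, H, d_latent, dtype=torch.bfloat16)
# The corresponding output would have shape (B, 1, H, d_latent).
```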
@@ -61,6 +67,8 @@ For detailed information on how to invoke them, check out either the tests in `C
to simplify the sample, clarified that `fmha_gen` sample only supports head
dim 128.

+* 4.3.0: For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
+
# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Also changed (diffs not shown here):
fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp