
Commit 06285d6

sarithad-metameta-codesync[bot] authored and committed
Add Paged Attention support to FMHA FWD CUTLASS kernel for variable length (#5033)
Summary:
Pull Request resolved: #5033
X-link: https://github.com/facebookresearch/FBGEMM/pull/2046

Added paged attention for the variable sequence length case to the Blackwell CUTLASS FWD kernel.

Reviewed By: Aya-ZIbra

Differential Revision: D84284273

fbshipit-source-id: d1c102d225a8ebd704811fba8a2a5791e471908d
1 parent ccae43d commit 06285d6
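
For context: with paged attention the KV cache lives in fixed-size physical blocks of shape (num_blocks, page_block_size, num_KV_heads, head_dim), and a per-batch page table maps logical KV positions to those blocks. A minimal illustrative sketch of that mapping (not code from this commit; the helper and its parameters are hypothetical, named after the kernel's arguments):

#include <cstdint>

// Illustrative only: resolve a logical KV position of one batch element to a
// physical row in a paged K/V cache laid out as
// (num_blocks, page_block_size, num_KV_heads, head_dim).
inline int64_t paged_kv_row(
    const int32_t* page_table_row, // page_table[b], one physical block id per logical block
    int logical_pos,               // position in [0, seqlen_k[b])
    int page_block_size) {
  const int logical_block = logical_pos / page_block_size;
  const int offset_in_block = logical_pos % page_block_size;
  return static_cast<int64_t>(page_table_row[logical_block]) * page_block_size + offset_in_block;
}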

File tree

5 files changed (+194 lines, -51 lines)


fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu

Lines changed: 8 additions & 2 deletions

@@ -104,7 +104,13 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
   };
 
   auto dispatch_mask = [&](auto varlen) {
-    int seq_k = kIsPaged ? static_cast<int>(*seqlen_k) : varlen ? k.size(0) : k.size(1);
+    int seq_k = kIsPaged
+        ? (varlen
+              ? static_cast<int>(*max_seq_len_k)
+              : static_cast<int>(*seqlen_k))
+        : (varlen
+              ? k.size(0)
+              : k.size(1));
     if (causal) {
       if (bottom_right) {
         return dispatch_head_dim(varlen, CausalMask</*kIsQBegin=*/false>{});
@@ -113,7 +119,7 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
       }
     } else if (local) {
       if (bottom_right) {
-        return dispatch_head_dim(varlen, LocalMask</*kIsQBegin=*/false>{});
+        return dispatch_head_dim(varlen, LocalMask</*kIsQBegin=*/false>{});
       } else {
         return dispatch_head_dim(varlen, LocalMask</*kIsQBegin=*/true>{});
       }
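
Note on the dispatch change above: with a paged cache, k.size(1) is the page block size, not the KV sequence length, so the sequence length has to come from the host-side arguments instead, seqlen_k for fixed-length batches and max_seq_len_k (the longest KV sequence in the batch) once variable length is enabled. A small sketch of that selection, with hypothetical plain ints standing in for the optional arguments:

// Sketch only: mirrors the seq_k selection in dispatch_fmha_fwd; the free
// function and plain-int parameters are illustrative, not part of the kernel.
int select_seq_k(bool is_paged, bool varlen,
                 int max_seq_len_k, int seqlen_k,
                 int k_size0, int k_size1) {
  if (is_paged) {
    // Paged K/V: (num_blocks, page_block_size, H_K, D), so read the length
    // from the side arguments.
    return varlen ? max_seq_len_k : seqlen_k;
  }
  // Non-paged: varlen K is packed as (total_K, H_K, D); fixed K is (B, SK, H_K, D).
  return varlen ? k_size0 : k_size1;
}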

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_template.cuh

Lines changed: 28 additions & 8 deletions

@@ -91,11 +91,19 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
           StrideLSE>,
       TileScheduler>>;
 
-  if (kIsPaged && !kIsVarlen) {
-    TORCH_CHECK(
+  if (kIsPaged) {
+    if (kIsVarlen) { // Variable length
+      TORCH_CHECK(
+          q.dim() == 3,
+          "Expect Q shape to be (total_Q_seqlen, num_Q_heads, head_dim) ",
+          "Found shape ", q.sizes());
+    }
+    else { // Fixed Length
+      TORCH_CHECK(
         q.dim() == 4,
         "Expect Q shape to be (batch_size, Q_seqlen, num_Q_heads, head_dim). ",
         "Found shape ", q.sizes());
+    }
     TORCH_CHECK(
         k.dim() == 4,
         "Expect K shape to be (num_blocks, page_block_size, num_KV_heads, head_dim) ",
@@ -113,7 +121,10 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
     TORCH_CHECK((k.size(1) % tile_N) == 0, "Page Block Size should be divisible by N tile size");
     TORCH_CHECK((v.size(1) % tile_N) == 0, "Page Block Size should be divisible by N tile size");
 
-    TORCH_CHECK(seqlen_k.has_value(), "seqlen_k should be set");
+    // For fixed length sequences, seqlen_k should be set.
+    if (!kIsVarlen) {
+      TORCH_CHECK(seqlen_k.has_value(), "seqlen_k should be set");
+    }
   }
   else if (kIsVarlen) {
     TORCH_CHECK(
@@ -153,7 +164,8 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
 
   // Extract dimensions from input tensors
   int H_Q = kIsVarlen ? q.size(1) : q.size(2); // Number of Q heads
-  int H_K = kIsVarlen ? k.size(1) : k.size(2); // Number of K heads
+  int H_K = (kIsPaged && kIsVarlen) ? k.size(2)
+                                    : (kIsVarlen ? k.size(1) : k.size(2)); // Number of K heads
   int D = q.size(q.dim() - 1); // Head dimension (D)
 
   TORCH_CHECK(H_Q % H_K == 0);
@@ -162,14 +174,20 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
 
   // SQ represents SumB(Q) for varlen (jagged len)
   int SQ = kIsVarlen ? q.size(0) : q.size(1);
-  int SK = kIsPaged ? static_cast<int>(*seqlen_k) : kIsVarlen ? k.size(0) : k.size(1);
+  int SK = kIsPaged
+      ? (kIsVarlen
+            ? static_cast<int>(*max_seq_len_k)
+            : static_cast<int>(*seqlen_k))
+      : (kIsVarlen
+            ? k.size(0)
+            : k.size(1));
   int B = kIsVarlen ? cu_seqlens_q->size(0) - 1 : q.size(0);
 
   // Parameters for paged attention.
   int page_table_stride = kIsPaged ? page_table.value().size(1) : 0;
   int num_blocks = kIsPaged ? k.size(0) : 1; // num_blocks
   int page_block_size = kIsPaged ? k.size(1) : 1; // page_block_size
-  // num KV tiles > 1 within a page in the case of page_block_size > TileShapeN.
+  // num KV tiles > 1 within a page in the case of page_block_size > TileShapeN.
   int num_KV_tiles_per_page = kIsPaged ? k.size(1) / (get<1>(TileShape{}).value) : 1;
 
   ProblemShapeType problem_shape;
@@ -250,8 +268,10 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
   typename Operation::Arguments arguments;
   if constexpr (kIsVarlen) {
     get<2, 1>(stride_Q) = 0;
-    get<2, 1>(stride_K) = 0;
-    get<2, 1>(stride_V) = 0;
+    if (!kIsPaged) {
+      get<2, 1>(stride_K) = 0;
+      get<2, 1>(stride_V) = 0;
+    }
     get<2, 1>(stride_O) = 0;
     get<1, 1>(stride_LSE) = 0;
   }
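
The divisibility checks above guarantee that each page holds a whole number of KV tiles, which is what num_KV_tiles_per_page relies on. As a quick worked example under assumed values: with page_block_size = 256 and a KV tile size of 128, num_KV_tiles_per_page = 256 / 128 = 2, and a sequence with max_seq_len_k = 1000 occupies ceil(1000 / 256) = 4 physical blocks. A small illustrative helper (not part of the kernel) for the same arithmetic:

// Illustrative only: paged-KV geometry implied by the checks above.
struct PagedKvGeometry {
  int num_kv_tiles_per_page; // KV tiles per physical block
  int blocks_needed;         // physical blocks one sequence occupies
};

inline PagedKvGeometry paged_kv_geometry(int page_block_size, int tile_n, int seq_len_k) {
  // The kernel enforces page_block_size % tile_n == 0 before taking this ratio.
  PagedKvGeometry g;
  g.num_kv_tiles_per_page = page_block_size / tile_n;
  g.blocks_needed = (seq_len_k + page_block_size - 1) / page_block_size;
  return g;
}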

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp

Lines changed: 15 additions & 24 deletions

@@ -118,7 +118,6 @@ struct Sm100FmhaLoadTmaWarpspecialized {
     auto dV = args.dV;
     bool kIsPaged = args.ptr_page_table ? true : false;
 
-
     // Local changes (D79534034)
     int get_0 = int(get<0>(problem_shape));
     int get_1 = int(get<1>(problem_shape));
@@ -128,10 +127,12 @@ struct Sm100FmhaLoadTmaWarpspecialized {
       get_0 = get<0>(problem_shape).total_length;
     }
 
-    if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
-      get<2, 1>(dK) = 0;
-      get<2, 1>(dV) = 0;
-      get_1 = get<1>(problem_shape).total_length;
+    if (!kIsPaged) {
+      if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
+        get<2, 1>(dK) = 0;
+        get<2, 1>(dV) = 0;
+        get_1 = get<1>(problem_shape).total_length;
+      }
     }
 
     TMA_Q tma_load_q;
@@ -141,7 +142,8 @@ struct Sm100FmhaLoadTmaWarpspecialized {
     if (kIsPaged) { // Paged Case
       //Create TMA Atom/Descriptor for Q, K, V
      //Q
-      Layout layout_Q = make_layout(select<0,2,3>(problem_shape), dQ);
+      auto problem_shape_q = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape));
+      Layout layout_Q = make_layout(select<0,2,3>(problem_shape_q), dQ);
       Tensor mQ = make_tensor(make_gmem_ptr(ptr_Q), layout_Q);
 
       auto cluster_layout_vmnk =
@@ -152,21 +154,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
           typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk);
 
       // K
-      auto problem_shape_paged_k = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape));
-      get<1> (problem_shape_paged_k) = args.page_block_size;
-      get<3, 1>(problem_shape_paged_k) = args.num_blocks;
-      Layout layout_k = make_layout(select<1,2,3>(problem_shape_paged_k), dK);
+      auto problem_shape_paged_kv = make_tuple(get_0, args.page_block_size, get<2>(problem_shape), make_tuple(get<0>(get<3>(problem_shape)), args.num_blocks));
+      Layout layout_k = make_layout(select<1,2,3>(problem_shape_paged_kv), dK);
       Tensor mK = make_tensor(make_gmem_ptr(ptr_K), layout_k);
 
       tma_load_k = make_tma_atom_B_sm100<Element>(
          cute::SM90_TMA_LOAD{}, mK, SmemLayoutK{}(_, _, _, _0{}), TileShapeQK{},
          typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk);
 
      // V
-      auto problem_shape_paged_v = make_tuple(get_0, get<2>(problem_shape), get_1, get<3>(problem_shape));
-      get<2> (problem_shape_paged_v) = args.page_block_size;
-      get<3, 1>(problem_shape_paged_v) = args.num_blocks;
-      Layout layout_v = make_layout(select<1,2,3>(problem_shape_paged_v), select<1,0,2>(dV));
+      Layout layout_v = make_layout(select<2,1,3>(problem_shape_paged_kv), select<1,0,2>(dV));
       Tensor mV = make_tensor(make_gmem_ptr(ptr_V), layout_v);
 
       tma_load_v = make_tma_atom_B_sm100<Element>(
@@ -368,7 +365,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
       }
     }
 
-  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
+  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
   CUTLASS_DEVICE void
   load_paged(
       BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
@@ -418,11 +415,8 @@ template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
     Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
 
     // compute gK, sK
-    ProblemShapeK problem_shape_k = problem_shape;
-    get<1> (problem_shape_k) = params.page_block_size;
-    get<3, 1>(problem_shape_k) = params.num_blocks;
-
-    Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_k));
+    ProblemShapeK problem_shape_kv = make_tuple(get<0>(problem_shape), params.page_block_size, get<2>(problem_shape), make_tuple(get<0>(get<3>(problem_shape)), params.num_blocks));
+    Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_kv));
 
     Tensor gK_kdl = local_tile(mK_kdl_p, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
     Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
@@ -437,10 +431,7 @@ template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
 
     // compute gV, sV
     ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
-    ProblemShapeK problem_shape_v = problem_shape;
-    get<1> (problem_shape_v) = params.page_block_size;
-    get<3, 1>(problem_shape_v) = params.num_blocks;
-    Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
+    Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_kv));
 
     Tensor gV_dkl = local_tile(mV_dkl_p, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
     Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
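
In the paged path above, the load collective builds one TMA descriptor over the whole physical cache: the KV sequence mode of the problem shape becomes page_block_size and num_blocks is folded into the batch mode, which is also why the K/V batch strides are left intact for paged varlen instead of being zeroed. At load time a logical KV tile then has to be translated into a physical block plus a tile offset within that block via the page table. A hedged sketch of that index translation (the helper and names are illustrative, not code from this file):

#include <cstdint>

// Illustrative only: map a logical KV tile index to the physical block and the
// tile within that block that a paged load would address.
struct PagedTileCoord {
  int32_t physical_block; // block id read from the page table
  int tile_within_page;   // which N tile inside that block
};

inline PagedTileCoord paged_tile_coord(
    const int32_t* page_table_row, // page_table[b]
    int kv_tile_idx,               // logical KV tile index for this CTA
    int num_kv_tiles_per_page) {
  const int logical_block = kv_tile_idx / num_kv_tiles_per_page;
  const int tile_in_page = kv_tile_idx % num_kv_tiles_per_page;
  return {page_table_row[logical_block], tile_in_page};
}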

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp

Lines changed: 8 additions & 8 deletions

@@ -82,9 +82,9 @@ struct Sm100FmhaCtxKernelWarpspecializedSchedule {
   static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
   static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
   static const int NumRegsEmpty = 24;
-
+
   static const int NumWarps = 16;
-
+
 };
 
 
@@ -148,7 +148,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
   static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
   static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
   static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
-
+
   static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue);
   static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad);
 
@@ -177,13 +177,13 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
   };
 
   static constexpr bool IsPersistent = std::is_same_v<TileScheduler, PersistentTileScheduler> || std::is_same_v<TileScheduler, CausalPersistentTileScheduler>;
-  using MainloopEpilogueStorage = std::conditional_t<IsPersistent,
-      std::conditional_t<IsMla,
+  using MainloopEpilogueStorage = std::conditional_t<IsPersistent,
+      std::conditional_t<IsMla,
          std::conditional_t<CollectiveMainloop::IsOrderLoadEpilogue, UnionType, StructType>,
          StructType>,
      UnionType>;
 
-  MainloopEpilogueStorage mainloop_epilogue;
+  MainloopEpilogueStorage mainloop_epilogue;
 
   struct PipelineStorage {
     alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
@@ -305,7 +305,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
         shared_storage.pipelines.load_q,
         pipeline_load_q_params,
         ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{});
-
+
     typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
     if (role == WarpRole::Load) {
       pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
@@ -565,7 +565,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
     warpgroup_reg_set<NumRegsOther>();
 
     if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) {
-      cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
+      cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
          cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
     }
 