@@ -118,7 +118,6 @@ struct Sm100FmhaLoadTmaWarpspecialized {
118118 auto dV = args.dV ;
119119 bool kIsPaged = args.ptr_page_table ? true : false ;
120120
121-
122121 // Local changes (D79534034)
123122 int get_0 = int (get<0 >(problem_shape));
124123 int get_1 = int (get<1 >(problem_shape));
@@ -128,10 +127,12 @@ struct Sm100FmhaLoadTmaWarpspecialized {
128127 get_0 = get<0 >(problem_shape).total_length ;
129128 }
130129
131- if constexpr (is_variable_length_v<tuple_element_t <1 , ProblemShape>>) {
132- get<2 , 1 >(dK) = 0 ;
133- get<2 , 1 >(dV) = 0 ;
134- get_1 = get<1 >(problem_shape).total_length ;
130+ if (!kIsPaged ) {
131+ if constexpr (is_variable_length_v<tuple_element_t <1 , ProblemShape>>) {
132+ get<2 , 1 >(dK) = 0 ;
133+ get<2 , 1 >(dV) = 0 ;
134+ get_1 = get<1 >(problem_shape).total_length ;
135+ }
135136 }
136137
137138 TMA_Q tma_load_q;
@@ -141,7 +142,8 @@ struct Sm100FmhaLoadTmaWarpspecialized {
141142 if (kIsPaged ) { // Paged Case
142143 // Create TMA Atom/Descriptor for Q, K, V
143144 // Q
144- Layout layout_Q = make_layout (select<0 ,2 ,3 >(problem_shape), dQ);
145+ auto problem_shape_q = make_tuple (get_0, get_1, get<2 >(problem_shape), get<3 >(problem_shape));
146+ Layout layout_Q = make_layout (select<0 ,2 ,3 >(problem_shape_q), dQ);
145147 Tensor mQ = make_tensor (make_gmem_ptr (ptr_Q), layout_Q);
146148
147149 auto cluster_layout_vmnk =
@@ -152,21 +154,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
152154 typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk);
153155
154156 // K
155- auto problem_shape_paged_k = make_tuple (get_0, get_1, get<2 >(problem_shape), get<3 >(problem_shape));
156- get<1 > (problem_shape_paged_k) = args.page_block_size ;
157- get<3 , 1 >(problem_shape_paged_k) = args.num_blocks ;
158- Layout layout_k = make_layout (select<1 ,2 ,3 >(problem_shape_paged_k), dK);
157+ auto problem_shape_paged_kv = make_tuple (get_0, args.page_block_size , get<2 >(problem_shape), make_tuple (get<0 >(get<3 >(problem_shape)), args.num_blocks ));
158+ Layout layout_k = make_layout (select<1 ,2 ,3 >(problem_shape_paged_kv), dK);
159159 Tensor mK = make_tensor (make_gmem_ptr (ptr_K), layout_k);
160160
161161 tma_load_k = make_tma_atom_B_sm100<Element>(
162162 cute::SM90_TMA_LOAD{}, mK , SmemLayoutK{}(_, _, _, _0{}), TileShapeQK{},
163163 typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk);
164164
165165 // V
166- auto problem_shape_paged_v = make_tuple (get_0, get<2 >(problem_shape), get_1, get<3 >(problem_shape));
167- get<2 > (problem_shape_paged_v) = args.page_block_size ;
168- get<3 , 1 >(problem_shape_paged_v) = args.num_blocks ;
169- Layout layout_v = make_layout (select<1 ,2 ,3 >(problem_shape_paged_v), select<1 ,0 ,2 >(dV));
166+ Layout layout_v = make_layout (select<2 ,1 ,3 >(problem_shape_paged_kv), select<1 ,0 ,2 >(dV));
170167 Tensor mV = make_tensor (make_gmem_ptr (ptr_V), layout_v);
171168
172169 tma_load_v = make_tma_atom_B_sm100<Element>(
@@ -368,7 +365,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
368365 }
369366 }
370367
371- template <class BlkCoord , class ProblemShape , class ParamsProblemShape >
368+ template <class BlkCoord , class ProblemShape , class ParamsProblemShape >
372369 CUTLASS_DEVICE void
373370 load_paged (
374371 BlkCoord const & blk_coord_in, ProblemShape const & problem_shape,
@@ -418,11 +415,8 @@ template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
418415 Tensor tQgQ = tQgQ_qdl (_, _, _0{}, get<2 >(blk_coord_q));
419416
420417 // compute gK, sK
421- ProblemShapeK problem_shape_k = problem_shape;
422- get<1 > (problem_shape_k) = params.page_block_size ;
423- get<3 , 1 >(problem_shape_k) = params.num_blocks ;
424-
425- Tensor mK_kdl_p = params.tma_load_k .get_tma_tensor (select<1 ,2 ,3 >(problem_shape_k));
418+ ProblemShapeK problem_shape_kv = make_tuple (get<0 >(problem_shape), params.page_block_size , get<2 >(problem_shape), make_tuple (get<0 >(get<3 >(problem_shape)), params.num_blocks ));
419+ Tensor mK_kdl_p = params.tma_load_k .get_tma_tensor (select<1 ,2 ,3 >(problem_shape_kv));
426420
427421 Tensor gK_kdl = local_tile (mK_kdl_p , TileShapeQK{}, make_coord (_, _, _), Step<X, _1, _1>{});
428422 Tensor tSgK_kdl = mma_qk.partition_B (gK_kdl );
@@ -437,10 +431,7 @@ template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
437431
438432 // compute gV, sV
439433 ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice (0 );
440- ProblemShapeK problem_shape_v = problem_shape;
441- get<1 > (problem_shape_v) = params.page_block_size ;
442- get<3 , 1 >(problem_shape_v) = params.num_blocks ;
443- Tensor mV_dkl_p = params.tma_load_v .get_tma_tensor (select<2 ,1 ,3 >(problem_shape_v));
434+ Tensor mV_dkl_p = params.tma_load_v .get_tma_tensor (select<2 ,1 ,3 >(problem_shape_kv));
444435
445436 Tensor gV_dkl = local_tile (mV_dkl_p , TileShapePV{}, make_coord (_, _, _), Step<X, _1, _1>{});
446437 Tensor tOgV_dkl = mma_pv.partition_B (gV_dkl );
0 commit comments