@@ -21,64 +21,66 @@ __global__ void mha_kernel_sm80(void* o,
2121 using namespace cute ;
2222
2323 // type alias
24- using T = typename Traits::T;
24+ using Element = typename Traits::Element;
25+ using BLK_M = typename Traits::BLK_M;
26+ using BLK_N = typename Traits::BLK_N;
27+ using BLK_K = typename Traits::BLK_K;
28+ using HEAD_DIM = typename Traits::HEAD_DIM;
29+
30+ using TiledMMA = typename Traits::TiledMMA;
31+ using Convertor = typename Traits::FragmentConvertor;
32+
2533 using SmemLayoutQ = typename Traits::SmemLayoutQ;
2634 using SmemLayoutK = typename Traits::SmemLayoutKV;
2735 using SmemLayoutV = typename Traits::SmemLayoutKV;
2836 using SmemLayoutVt = typename Traits::SmemLayoutVt;
29-
3037 using SmemLayoutO = typename Traits::SmemLayoutO;
31- using SmemCopyAtom = typename Traits::SmemCopyAtom;
32- using SmemCopyAtomO = typename Traits::SmemCopyAtomO;
3338 using GmemTiledCopyQKV = typename Traits::GmemTiledCopyQKV;
3439 using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
35- using SmemCopyAtomTransposed = typename Traits::SmemCopyAtomTransposed;
36- using TiledMMA = typename Traits::TiledMMA;
40+
41+ using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
42+ using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
43+ using SmemTiledCopyVT = typename Traits::SmemTiledCopyVT;
44+ using SmemTiledCopyO = typename Traits::SmemTiledCopyO;
3745
3846 const int m_block = blockIdx .x ;
3947 const int base_id = blockIdx .y ;
4048 const int tidx = threadIdx .x ;
4149
42- constexpr int kBlockM = Traits::kBlockM ;
43- constexpr int kBlockN = Traits::kBlockN ;
44- constexpr int kHeadDim = Traits::kHeadDim ;
45-
4650 // ProblemShape
4751 // TODO: support non-contiguous layout
4852 const int offset = base_id * h_stride;
4953 // (q_len, head_dim)
50- auto Q = make_tensor (make_gmem_ptr (static_cast <T*>(q) + offset),
51- make_shape (q_len, Int< kHeadDim > {}),
52- make_stride (Int< kHeadDim > {}, _1{}));
53- auto O = make_tensor (make_gmem_ptr (static_cast <T*>(o) + offset),
54- make_shape (q_len, Int< kHeadDim > {}),
55- make_stride (Int< kHeadDim > {}, _1{}));
54+ auto Q = make_tensor (make_gmem_ptr ((Element*)q + offset),
55+ make_shape (q_len, HEAD_DIM {}),
56+ make_stride (HEAD_DIM {}, _1{}));
57+ auto O = make_tensor (make_gmem_ptr ((Element*)o + offset),
58+ make_shape (q_len, HEAD_DIM {}),
59+ make_stride (HEAD_DIM {}, _1{}));
5660 // (kv_len, head_dim)
57- auto K = make_tensor (make_gmem_ptr (static_cast <T*>(k) + offset),
58- make_shape (kv_len, Int< kHeadDim > {}),
59- make_stride (Int< kHeadDim > {}, _1{}));
60- auto V = make_tensor (make_gmem_ptr (static_cast <T*>(v) + offset),
61- make_shape (kv_len, Int< kHeadDim > {}),
62- make_stride (Int< kHeadDim > {}, _1{}));
61+ auto K = make_tensor (make_gmem_ptr ((Element*)k + offset),
62+ make_shape (kv_len, HEAD_DIM {}),
63+ make_stride (HEAD_DIM {}, _1{}));
64+ auto V = make_tensor (make_gmem_ptr ((Element*)v + offset),
65+ make_shape (kv_len, HEAD_DIM {}),
66+ make_stride (HEAD_DIM {}, _1{}));
6367
6468 // CTA/Block Shape
6569 // (BLK_M, head_dim)
66- Tensor gQ = local_tile (
67- Q, make_tile (Int< kBlockM > {}, Int< kHeadDim > {}), make_coord (m_block, _));
68- Tensor gO = local_tile (
69- O, make_tile (Int< kBlockM > {}, Int< kHeadDim > {}), make_coord (m_block, _));
70+ Tensor gQ =
71+ local_tile ( Q, make_tile (BLK_M {}, HEAD_DIM {}), make_coord (m_block, _));
72+ Tensor gO =
73+ local_tile ( O, make_tile (BLK_M {}, HEAD_DIM {}), make_coord (m_block, _));
7074
7175 // (BLK_N, head_dim)
72- Tensor gK = local_tile (
73- K, make_tile (Int<kBlockN >{}, Int<kHeadDim >{}), make_coord (0 , _));
74- Tensor gV = local_tile (
75- V, make_tile (Int<kBlockN >{}, Int<kHeadDim >{}), make_coord (0 , _));
76+ Tensor gK = local_tile (K, make_tile (BLK_N{}, HEAD_DIM{}), make_coord (0 , _));
77+ Tensor gV = local_tile (V, make_tile (BLK_N{}, HEAD_DIM{}), make_coord (0 , _));
7678
7779 // Smem
7880 extern __shared__ char smem[];
79- T * q_smem = static_cast <T*>(smem) ;
80- T * k_smem = q_smem + cosize (SmemLayoutQ{});
81- T * v_smem = k_smem + cosize (SmemLayoutK{});
81+ Element * q_smem = (Element*)smem ;
82+ Element * k_smem = q_smem + cosize (SmemLayoutQ{});
83+ Element * v_smem = k_smem + cosize (SmemLayoutK{});
8284
8385 // (BLK_M, BLK_K), k-major
8486 Tensor sQ = make_tensor (make_smem_ptr (q_smem), SmemLayoutQ{});
@@ -90,16 +92,25 @@ __global__ void mha_kernel_sm80(void* o,
9092 // (BLK_K, BLK_N), k-major
9193 Tensor sVt = make_tensor (make_smem_ptr (v_smem), SmemLayoutVt{});
9294
93- // rmem for mma
95+ // Fragments for GEMM
9496 TiledMMA tiled_mma;
9597 auto thr_mma = tiled_mma.get_slice (tidx);
96- // gemm-I
98+ // GEMM-I: S = Q@K.T
9799 auto tSrQ = thr_mma.partition_fragment_A (sQ ); // (MMA,MMA_M,MMA_K)
98100 auto tSrK = thr_mma.partition_fragment_B (sK ); // (MMA,MMA_N,MMA_K)
101+ auto tSrAccS = partition_fragment_C (
102+ tiled_mma, Shape<BLK_M, BLK_N>{}); // (MMA,MMA_M,MMA_N)
99103
100- // gemm -II
104+ // GEMM -II: O = softmax(S)@V
101105 auto tOrVt = thr_mma.partition_fragment_B (sVt ); // (MMA,MMA_K,MMA_N)
106+ auto tOrAccO = partition_fragment_C (
107+ tiled_mma, Shape<BLK_M, HEAD_DIM>{}); // (MMA,MMA_M,MMA_K)
102108
109+ // reshape for iterating over rows and columns
110+ auto tOrAccO_rc_view = Convertor::to_rowcol (tOrAccO);
111+ auto tSrAccS_rc_view = Convertor::to_rowcol (tSrAccS);
112+
113+ // Tiled Copy
103114 // g2s tiled copy for qkv
104115 GmemTiledCopyQKV gmem_tiled_copy_QKV;
105116 auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice (tidx);
@@ -111,79 +122,64 @@ __global__ void mha_kernel_sm80(void* o,
111122 auto tVsV = gmem_thr_copy_QKV.partition_D (sV );
112123
113124 // s2r tiled copy for qkv
114- auto smem_tiled_copy_Q = make_tiled_copy_A (SmemCopyAtom{}, tiled_mma) ;
125+ SmemTiledCopyQ smem_tiled_copy_Q;
115126 auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice (tidx);
116127 auto tSsQ = smem_thr_copy_Q.partition_S (sQ );
117128 auto tSrQ_copy_view = smem_thr_copy_Q.retile_D (tSrQ);
118129
119- auto smem_tiled_copy_K = make_tiled_copy_B (SmemCopyAtom{}, tiled_mma) ;
130+ SmemTiledCopyK smem_tiled_copy_K;
120131 auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice (tidx);
121132 auto tSsK = smem_thr_copy_K.partition_S (sK );
122133 auto tSrK_copy_view = smem_thr_copy_K.retile_D (tSrK);
123134
124- auto smem_tiled_copy_V =
125- make_tiled_copy_B (SmemCopyAtomTransposed{}, tiled_mma);
126- auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice (tidx);
127- auto tOsVt = smem_thr_copy_V.partition_S (sVt );
128- auto tOrVt_copy_view = smem_thr_copy_V.retile_D (tOrVt);
135+ SmemTiledCopyVT smem_tiled_copy_Vt;
136+ auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice (tidx);
137+ auto tOsVt = smem_thr_copy_Vt.partition_S (sVt );
138+ auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D (tOrVt);
129139
130140 // ############### Prologue ###############
131141
132- // produce q
142+ // produce q: [] => [q]
133143 cute::copy (gmem_tiled_copy_QKV, tQgQ, tQsQ);
134144 cp_async_fence ();
135145
136- // produce k
146+ // produce k: [q] => [q, k]
137147 cute::copy (gmem_tiled_copy_QKV, tKgK, tKsK);
138148 cp_async_fence ();
139149
140- // multiply sm scale
141150 // wait q: [q, k] => [k]
142- cp_async_wait<0 >();
151+ cp_async_wait<1 >();
143152 __syncthreads ();
144153
145154 // apply sm_scale
146155 // TODO: use thread parallelism
147156 for (int i = 0 ; i < size (tQsQ); ++i) {
148- tQsQ (i) = T (tQsQ (i) * sm_scale);
157+ tQsQ (i) = Element (tQsQ (i) * sm_scale);
149158 }
150159
151- // Final output fragment
152- auto tOrO =
153- partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kHeadDim >>{});
154- clear (tOrO);
155-
156- // reshape for iteration
157- auto ol = logical_divide (tOrO.layout (), Shape<_2>{});
158- auto rAccOut_new_layout = make_layout (make_layout (get<0 , 1 >(ol), get<1 >(ol)),
159- make_layout (get<0 , 0 >(ol), get<2 >(ol)));
160- auto tOrO_rc = make_tensor (tOrO.data (), rAccOut_new_layout);
161-
162160 // RowsPerThread = #rows_per_MMA * #MMA_M
163- constexpr int RowsPerThread = 2 * size<1 >(tOrO );
161+ constexpr int RowsPerThread = 2 * size<1 >(tOrAccO );
164162 OnlineSoftmax<RowsPerThread> softmax;
165163
166164 // ############### Mainloop ###############
167165
168166 const int n_block_min = 0 ;
169- const int n_block_max = cute::ceil_div (kv_len, kBlockN );
167+ const int n_block_max = cute::ceil_div (kv_len, BLK_N{});
168+
169+ // clear output
170+ clear (tOrAccO);
170171 CUTE_NO_UNROLL
171172 for (int ni = n_block_min; ni < n_block_max; ++ni) {
172- // attention score
173- // (MMA=4,MMA_M,MMA_N) (fp32)
174- auto tSrS =
175- partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kBlockN >>{});
176- // clear attention score
177- clear (tSrS);
173+ // clear attention score for each block
174+ clear (tSrAccS);
178175
179176 // wait k, queue: [q, k] => []
180177 cp_async_wait<0 >();
181178 __syncthreads ();
182179
183180 // produce v, [] => [v]
184181 {
185- gV = local_tile (
186- V, make_tile (Int<kBlockN >{}, Int<kHeadDim >{}), make_coord (ni, _));
182+ gV = local_tile (V, make_tile (BLK_N{}, HEAD_DIM{}), make_coord (ni, _));
187183 tVgV = gmem_thr_copy_QKV.partition_S (gV (_, _, 0 ));
188184 cute::copy (gmem_tiled_copy_QKV, tVgV, tVsV);
189185 }
@@ -194,71 +190,64 @@ __global__ void mha_kernel_sm80(void* o,
194190 for (int ki = 0 ; ki < size<2 >(tSrQ); ++ki) {
195191 cute::copy (smem_tiled_copy_Q, tSsQ (_, _, ki), tSrQ_copy_view (_, _, ki));
196192 cute::copy (smem_tiled_copy_K, tSsK (_, _, ki), tSrK_copy_view (_, _, ki));
197- cute::gemm (tiled_mma, tSrQ (_, _, ki), tSrK (_, _, ki), tSrS );
193+ cute::gemm (tiled_mma, tSrQ (_, _, ki), tSrK (_, _, ki), tSrAccS );
198194 }
199195
200- // reshape for iteration
201- auto sl = logical_divide (tSrS.layout (), Shape<_2>{});
202- auto rAccScore_new_layout =
203- make_layout (make_layout (get<0 , 1 >(sl), get<1 >(sl)),
204- make_layout (get<0 , 0 >(sl), get<2 >(sl)));
205- auto tSrS_rc = make_tensor (tSrS.data (), rAccScore_new_layout);
206-
207- softmax.rescale (tSrS_rc, tOrO_rc);
196+ // apply softmax and rescale
197+ softmax.rescale (tSrAccS_rc_view, tOrAccO_rc_view);
208198
209199 // wait v, [v] => []
210200 cp_async_wait<0 >();
211201 __syncthreads ();
212202
213203 // produce next k: [] => [k]
214204 if (ni != n_block_max - 1 ) {
215- gK = local_tile (
216- K, make_tile (Int<kBlockN >{}, Int<kHeadDim >{}), make_coord (ni + 1 , _));
205+ gK = local_tile (K, make_tile (BLK_N{}, HEAD_DIM{}), make_coord (ni + 1 , _));
217206 tKgK = gmem_thr_copy_QKV.partition_S (gK (_, _, 0 ));
218207 cute::copy (gmem_tiled_copy_QKV, tKgK, tKsK);
219208 }
220209 cp_async_fence ();
221210
222211 // 2> O = softmax(S)*V
223212
224- // cast scores from fp32 to fp16
225- auto tSrS_T = make_tensor_like<T>(tSrS );
213+ // cast scores from Accumulator to Element
214+ auto tSrS = make_tensor_like<Element>(tSrAccS );
226215 CUTE_UNROLL
227- for (int i = 0 ; i < size (tSrS ); ++i) {
228- tSrS_T (i) = static_cast <T>( tSrS (i));
216+ for (int i = 0 ; i < size (tSrAccS ); ++i) {
217+ tSrS (i) = static_cast <Element>( tSrAccS (i));
229218 }
219+
230220 // convert layout from gemm-I C to gemm-II A
231- auto l = logical_divide (tSrS_T.layout (), Shape<X, X, _2>{});
232- auto scores_new_layout = make_layout (
233- make_layout (get<0 >(l), get<2 , 0 >(l)), get<1 >(l), get<2 , 1 >(l));
234- auto tOrS = make_tensor (tSrS_T.data (), scores_new_layout);
221+ auto tOrS = Convertor::to_mma_a (tSrS);
235222
236223 CUTE_UNROLL
237224 for (int ki = 0 ; ki < size<2 >(tOrS); ++ki) {
238- cute::copy (smem_tiled_copy_V, tOsVt (_, _, ki), tOrVt_copy_view (_, _, ki));
239- cute::gemm (tiled_mma, tOrS (_, _, ki), tOrVt (_, _, ki), tOrO);
225+ cute::copy (
226+ smem_tiled_copy_Vt, tOsVt (_, _, ki), tOrVt_copy_view (_, _, ki));
227+ cute::gemm (tiled_mma, tOrS (_, _, ki), tOrVt (_, _, ki), tOrAccO);
240228 }
241229 }
242230
243231 // ############### Epilogue ###############
244232
245233 // normalize output: o /= rowsum
246- softmax.finalize (tOrO_rc );
234+ softmax.finalize (tOrAccO_rc_view );
247235
248236 // write output to gmem
249- // 1> covernt output from fp32 to fp16
250- auto tOrO_T = make_tensor_like<T>(tOrO );
237+  // 1> convert output from ElementAccumulator to Element
238+ auto tOrO = make_tensor_like<Element>(tOrAccO );
251239 CUTE_UNROLL
252- for (int si = 0 ; si < size (tOrO ); ++si) {
253- tOrO_T (si) = static_cast <T>( tOrO (si));
240+ for (int si = 0 ; si < size (tOrAccO ); ++si) {
241+ tOrO (si) = static_cast <Element>( tOrAccO (si));
254242 }
255243
256244 // 2. copy output from reg to smem
257245 auto sO = make_tensor (sQ .data (), SmemLayoutO{});
258- auto smem_tiled_copy_O = make_tiled_copy_C (SmemCopyAtomO{}, tiled_mma);
246+
247+ SmemTiledCopyO smem_tiled_copy_O;
259248 auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice (tidx);
260249 // ((Atom,AtomNum),MMA_M,MMA_N)
261- auto taccOrO = smem_thr_copy_O.retile_S (tOrO_T );
250+ auto taccOrO = smem_thr_copy_O.retile_S (tOrO );
262251 // ((Atom,AtomNum),PIPE_M,PIPE_N)
263252 auto taccOsO = smem_thr_copy_O.partition_D (sO );
264253 cute::copy (smem_tiled_copy_O, taccOrO, taccOsO);
0 commit comments