#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <cute/tensor.hpp>

#include "online_softmax.cuh"

namespace llm {

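// Multi-head attention kernel for SM80 (Ampere), built on CuTe. Computes
//   O = softmax(Q @ K^T * sm_scale) @ V
// for one (batch, head) per blockIdx.y and one kBlockM-row tile of queries
// per blockIdx.x, streaming kBlockN-row tiles of K/V through shared memory
// with an online softmax. Dynamic shared memory must hold the q, k and v
// tiles.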
template <typename Traits>
__global__ void mha_kernel_sm80(void* o,
                                const void* q,
                                const void* k,
                                const void* v,
                                int h_stride,
                                int q_len,
                                int kv_len,
                                float sm_scale) {
  using namespace cute;

  // type aliases
  using T = typename Traits::T;
  using SmemLayoutQ = typename Traits::SmemLayoutQ;
  using SmemLayoutK = typename Traits::SmemLayoutKV;
  using SmemLayoutV = typename Traits::SmemLayoutKV;
  using SmemLayoutVt = typename Traits::SmemLayoutVt;

  using SmemLayoutO = typename Traits::SmemLayoutO;
  using SmemCopyAtom = typename Traits::SmemCopyAtom;
  using SmemCopyAtomO = typename Traits::SmemCopyAtomO;
  using GmemTiledCopyQKV = typename Traits::GmemTiledCopyQKV;
  using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
  using SmemCopyAtomTransposed = typename Traits::SmemCopyAtomTransposed;
  using TiledMMA = typename Traits::TiledMMA;

  const int m_block = blockIdx.x;
  const int base_id = blockIdx.y;
  const int tidx = threadIdx.x;

  constexpr int kBlockM = Traits::kBlockM;
  constexpr int kBlockN = Traits::kBlockN;
  constexpr int kHeadDim = Traits::kHeadDim;

  // ProblemShape
  // TODO: support non-contiguous layout
  // h_stride is the per-head stride in elements; q/k/v/o are assumed to
  // share it, i.e. each tensor is laid out as [heads, seq_len, head_dim].
  const int offset = base_id * h_stride;
  // (q_len, head_dim)
  auto Q = make_tensor(make_gmem_ptr(static_cast<const T*>(q) + offset),
                       make_shape(q_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));
  auto O = make_tensor(make_gmem_ptr(static_cast<T*>(o) + offset),
                       make_shape(q_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));
  // (kv_len, head_dim)
  auto K = make_tensor(make_gmem_ptr(static_cast<const T*>(k) + offset),
                       make_shape(kv_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));
  auto V = make_tensor(make_gmem_ptr(static_cast<const T*>(v) + offset),
                       make_shape(kv_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));

  // CTA/Block Shape
  // (BLK_M, head_dim)
  Tensor gQ = local_tile(
      Q, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));
  Tensor gO = local_tile(
      O, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));

  // (BLK_N, head_dim)
  Tensor gK = local_tile(
      K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));
  Tensor gV = local_tile(
      V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));

  // Smem
  extern __shared__ char smem[];
  T* q_smem = reinterpret_cast<T*>(smem);
  T* k_smem = q_smem + cosize(SmemLayoutQ{});
  T* v_smem = k_smem + cosize(SmemLayoutK{});

  // (BLK_M, BLK_K), k-major
  Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
  // (BLK_N, BLK_K), k-major
  Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{});
  Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{});

  // Tensor for V^T, used in GEMM-II.
  // (BLK_K, BLK_N), k-major
  Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});
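  // Note: sVt aliases the same shared memory as sV; SmemLayoutVt is a
  // transposed view of the (BLK_N, BLK_K) tile, so V^T can be fed directly
  // as the B operand of GEMM-II without a separate transpose pass.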

  // rmem for mma
  TiledMMA tiled_mma;
  auto thr_mma = tiled_mma.get_slice(tidx);
  // gemm-I
  auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA, MMA_M, MMA_K)
  auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA, MMA_N, MMA_K)

  // gemm-II
  auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA, MMA_K, MMA_N)

  // g2s tiled copy for qkv
  GmemTiledCopyQKV gmem_tiled_copy_QKV;
  auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
  auto tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0));
  auto tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  auto tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
  auto tKsK = gmem_thr_copy_QKV.partition_D(sK);
  auto tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
  auto tVsV = gmem_thr_copy_QKV.partition_D(sV);

  // s2r tiled copy for qkv
  auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
  auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
  auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);

  auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
  auto tSsK = smem_thr_copy_K.partition_S(sK);
  auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);

  auto smem_tiled_copy_V =
      make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);
  auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
  auto tOsVt = smem_thr_copy_V.partition_S(sVt);
  auto tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
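  // retile_D reshapes each register fragment into the value grouping the
  // copy atom expects (e.g. ldmatrix lanes), without moving any data.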

  // ############### Prologue ###############

  // produce q: [] => [q]
  cute::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ);
  cp_async_fence();

  // produce k: [q] => [q, k]
  cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
  cp_async_fence();

  // wait for q and k: [q, k] => []
  // (cp_async_wait<0> drains the whole async-copy queue)
  cp_async_wait<0>();
  __syncthreads();

  // apply sm_scale to q in smem, so GEMM-I directly yields Q @ K^T * sm_scale
  // TODO: use thread parallelism
  for (int i = 0; i < size(tQsQ); ++i) {
    tQsQ(i) = T(tQsQ(i) * sm_scale);
  }
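  // Scaling q once here is equivalent to scaling every score tile later,
  // since (Q * sm_scale) @ K^T == (Q @ K^T) * sm_scale, but it costs one
  // pass over the (kBlockM, kHeadDim) q tile instead of one pass over each
  // (kBlockM, kBlockN) score tile per kv block.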

  // final output accumulator (fp32)
  auto tOrO =
      partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
  clear(tOrO);

  // reshape from (MMA, MMA_M, MMA_N) to a (row, col) view for iteration
  auto ol = logical_divide(tOrO.layout(), Shape<_2>{});
  auto rAccOut_new_layout = make_layout(make_layout(get<0, 1>(ol), get<1>(ol)),
                                        make_layout(get<0, 0>(ol), get<2>(ol)));
  auto tOrO_rc = make_tensor(tOrO.data(), rAccOut_new_layout);
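  // For the SM80 m16n8k16 MMA assumed by Traits, each thread holds a (2, 2)
  // accumulator mode: 2 adjacent columns x 2 rows. The split above regroups
  // it into ((2, MMA_M), (2, MMA_N)) = (rows, cols), so per-row max/sum
  // reductions can iterate over whole rows.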

  // RowsPerThread = #rows_per_MMA * #MMA_M
  constexpr int RowsPerThread = 2 * size<1>(tOrO);
  OnlineSoftmax<RowsPerThread> softmax;
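  // OnlineSoftmax tracks a running row max m and row sum l (implementation
  // in online_softmax.cuh). Per kv tile, rescale() is expected to apply the
  // standard online-softmax recurrence:
  //   m_new = max(m_old, rowmax(S));  S = exp(S - m_new)
  //   l_new = exp(m_old - m_new) * l_old + rowsum(S)
  //   O    *= exp(m_old - m_new)
  // and finalize() divides O by the final l.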

  // ############### Mainloop ###############

  const int n_block_min = 0;
  const int n_block_max = cute::ceil_div(kv_len, kBlockN);
  CUTE_NO_UNROLL
  for (int ni = n_block_min; ni < n_block_max; ++ni) {
    // attention scores for this kv tile
    // (MMA=4, MMA_M, MMA_N), fp32
    auto tSrS =
        partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
    // clear attention scores
    clear(tSrS);

    // wait for k: [k] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce v: [] => [v]
    {
      gV = local_tile(
          V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(ni, _));
      tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tVgV, tVsV);
    }
    cp_async_fence();
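    // The v load is asynchronous (cp.async), so it overlaps with GEMM-I
    // below, which only reads the q and k tiles already in smem.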

    // 1> S = Q @ K^T
    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
      cute::copy(smem_tiled_copy_Q, tSsQ(_, _, ki), tSrQ_copy_view(_, _, ki));
      cute::copy(smem_tiled_copy_K, tSsK(_, _, ki), tSrK_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrS);
    }

    // reshape for iteration
    auto sl = logical_divide(tSrS.layout(), Shape<_2>{});
    auto rAccScore_new_layout =
        make_layout(make_layout(get<0, 1>(sl), get<1>(sl)),
                    make_layout(get<0, 0>(sl), get<2>(sl)));
    auto tSrS_rc = make_tensor(tSrS.data(), rAccScore_new_layout);

    softmax.rescale(tSrS_rc, tOrO_rc);

    // wait for v: [v] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce next k: [] => [k]
    if (ni != n_block_max - 1) {
      gK = local_tile(
          K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(ni + 1, _));
      tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
    }
    cp_async_fence();
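    // An empty commit group is fenced even on the last iteration, keeping
    // every iteration's async-copy queue structure identical for the
    // cp_async_wait<0> at the top of the loop.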

    // 2> O = softmax(S) @ V

    // cast scores from the fp32 accumulator to T (fp16/bf16)
    auto tSrS_T = make_tensor_like<T>(tSrS);
    CUTE_UNROLL
    for (int i = 0; i < size(tSrS); ++i) {
      tSrS_T(i) = static_cast<T>(tSrS(i));
    }
    // convert layout from GEMM-I C to GEMM-II A
    auto l = logical_divide(tSrS_T.layout(), Shape<X, X, _2>{});
    auto scores_new_layout = make_layout(
        make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
    auto tOrS = make_tensor(tSrS_T.data(), scores_new_layout);
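    // The regrouping above fuses pairs of 8-wide GEMM-I accumulator tiles
    // into 16-wide GEMM-II A fragments (assuming an m16n8k16 MMA), so the
    // probabilities can feed the second GEMM without a smem round trip.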

    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tOrS); ++ki) {
      cute::copy(smem_tiled_copy_V, tOsVt(_, _, ki), tOrVt_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrO);
    }
  }

  // ############### Epilogue ###############

  // normalize output: o /= rowsum
  softmax.finalize(tOrO_rc);

  // write output to gmem
  // 1. convert output from the fp32 accumulator to T
  auto tOrO_T = make_tensor_like<T>(tOrO);
  CUTE_UNROLL
  for (int si = 0; si < size(tOrO); ++si) {
    tOrO_T(si) = static_cast<T>(tOrO(si));
  }

  // 2. copy output from reg to smem (reusing q's smem, which is free now)
  auto sO = make_tensor(sQ.data(), SmemLayoutO{});
  auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
  auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom, AtomNum), MMA_M, MMA_N)
  auto taccOrO = smem_thr_copy_O.retile_S(tOrO_T);
  // ((Atom, AtomNum), PIPE_M, PIPE_N)
  auto taccOsO = smem_thr_copy_O.partition_D(sO);
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

  // 3. copy output from smem to gmem
  GmemTiledCopyO gmem_tiled_copy_O;
  auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom, AtomNum), ATOM_M, ATOM_N)
  auto tOsO = gmem_thr_copy_O.partition_S(sO);
  auto tOgO = gmem_thr_copy_O.partition_D(gO(_, _, 0));

  // wait for the smem copy to finish before copying to gmem
  __syncthreads();
  cute::copy(gmem_tiled_copy_O, tOsO, tOgO);
}

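// A minimal host-side launch sketch, illustrating how the kernel above is
// meant to be configured. The Traits members kThreadNum and kSmemSize are
// hypothetical names used only in this sketch (the kernel itself never
// reads them); treat this as an assumption-laden example, not part of the
// original file.
template <typename Traits>
void launch_mha_kernel_sm80(void* o,
                            const void* q,
                            const void* k,
                            const void* v,
                            int n_heads,  // grid.y: one (batch, head) each
                            int h_stride,
                            int q_len,
                            int kv_len,
                            float sm_scale,
                            cudaStream_t stream) {
  // one block per kBlockM-row query tile, per head
  const dim3 grid(cute::ceil_div(q_len, Traits::kBlockM), n_heads);
  const dim3 block(Traits::kThreadNum);        // hypothetical trait
  const size_t smem_size = Traits::kSmemSize;  // q + k + v tiles, in bytes
  // opt in to >48 KB dynamic smem if the tile sizes require it
  cudaFuncSetAttribute(mha_kernel_sm80<Traits>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       static_cast<int>(smem_size));
  mha_kernel_sm80<Traits><<<grid, block, smem_size, stream>>>(
      o, q, k, v, h_stride, q_len, kv_len, sm_scale);
}
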
}  // namespace llm