#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <cute/tensor.hpp>

#include "online_softmax.cuh"
9+
10+ namespace llm {
11+
// Fused multi-head-attention forward kernel (flash-attention style) for SM80.
//
// Grid/block layout expected (from the indexing below):
//   blockIdx.x -> query-row tile index (steps of BLK_M over q_len)
//   blockIdx.y -> slice index scaled by h_stride to offset into q/k/v/o
//                 (presumably a flattened batch*head index — TODO confirm
//                 against the host-side launcher)
//   threadIdx.x -> thread id consumed by the CuTe tiled-MMA/tiled-copy slices.
//
// Dynamic shared memory requirement: the kernel carves one Q tile plus one K
// and one V tile out of `extern __shared__ smem`, i.e.
//   (cosize(SmemLayoutQ) + 2 * cosize(SmemLayoutKV)) * sizeof(Element) bytes.
//
// Pipeline (cp.async producer/consumer):
//   prologue: issue Q copy, issue K copy, wait for Q, pre-scale Q by sm_scale
//   mainloop: wait K -> issue V copy -> GEMM-I (S = Q@K^T) -> online-softmax
//             rescale -> wait V -> issue next K copy -> GEMM-II (O += P@V)
//   epilogue: normalize O by row sums, stage O through smem, store to gmem.
template <typename Traits>
__global__ void mha_kernel_sm80(void* o,
                                const void* q,
                                const void* k,
                                const void* v,
                                int h_stride,    // elements between consecutive blockIdx.y slices
                                int q_len,       // number of query rows
                                int kv_len,      // number of key/value rows
                                float sm_scale) {  // softmax scaling (e.g. 1/sqrt(head_dim))
  using namespace cute;

  // type alias
  using Element = typename Traits::Element;
  using BLK_M = typename Traits::BLK_M;
  using BLK_N = typename Traits::BLK_N;
  using BLK_K = typename Traits::BLK_K;
  using HEAD_DIM = typename Traits::HEAD_DIM;

  using TiledMMA = typename Traits::TiledMMA;
  using Convertor = typename Traits::FragmentConvertor;

  // K and V share one smem layout; Vt is the same V buffer viewed transposed.
  using SmemLayoutQ = typename Traits::SmemLayoutQ;
  using SmemLayoutK = typename Traits::SmemLayoutKV;
  using SmemLayoutV = typename Traits::SmemLayoutKV;
  using SmemLayoutVt = typename Traits::SmemLayoutVt;
  using SmemLayoutO = typename Traits::SmemLayoutO;
  using GmemTiledCopyQKV = typename Traits::GmemTiledCopyQKV;
  using GmemTiledCopyO = typename Traits::GmemTiledCopyO;

  using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
  using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
  using SmemTiledCopyVT = typename Traits::SmemTiledCopyVT;
  using SmemTiledCopyO = typename Traits::SmemTiledCopyO;

  const int m_block = blockIdx.x;
  const int base_id = blockIdx.y;
  const int tidx = threadIdx.x;

  // ProblemShape
  // TODO: support non-contiguous layout
  // Q/K/V/O all start at the same per-slice offset; rows are contiguous with
  // stride HEAD_DIM (row-major, k-contiguous).
  const int offset = base_id * h_stride;
  // (q_len, head_dim)
  auto Q = make_tensor(make_gmem_ptr((Element*)q + offset),
                       make_shape(q_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));
  auto O = make_tensor(make_gmem_ptr((Element*)o + offset),
                       make_shape(q_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));
  // (kv_len, head_dim)
  auto K = make_tensor(make_gmem_ptr((Element*)k + offset),
                       make_shape(kv_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));
  auto V = make_tensor(make_gmem_ptr((Element*)v + offset),
                       make_shape(kv_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));

  // CTA/Block Shape
  // (BLK_M, head_dim): this CTA owns query tile `m_block`.
  Tensor gQ =
      local_tile(Q, make_tile(BLK_M{}, HEAD_DIM{}), make_coord(m_block, _));
  Tensor gO =
      local_tile(O, make_tile(BLK_M{}, HEAD_DIM{}), make_coord(m_block, _));

  // (BLK_N, head_dim): start at KV tile 0; retiled to `ni` inside the loop.
  Tensor gK = local_tile(K, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(0, _));
  Tensor gV = local_tile(V, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(0, _));

  // Smem: [ Q tile | K tile | V tile ] carved from dynamic shared memory.
  extern __shared__ char smem[];
  Element* q_smem = (Element*)smem;
  Element* k_smem = q_smem + cosize(SmemLayoutQ{});
  Element* v_smem = k_smem + cosize(SmemLayoutK{});

  // (BLK_M, BLK_K), k-major
  Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
  // (BLK_N, BLK_K), k-major
  Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{});
  Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{});

  // Tensor for V^t; used in GEMM-II.
  // (BLK_K, BLK_N), k-major — aliases v_smem, no extra storage.
  Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});

  // Fragments for GEMM
  TiledMMA tiled_mma;
  auto thr_mma = tiled_mma.get_slice(tidx);
  // GEMM-I: S = Q@K.T
  auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
  auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
  auto tSrAccS = partition_fragment_C(
      tiled_mma, Shape<BLK_M, BLK_N>{});  // (MMA,MMA_M,MMA_N)

  // GEMM-II: O = softmax(S)@V
  auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA,MMA_K,MMA_N)
  auto tOrAccO = partition_fragment_C(
      tiled_mma, Shape<BLK_M, HEAD_DIM>{});  // (MMA,MMA_M,MMA_K)

  // reshape for iterating over rows and columns
  // (the online softmax walks the accumulators row by row)
  auto tOrAccO_rc_view = Convertor::to_rowcol(tOrAccO);
  auto tSrAccS_rc_view = Convertor::to_rowcol(tSrAccS);

  // Tiled Copy
  // g2s tiled copy for qkv (cp.async path)
  GmemTiledCopyQKV gmem_tiled_copy_QKV;
  auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
  auto tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0));
  auto tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  auto tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
  auto tKsK = gmem_thr_copy_QKV.partition_D(sK);
  auto tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
  auto tVsV = gmem_thr_copy_QKV.partition_D(sV);

  // s2r tiled copy for qkv (ldmatrix-style loads into MMA fragments)
  SmemTiledCopyQ smem_tiled_copy_Q;
  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
  auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
  auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);

  SmemTiledCopyK smem_tiled_copy_K;
  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
  auto tSsK = smem_thr_copy_K.partition_S(sK);
  auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);

  SmemTiledCopyVT smem_tiled_copy_Vt;
  auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
  auto tOsVt = smem_thr_copy_Vt.partition_S(sVt);
  auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt);

  // ############### Prologue ###############

  // produce q: [] => [q]  (commit as its own cp.async group)
  cute::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ);
  cp_async_fence();

  // produce k: [q] => [q, k]  (second group, may still be in flight below)
  cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
  cp_async_fence();

  // wait q: [q, k] => [k]
  // cp_async_wait<1> blocks until at most one group is pending, i.e. the Q
  // copy is complete; the K copy may still be in flight (it is waited on at
  // the top of the mainloop).
  cp_async_wait<1>();
  __syncthreads();

  // apply sm_scale
  // Pre-scaling Q in smem makes GEMM-I produce already-scaled logits. Each
  // thread rescales exactly the elements it wrote via tQsQ, so no extra
  // synchronization is needed. Note the scale is applied in Element
  // precision (e.g. half), not in fp32.
  // TODO: use thread parallelism
  for (int i = 0; i < size(tQsQ); ++i) {
    tQsQ(i) = Element(tQsQ(i) * sm_scale);
  }

  // RowsPerThread = #rows_per_MMA * #MMA_M
  // (the 2 is presumably the rows-per-thread of the SM80 MMA atom's C layout
  //  — TODO confirm against Traits::TiledMMA)
  constexpr int RowsPerThread = 2 * size<1>(tOrAccO);
  OnlineSoftmax<RowsPerThread> softmax;

  // ############### Mainloop ###############

  const int n_block_min = 0;
  const int n_block_max = cute::ceil_div(kv_len, BLK_N{});
  // NOTE(review): ceil_div implies a partial tail tile when kv_len (or q_len)
  // is not a multiple of the block size, but no copy/store predication is
  // visible here — presumably inputs are padded/aligned; confirm at call site.

  // clear output
  clear(tOrAccO);
  CUTE_NO_UNROLL
  for (int ni = n_block_min; ni < n_block_max; ++ni) {
    // clear attention score for each block
    clear(tSrAccS);

    // wait k, queue: [q, k] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce v, [] => [v]
    // V's copy overlaps with GEMM-I below (they touch different buffers).
    {
      gV = local_tile(V, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(ni, _));
      tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tVgV, tVsV);
    }
    cp_async_fence();

    // 1> S = Q@K.T
    // Interleave smem->reg loads of the next k-slice with the MMA on the
    // current one.
    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
      cute::copy(smem_tiled_copy_Q, tSsQ(_, _, ki), tSrQ_copy_view(_, _, ki));
      cute::copy(smem_tiled_copy_K, tSsK(_, _, ki), tSrK_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
    }

    // apply softmax and rescale
    // (online softmax: exponentiates S in place and rescales the partial O
    //  accumulator by the change in the running row max/sum)
    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

    // wait v, [v] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce next k: [] => [k]
    // Safe to overwrite k_smem: GEMM-I above already moved K into registers,
    // and the __syncthreads() ensures all threads are past it. The copy
    // overlaps with GEMM-II (which only reads v_smem).
    if (ni != n_block_max - 1) {
      gK = local_tile(K, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(ni + 1, _));
      tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
    }
    // Fence unconditionally (possibly committing an empty group) so every
    // iteration commits the same number of groups.
    cp_async_fence();

    // 2> O = softmax(S)*V

    // cast scores from Accumulator to Element (fp32 -> Element for the MMA A
    // operand of GEMM-II)
    auto tSrS = make_tensor_like<Element>(tSrAccS);
    CUTE_UNROLL
    for (int i = 0; i < size(tSrAccS); ++i) {
      tSrS(i) = static_cast<Element>(tSrAccS(i));
    }

    // convert layout from gemm-I C to gemm-II A
    auto tOrS = Convertor::to_mma_a(tSrS);

    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tOrS); ++ki) {
      cute::copy(
          smem_tiled_copy_Vt, tOsVt(_, _, ki), tOrVt_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
    }
  }

  // ############### Epilogue ###############

  // normalize output: o /= rowsum
  softmax.finalize(tOrAccO_rc_view);

  // write output to gmem
  // 1> convert output from ElementAccumulator to Element
  auto tOrO = make_tensor_like<Element>(tOrAccO);
  CUTE_UNROLL
  for (int si = 0; si < size(tOrAccO); ++si) {
    tOrO(si) = static_cast<Element>(tOrAccO(si));
  }

  // 2. copy output from reg to smem
  // Reuses q_smem for O: Q was last read in the final GEMM-I, and a
  // __syncthreads() has occurred since, so the buffer is free.
  auto sO = make_tensor(sQ.data(), SmemLayoutO{});

  SmemTiledCopyO smem_tiled_copy_O;
  auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom,AtomNum),MMA_M,MMA_N)
  auto taccOrO = smem_thr_copy_O.retile_S(tOrO);
  // ((Atom,AtomNum),PIPE_M,PIPE_N)
  auto taccOsO = smem_thr_copy_O.partition_D(sO);
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

  // 3. copy output from smem to gmem
  GmemTiledCopyO gmem_tiled_copy_O;
  auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom,AtomNum),ATOM_M,ATOM_N)
  auto tOsO = gmem_thr_copy_O.partition_S(sO);
  auto tOgO = gmem_thr_copy_O.partition_D(gO(_, _, 0));

  // wait for smem copy before copy to gmem
  __syncthreads();
  cute::copy(gmem_tiled_copy_O, tOsO, tOgO);
}
266+
267+ } // namespace llm