
Commit 257bbf4

refactor and fix build errors

1 parent: 38b82c0

5 files changed (+198, -163 lines)

src/kernels/attention/attention_cpu_test.cpp

Lines changed: 3 additions & 4 deletions
@@ -1,9 +1,8 @@
-#include <ATen/ops/equal.h>
-#include <gtest/gtest.h>
-#include <torch/csrc/autograd/generated/variable_factories.h>
-
 #include "attention_cpu.h"

+#include <gtest/gtest.h>
+#include <torch/torch.h>
+
 namespace llm {
 namespace {
 // Multi-head attention implementation using pytorch
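
For context, the CPU test builds its expected output from a plain-PyTorch multi-head attention, which is why the narrow ATen/autograd headers could be replaced by the umbrella <torch/torch.h>. A minimal sketch of such a reference, assuming [seq_len, n_heads, head_dim] inputs as in the kernel test and ignoring grouped KV heads and masking (the repo's actual attention_ref helper may differ):

#include <cmath>
#include <torch/torch.h>

// Hypothetical reference: scaled-dot-product attention with plain torch ops.
// Assumed shapes: query/key/value are [seq_len, n_heads, head_dim].
torch::Tensor attention_ref_sketch(torch::Tensor query,
                                   torch::Tensor key,
                                   torch::Tensor value) {
  const float sm_scale = 1.0f / std::sqrt(static_cast<float>(query.size(-1)));
  // [n_heads, seq_len, head_dim]
  auto q = query.transpose(0, 1);
  auto k = key.transpose(0, 1);
  auto v = value.transpose(0, 1);
  // scores: [n_heads, q_len, kv_len]
  auto scores = torch::matmul(q, k.transpose(-2, -1)) * sm_scale;
  auto probs = torch::softmax(scores, /*dim=*/-1);
  // back to [seq_len, n_heads, head_dim]
  return torch::matmul(probs, v).transpose(0, 1);
}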

src/kernels/attention/attention_kernel_sm80.cuh

Lines changed: 86 additions & 97 deletions
@@ -21,64 +21,66 @@ __global__ void mha_kernel_sm80(void* o,
   using namespace cute;

   // type alias
-  using T = typename Traits::T;
+  using Element = typename Traits::Element;
+  using BLK_M = typename Traits::BLK_M;
+  using BLK_N = typename Traits::BLK_N;
+  using BLK_K = typename Traits::BLK_K;
+  using HEAD_DIM = typename Traits::HEAD_DIM;
+
+  using TiledMMA = typename Traits::TiledMMA;
+  using Convertor = typename Traits::FragmentConvertor;
+
   using SmemLayoutQ = typename Traits::SmemLayoutQ;
   using SmemLayoutK = typename Traits::SmemLayoutKV;
   using SmemLayoutV = typename Traits::SmemLayoutKV;
   using SmemLayoutVt = typename Traits::SmemLayoutVt;
-
   using SmemLayoutO = typename Traits::SmemLayoutO;
-  using SmemCopyAtom = typename Traits::SmemCopyAtom;
-  using SmemCopyAtomO = typename Traits::SmemCopyAtomO;
   using GmemTiledCopyQKV = typename Traits::GmemTiledCopyQKV;
   using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
-  using SmemCopyAtomTransposed = typename Traits::SmemCopyAtomTransposed;
-  using TiledMMA = typename Traits::TiledMMA;
+
+  using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
+  using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
+  using SmemTiledCopyVT = typename Traits::SmemTiledCopyVT;
+  using SmemTiledCopyO = typename Traits::SmemTiledCopyO;

   const int m_block = blockIdx.x;
   const int base_id = blockIdx.y;
   const int tidx = threadIdx.x;

-  constexpr int kBlockM = Traits::kBlockM;
-  constexpr int kBlockN = Traits::kBlockN;
-  constexpr int kHeadDim = Traits::kHeadDim;
-
   // ProblemShape
   // TODO: support non-contiguous layout
   const int offset = base_id * h_stride;
   // (q_len, head_dim)
-  auto Q = make_tensor(make_gmem_ptr(static_cast<T*>(q) + offset),
-                       make_shape(q_len, Int<kHeadDim>{}),
-                       make_stride(Int<kHeadDim>{}, _1{}));
-  auto O = make_tensor(make_gmem_ptr(static_cast<T*>(o) + offset),
-                       make_shape(q_len, Int<kHeadDim>{}),
-                       make_stride(Int<kHeadDim>{}, _1{}));
+  auto Q = make_tensor(make_gmem_ptr((Element*)q + offset),
+                       make_shape(q_len, HEAD_DIM{}),
+                       make_stride(HEAD_DIM{}, _1{}));
+  auto O = make_tensor(make_gmem_ptr((Element*)o + offset),
+                       make_shape(q_len, HEAD_DIM{}),
+                       make_stride(HEAD_DIM{}, _1{}));
   // (kv_len, head_dim)
-  auto K = make_tensor(make_gmem_ptr(static_cast<T*>(k) + offset),
-                       make_shape(kv_len, Int<kHeadDim>{}),
-                       make_stride(Int<kHeadDim>{}, _1{}));
-  auto V = make_tensor(make_gmem_ptr(static_cast<T*>(v) + offset),
-                       make_shape(kv_len, Int<kHeadDim>{}),
-                       make_stride(Int<kHeadDim>{}, _1{}));
+  auto K = make_tensor(make_gmem_ptr((Element*)k + offset),
+                       make_shape(kv_len, HEAD_DIM{}),
+                       make_stride(HEAD_DIM{}, _1{}));
+  auto V = make_tensor(make_gmem_ptr((Element*)v + offset),
+                       make_shape(kv_len, HEAD_DIM{}),
+                       make_stride(HEAD_DIM{}, _1{}));

   // CTA/Block Shape
   // (BLK_M, head_dim)
-  Tensor gQ = local_tile(
-      Q, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));
-  Tensor gO = local_tile(
-      O, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));
+  Tensor gQ =
+      local_tile(Q, make_tile(BLK_M{}, HEAD_DIM{}), make_coord(m_block, _));
+  Tensor gO =
+      local_tile(O, make_tile(BLK_M{}, HEAD_DIM{}), make_coord(m_block, _));

   // (BLK_N, head_dim)
-  Tensor gK = local_tile(
-      K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));
-  Tensor gV = local_tile(
-      V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));
+  Tensor gK = local_tile(K, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(0, _));
+  Tensor gV = local_tile(V, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(0, _));

   // Smem
   extern __shared__ char smem[];
-  T* q_smem = static_cast<T*>(smem);
-  T* k_smem = q_smem + cosize(SmemLayoutQ{});
-  T* v_smem = k_smem + cosize(SmemLayoutK{});
+  Element* q_smem = (Element*)smem;
+  Element* k_smem = q_smem + cosize(SmemLayoutQ{});
+  Element* v_smem = k_smem + cosize(SmemLayoutK{});

   // (BLK_M, BLK_K), k-major
   Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
@@ -90,16 +92,25 @@ __global__ void mha_kernel_sm80(void* o,
   // (BLK_K, BLK_N), k-major
   Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});

-  // rmem for mma
+  // Fragments for GEMM
   TiledMMA tiled_mma;
   auto thr_mma = tiled_mma.get_slice(tidx);
-  // gemm-I
+  // GEMM-I: S = Q@K.T
   auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
   auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
+  auto tSrAccS = partition_fragment_C(
+      tiled_mma, Shape<BLK_M, BLK_N>{});  // (MMA,MMA_M,MMA_N)

-  // gemm-II
+  // GEMM-II: O = softmax(S)@V
   auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA,MMA_K,MMA_N)
+  auto tOrAccO = partition_fragment_C(
+      tiled_mma, Shape<BLK_M, HEAD_DIM>{});  // (MMA,MMA_M,MMA_K)

+  // reshape for iterating over rows and columns
+  auto tOrAccO_rc_view = Convertor::to_rowcol(tOrAccO);
+  auto tSrAccS_rc_view = Convertor::to_rowcol(tSrAccS);
+
+  // Tiled Copy
   // g2s tiled copy for qkv
   GmemTiledCopyQKV gmem_tiled_copy_QKV;
   auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -111,79 +122,64 @@ __global__ void mha_kernel_sm80(void* o,
   auto tVsV = gmem_thr_copy_QKV.partition_D(sV);

   // s2r tiled copy for qkv
-  auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
+  SmemTiledCopyQ smem_tiled_copy_Q;
   auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
   auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
   auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);

-  auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
+  SmemTiledCopyK smem_tiled_copy_K;
   auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
   auto tSsK = smem_thr_copy_K.partition_S(sK);
   auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);

-  auto smem_tiled_copy_V =
-      make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);
-  auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
-  auto tOsVt = smem_thr_copy_V.partition_S(sVt);
-  auto tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
+  SmemTiledCopyVT smem_tiled_copy_Vt;
+  auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
+  auto tOsVt = smem_thr_copy_Vt.partition_S(sVt);
+  auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt);

   // ############### Prologue ###############

-  // produce q
+  // produce q: [] => [q]
   cute::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ);
   cp_async_fence();

-  // produce k
+  // produce k: [q] => [q, k]
   cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
   cp_async_fence();

-  // multiply sm scale
   // wait q: [q, k] => [k]
-  cp_async_wait<0>();
+  cp_async_wait<1>();
   __syncthreads();

   // apply sm_scale
   // TODO: use thread parallelism
   for (int i = 0; i < size(tQsQ); ++i) {
-    tQsQ(i) = T(tQsQ(i) * sm_scale);
+    tQsQ(i) = Element(tQsQ(i) * sm_scale);
   }

-  // Final output fragment
-  auto tOrO =
-      partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
-  clear(tOrO);
-
-  // reshape for iteration
-  auto ol = logical_divide(tOrO.layout(), Shape<_2>{});
-  auto rAccOut_new_layout = make_layout(make_layout(get<0, 1>(ol), get<1>(ol)),
-                                        make_layout(get<0, 0>(ol), get<2>(ol)));
-  auto tOrO_rc = make_tensor(tOrO.data(), rAccOut_new_layout);
-
   // RowsPerThread = #rows_per_MMA * #MMA_M
-  constexpr int RowsPerThread = 2 * size<1>(tOrO);
+  constexpr int RowsPerThread = 2 * size<1>(tOrAccO);
   OnlineSoftmax<RowsPerThread> softmax;

   // ############### Mainloop ###############

   const int n_block_min = 0;
-  const int n_block_max = cute::ceil_div(kv_len, kBlockN);
+  const int n_block_max = cute::ceil_div(kv_len, BLK_N{});
+
+  // clear output
+  clear(tOrAccO);
   CUTE_NO_UNROLL
   for (int ni = n_block_min; ni < n_block_max; ++ni) {
-    // attention score
-    // (MMA=4,MMA_M,MMA_N) (fp32)
-    auto tSrS =
-        partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
-    // clear attention score
-    clear(tSrS);
+    // clear attention score for each block
+    clear(tSrAccS);

     // wait k, queue: [q, k] => []
     cp_async_wait<0>();
     __syncthreads();

     // produce v, [] => [v]
     {
-      gV = local_tile(
-          V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(ni, _));
+      gV = local_tile(V, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(ni, _));
       tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
       cute::copy(gmem_tiled_copy_QKV, tVgV, tVsV);
     }
@@ -194,71 +190,64 @@ __global__ void mha_kernel_sm80(void* o,
     for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
       cute::copy(smem_tiled_copy_Q, tSsQ(_, _, ki), tSrQ_copy_view(_, _, ki));
       cute::copy(smem_tiled_copy_K, tSsK(_, _, ki), tSrK_copy_view(_, _, ki));
-      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrS);
+      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
     }

-    // reshape for iteration
-    auto sl = logical_divide(tSrS.layout(), Shape<_2>{});
-    auto rAccScore_new_layout =
-        make_layout(make_layout(get<0, 1>(sl), get<1>(sl)),
-                    make_layout(get<0, 0>(sl), get<2>(sl)));
-    auto tSrS_rc = make_tensor(tSrS.data(), rAccScore_new_layout);
-
-    softmax.rescale(tSrS_rc, tOrO_rc);
+    // apply softmax and rescale
+    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

     // wait v, [v] => []
     cp_async_wait<0>();
     __syncthreads();

     // produce next k: [] => [k]
     if (ni != n_block_max - 1) {
-      gK = local_tile(
-          K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(ni + 1, _));
+      gK = local_tile(K, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(ni + 1, _));
       tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
       cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
     }
     cp_async_fence();

     // 2> O = softmax(S)*V

-    // cast scores from fp32 to fp16
-    auto tSrS_T = make_tensor_like<T>(tSrS);
+    // cast scores from Accumulator to Element
+    auto tSrS = make_tensor_like<Element>(tSrAccS);
     CUTE_UNROLL
-    for (int i = 0; i < size(tSrS); ++i) {
-      tSrS_T(i) = static_cast<T>(tSrS(i));
+    for (int i = 0; i < size(tSrAccS); ++i) {
+      tSrS(i) = static_cast<Element>(tSrAccS(i));
     }
+
     // convert layout from gemm-I C to gemm-II A
-    auto l = logical_divide(tSrS_T.layout(), Shape<X, X, _2>{});
-    auto scores_new_layout = make_layout(
-        make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
-    auto tOrS = make_tensor(tSrS_T.data(), scores_new_layout);
+    auto tOrS = Convertor::to_mma_a(tSrS);

     CUTE_UNROLL
     for (int ki = 0; ki < size<2>(tOrS); ++ki) {
-      cute::copy(smem_tiled_copy_V, tOsVt(_, _, ki), tOrVt_copy_view(_, _, ki));
-      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrO);
+      cute::copy(
+          smem_tiled_copy_Vt, tOsVt(_, _, ki), tOrVt_copy_view(_, _, ki));
+      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
     }
   }

   // ############### Epilogue ###############

   // normalize output: o /= rowsum
-  softmax.finalize(tOrO_rc);
+  softmax.finalize(tOrAccO_rc_view);

   // write output to gmem
-  // 1> covernt output from fp32 to fp16
-  auto tOrO_T = make_tensor_like<T>(tOrO);
+  // 1> covernt output from ElementAccumulator to Element
+  auto tOrO = make_tensor_like<Element>(tOrAccO);
   CUTE_UNROLL
-  for (int si = 0; si < size(tOrO); ++si) {
-    tOrO_T(si) = static_cast<T>(tOrO(si));
+  for (int si = 0; si < size(tOrAccO); ++si) {
+    tOrO(si) = static_cast<Element>(tOrAccO(si));
   }

   // 2. copy output from reg to smem
   auto sO = make_tensor(sQ.data(), SmemLayoutO{});
-  auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
+
+  SmemTiledCopyO smem_tiled_copy_O;
   auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
   // ((Atom,AtomNum),MMA_M,MMA_N)
-  auto taccOrO = smem_thr_copy_O.retile_S(tOrO_T);
+  auto taccOrO = smem_thr_copy_O.retile_S(tOrO);
   // ((Atom,AtomNum),PIPE_M,PIPE_N)
   auto taccOsO = smem_thr_copy_O.partition_D(sO);
   cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
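
The two accumulator reshapes the old kernel performed inline with logical_divide now live behind Traits::FragmentConvertor (the Convertor::to_rowcol and Convertor::to_mma_a calls above). A sketch of what that helper could look like, reusing the layout algebra from the removed lines verbatim; the actual implementation in attention_traits_sm80.h may be organized differently:

#include <cute/tensor.hpp>

// Hypothetical FragmentConvertor: wraps the accumulator-layout reshapes the
// old kernel did inline with logical_divide.
struct FragmentConvertorSketch {
  // View an accumulator laid out as ((2, 2), MMA_M, MMA_N) as
  // ((2, MMA_M), (2, MMA_N)) so rows and columns can be iterated directly.
  template <typename FragmentC>
  CUTE_HOST_DEVICE static auto to_rowcol(FragmentC&& acc) {
    using namespace cute;
    auto l = logical_divide(acc.layout(), Shape<_2>{});
    auto rowcol = make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                              make_layout(get<0, 0>(l), get<2>(l)));
    return make_tensor(acc.data(), rowcol);
  }

  // Reinterpret a GEMM-I C fragment ((2, 2), MMA_M, MMA_N) as a GEMM-II A
  // fragment ((2, 2, 2), MMA_M, MMA_N / 2), as the removed code did.
  template <typename FragmentS>
  CUTE_HOST_DEVICE static auto to_mma_a(FragmentS&& s) {
    using namespace cute;
    auto l = logical_divide(s.layout(), Shape<X, X, _2>{});
    auto a = make_layout(make_layout(get<0>(l), get<2, 0>(l)),
                         get<1>(l),
                         get<2, 1>(l));
    return make_tensor(s.data(), a);
  }
};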

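Similarly, the kernel only exercises rescale() and finalize() on OnlineSoftmax<RowsPerThread>. A rough per-thread sketch of the online-softmax bookkeeping those two calls imply, operating on the (row, column) views built above; the real class must additionally reduce each row's max and sum across the threads that share an MMA row, which this sketch omits:

#include <cmath>

#include <cute/tensor.hpp>

// Hypothetical OnlineSoftmax sketch: per-thread running max/sum per row.
// The real class must also reduce row_max/row_sum across the threads that
// share each MMA row (e.g. with __shfl_xor_sync); that step is omitted here.
template <int kRows>
struct OnlineSoftmaxSketch {
  float row_max[kRows];
  float row_sum[kRows];

  CUTE_HOST_DEVICE OnlineSoftmaxSketch() {
    for (int i = 0; i < kRows; ++i) {
      row_max[i] = -INFINITY;
      row_sum[i] = 0.f;
    }
  }

  // sAcc: scores of the current kv block, oAcc: running output accumulator,
  // both viewed as (row, column) via FragmentConvertor::to_rowcol.
  template <typename FragmentS, typename FragmentO>
  CUTE_HOST_DEVICE void rescale(FragmentS& sAcc, FragmentO& oAcc) {
    for (int r = 0; r < cute::size<0>(sAcc); ++r) {
      // extend the running max with this block's scores
      float cur_max = row_max[r];
      for (int c = 0; c < cute::size<1>(sAcc); ++c) {
        cur_max = fmaxf(cur_max, sAcc(r, c));
      }
      // rescale the previously accumulated output and sum to the new max
      const float scale = expf(row_max[r] - cur_max);
      row_sum[r] *= scale;
      for (int c = 0; c < cute::size<1>(oAcc); ++c) {
        oAcc(r, c) *= scale;
      }
      // exponentiate the scores in place and grow the running sum
      for (int c = 0; c < cute::size<1>(sAcc); ++c) {
        sAcc(r, c) = expf(sAcc(r, c) - cur_max);
        row_sum[r] += sAcc(r, c);
      }
      row_max[r] = cur_max;
    }
  }

  // normalize: o /= rowsum
  template <typename FragmentO>
  CUTE_HOST_DEVICE void finalize(FragmentO& oAcc) {
    for (int r = 0; r < cute::size<0>(oAcc); ++r) {
      const float inv_sum = 1.f / row_sum[r];
      for (int c = 0; c < cute::size<1>(oAcc); ++c) {
        oAcc(r, c) *= inv_sum;
      }
    }
  }
};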
src/kernels/attention/attention_kernel_sm80_test.cu

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <torch/torch.h>

 #include "attention_kernel_sm80.cuh"
+#include "attention_traits_sm80.h"

 namespace llm {
 namespace {
@@ -50,6 +51,9 @@ TEST_P(AttentionKernelTest, MHA) {
   const auto key = torch::randn({seq_len, n_kv_heads, head_dim}, options);
   const auto value = torch::randn({seq_len, n_kv_heads, head_dim}, options);

+  using AttentionTraits = AttentionTraitsSM80<cute::half_t, 64, 64, 64>;
+  auto attention_kernel = mha_kernel_sm80<AttentionTraits>;
+
   auto ref_out = attention_ref(query, key, value);

   // auto out = torch::empty_like(query);