
Commit 38b82c0

attention kernel: added attention kernel for sm80
1 parent 6fbe549

File tree

7 files changed: +518 −1 lines changed


.clang-format

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ BinPackArguments: false
 ExperimentalAutoDetectBinPacking: false
 AllowAllParametersOfDeclarationOnNextLine: false
 DerivePointerAlignment: false
+AlwaysBreakTemplateDeclarations: Yes
 PointerAlignment: Left
 ColumnLimit: 80
 ...

src/kernels/attention/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -18,7 +18,8 @@ cc_test(
     attention_cpu.h
   SRCS
     cute_test.cpp
-    attention_test.cpp
+    attention_cpu_test.cpp
+    attention_kernel_sm80_test.cu
   DEPS
     :attention.kernel
     glog::glog
File renamed without changes (presumably attention_test.cpp → attention_cpu_test.cpp, matching the CMakeLists change above).
src/kernels/attention/attention_kernel_sm80.cuh (new file)

Lines changed: 278 additions & 0 deletions
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <cute/tensor.hpp>

#include "online_softmax.cuh"

namespace llm {

template <typename Traits>
__global__ void mha_kernel_sm80(void* o,
                                const void* q,
                                const void* k,
                                const void* v,
                                int h_stride,
                                int q_len,
                                int kv_len,
                                float sm_scale) {
  using namespace cute;

  // type alias
  using T = typename Traits::T;
  using SmemLayoutQ = typename Traits::SmemLayoutQ;
  using SmemLayoutK = typename Traits::SmemLayoutKV;
  using SmemLayoutV = typename Traits::SmemLayoutKV;
  using SmemLayoutVt = typename Traits::SmemLayoutVt;

  using SmemLayoutO = typename Traits::SmemLayoutO;
  using SmemCopyAtom = typename Traits::SmemCopyAtom;
  using SmemCopyAtomO = typename Traits::SmemCopyAtomO;
  using GmemTiledCopyQKV = typename Traits::GmemTiledCopyQKV;
  using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
  using SmemCopyAtomTransposed = typename Traits::SmemCopyAtomTransposed;
  using TiledMMA = typename Traits::TiledMMA;

  const int m_block = blockIdx.x;
  const int base_id = blockIdx.y;
  const int tidx = threadIdx.x;

  constexpr int kBlockM = Traits::kBlockM;
  constexpr int kBlockN = Traits::kBlockN;
  constexpr int kHeadDim = Traits::kHeadDim;

  // ProblemShape
  // TODO: support non-contiguous layout
  const int offset = base_id * h_stride;
  // (q_len, head_dim)
  auto Q = make_tensor(make_gmem_ptr(static_cast<const T*>(q) + offset),
                       make_shape(q_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));
  auto O = make_tensor(make_gmem_ptr(static_cast<T*>(o) + offset),
                       make_shape(q_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));
  // (kv_len, head_dim)
  auto K = make_tensor(make_gmem_ptr(static_cast<const T*>(k) + offset),
                       make_shape(kv_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));
  auto V = make_tensor(make_gmem_ptr(static_cast<const T*>(v) + offset),
                       make_shape(kv_len, Int<kHeadDim>{}),
                       make_stride(Int<kHeadDim>{}, _1{}));

  // CTA/Block Shape
  // (BLK_M, head_dim)
  Tensor gQ = local_tile(
      Q, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));
  Tensor gO = local_tile(
      O, make_tile(Int<kBlockM>{}, Int<kHeadDim>{}), make_coord(m_block, _));

  // (BLK_N, head_dim)
  Tensor gK = local_tile(
      K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));
  Tensor gV = local_tile(
      V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(0, _));

  // Smem
  extern __shared__ char smem[];
  T* q_smem = reinterpret_cast<T*>(smem);
  T* k_smem = q_smem + cosize(SmemLayoutQ{});
  T* v_smem = k_smem + cosize(SmemLayoutK{});

  // (BLK_M, BLK_K), k-major
  Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
  // (BLK_N, BLK_K), k-major
  Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{});
  Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{});

  // Tensor for V^t; used in GEMM-II.
  // (BLK_K, BLK_N), k-major
  Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});
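  // Note: sVt aliases the same v_smem storage as sV; SmemLayoutVt simply
  // reinterprets those bytes as (BLK_K, BLK_N) so GEMM-II can consume V
  // transposed without an explicit transpose pass in shared memory.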

  // rmem for mma
  TiledMMA tiled_mma;
  auto thr_mma = tiled_mma.get_slice(tidx);
  // gemm-I
  auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
  auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)

  // gemm-II
  auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA,MMA_K,MMA_N)

  // g2s tiled copy for qkv
  GmemTiledCopyQKV gmem_tiled_copy_QKV;
  auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
  auto tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0));
  auto tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  auto tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
  auto tKsK = gmem_thr_copy_QKV.partition_D(sK);
  auto tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
  auto tVsV = gmem_thr_copy_QKV.partition_D(sV);

  // s2r tiled copy for qkv
  auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
  auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
  auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);

  auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
  auto tSsK = smem_thr_copy_K.partition_S(sK);
  auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);

  auto smem_tiled_copy_V =
      make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);
  auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
  auto tOsVt = smem_thr_copy_V.partition_S(sVt);
  auto tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);

  // ############### Prologue ###############

  // produce q
  cute::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ);
  cp_async_fence();

  // produce k
  cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
  cp_async_fence();

  // wait q: [q, k] => [k]
  cp_async_wait<1>();  // leave the k copy in flight; only q is needed below
  __syncthreads();

  // apply sm_scale to q while it sits in smem
  // TODO: use thread parallelism
  for (int i = 0; i < size(tQsQ); ++i) {
    tQsQ(i) = T(tQsQ(i) * sm_scale);
  }

  // Final output fragment
  auto tOrO =
      partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
  clear(tOrO);

  // reshape for iteration
  auto ol = logical_divide(tOrO.layout(), Shape<_2>{});
  auto rAccOut_new_layout = make_layout(make_layout(get<0, 1>(ol), get<1>(ol)),
                                        make_layout(get<0, 0>(ol), get<2>(ol)));
  auto tOrO_rc = make_tensor(tOrO.data(), rAccOut_new_layout);
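  // tOrO_rc views the same registers as (rows=(2, MMA_M), cols=(2, MMA_N)):
  // each sm80 MMA leaves 4 accumulators per thread, 2 rows x 2 cols, so this
  // re-view lets the softmax reduce and rescale one row at a time.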

  // RowsPerThread = #rows_per_MMA * #MMA_M
  constexpr int RowsPerThread = 2 * size<1>(tOrO);
  OnlineSoftmax<RowsPerThread> softmax;
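  // OnlineSoftmax (from online_softmax.cuh, not shown in this diff) is assumed
  // to keep a running row-max m and row-sum l across kv blocks, in the usual
  // streaming-softmax fashion:
  //   m_new = max(m_old, rowmax(S))
  //   scale = exp(m_old - m_new)
  //   l = l * scale + rowsum(exp(S - m_new)); O *= scale
  //   S = exp(S - m_new)   // probabilities fed to GEMM-II
  // finalize() then divides O by l.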

  // ############### Mainloop ###############

  const int n_block_min = 0;
  const int n_block_max = cute::ceil_div(kv_len, kBlockN);
  CUTE_NO_UNROLL
  for (int ni = n_block_min; ni < n_block_max; ++ni) {
    // attention score
    // (MMA=4,MMA_M,MMA_N) (fp32)
    auto tSrS =
        partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
    // clear attention score
    clear(tSrS);

    // wait k, queue: [k] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce v, [] => [v]
    {
      gV = local_tile(
          V, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(ni, _));
      tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tVgV, tVsV);
    }
    cp_async_fence();

    // 1> S = Q@K.T
    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
      cute::copy(smem_tiled_copy_Q, tSsQ(_, _, ki), tSrQ_copy_view(_, _, ki));
      cute::copy(smem_tiled_copy_K, tSsK(_, _, ki), tSrK_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrS);
    }

    // reshape for iteration
    auto sl = logical_divide(tSrS.layout(), Shape<_2>{});
    auto rAccScore_new_layout =
        make_layout(make_layout(get<0, 1>(sl), get<1>(sl)),
                    make_layout(get<0, 0>(sl), get<2>(sl)));
    auto tSrS_rc = make_tensor(tSrS.data(), rAccScore_new_layout);

    softmax.rescale(tSrS_rc, tOrO_rc);

    // wait v, [v] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce next k: [] => [k]
    if (ni != n_block_max - 1) {
      gK = local_tile(
          K, make_tile(Int<kBlockN>{}, Int<kHeadDim>{}), make_coord(ni + 1, _));
      tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
    }
    cp_async_fence();

    // 2> O = softmax(S)*V

    // cast scores from fp32 to fp16
    auto tSrS_T = make_tensor_like<T>(tSrS);
    CUTE_UNROLL
    for (int i = 0; i < size(tSrS); ++i) {
      tSrS_T(i) = static_cast<T>(tSrS(i));
    }
    // convert layout from gemm-I C to gemm-II A
    auto l = logical_divide(tSrS_T.layout(), Shape<X, X, _2>{});
    auto scores_new_layout = make_layout(
        make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
    auto tOrS = make_tensor(tSrS_T.data(), scores_new_layout);
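    // gemm-I leaves scores as C fragments ((2,2), MMA_M, MMA_N) over 16x8
    // tiles, while gemm-II wants A fragments ((2,2,2), MMA_M, MMA_N/2) over
    // 16x16 tiles; the logical_divide above folds each pair of adjacent 8-wide
    // score tiles into one 16-wide A tile, reusing registers with no movement.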

    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tOrS); ++ki) {
      cute::copy(smem_tiled_copy_V, tOsVt(_, _, ki), tOrVt_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrO);
    }
  }

  // ############### Epilogue ###############

  // normalize output: o /= rowsum
  softmax.finalize(tOrO_rc);

  // write output to gmem
  // 1> convert output from fp32 to fp16
  auto tOrO_T = make_tensor_like<T>(tOrO);
  CUTE_UNROLL
  for (int si = 0; si < size(tOrO); ++si) {
    tOrO_T(si) = static_cast<T>(tOrO(si));
  }

  // 2> copy output from reg to smem (reusing q's smem storage)
  auto sO = make_tensor(sQ.data(), SmemLayoutO{});
  auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
  auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom,AtomNum),MMA_M,MMA_N)
  auto taccOrO = smem_thr_copy_O.retile_S(tOrO_T);
  // ((Atom,AtomNum),PIPE_M,PIPE_N)
  auto taccOsO = smem_thr_copy_O.partition_D(sO);
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

  // 3> copy output from smem to gmem
  GmemTiledCopyO gmem_tiled_copy_O;
  auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom,AtomNum),ATOM_M,ATOM_N)
  auto tOsO = gmem_thr_copy_O.partition_S(sO);
  auto tOgO = gmem_thr_copy_O.partition_D(gO(_, _, 0));

  // wait for the smem copy to finish before copying to gmem
  __syncthreads();
  cute::copy(gmem_tiled_copy_O, tOsO, tOgO);
}

} // namespace llm
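
The Traits policy type that parameterizes mha_kernel_sm80 is not part of the visible diff. As a rough sketch only — AttentionTraitsSM80 is a placeholder name, and kThreads plus the smem-size arithmetic are assumptions rather than code from this commit — a host-side launcher might look like:

// Hypothetical host-side launcher for mha_kernel_sm80; Traits is assumed to
// expose kBlockM/kBlockN/kHeadDim, kThreads, and the smem layouts used above.
template <typename Traits>
void launch_mha_sm80(void* o, const void* q, const void* k, const void* v,
                     int n_heads, int h_stride, int q_len, int kv_len,
                     float sm_scale, cudaStream_t stream) {
  using T = typename Traits::T;
  // one CTA per (q-block, head), matching blockIdx.x/blockIdx.y in the kernel
  const dim3 grid(cute::ceil_div(q_len, Traits::kBlockM), n_heads);
  const dim3 block(Traits::kThreads);  // assumed member, e.g. 128

  // q tile + k tile + v tile, laid out back to back in dynamic smem
  const size_t smem_size =
      (cute::cosize(typename Traits::SmemLayoutQ{}) +
       2 * cute::cosize(typename Traits::SmemLayoutKV{})) *
      sizeof(T);

  auto kernel = &mha_kernel_sm80<Traits>;
  if (smem_size >= 48 * 1024) {
    // sm80 allows >48KB of dynamic smem only after an explicit opt-in
    cudaFuncSetAttribute(kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(smem_size));
  }
  kernel<<<grid, block, smem_size, stream>>>(
      o, q, k, v, h_stride, q_len, kv_len, sm_scale);
}

Note that the kernel currently assumes a contiguous, head-major layout (row stride of kHeadDim, one shared h_stride for q/k/v), per its "support non-contiguous layout" TODO.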
src/kernels/attention/attention_kernel_sm80_test.cu (new file)

Lines changed: 69 additions & 0 deletions
#include <gtest/gtest.h>
#include <torch/torch.h>

#include <cassert>
#include <cmath>

#include "attention_kernel_sm80.cuh"

namespace llm {
namespace {
// Reference multi-head attention implementation using PyTorch
torch::Tensor attention_ref(
    torch::Tensor query,  // [q_seq_len, n_heads, head_dim]
    torch::Tensor key,    // [seq_len, n_kv_heads, head_dim]
    torch::Tensor value   // [seq_len, n_kv_heads, head_dim]
) {
  const auto q_seq_len = query.size(0);
  const auto n_heads = query.size(1);
  const auto head_dim = query.size(2);
  const auto seq_len = key.size(0);
  const auto n_kv_heads = key.size(1);

  assert(n_heads == n_kv_heads);

  // query * key => [n_heads, q_seq_len, seq_len]
  auto scores = torch::einsum("qhd,khd->hqk", {query, key});
  // apply scale
  const float sm_scale = static_cast<float>(1.0 / std::sqrt(head_dim));
  scores *= sm_scale;

  // safe softmax
  scores = torch::softmax(scores, /*dim=*/-1);

  // score * value => [q_seq_len, n_heads, head_dim]
  return torch::einsum("hqk,khd->qhd", {scores, value});
}
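
// In index form, attention_ref computes, per head h:
//   S[h, q, k] = sm_scale * sum_d Q[q, h, d] * K[k, h, d]
//   P[h, q, :] = softmax(S[h, q, :])
//   O[q, h, d] = sum_k P[h, q, k] * V[k, h, d]
// i.e. O = softmax(Q K^T / sqrt(head_dim)) V, the same quantity the sm80
// kernel computes blockwise with its online softmax.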

}  // namespace

class AttentionKernelTest
    : public ::testing::TestWithParam<std::tuple<int64_t /*seq_len*/,
                                                 int64_t /*q_seq_len*/,
                                                 int64_t /*n_heads*/,
                                                 int64_t /*n_kv_heads*/,
                                                 int64_t /*head_dim*/>> {};

TEST_P(AttentionKernelTest, MHA) {
  const auto [seq_len, q_seq_len, n_heads, n_kv_heads, head_dim] = GetParam();

  const auto options = torch::dtype(torch::kFloat).device(torch::kCPU);

  const auto query = torch::randn({q_seq_len, n_heads, head_dim}, options);
  const auto key = torch::randn({seq_len, n_kv_heads, head_dim}, options);
  const auto value = torch::randn({seq_len, n_kv_heads, head_dim}, options);

  auto ref_out = attention_ref(query, key, value);

  // TODO: compare against the CUDA kernel once a host-side wrapper exists:
  // auto out = torch::empty_like(query);
  // mha(query, key, value, out);
  // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-5, /*atol=*/1e-5));
}

INSTANTIATE_TEST_SUITE_P(
    MHA,
    AttentionKernelTest,
    ::testing::Combine(::testing::Values(64),  // seq_len
                       ::testing::Values(64),  // q_seq_len
                       ::testing::Values(8),   // n_heads
                       ::testing::Values(8),   // n_kv_heads
                       ::testing::Values(64)   // head_dim
                       ));

}  // namespace llm
