Skip to content

Commit 5da2fe7

Browse files
authored
kernel: added attention kernel for sm80 (Happy new year!) (#355)
1 parent 6fbe549 commit 5da2fe7

File tree

8 files changed

+608
-7
lines changed

8 files changed

+608
-7
lines changed

.clang-format

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ BinPackArguments: false
66
ExperimentalAutoDetectBinPacking: false
77
AllowAllParametersOfDeclarationOnNextLine: false
88
DerivePointerAlignment: false
9+
AlwaysBreakTemplateDeclarations: Yes
910
PointerAlignment: Left
1011
ColumnLimit: 80
1112
...

src/kernels/attention/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ cc_test(
1818
attention_cpu.h
1919
SRCS
2020
cute_test.cpp
21-
attention_test.cpp
21+
attention_cpu_test.cpp
22+
attention_kernel_sm80_test.cu
2223
DEPS
2324
:attention.kernel
2425
glog::glog

src/kernels/attention/attention_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ inline void mha(torch::Tensor query,
9191
}
9292
// apply causal mask
9393
if (kv_idx_base + j > q_idx_base + q_idx) {
94-
s(j) = -INFINITY;
94+
s(j) = -5e4;
9595
}
9696
max = std::max(max, s(j));
9797
}

src/kernels/attention/attention_test.cpp renamed to src/kernels/attention/attention_cpu_test.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
#include <ATen/ops/equal.h>
2-
#include <gtest/gtest.h>
3-
#include <torch/csrc/autograd/generated/variable_factories.h>
4-
51
#include "attention_cpu.h"
62

3+
#include <gtest/gtest.h>
4+
#include <torch/torch.h>
5+
76
namespace llm {
87
namespace {
98
// Multi-head attention implementation using pytorch
@@ -36,7 +35,7 @@ torch::Tensor masked_self_attention(
3635
torch::Tensor mask = torch::ones({1, q_seq_len, seq_len}, torch::kBool);
3736
// returns the lower triangular part of a matrix
3837
mask = torch::tril(mask, /*diagonal=*/seq_len - q_seq_len).to(query);
39-
scores = scores.masked_fill(mask == 0, -INFINITY);
38+
scores = scores.masked_fill(mask == 0, -5e4);
4039

4140
// safe softmax
4241
scores = torch::softmax(scores, /*dim=*/-1);
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#pragma once
2+
3+
#include <cuda.h>
4+
#include <cuda_runtime.h>
5+
6+
#include <cute/tensor.hpp>
7+
8+
#include "online_softmax.cuh"
9+
10+
namespace llm {
11+
12+
template <typename Traits>
// Multi-head attention (flash-attention style) forward kernel.
// Requires SM80+: relies on the cp.async pipeline primitives
// (cp_async_fence / cp_async_wait) for global->shared copies.
//
// Grid:  blockIdx.x = query-row tile index (BLK_M rows of Q per CTA);
//        blockIdx.y = slice id whose element offset into q/k/v/o is
//        blockIdx.y * h_stride (presumably one (batch, head) slice —
//        the caller's layout determines this; TODO confirm).
// Block: threadIdx.x is the flat thread id consumed by the tiled-MMA
//        and tiled-copy slicers supplied by Traits.
//
// Dynamic shared memory must be sized by the launcher to hold the Q, K
// and V tiles: (cosize(SmemLayoutQ) + 2 * cosize(SmemLayoutKV)) elements.
//
// Parameters:
//   o        - output buffer, (q_len, head_dim) row-major per slice
//   q        - query buffer, (q_len, head_dim) row-major per slice
//   k, v     - key/value buffers, (kv_len, head_dim) row-major per slice
//   h_stride - element stride between consecutive blockIdx.y slices
//   q_len    - number of query rows
//   kv_len   - number of key/value rows
//   sm_scale - softmax scale; folded into Q in shared memory before GEMM-I
//
// NOTE(review): no causal mask is applied in this kernel, unlike the CPU
// reference implementation — confirm whether masking is handled elsewhere
// or only non-causal inputs are expected here.
__global__ void mha_kernel_sm80(void* o,
                                const void* q,
                                const void* k,
                                const void* v,
                                int h_stride,
                                int q_len,
                                int kv_len,
                                float sm_scale) {
  using namespace cute;

  // type alias
  using Element = typename Traits::Element;
  using BLK_M = typename Traits::BLK_M;
  using BLK_N = typename Traits::BLK_N;
  using BLK_K = typename Traits::BLK_K;
  using HEAD_DIM = typename Traits::HEAD_DIM;

  using TiledMMA = typename Traits::TiledMMA;
  using Convertor = typename Traits::FragmentConvertor;

  // K and V share one smem layout; V additionally gets a transposed view
  // (SmemLayoutVt) over the same storage for GEMM-II.
  using SmemLayoutQ = typename Traits::SmemLayoutQ;
  using SmemLayoutK = typename Traits::SmemLayoutKV;
  using SmemLayoutV = typename Traits::SmemLayoutKV;
  using SmemLayoutVt = typename Traits::SmemLayoutVt;
  using SmemLayoutO = typename Traits::SmemLayoutO;
  using GmemTiledCopyQKV = typename Traits::GmemTiledCopyQKV;
  using GmemTiledCopyO = typename Traits::GmemTiledCopyO;

  using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
  using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
  using SmemTiledCopyVT = typename Traits::SmemTiledCopyVT;
  using SmemTiledCopyO = typename Traits::SmemTiledCopyO;

  const int m_block = blockIdx.x;
  const int base_id = blockIdx.y;
  const int tidx = threadIdx.x;

  // ProblemShape
  // TODO: support non-contiguous layout
  const int offset = base_id * h_stride;
  // (q_len, head_dim)
  auto Q = make_tensor(make_gmem_ptr((Element*)q + offset),
                       make_shape(q_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));
  auto O = make_tensor(make_gmem_ptr((Element*)o + offset),
                       make_shape(q_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));
  // (kv_len, head_dim)
  auto K = make_tensor(make_gmem_ptr((Element*)k + offset),
                       make_shape(kv_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));
  auto V = make_tensor(make_gmem_ptr((Element*)v + offset),
                       make_shape(kv_len, HEAD_DIM{}),
                       make_stride(HEAD_DIM{}, _1{}));

  // CTA/Block Shape
  // (BLK_M, head_dim): this CTA's query/output tile, selected by m_block
  Tensor gQ =
      local_tile(Q, make_tile(BLK_M{}, HEAD_DIM{}), make_coord(m_block, _));
  Tensor gO =
      local_tile(O, make_tile(BLK_M{}, HEAD_DIM{}), make_coord(m_block, _));

  // (BLK_N, head_dim): start at KV tile 0; re-tiled to tile `ni` inside
  // the mainloop as the CTA walks the key/value sequence
  Tensor gK = local_tile(K, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(0, _));
  Tensor gV = local_tile(V, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(0, _));

  // Smem: dynamic shared memory partitioned as [Q | K | V]
  extern __shared__ char smem[];
  Element* q_smem = (Element*)smem;
  Element* k_smem = q_smem + cosize(SmemLayoutQ{});
  Element* v_smem = k_smem + cosize(SmemLayoutK{});

  // (BLK_M, BLK_K), k-major
  Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
  // (BLK_N, BLK_K), k-major
  Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{});
  Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{});

  // Tensor for V^t; used in GEMM-II.
  // (BLK_K, BLK_N), k-major — a transposed view over the same v_smem bytes
  Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});

  // Fragments for GEMM
  TiledMMA tiled_mma;
  auto thr_mma = tiled_mma.get_slice(tidx);
  // GEMM-I: S = Q@K.T
  auto tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
  auto tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
  auto tSrAccS = partition_fragment_C(
      tiled_mma, Shape<BLK_M, BLK_N>{});  // (MMA,MMA_M,MMA_N)

  // GEMM-II: O = softmax(S)@V
  auto tOrVt = thr_mma.partition_fragment_B(sVt);  // (MMA,MMA_K,MMA_N)
  auto tOrAccO = partition_fragment_C(
      tiled_mma, Shape<BLK_M, HEAD_DIM>{});  // (MMA,MMA_M,MMA_K)

  // reshape for iterating over rows and columns (used by the online
  // softmax, which operates row-wise on the accumulators)
  auto tOrAccO_rc_view = Convertor::to_rowcol(tOrAccO);
  auto tSrAccS_rc_view = Convertor::to_rowcol(tSrAccS);

  // Tiled Copy
  // g2s tiled copy for qkv
  GmemTiledCopyQKV gmem_tiled_copy_QKV;
  auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
  auto tQgQ = gmem_thr_copy_QKV.partition_S(gQ(_, _, 0));
  auto tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
  auto tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
  auto tKsK = gmem_thr_copy_QKV.partition_D(sK);
  auto tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
  auto tVsV = gmem_thr_copy_QKV.partition_D(sV);

  // s2r tiled copy for qkv
  SmemTiledCopyQ smem_tiled_copy_Q;
  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
  auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
  auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);

  SmemTiledCopyK smem_tiled_copy_K;
  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
  auto tSsK = smem_thr_copy_K.partition_S(sK);
  auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);

  SmemTiledCopyVT smem_tiled_copy_Vt;
  auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
  auto tOsVt = smem_thr_copy_Vt.partition_S(sVt);
  auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt);

  // ############### Prologue ###############
  // The bracket notation below tracks which cp.async commit groups are
  // in flight, e.g. [q, k] means the q and k copies are still pending.

  // produce q: [] => [q]
  cute::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ);
  cp_async_fence();

  // produce k: [q] => [q, k]
  cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
  cp_async_fence();

  // wait q: [q, k] => [k]
  cp_async_wait<1>();
  __syncthreads();

  // apply sm_scale
  // Each thread rescales exactly the smem elements it copied (its tQsQ
  // view); the __syncthreads() at the top of the first mainloop iteration
  // orders these writes before any cross-thread reads of sQ.
  // TODO: use thread parallelism
  for (int i = 0; i < size(tQsQ); ++i) {
    tQsQ(i) = Element(tQsQ(i) * sm_scale);
  }

  // RowsPerThread = #rows_per_MMA * #MMA_M
  constexpr int RowsPerThread = 2 * size<1>(tOrAccO);
  OnlineSoftmax<RowsPerThread> softmax;

  // ############### Mainloop ###############
  // One iteration per BLK_N-sized KV tile: prefetch the next K/V via
  // cp.async while computing on the current tile, rescaling the partial
  // output with the online softmax as the row max/sum evolve.

  const int n_block_min = 0;
  const int n_block_max = cute::ceil_div(kv_len, BLK_N{});

  // clear output
  clear(tOrAccO);
  CUTE_NO_UNROLL
  for (int ni = n_block_min; ni < n_block_max; ++ni) {
    // clear attention score for each block
    clear(tSrAccS);

    // wait k, queue: [q, k] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce v, [] => [v]
    {
      gV = local_tile(V, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(ni, _));
      tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tVgV, tVsV);
    }
    cp_async_fence();

    // 1> S = Q@K.T
    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
      cute::copy(smem_tiled_copy_Q, tSsQ(_, _, ki), tSrQ_copy_view(_, _, ki));
      cute::copy(smem_tiled_copy_K, tSsK(_, _, ki), tSrK_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
    }

    // apply softmax and rescale
    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

    // wait v, [v] => []
    cp_async_wait<0>();
    __syncthreads();

    // produce next k: [] => [k]
    if (ni != n_block_max - 1) {
      gK = local_tile(K, make_tile(BLK_N{}, HEAD_DIM{}), make_coord(ni + 1, _));
      tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, 0));
      cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK);
    }
    // Fence is issued even on the last iteration (with no copy enqueued)
    // so the commit-group count stays uniform across iterations.
    cp_async_fence();

    // 2> O = softmax(S)*V

    // cast scores from Accumulator to Element
    auto tSrS = make_tensor_like<Element>(tSrAccS);
    CUTE_UNROLL
    for (int i = 0; i < size(tSrAccS); ++i) {
      tSrS(i) = static_cast<Element>(tSrAccS(i));
    }

    // convert layout from gemm-I C to gemm-II A
    auto tOrS = Convertor::to_mma_a(tSrS);

    CUTE_UNROLL
    for (int ki = 0; ki < size<2>(tOrS); ++ki) {
      cute::copy(
          smem_tiled_copy_Vt, tOsVt(_, _, ki), tOrVt_copy_view(_, _, ki));
      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
    }
  }

  // ############### Epilogue ###############

  // normalize output: o /= rowsum
  softmax.finalize(tOrAccO_rc_view);

  // write output to gmem
  // 1. convert output from ElementAccumulator to Element
  auto tOrO = make_tensor_like<Element>(tOrAccO);
  CUTE_UNROLL
  for (int si = 0; si < size(tOrAccO); ++si) {
    tOrO(si) = static_cast<Element>(tOrAccO(si));
  }

  // 2. copy output from reg to smem
  // Reuses Q's shared-memory storage for the output tile; safe because
  // sQ is no longer read after the mainloop.
  auto sO = make_tensor(sQ.data(), SmemLayoutO{});

  SmemTiledCopyO smem_tiled_copy_O;
  auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom,AtomNum),MMA_M,MMA_N)
  auto taccOrO = smem_thr_copy_O.retile_S(tOrO);
  // ((Atom,AtomNum),PIPE_M,PIPE_N)
  auto taccOsO = smem_thr_copy_O.partition_D(sO);
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

  // 3. copy output from smem to gmem
  GmemTiledCopyO gmem_tiled_copy_O;
  auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
  // ((Atom,AtomNum),ATOM_M,ATOM_N)
  auto tOsO = gmem_thr_copy_O.partition_S(sO);
  auto tOgO = gmem_thr_copy_O.partition_D(gO(_, _, 0));

  // wait for smem copy before copy to gmem
  __syncthreads();
  cute::copy(gmem_tiled_copy_O, tOsO, tOgO);
}
266+
267+
} // namespace llm

0 commit comments

Comments
 (0)