
Commit aacbdfd

kernel: added logits soft cap support for attention
1 parent b78389c commit aacbdfd

File tree

8 files changed (+172 / -50 lines)


.devcontainer/Dockerfile

Lines changed: 5 additions & 1 deletion

@@ -2,8 +2,12 @@ ARG BASE_IMAGE=vectorchai/scalellm_devel:cuda12.4
 FROM ${BASE_IMAGE}
 
 ARG USER=vscode
+ARG UID=1000
+ARG GID=1000
+
 # Run as non-root user
-RUN useradd -m ${USER}
+RUN groupadd --gid ${GID} ${USER} \
+    && useradd --uid ${UID} --gid ${GID} -m ${USER} --shell /bin/bash
 RUN echo ${USER} ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/${USER} \
     && chmod 0440 /etc/sudoers.d/${USER}
 
.devcontainer/devcontainer.json

Lines changed: 8 additions & 4 deletions

@@ -5,7 +5,9 @@
     "dockerfile": "Dockerfile",
     "args": {
       "BASE_IMAGE": "vectorchai/scalellm_devel:cuda12.4",
-      "USER": "${localEnv:USER}"
+      "USER": "${localEnv:USER:vscode}",
+      "UID": "${localEnv:UID:1000}",
+      "GID": "${localEnv:GID:1000}"
     }
   },
   // Access GPUs from inside the container
@@ -17,8 +19,8 @@
     "HUGGING_FACE_HUB_TOKEN": "${localEnv:HUGGING_FACE_HUB_TOKEN}"
   },
   // Run as the current user
-  "remoteUser": "${localEnv:USER}",
-  "containerUser": "${localEnv:USER}",
+  "remoteUser": "${localEnv:USER:vscode}",
+  "containerUser": "${localEnv:USER:vscode}",
   "updateRemoteUserUID": true,
   // Ports should be forwarded from inside container to the local machine
   "forwardPorts": [],
@@ -35,7 +37,9 @@
       "ms-vscode.cpptools-extension-pack",
       "llvm-vs-code-extensions.vscode-clangd",
       "ms-python.python",
-      "ms-azuretools.vscode-docker"
+      "ms-azuretools.vscode-docker",
+      "ziruiwang.nvidia-monitor",
+      "mutantdino.resourcemonitor"
     ],
     "settings": {
       "extensions.verifySignature": false,

src/kernels/attention/CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -7,6 +7,10 @@ cc_library(
     attention.kernel
   HDRS
     attention_cpu.h
+    ptx.cuh
+    fast_cast.cuh
+    online_softmax.cuh
+    attention_traits_sm80.h
     attention_kernel_sm80.cuh
   SRCS
     # attention.cu
@@ -37,6 +41,8 @@ cc_binary(
     nvbench::nvbench
     nvbench::main
     :attention.kernel
+  COPTS
+    -lineinfo
 )
 
 add_subdirectory(flash_attn)

src/kernels/attention/attention_bench_sm80.cu

Lines changed: 11 additions & 8 deletions

@@ -20,6 +20,7 @@ void attention_bench_sm80(nvbench::state& state) {
   const auto n_heads = state.get_int64("n_heads");
   const auto n_kv_heads = state.get_int64("n_kv_heads");
   const auto head_dim = state.get_int64("head_dim");
+  const float logits_soft_cap = state.get_float64("logits_soft_cap");
 
   const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA);
   const auto query =
@@ -31,7 +32,7 @@ void attention_bench_sm80(nvbench::state& state) {
 
   auto out = torch::empty_like(query);
 
-  const float sm_scale = 1.0 / sqrt(head_dim) * M_LOG2E;
+  const float sm_scale = 1.0 / sqrt(head_dim);
   const auto h_stride = query.stride(1);
   const auto kv_h_stride = key.stride(1);
 
@@ -43,7 +44,7 @@ void attention_bench_sm80(nvbench::state& state) {
       AttentionTraitsSM80<cute::half_t, kHeadDim, kBlockM, kBlockN>;
 
   dim3 block = AttentionTraits::kThreadNum;
-  dim3 grid((q_len + kBlockM - 1) / kBlockM, batch_size * head_dim);
+  dim3 grid(q_len / kBlockM, batch_size * n_heads);
 
   const auto smem_size = AttentionTraits::kSmemSize;
   auto attention_kernel = mha_kernel_sm80<AttentionTraits>;
@@ -61,14 +62,16 @@ void attention_bench_sm80(nvbench::state& state) {
         kv_h_stride,
         q_len,
         kv_len,
-        sm_scale);
+        sm_scale,
+        logits_soft_cap);
   });
 }
 
 NVBENCH_BENCH(attention_bench_sm80)
     .add_int64_axis("batch_size", {1})
-    .add_int64_axis("q_len", {64})
-    .add_int64_axis("kv_len", {64, 128})
-    .add_int64_axis("n_heads", {2})
-    .add_int64_axis("n_kv_heads", {2})
-    .add_int64_axis("head_dim", {64});
+    .add_int64_axis("q_len", {1024})
+    .add_int64_axis("kv_len", {1024})
+    .add_int64_axis("n_heads", {32})
+    .add_int64_axis("n_kv_heads", {32})
+    .add_int64_axis("head_dim", {64})
+    .add_float64_axis("logits_soft_cap", {0.0});

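For reference (not part of the commit), the corrected grid above assigns one thread block per kBlockM-row query block along x and one per (batch, head) pair along y. The standalone sketch below just reproduces that arithmetic for the new default benchmark axes, assuming the same kBlockM = 64 tile size used in the unit test and that q_len is a multiple of kBlockM, as the bench now requires:

#include <cstdio>

int main() {
  // New default benchmark axes (see the NVBENCH_BENCH axes above).
  const int batch_size = 1;
  const int q_len = 1024;
  const int n_heads = 32;
  const int kBlockM = 64;  // assumed query rows per thread block
  // grid.x: one block per kBlockM-row query block (q_len % kBlockM == 0).
  // grid.y: one block per (batch, head) pair, i.e. batch_size * n_heads.
  const int grid_x = q_len / kBlockM;
  const int grid_y = batch_size * n_heads;
  std::printf("grid = (%d, %d) -> %d thread blocks\n", grid_x, grid_y,
              grid_x * grid_y);  // prints: grid = (16, 32) -> 512 thread blocks
  return 0;
}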

src/kernels/attention/attention_kernel_sm80.cuh

Lines changed: 56 additions & 5 deletions

@@ -7,6 +7,7 @@
 
 #include "fast_cast.cuh"
 #include "online_softmax.cuh"
+#include "ptx.cuh"
 
 namespace llm {
 
@@ -19,7 +20,8 @@ __global__ void mha_kernel_sm80(void* o,
                                 int64_t kv_h_stride,
                                 int64_t q_len,
                                 int64_t kv_len,
-                                float sm_scale) {
+                                float sm_scale,
+                                float logits_soft_cap) {
   using namespace cute;
 
   // type alias
@@ -49,6 +51,30 @@ __global__ void mha_kernel_sm80(void* o,
   const auto base_id = blockIdx.y;
   const auto tidx = threadIdx.x;
 
+  // preprocess input parameters
+  // TODO: Move following logic to the host side?
+  if (logits_soft_cap != 0.0) {
+    // Softmax(x * sm_scale) + apply_logits_soft_cap
+    //   => Softmax(Tanh(x * sm_scale / soft_cap) * soft_cap)
+    //   => Softmax(S' * sm_scale') where
+    //      S' = Tanh(x * sm_scale / soft_cap)
+    //         = Tanh(x * soft_cap')
+    //      soft_cap' = sm_scale / soft_cap
+    //      sm_scale' = soft_cap
+    const auto sm_scale_hat = logits_soft_cap;
+    logits_soft_cap = sm_scale * ptx::rcp(logits_soft_cap);
+    sm_scale = sm_scale_hat;
+  }
+  auto apply_logits_soft_cap = [&](auto& tSrAccS) {
+    CUTE_UNROLL
+    for (int i = 0; i < size(tSrAccS); ++i) {
+      tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap);
+    }
+  };
+
+  // use exp2f instead of expf for better performance
+  sm_scale *= M_LOG2E;
+
   // ProblemShape
   // TODO: support non-contiguous layout
   // (q_len, head_dim)
@@ -136,10 +162,22 @@ __global__ void mha_kernel_sm80(void* o,
   // S = Q@K.T
   // tSrAccS: (MMA,MMA_M,MMA_N)
   auto compute_qk = [&](auto& tSrAccS) {
+    // prefetch kv
+    cute::copy(smem_tiled_copy_Q, tSsQ(_, _, _0{}), tSrQ_copy_view(_, _, _0{}));
+    cute::copy(smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{}));
+
     CUTE_UNROLL
     for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
-      cute::copy(smem_tiled_copy_Q, tSsQ(_, _, ki), tSrQ_copy_view(_, _, ki));
-      cute::copy(smem_tiled_copy_K, tSsK(_, _, ki), tSrK_copy_view(_, _, ki));
+      // prefetch next kv
+      if (ki != size<2>(tSrQ) - 1) {
+        const auto next_ki = ki + 1;
+        cute::copy(smem_tiled_copy_Q,
+                   tSsQ(_, _, next_ki),
+                   tSrQ_copy_view(_, _, next_ki));
+        cute::copy(smem_tiled_copy_K,
+                   tSsK(_, _, next_ki),
+                   tSrK_copy_view(_, _, next_ki));
+      }
       cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
     }
   };
@@ -163,10 +201,18 @@ __global__ void mha_kernel_sm80(void* o,
     // convert layout from gemm-I C to gemm-II A
     auto tOrS = make_tensor(tSrS.data(), Layout::to_mma_a(tSrS.layout()));
 
+    // prefetch V^t
+    cute::copy(
+        smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{}));
     CUTE_UNROLL
     for (int ki = 0; ki < size<2>(tOrS); ++ki) {
-      cute::copy(
-          smem_tiled_copy_Vt, tOsVt(_, _, ki), tOrVt_copy_view(_, _, ki));
+      // prefetch next V^t
+      if (ki != size<2>(tOrS) - 1) {
+        const auto next_ki = ki + 1;
+        cute::copy(smem_tiled_copy_Vt,
+                   tOsVt(_, _, next_ki),
+                   tOrVt_copy_view(_, _, next_ki));
+      }
      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
    }
  };
@@ -246,6 +292,11 @@ __global__ void mha_kernel_sm80(void* o,
     // 1> S = Q@K.T
     compute_qk(tSrAccS);
 
+    // apply soft cap if needed
+    if (logits_soft_cap != 0.0) {
+      apply_logits_soft_cap(tSrAccS);
+    }
+
     // apply softmax and rescale
     softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
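The comment block in the kernel above folds the soft cap into the existing scales so the hot loop only needs one multiply and one tanh per element: soft_cap' = sm_scale / soft_cap is applied inside the tanh, and the original cap value takes over the role of the softmax scale. A minimal host-side sketch (not from the commit; it uses exact std::tanh and division in place of the kernel's ptx::tanh and ptx::rcp approximations) that checks the folded form matches the straightforward one:

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const float head_dim = 64.0f;
  const float sm_scale = 1.0f / std::sqrt(head_dim);
  const float soft_cap = 50.0f;

  // Folded parameters, mirroring the kernel's preprocessing step.
  const float soft_cap_hat = sm_scale / soft_cap;  // goes inside the tanh
  const float sm_scale_hat = soft_cap;             // reused as the softmax scale

  for (float x : {-500.0f, -3.0f, 0.0f, 7.5f, 1200.0f}) {
    // Straightforward form: scale the logit, then soft-cap it.
    const float reference = std::tanh(x * sm_scale / soft_cap) * soft_cap;
    // Folded form used by the kernel: tanh with soft_cap', then sm_scale'.
    const float folded = std::tanh(x * soft_cap_hat) * sm_scale_hat;
    assert(std::fabs(reference - folded) < 1e-5f);
    std::printf("x = %8.1f  capped logit = %8.4f\n", x, reference);
  }
  return 0;
}
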
src/kernels/attention/attention_kernel_sm80_test.cu

Lines changed: 46 additions & 24 deletions

@@ -11,32 +11,40 @@ namespace {
 torch::Tensor attention_ref(
     torch::Tensor query,  // [batch_size, n_heads, q_len, head_dim]
     torch::Tensor key,    // [batch_size, n_kv_heads, kv_len, head_dim]
-    torch::Tensor value  // [batch_size, n_kv_heads, kv_len, head_dim]
-) {
+    torch::Tensor value,  // [batch_size, n_kv_heads, kv_len, head_dim]
+    float logits_soft_cap) {
   const auto n_heads = query.size(1);
   const auto n_kv_heads = key.size(1);
   const auto head_dim = query.size(3);
   assert(n_heads == n_kv_heads);
 
   const float sm_scale = 1.0 / sqrt(head_dim);
   // query * key => [n_heads, q_seq_len, seq_len]
-  auto scores = torch::einsum("bhqd,bhkd->bhqk", {query, key});
+  auto scores = torch::einsum("bhqd,bhkd->bhqk",
+                              {query.to(torch::kFloat), key.to(torch::kFloat)});
   // apply scale
   scores *= sm_scale;
 
+  // apply softcap if needed
+  if (logits_soft_cap != 0.0) {
+    scores = torch::tanh(scores / logits_soft_cap) * logits_soft_cap;
+  }
+
   // safe softmax
   scores = torch::softmax(scores, /*dim=*/-1);
 
   // score * value => [batch_size, n_heads, q_seq_len, head_dim]
-  return torch::einsum("bhqk,bhkd->bhqd", {scores, value});
+  return torch::einsum("bhqk,bhkd->bhqd", {scores, value.to(torch::kFloat)})
+      .type_as(query);
 }
 
 torch::Tensor attention_sm80(
     torch::Tensor query,  // [batch_size, n_heads, q_len, head_dim]
     torch::Tensor key,    // [batch_size, n_kv_heads, kv_len, head_dim]
-    torch::Tensor value  // [batch_size, n_kv_heads, kv_len, head_dim]
-) {
+    torch::Tensor value,  // [batch_size, n_kv_heads, kv_len, head_dim]
+    float logits_soft_cap) {
   const auto batch_size = query.size(0);
+  const auto n_heads = query.size(1);
   const auto q_len = query.size(2);
   const auto kv_len = key.size(2);
   const auto head_dim = query.size(3);
@@ -50,13 +58,13 @@ torch::Tensor attention_sm80(
   constexpr int32_t kBlockM = 64;
   constexpr int32_t kBlockN = 64;
 
-  const float sm_scale = 1.0 / sqrt(head_dim) * M_LOG2E;
+  const float sm_scale = 1.0 / sqrt(head_dim);
 
   using AttentionTraits =
       AttentionTraitsSM80<cute::half_t, kHeadDim, kBlockM, kBlockN>;
 
   dim3 block = AttentionTraits::kThreadNum;
-  dim3 grid((q_len + kBlockM - 1) / kBlockM, batch_size * head_dim);
+  dim3 grid((q_len + kBlockM - 1) / kBlockM, batch_size * n_heads);
 
   const auto smem_size = AttentionTraits::kSmemSize;
   auto attention_kernel = mha_kernel_sm80<AttentionTraits>;
@@ -72,7 +80,8 @@ torch::Tensor attention_sm80(
       kv_h_stride,
       q_len,
       kv_len,
-      sm_scale);
+      sm_scale,
+      logits_soft_cap);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return out;
 }
@@ -85,11 +94,23 @@ class AttentionKernelTest
                      int64_t /*kv_len*/,
                      int64_t /*n_heads*/,
                      int64_t /*n_kv_heads*/,
-                     int64_t /*head_dim*/>> {};
+                     int64_t /*head_dim*/,
+                     float /*logits_soft_cap*/>> {
+ public:
+  void SetUp() override {
+    // Set random seed for test stability
+    torch::manual_seed(0);
+  }
+};
 
 TEST_P(AttentionKernelTest, MHA) {
-  const auto [batch_size, q_len, kv_len, n_heads, n_kv_heads, head_dim] =
-      GetParam();
+  const auto [batch_size,
+              q_len,
+              kv_len,
+              n_heads,
+              n_kv_heads,
+              head_dim,
+              logits_soft_cap] = GetParam();
 
   const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA);
 
@@ -100,21 +121,22 @@ TEST_P(AttentionKernelTest, MHA) {
   const auto value =
       torch::randn({batch_size, n_kv_heads, kv_len, head_dim}, options);
 
-  auto ref_out = attention_ref(query, key, value);
-  auto out = attention_sm80(query, key, value);
+  auto ref_out = attention_ref(query, key, value, logits_soft_cap);
+  auto out = attention_sm80(query, key, value, logits_soft_cap);
 
   EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
 }
 
-INSTANTIATE_TEST_SUITE_P(MHA,
-                         AttentionKernelTest,
-                         ::testing::Combine(::testing::Values(1),  // batch_size
-                                            ::testing::Values(64),  // q_len
-                                            ::testing::Values(64,
-                                                              256),  // kv_len
-                                            ::testing::Values(2),  // n_heads
-                                            ::testing::Values(2),  // n_kv_heads
-                                            ::testing::Values(64)  // head_dim
-                                            ));
+INSTANTIATE_TEST_SUITE_P(
+    MHA,
+    AttentionKernelTest,
+    ::testing::Combine(::testing::Values(1, 2, 4),         // batch_size
+                       ::testing::Values(128, 256, 1024),  // q_len
+                       ::testing::Values(128, 256, 1024),  // kv_len
+                       ::testing::Values(16),              // n_heads
+                       ::testing::Values(16),              // n_kv_heads
+                       ::testing::Values(64),              // head_dim
+                       ::testing::Values(0.0, 50.0)        // logits_soft_cap
+                       ));
 
 }  // namespace llm

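As a quick numeric illustration (not part of the commit) of what the reference's soft-cap transform tanh(s / cap) * cap does for the cap value of 50.0 exercised by the test suite: it is close to the identity for small scores and a smooth clamp toward +/-cap for large ones, while cap = 0.0 skips the transform entirely.

#include <cmath>
#include <cstdio>

int main() {
  const float cap = 50.0f;  // matches the non-zero logits_soft_cap test value
  for (float s : {1.0f, 10.0f, 40.0f, 100.0f, 1000.0f}) {
    const float capped = std::tanh(s / cap) * cap;
    // Near-linear for |s| much smaller than cap, saturating toward +/-cap
    // as |s| grows, which keeps the post-scale logits bounded.
    std::printf("s = %7.1f  ->  capped = %7.3f\n", s, capped);
  }
  return 0;
}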
