@@ -11,32 +11,40 @@ namespace {
1111torch::Tensor attention_ref (
1212 torch::Tensor query, // [batch_size, n_heads, q_len, head_dim]
1313 torch::Tensor key, // [batch_size, n_kv_heads, kv_len, head_dim]
14- torch::Tensor value // [batch_size, n_kv_heads, kv_len, head_dim]
15- ) {
14+ torch::Tensor value, // [batch_size, n_kv_heads, kv_len, head_dim]
15+ float logits_soft_cap ) {
1616 const auto n_heads = query.size (1 );
1717 const auto n_kv_heads = key.size (1 );
1818 const auto head_dim = query.size (3 );
1919 assert (n_heads == n_kv_heads);
2020
2121 const float sm_scale = 1.0 / sqrt (head_dim);
2222 // query * key => [n_heads, q_seq_len, seq_len]
23- auto scores = torch::einsum (" bhqd,bhkd->bhqk" , {query, key});
23+ auto scores = torch::einsum (" bhqd,bhkd->bhqk" ,
24+ {query.to (torch::kFloat ), key.to (torch::kFloat )});
2425 // apply scale
2526 scores *= sm_scale;
2627
28+ // apply softcap if needed
29+ if (logits_soft_cap != 0.0 ) {
30+ scores = torch::tanh (scores / logits_soft_cap) * logits_soft_cap;
31+ }
32+
2733 // safe softmax
2834 scores = torch::softmax (scores, /* dim=*/ -1 );
2935
3036 // score * value => [batch_size, n_heads, q_seq_len, head_dim]
31- return torch::einsum (" bhqk,bhkd->bhqd" , {scores, value});
37+ return torch::einsum (" bhqk,bhkd->bhqd" , {scores, value.to (torch::kFloat )})
38+ .type_as (query);
3239}
3340
3441torch::Tensor attention_sm80 (
3542 torch::Tensor query, // [batch_size, n_heads, q_len, head_dim]
3643 torch::Tensor key, // [batch_size, n_kv_heads, kv_len, head_dim]
37- torch::Tensor value // [batch_size, n_kv_heads, kv_len, head_dim]
38- ) {
44+ torch::Tensor value, // [batch_size, n_kv_heads, kv_len, head_dim]
45+ float logits_soft_cap ) {
3946 const auto batch_size = query.size (0 );
47+ const auto n_heads = query.size (1 );
4048 const auto q_len = query.size (2 );
4149 const auto kv_len = key.size (2 );
4250 const auto head_dim = query.size (3 );
@@ -50,13 +58,13 @@ torch::Tensor attention_sm80(
5058 constexpr int32_t kBlockM = 64 ;
5159 constexpr int32_t kBlockN = 64 ;
5260
53- const float sm_scale = 1.0 / sqrt (head_dim) * M_LOG2E ;
61+ const float sm_scale = 1.0 / sqrt (head_dim);
5462
5563 using AttentionTraits =
5664 AttentionTraitsSM80<cute::half_t , kHeadDim , kBlockM , kBlockN >;
5765
5866 dim3 block = AttentionTraits::kThreadNum ;
59- dim3 grid ((q_len + kBlockM - 1 ) / kBlockM , batch_size * head_dim );
67+ dim3 grid ((q_len + kBlockM - 1 ) / kBlockM , batch_size * n_heads );
6068
6169 const auto smem_size = AttentionTraits::kSmemSize ;
6270 auto attention_kernel = mha_kernel_sm80<AttentionTraits>;
@@ -72,7 +80,8 @@ torch::Tensor attention_sm80(
7280 kv_h_stride,
7381 q_len,
7482 kv_len,
75- sm_scale);
83+ sm_scale,
84+ logits_soft_cap);
7685 C10_CUDA_KERNEL_LAUNCH_CHECK ();
7786 return out;
7887}
@@ -85,11 +94,23 @@ class AttentionKernelTest
8594 int64_t /* kv_len*/ ,
8695 int64_t /* n_heads*/ ,
8796 int64_t /* n_kv_heads*/ ,
88- int64_t /* head_dim*/ >> {};
97+ int64_t /* head_dim*/ ,
98+ float /* logits_soft_cap*/ >> {
99+ public:
100+ void SetUp () override {
101+ // Set random seed for test stability
102+ torch::manual_seed (0 );
103+ }
104+ };
89105
90106TEST_P (AttentionKernelTest, MHA) {
91- const auto [batch_size, q_len, kv_len, n_heads, n_kv_heads, head_dim] =
92- GetParam ();
107+ const auto [batch_size,
108+ q_len,
109+ kv_len,
110+ n_heads,
111+ n_kv_heads,
112+ head_dim,
113+ logits_soft_cap] = GetParam ();
93114
94115 const auto options = torch::dtype (torch::kHalf ).device (torch::kCUDA );
95116
@@ -100,21 +121,22 @@ TEST_P(AttentionKernelTest, MHA) {
100121 const auto value =
101122 torch::randn ({batch_size, n_kv_heads, kv_len, head_dim}, options);
102123
103- auto ref_out = attention_ref (query, key, value);
104- auto out = attention_sm80 (query, key, value);
124+ auto ref_out = attention_ref (query, key, value, logits_soft_cap );
125+ auto out = attention_sm80 (query, key, value, logits_soft_cap );
105126
106127 EXPECT_TRUE (torch::allclose (out, ref_out, /* rtol=*/ 1e-3 , /* atol=*/ 1e-3 ));
107128}
108129
109- INSTANTIATE_TEST_SUITE_P (MHA,
110- AttentionKernelTest,
111- ::testing::Combine (::testing::Values(1 ), // batch_size
112- ::testing::Values(64 ), // q_len
113- ::testing::Values(64 ,
114- 256 ), // kv_len
115- ::testing::Values(2 ), // n_heads
116- ::testing::Values(2 ), // n_kv_heads
117- ::testing::Values(64 ) // head_dim
118- ));
130+ INSTANTIATE_TEST_SUITE_P (
131+ MHA,
132+ AttentionKernelTest,
133+ ::testing::Combine (::testing::Values(1 , 2 , 4 ), // batch_size
134+ ::testing::Values(128 , 256 , 1024 ), // q_len
135+ ::testing::Values(128 , 256 , 1024 ), // kv_len
136+ ::testing::Values(16 ), // n_heads
137+ ::testing::Values(16 ), // n_kv_heads
138+ ::testing::Values(64 ), // head_dim
139+ ::testing::Values(0.0 , 50.0 ) // logits_soft_cap
140+ ));
119141
120142} // namespace llm
0 commit comments