Skip to content

Commit c54d4c8

Browse files
committed
use einsum
1 parent 9391866 commit c54d4c8

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

src/kernels/attention/attention_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ inline void mha(torch::Tensor query,
9191
}
9292
// apply causal mask
9393
if (kv_idx_base + j > q_idx_base + q_idx) {
94-
s(j) = -INFINITY;
94+
s(j) = -5e4;
9595
}
9696
max = std::max(max, s(j));
9797
}

src/kernels/attention/attention_cpu_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ torch::Tensor masked_self_attention(
3535
torch::Tensor mask = torch::ones({1, q_seq_len, seq_len}, torch::kBool);
3636
// returns the lower triangular part of a matrix
3737
mask = torch::tril(mask, /*diagonal=*/seq_len - q_seq_len).to(query);
38-
scores = scores.masked_fill(mask == 0, -INFINITY);
38+
scores = scores.masked_fill(mask == 0, -5e4);
3939

4040
// safe softmax
4141
scores = torch::softmax(scores, /*dim=*/-1);

src/kernels/attention/attention_kernel_sm80_test.cu

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,15 @@ torch::Tensor attention_ref(
2121

2222
const float sm_scale = 1.0 / sqrt(head_dim);
2323
// query * key => [n_heads, q_seq_len, seq_len]
24-
// auto scores = torch::einsum("bhqd,bhkd->bhqk", {query, key});
25-
auto scores = torch::matmul(query, key.transpose(-2, -1));
24+
auto scores = torch::einsum("bhqd,bhkd->bhqk", {query, key});
2625
// apply scale
2726
scores *= sm_scale;
2827

2928
// safe softmax
3029
scores = torch::softmax(scores, /*dim=*/-1);
3130

32-
// score * value => [batch_size, q_seq_len, n_heads, head_dim]
33-
// return torch::einsum("bhqk,bkhd->bhqd", {scores, value});
34-
return torch::matmul(scores, value);
31+
// score * value => [batch_size, n_heads, q_seq_len, head_dim]
32+
return torch::einsum("bhqk,bhkd->bhqd", {scores, value});
3533
}
3634

3735
torch::Tensor attention_sm80(

0 commit comments

Comments (0)