Commit f7dfc16

fix lightllm vit triton (#949)
Co-authored-by: baishihao <[email protected]>
1 parent b887d8d · commit f7dfc16

1 file changed: +2 -2 lines changed


lightllm/models/vit/triton_kernel/flashattention_nopad.py
Lines changed: 2 additions & 2 deletions

@@ -48,7 +48,7 @@ def _fwd_kernel(
     q = tl.load(Q + off_q, mask=offs_m[:, None] < seq_len, other=0.0)
     # initialize pointer to m and l
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

     for start_n in range(0, seq_len, BLOCK_N):
@@ -68,7 +68,7 @@ def _fwd_kernel(
     qk += tl.where((start_n + offs_n[None, :]) < seq_len, 0, float("-inf"))

     # -- compute m_ij, p, l_ij
-    m_ij = tl.maximum(tl.max(qk, 1), l_i)
+    m_ij = tl.maximum(tl.max(qk, 1), m_i)
     p = tl.exp(qk - m_ij[:, None])
     l_ij = tl.sum(p, 1)

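For context, both changed lines belong to the online-softmax (flash-attention style) recurrence in _fwd_kernel: m_i tracks the running maximum of the attention logits and is seeded with -inf, and each block's new maximum m_ij must be taken against that running max (m_i) rather than against the accumulator l_i. Below is a minimal NumPy sketch of the generic recurrence, for illustration only: the names mirror the Triton kernel (m_i, l_i, acc, BLOCK_N), but the block size, the plain running-sum form of l_i, and the single final normalization are assumptions for clarity, not the exact lightllm implementation (the kernel above seeds its l_i at -inf and maintains it in a different form).

# Minimal NumPy sketch of the online-softmax recurrence (assumptions noted above).
import numpy as np

def online_softmax_attention(q, k, v, block_n=4):
    """softmax(q @ k.T) @ v, computed one key/value block at a time."""
    seq_len = k.shape[0]
    m_i = np.full(q.shape[0], -np.inf)        # running max of logits: must start at -inf
    l_i = np.zeros(q.shape[0])                # running sum of exp(logit - running max)
    acc = np.zeros((q.shape[0], v.shape[1]))  # unnormalized output accumulator

    for start_n in range(0, seq_len, block_n):
        qk = q @ k[start_n:start_n + block_n].T        # logits for this block
        m_ij = np.maximum(qk.max(axis=1), m_i)         # new running max: compared with m_i, not l_i
        p = np.exp(qk - m_ij[:, None])                 # block numerators, stabilized by m_ij
        alpha = np.exp(m_i - m_ij)                     # rescales previously accumulated l_i and acc
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v[start_n:start_n + block_n]
        m_i = m_ij

    return acc / l_i[:, None]                          # normalize once at the end

# Sanity check against a direct softmax.
rng = np.random.default_rng(0)
q = rng.standard_normal((5, 16))
k = rng.standard_normal((8, 16))
v = rng.standard_normal((8, 16))
logits = q @ k.T
weights = np.exp(logits - logits.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
assert np.allclose(online_softmax_attention(q, k, v), weights @ v)

The sketch is meant only to show the two invariants the commit touches: the running maximum starts at -inf, and each block's m_ij is computed against that running maximum (m_i), not against the l_i accumulator.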