Skip to content

Commit 82e7ee8

Browse files
By thew3i and root (co-authored)
[ragged-paged-attn] Apply kv mask to filter out NaNs (#9219)
Co-authored-by: root <root@t1v-n-573bae2d-w-0.us-central2-b.c.tpu-prod-env-one-vm.internal>
1 parent e28174c commit 82e7ee8

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -476,18 +476,16 @@ def __init__(
476476
vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
477477
sem,
478478
page_indices_ref, # i32[max_num_seqs, pages_per_seq]
479-
offset, # [seq_idx, kv_pages_start]
479+
metadata, # [seq_idx, start_page_idx, end_page_idx]
480480
):
481481
self._vmem_buf = vmem_buf
482-
seq_id, kv_pages_start = offset
483-
pages_per_seq = page_indices_ref.shape[1]
482+
seq_id, start_page_idx, end_page_idx = metadata
484483
self._async_copies = []
485484
# TODO(jevinjiang): Only fetch dynamic shape in need! This will insert
486485
# a bunch of if-ops. Check the performance when we have benchmarking setup.
487486
for i in range(vmem_buf.shape[0]):
488-
page_idx = kv_pages_start + i
489-
page_idx = jax.lax.select(page_idx < pages_per_seq, page_idx,
490-
pages_per_seq - 1)
487+
page_idx = start_page_idx + i
488+
page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0)
491489
self._async_copies.append(
492490
pltpu.make_async_copy(
493491
pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
@@ -719,6 +717,7 @@ def ragged_paged_attention_kernel(
719717
if mask_value is None:
720718
mask_value = DEFAULT_MASK_VALUE
721719
num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
720+
pages_per_seq = page_indices_ref.shape[-1]
722721
num_seqs = num_seqs_ref[0]
723722
_, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = (
724723
kv_bufs.shape)
@@ -737,7 +736,10 @@ def ragged_paged_attention_kernel(
737736

738737
def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx,
739738
buf_idx):
740-
offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk)
739+
start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk
740+
end_kv_page_idx = jnp.minimum(pages_per_seq,
741+
cdiv(kv_lens_ref[seq_idx], page_size))
742+
metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx)
741743
heads_start = heads_blk_idx * num_combined_kv_heads_per_blk
742744
async_copy_kv = MultiPageAsyncCopyDescriptor(
743745
kv_pages_hbm_ref.at[:, :,
@@ -746,7 +748,7 @@ def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx,
746748
kv_bufs.at[buf_idx],
747749
sems.at[buf_idx],
748750
page_indices_ref,
749-
offset,
751+
metadata,
750752
)
751753
return async_copy_kv
752754

@@ -841,19 +843,15 @@ def flash_attention(
841843
num_q_per_blk * num_q_heads_per_kv_head,
842844
head_dim,
843845
)
844-
assert k.shape == (
846+
assert (k.shape == v.shape == (
845847
num_kv_per_blk,
846848
head_dim,
847-
), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}"
848-
assert v.shape == (num_kv_per_blk, head_dim)
849-
assert head_m_ref.shape == (
849+
))
850+
assert k.dtype == v.dtype
851+
assert (head_m_ref.shape == head_l_ref.shape == (
850852
num_q_per_blk * num_q_heads_per_kv_head,
851853
128,
852-
)
853-
assert head_l_ref.shape == (
854-
num_q_per_blk * num_q_heads_per_kv_head,
855-
128,
856-
)
854+
))
857855
assert head_acc_ref.shape == (
858856
num_q_per_blk,
859857
num_q_heads_per_kv_head,
@@ -867,6 +865,12 @@ def masked_store(ref, val, start, end, group=1):
867865
pl.store(
868866
ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask)
869867

868+
# kv lens will be contracting dim, we should mask out the NaNs.
869+
kv_mask = (
870+
lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start)
871+
k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
872+
v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype)
873+
870874
qk = (
871875
jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) *
872876
sm_scale)
@@ -1111,7 +1115,7 @@ def ragged_paged_attention(
11111115
11121116
Args:
11131117
q: concatenated all sequences' queries.
1114-
kv_pages: paged K cache. Normally in HBM.
1118+
kv_pages: paged KV cache. Normally in HBM.
11151119
kv_lens: padded kv lengths. Only the first num_seqs values are valid.
11161120
page_indices: the first index indicates which page to use in the kv cache
11171121
for each sequence. Only the first num_seqs values are valid.
@@ -1185,7 +1189,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
11851189
)
11861190
in_specs = [
11871191
q_block_spec,
1188-
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
1192+
pl.BlockSpec(memory_space=pltpu.ANY),
11891193
]
11901194
out_specs = q_block_spec
11911195
lm_scratch = pltpu.VMEM(

0 commit comments

Comments (0)