@@ -476,18 +476,16 @@ def __init__(
       vmem_buf,  # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
       sem,
       page_indices_ref,  # i32[max_num_seqs, pages_per_seq]
-      offset,  # [seq_idx, kv_pages_start]
+      metadata,  # [seq_idx, start_page_idx, end_page_idx]
   ):
     self._vmem_buf = vmem_buf
-    seq_id, kv_pages_start = offset
-    pages_per_seq = page_indices_ref.shape[1]
+    seq_id, start_page_idx, end_page_idx = metadata
     self._async_copies = []
     # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert
     # a bunch of if-ops. Check the performance when we have benchmarking setup.
     for i in range(vmem_buf.shape[0]):
-      page_idx = kv_pages_start + i
-      page_idx = jax.lax.select(page_idx < pages_per_seq, page_idx,
-                                pages_per_seq - 1)
+      page_idx = start_page_idx + i
+      page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0)
       self._async_copies.append(
           pltpu.make_async_copy(
               pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
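Note: with this change, a page index past the sequence's last valid page is redirected to page 0 instead of being clamped to pages_per_seq - 1; the stale data fetched for those slots is masked out later in flash_attention. A minimal standalone sketch of the select semantics, with made-up block bounds (not kernel code):

import jax
import jax.numpy as jnp

start_page_idx = jnp.int32(4)  # hypothetical: this block covers pages [4, 8)
end_page_idx = jnp.int32(6)    # but only pages [0, 6) hold valid KV data
for i in range(4):             # num_kv_pages_per_blk = 4 (hypothetical)
  page_idx = start_page_idx + i
  page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, jnp.int32(0))
  print(page_idx)              # 4, 5, 0, 0 -- out-of-range slots fetch page 0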
@@ -719,6 +717,7 @@ def ragged_paged_attention_kernel(
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
+  pages_per_seq = page_indices_ref.shape[-1]
   num_seqs = num_seqs_ref[0]
   _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = (
       kv_bufs.shape)
@@ -737,7 +736,10 @@ def ragged_paged_attention_kernel(
 
   def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx,
                                        buf_idx):
-    offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk)
+    start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk
+    end_kv_page_idx = jnp.minimum(pages_per_seq,
+                                  cdiv(kv_lens_ref[seq_idx], page_size))
+    metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx)
     heads_start = heads_blk_idx * num_combined_kv_heads_per_blk
     async_copy_kv = MultiPageAsyncCopyDescriptor(
         kv_pages_hbm_ref.at[:, :,
@@ -746,7 +748,7 @@ def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx,
         kv_bufs.at[buf_idx],
         sems.at[buf_idx],
         page_indices_ref,
-        offset,
+        metadata,
     )
     return async_copy_kv
 
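The new end bound counts how many pages actually hold data for the sequence: ceil(kv_len / page_size), capped at pages_per_seq. A small standalone check with hypothetical numbers (cdiv stands in for the kernel's ceiling-division helper):

import jax.numpy as jnp

def cdiv(a, b):  # ceiling division, mirroring the helper used by the kernel
  return (a + b - 1) // b

page_size, pages_per_seq = 16, 8
kv_len = jnp.int32(35)  # fills two full pages plus 3 tokens of a third
end_kv_page_idx = jnp.minimum(pages_per_seq, cdiv(kv_len, page_size))
print(end_kv_page_idx)  # 3 -- pages [0, 3) are valid; later slots fetch page 0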
@@ -841,19 +843,15 @@ def flash_attention(
         num_q_per_blk * num_q_heads_per_kv_head,
         head_dim,
     )
-    assert k.shape == (
+    assert (k.shape == v.shape == (
         num_kv_per_blk,
         head_dim,
-    ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}"
-    assert v.shape == (num_kv_per_blk, head_dim)
-    assert head_m_ref.shape == (
+    ))
+    assert k.dtype == v.dtype
+    assert (head_m_ref.shape == head_l_ref.shape == (
         num_q_per_blk * num_q_heads_per_kv_head,
         128,
-    )
-    assert head_l_ref.shape == (
-        num_q_per_blk * num_q_heads_per_kv_head,
-        128,
-    )
+    ))
     assert head_acc_ref.shape == (
         num_q_per_blk,
         num_q_heads_per_kv_head,
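The consolidated asserts rely on Python's comparison chaining: a == b == c evaluates as (a == b) and (b == c), so one statement now checks both refs against the expected shape. A trivial illustration with hypothetical shapes:

k_shape = v_shape = (8, 128)             # hypothetical shapes
assert (k_shape == v_shape == (8, 128))  # passes only if all three agree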
@@ -867,6 +865,12 @@ def masked_store(ref, val, start, end, group=1):
       pl.store(
           ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask)
 
+    # kv lens will be contracting dim, we should mask out the NaNs.
+    kv_mask = (
+        lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start)
+    k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
+    v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype)
+
     qk = (
         jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) *
         sm_scale)
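Because out-of-range slots now fetch page 0, k and v can carry arbitrary (possibly NaN) values in the padded tail, and the kv dimension is contracted in the matmuls below, so a single NaN would poison the whole output; the new mask zeroes those rows first. A standalone illustration with tiny hypothetical shapes (in the kernel, k is [num_kv_per_blk, head_dim]):

import jax.numpy as jnp
from jax import lax

k = jnp.full((4, 2), jnp.nan, dtype=jnp.float32).at[:2].set(1.0)
kv_len, kv_len_start = 6, 4  # only 6 - 4 = 2 rows of this block are valid
kv_mask = lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start
k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
print(k)  # rows 0-1 keep their values; padded rows 2-3 become 0, not NaN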
@@ -1111,7 +1115,7 @@ def ragged_paged_attention(
 
   Args:
     q: concatenated all sequences' queries.
-    kv_pages: paged K cache. Normally in HBM.
+    kv_pages: paged KV cache. Normally in HBM.
    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
     page_indices: the first index indicates which page to use in the kv cache
       for each sequence. Only the first num_seqs values are valid.
@@ -1185,7 +1189,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
   )
   in_specs = [
       q_block_spec,
-      pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+      pl.BlockSpec(memory_space=pltpu.ANY),
   ]
   out_specs = q_block_spec
   lm_scratch = pltpu.VMEM(
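The last hunk is only a spelling update: pltpu.ANY is the shorter alias for pltpu.TPUMemorySpace.ANY in recent JAX, leaving the KV pages unblocked so the kernel can DMA pages into VMEM manually. A minimal sketch of such a spec (assuming a JAX version where the alias exists):

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# No automatic blocking or VMEM staging; the kernel issues its own copies.
kv_spec = pl.BlockSpec(memory_space=pltpu.ANY)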