
Commit 0e690c1

jburnim authored and Google-ML-Automation committed
Use additional semaphores to avoid data races in TPU paged_attention_kernel.
Also prevents an out-of-bounds read of SMEM, and re-enables tests for the TPU paged_attention_kernel.

@apaszke confirmed the presence of data races using the race detector in the new TPU interpret mode. With the additional semaphores, the race detector no longer detects any races in this kernel, and I no longer see any test failures in 20+ test runs on a TPU.

Details on the data races:

- In each iteration, the kernel: (a) starts copying data for `k` and `v` for the next iteration, (b) waits for the copy of `k` for the current iteration to finish, and (c) waits for the copy of `v` for the current iteration to finish.
- It is possible for these copies to complete out of order -- that is: (a) the copies for the next iteration can finish before the copies for the current iteration, and (b) the copies of `v` for the current iteration can finish before the copies of `k` for the current iteration.
- If the same DMA semaphore is used for everything, then out-of-order copies can lead to: (a) `k = async_copy_k.wait_and_get_loaded()` returning before all of its data is available, because the underlying semaphore was signaled by the completion of copies of `v` for the current iteration or copies of `k` or `v` for the next iteration; or (b) `v = async_copy_v.wait_and_get_loaded()` returning before all of its data is available, because the underlying semaphore was signaled by the completion of copies of `k` or `v` for the next iteration.

PiperOrigin-RevId: 762136079
1 parent fc68336 commit 0e690c1
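
For readers unfamiliar with the pattern, the sketch below illustrates the approach the fix takes: give each of `k` and `v` its own array of two DMA semaphores (one per double-buffer slot), so that a wait can only be satisfied by the copy it actually belongs to. This is a minimal, hypothetical sketch, not the kernel's code: the function and argument names (`start_kv_copies`, `k_hbm_ref`, `block`, and so on) are made up, and the real kernel goes through its own async-copy descriptor helper (`create_kv_async_copy_descriptors` in the diff below) rather than calling `pltpu.make_async_copy` directly.

# Hypothetical sketch of per-operand, per-buffer-slot DMA semaphores inside a
# Pallas TPU kernel body. All names here are made up for illustration.
from jax.experimental.pallas import tpu as pltpu


def start_kv_copies(k_hbm_ref, v_hbm_ref, k_vmem_buffer, v_vmem_buffer,
                    k_sems, v_sems, buffer_index, block):
  """Starts async copies of one k/v block into the given double-buffer slot."""
  copy_k = pltpu.make_async_copy(
      k_hbm_ref.at[block],             # source block in HBM
      k_vmem_buffer.at[buffer_index],  # destination VMEM slot (0 or 1)
      k_sems.at[buffer_index],         # semaphore owned by this k copy
  )
  copy_v = pltpu.make_async_copy(
      v_hbm_ref.at[block],
      v_vmem_buffer.at[buffer_index],
      v_sems.at[buffer_index],         # a different semaphore for the v copy
  )
  copy_k.start()
  copy_v.start()
  return copy_k, copy_v

The caller holds on to the returned descriptors and later calls `copy_k.wait()` and `copy_v.wait()`. Because each wait targets its own semaphore, a finished `v` copy, or a prefetch into the other buffer slot, can no longer unblock a wait for `k` whose data has not yet landed, which is exactly the failure mode described in the commit message.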

2 files changed: +15 −10 lines


jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py

Lines changed: 15 additions & 8 deletions
@@ -127,7 +127,8 @@ def paged_flash_attention_kernel(
     k_scales_vmem_buffer,
     v_vmem_buffer,
     v_scales_vmem_buffer,
-    sem,
+    k_sems,
+    v_sems,
     *,
     batch_size: int,
     pages_per_compute_block: int,
@@ -176,7 +177,9 @@ def advance_to_next_non_zero_length():
 
     return (
         lax.cond(
-            jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0),
+            jnp.logical_and(
+                next_b < batch_size,
+                lengths_ref[lax.clamp(0, next_b, batch_size - 1)] == 0),
             advance_to_next_non_zero_length,
             lambda: next_b,
         ),
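
Why the clamp above matters: `jnp.logical_and` evaluates both of its operands, with no short-circuiting, so `lengths_ref[next_b]` was read from SMEM even when `next_b == batch_size`; clamping the index keeps that read in bounds, while the `next_b < batch_size` term still decides which branch `lax.cond` takes. A standalone sketch of the same guard, with made-up names (`lengths`, `safe_next_is_empty`):

import jax.numpy as jnp
from jax import lax


def safe_next_is_empty(lengths, next_b, batch_size):
  # Clamp the index into [0, batch_size - 1] so the load can never go past the
  # end of `lengths`; the next_b < batch_size term keeps the predicate correct.
  idx = lax.clamp(0, next_b, batch_size - 1)
  return jnp.logical_and(next_b < batch_size, lengths[idx] == 0)

(On plain `jnp` arrays an out-of-range index is clamped automatically, but for a scalar-memory ref inside the kernel it is a genuine out-of-bounds read, which is what the commit message refers to.)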
@@ -200,7 +203,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index):
         k_scales_vmem_buffer.at[buffer_index]
         if k_scales_vmem_buffer is not None
         else None,
-        sem,
+        k_sems.at[buffer_index],
         page_indices_ref,
         page_offset,
         pages_to_load,
@@ -213,7 +216,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index):
         v_scales_vmem_buffer.at[buffer_index]
         if v_scales_vmem_buffer is not None
         else None,
-        sem,
+        v_sems.at[buffer_index],
         page_indices_ref,
         page_offset,
         pages_to_load,
@@ -301,7 +304,8 @@ def paged_flash_attention_kernel_inline_seq_dim(
     k_scales_vmem_buffer,
     v_vmem_buffer,
     v_scales_vmem_buffer,
-    sem,
+    k_sems,
+    v_sems,
     *,
     batch_size: int,
     pages_per_compute_block: int,
@@ -336,7 +340,8 @@ def body(i, _):
         k_scales_vmem_buffer,
         v_vmem_buffer,
         v_scales_vmem_buffer,
-        sem,
+        k_sems,
+        v_sems,
         batch_size=batch_size,
         pages_per_compute_block=pages_per_compute_block,
         pages_per_sequence=pages_per_sequence,
@@ -584,7 +589,8 @@ def paged_attention(
             ),
             v_scales_pages.dtype,  # pytype: disable=attribute-error
         ),  # v_scales_pages buffer
-        pltpu.SemaphoreType.DMA,
+        pltpu.SemaphoreType.DMA((2,)),
+        pltpu.SemaphoreType.DMA((2,)),
     )
   else:
     in_specs = [
@@ -615,7 +621,8 @@ def paged_attention(
             v_pages.dtype,
         ),  # v_pages buffer
         None,
-        pltpu.SemaphoreType.DMA,
+        pltpu.SemaphoreType.DMA((2,)),
+        pltpu.SemaphoreType.DMA((2,)),
     )
 
   out, _, _ = pl.pallas_call(
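
How the `scratch_shapes` entries above reach the kernel: each scratch entry is passed to the kernel as an additional ref argument, so the two `pltpu.SemaphoreType.DMA((2,))` entries arrive as the new `k_sems` and `v_sems` parameters and are sliced per double-buffer slot with `.at[buffer_index]`. The snippet below is a sketch only; the VMEM shapes and dtype are placeholders, not the kernel's actual buffers.

import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

# Illustrative scratch tuple (shapes/dtype are placeholders): the two DMA
# entries become arrays of two semaphores each, one per double-buffer slot.
scratch_shapes = (
    pltpu.VMEM((2, 8, 128, 128), jnp.float32),  # double-buffered k scratch
    pltpu.VMEM((2, 8, 128, 128), jnp.float32),  # double-buffered v scratch
    pltpu.SemaphoreType.DMA((2,)),              # k_sems: one per buffer slot
    pltpu.SemaphoreType.DMA((2,)),              # v_sems: one per buffer slot
)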

tests/pallas/tpu_paged_attention_kernel_test.py

Lines changed: 0 additions & 2 deletions
@@ -265,8 +265,6 @@ def test_paged_attention(
       attn_logits_soft_cap,
       are_kv_quantized,
   ):
-    # TODO(mvoz, skyewm): Re-enable this test once the data race is fixed.
-    self.skipTest("This kernel has data races that need to be fixed.")
     if not jtu.is_device_tpu_at_least(4):
       self.skipTest("Only supports TPU generation 4 or above")
     if jtu.is_device_tpu(version=4) and are_kv_quantized:
