Use additional semaphores to avoid data races in TPU paged_attention_kernel.
This also prevents an out-of-bounds read of SMEM and re-enables the tests for the TPU paged_attention_kernel.
@apaszke confirmed the presence of data races using the race detector in the new TPU interpret mode. With the additional semaphores, the race detector no longer reports any races in this kernel, and I no longer see any test failures in 20+ test runs on a TPU.
Details on the data races:
- In each iteration, the kernel:
(a) Starts copying data for `k` and `v` for the next iteration.
(b) Waits for the copy of `k` for the current iteration to finish.
(c) Waits for the copy of `v` for the current iteration to finish.
- It is possible for these copies to happen out of order -- that is:
(a) The copies for the next iteration can finish before the copies
for the current iteration.
(b) And the copies for `v` for the current iteration can finish
before the copies for `k` for the current iteration.
- If the same DMA semaphore is used for everything, then out-of-order
copies can lead to:
(a) `k = async_copy_k.wait_and_get_loaded()` returns but the data
isn't all available because the underlying semaphore was
signaled by the completion of copies of `v` for the current
iteration or copies of `k` or `v` for the next iteration.
(b) `v = async_copy_v.wait_and_get_loaded()` returns but the data
isn't all available because the underlying semaphore was
signaled by the completion of copies of `k` or `v` for the
next iteration.
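
The failure mode above can be illustrated with a small toy model in plain Python. This is only a sketch, not Pallas code and not the kernel's implementation: `ToySemaphore`, `wait_returns_early`, and the copy names `k_cur`, `v_cur`, `k_next`, `v_next` are hypothetical, a semaphore is modeled as a simple counter of completed copies, and the one-semaphore-per-copy layout is an assumption made for illustration rather than the exact layout used in the fix.

```python
# Toy model (not Pallas code): a DMA semaphore is treated as a counter that
# each finished copy increments and each wait decrements.

class ToySemaphore:
    def __init__(self):
        self.count = 0

    def signal(self):
        self.count += 1

    def try_wait(self):
        """Return True if a wait would return now, consuming one signal."""
        if self.count > 0:
            self.count -= 1
            return True
        return False


def wait_returns_early(sem_for, completion_order, waited_copy):
    """Report whether waiting on `waited_copy` returns before its data landed.

    sem_for: maps each copy name to the semaphore it signals and waits on.
    completion_order: the copies that have finished so far, in order.
    """
    landed = set()
    for copy in completion_order:
        landed.add(copy)
        sem_for[copy].signal()
    returned = sem_for[waited_copy].try_wait()
    return returned and waited_copy not in landed


copies = ["k_cur", "v_cur", "k_next", "v_next"]

# One semaphore shared by every copy (the racy setup described above).
shared = ToySemaphore()
shared_map = {c: shared for c in copies}

# Assumed layout for illustration: one semaphore per copy.
separate_map = {c: ToySemaphore() for c in copies}

# Only `v` for the current iteration has finished; `k` is still in flight.
finished = ["v_cur"]

shared_race = wait_returns_early(shared_map, finished, "k_cur")
separate_race = wait_returns_early(separate_map, finished, "k_cur")
print(shared_race, separate_race)  # True False
```

In this model, the shared semaphore lets the wait for `k` return on the signal from `v`, while separate semaphores keep that wait pending until `k` itself has landed.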
PiperOrigin-RevId: 762136079