
Commit 1aaec81

rdyro authored and Google-ML-Automation committed
Add support for non-power-of-2 head size in flash attention
Introduce checks that sequence lengths are divisible by block sizes, to address #27224

PiperOrigin-RevId: 762051831
1 parent a827a27 commit 1aaec81
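The gist of the change: the head dimension is padded up to the next power of two inside the kernel (padded block shapes plus masked loads/stores), so a head size such as 72 produces the same attention output as before. Below is a minimal pure-JAX sketch, for illustration only, of why zero-padding the head dimension is harmless; it is not the kernel code, and next_power_of_2 here is a local helper rather than the pl.next_power_of_2 used in the diff.

import jax
import jax.numpy as jnp

def next_power_of_2(n: int) -> int:
  # Local helper for illustration; the kernel uses pl.next_power_of_2.
  return 1 << (n - 1).bit_length()

def padded_attention(q, k, v):
  # Zero-pad head_dim up to the next power of 2, run plain attention on the
  # padded arrays, then drop the padded output columns. Zero query/key columns
  # contribute nothing to q @ k.T, and zero value columns only fill the
  # columns we slice away, so the result matches unpadded attention.
  head_dim = q.shape[-1]
  pad = [(0, 0)] * (q.ndim - 1) + [(0, next_power_of_2(head_dim) - head_dim)]
  qp, kp, vp = (jnp.pad(x, pad) for x in (q, k, v))
  probs = jax.nn.softmax(qp @ kp.T, axis=-1)
  return (probs @ vp)[..., :head_dim]

q = jax.random.normal(jax.random.PRNGKey(0), (8, 72))   # head_dim=72, not a power of 2
k = jax.random.normal(jax.random.PRNGKey(1), (16, 72))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 72))
ref = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(padded_attention(q, k, v), ref, atol=1e-5)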

File tree

2 files changed
+104 -84 lines changed

jax/experimental/pallas/ops/gpu/attention.py

Lines changed: 102 additions & 82 deletions
@@ -86,28 +86,29 @@ def mha_forward_kernel(
     segment_ids_ref: jax.Array | None,  # segment_id arrays
     o_ref: Any,  # Output
     *residual_refs: Any,  # Residual outputs
-    num_heads: int,
     sm_scale: float,
     causal: bool,
     block_q: int,
-    block_d: int,
     block_k: int,
+    head_dim: int,
 ):
   seq_len = k_ref.shape[0]
   start_q = pl.program_id(0)
+  head_dim_padded = q_ref.shape[-1]
 
   # o is the buffer where we accumulate the output on sram.
   # m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
   m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf')
   l_i = jnp.zeros(block_q, dtype=jnp.float32)
   # acc is the buffer where we accumulate the output on sram.
-  o = jnp.zeros((block_q, block_d), dtype=jnp.float32)
+  o = jnp.zeros((block_q, head_dim_padded), dtype=jnp.float32)
 
   # Load q: it will stay in L1 throughout. Indices form a matrix because we
   # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
-  # q tile has shape [block_q, block_d], block_d == head_dim.
+  # q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim.
   curr_q_slice = pl.dslice(start_q * block_q, block_q)
-  q = q_ref[...]
+  head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
+  q = pl.load(q_ref, (slice(None), slice(None)), mask=head_mask, other=0.0)
   q_segment_ids = (
       None
       if segment_ids_ref is None
@@ -121,7 +122,7 @@ def body(start_k, carry):
     o_prev, m_prev, l_prev = carry
     curr_k_slice = pl.dslice(start_k * block_k, block_k)
 
-    k = pl.load(k_ref, (curr_k_slice, slice(None)))
+    k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
     qk = pl.dot(q, k.T)  # [block_q, block_k]
 
     # Scale logits to convert from base-2 to the natural log domain.
@@ -161,7 +162,7 @@ def body(start_k, carry):
     l_curr = s_curr.sum(axis=-1)
     l_next = l_prev_corr + l_curr
     o_prev_corr = correction[:, None] * o_prev
-    v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d)))
+    v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask)
     o_curr = pl.dot(s_curr.astype(v.dtype), v)
 
     o_next = o_prev_corr + o_curr
@@ -182,7 +183,8 @@ def body(start_k, carry):
     lse_ref = residual_refs[0]
     lse_ref[...] = m_i + jnp.log2(l_i)
   # Write output to dram.
-  o_ref[...] = o.astype(o_ref.dtype)
+  pl.store(o_ref, (slice(None), slice(o.shape[-1])), o.astype(o_ref.dtype),
+           mask=head_mask)
 
 def segment_mask(
     q_segment_ids: jax.Array,
@@ -235,6 +237,17 @@ def mha(
   kv_seq_len = k.shape[1]
   block_q = min(block_sizes.block_q, q_seq_len)
   block_k = min(block_sizes.block_k, kv_seq_len)
+  head_dim_padded = pl.next_power_of_2(head_dim)
+  if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]):
+    raise ValueError(
+        f"This kernel expects q, k, and v to have the same head dimension, but"
+        f" found {q.shape=}, {k.shape=}, {v.shape=}."
+    )
+  if q_seq_len % block_q != 0:
+    raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}")
+  if kv_seq_len % block_k != 0:
+    raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}")
+
   # Heuristics.
   grid_ = grid
   if grid_ is None:
@@ -243,21 +256,17 @@ def mha(
   num_warps_ = num_warps
   if num_warps_ is None:
     num_warps_ = 4 if head_dim <= 64 else 8
-  kernel = functools.partial(mha_forward_kernel, num_heads=num_heads,
-                             sm_scale=sm_scale, block_q=block_q,
-                             block_k=block_k, block_d=head_dim,
-                             causal=causal)
+  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
+                             block_q=block_q, block_k=block_k,
+                             head_dim=head_dim, causal=causal)
 
   in_specs = [
-      pl.BlockSpec(
-          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
-      ),
-      pl.BlockSpec(
-          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
-      ),
-      pl.BlockSpec(
-          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
-      ),
+      pl.BlockSpec((None, block_q, None, head_dim_padded),
+                   lambda i, j, k: (j, i, k, 0)),
+      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
+                   lambda _, j, k: (j, 0, k, 0)),
+      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
+                   lambda _, j, k: (j, 0, k, 0)),
   ]
   in_specs.append(
       None  # type: ignore[arg-type]
@@ -270,7 +279,7 @@ def mha(
       grid=grid_,
       in_specs=in_specs,
       out_specs=pl.BlockSpec(
-          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
+          (None, block_q, None, head_dim_padded), lambda i, j, k: (j, i, k, 0)
       ),
       compiler_params=plgpu.TritonCompilerParams(
           num_warps=num_warps_, num_stages=num_stages),
@@ -301,6 +310,17 @@ def _mha_forward(
   kv_seq_len = k.shape[1]
   block_q = min(block_sizes.block_q, q_seq_len)
   block_k = min(block_sizes.block_k, kv_seq_len)
+  if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]):
+    raise ValueError(
+        f"This kernel expects q, k, and v to have the same head dimension, but"
+        f" found {q.shape=}, {k.shape=}, {v.shape=}."
+    )
+  if q_seq_len % block_q != 0:
+    raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}")
+  if kv_seq_len % block_k != 0:
+    raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}")
+  head_dim_padded = pl.next_power_of_2(head_dim)
+
   # Heuristics.
   grid_ = grid
   if grid_ is None:
@@ -309,25 +329,22 @@ def _mha_forward(
   num_warps_ = num_warps
   if num_warps_ is None:
     num_warps_ = 4 if head_dim <= 64 else 8
-  kernel = functools.partial(mha_forward_kernel, num_heads=num_heads,
-                             sm_scale=sm_scale, causal=causal, block_q=block_q,
-                             block_k=block_k, block_d=head_dim)
+  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
+                             causal=causal, block_q=block_q, block_k=block_k,
+                             head_dim=head_dim)
   out_shape = [
       jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
       jax.ShapeDtypeStruct(
           shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32  # lse
       ),
   ]
   in_specs = [
-      pl.BlockSpec(
-          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
-      ),
-      pl.BlockSpec(
-          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
-      ),
-      pl.BlockSpec(
-          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
-      ),
+      pl.BlockSpec((None, block_q, None, head_dim_padded),
+                   lambda i, j, k: (j, i, k, 0)),
+      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
+                   lambda _, j, k: (j, 0, k, 0)),
+      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
+                   lambda _, j, k: (j, 0, k, 0)),
   ]
   in_specs.append(
       None  # type: ignore[arg-type]
@@ -339,9 +356,8 @@ def _mha_forward(
       grid=grid_,
       in_specs=in_specs,
       out_specs=[
-          pl.BlockSpec(
-              (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
-          ),
+          pl.BlockSpec((None, block_q, None, head_dim_padded),
+                       lambda i, j, k: (j, i, k, 0)),
           pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
       ],
       compiler_params=plgpu.TritonCompilerParams(
@@ -355,10 +371,11 @@ def _mha_forward(
   return out, (q, k, v, segment_ids, out, lse)
 
 
-def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref):
+def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int):
   # load
-  o = out_ref[...].astype(jnp.float32)
-  do = dout_ref[...].astype(jnp.float32)
+  head_mask = (jnp.arange(out_ref.shape[-1]) < head_dim)[None, :]
+  o = pl.load(out_ref, (slice(None), slice(None)), mask=head_mask, other=0.0)
+  do = pl.load(dout_ref, (slice(None), slice(None)), mask=head_mask, other=0.0)
   # compute
   delta = jnp.sum(o * do, axis=1)
   # write-back
@@ -368,17 +385,16 @@ def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref):
 def _preprocess_backward(out, do, lse, block_q: int,
                          debug: bool, interpret: bool):
   batch_size, seq_len, num_heads, head_dim = out.shape
+  head_dim_padded = pl.next_power_of_2(head_dim)
   out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype)
   delta = pl.pallas_call(
-      _preprocess_backward_kernel,
+      functools.partial(_preprocess_backward_kernel, head_dim=head_dim),
      grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
      in_specs=[
-          pl.BlockSpec(
-              (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
-          ),
-          pl.BlockSpec(
-              (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
-          ),
+          pl.BlockSpec((None, block_q, None, head_dim_padded),
+                       lambda i, j, k: (j, i, k, 0)),
+          pl.BlockSpec((None, block_q, None, head_dim_padded),
+                       lambda i, j, k: (j, i, k, 0)),
      ],
      out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
      compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3),
@@ -414,7 +430,7 @@ def mha_backward_kernel(
     block_kv_dkv: int,
     block_q_dq: int,
     block_kv_dq: int,
-    block_d: int,
+    head_dim: int,
 ):
   del out_ref  # Not needed
   q_seq_len = q_ref.shape[0]
@@ -427,11 +443,13 @@ def mha_backward_kernel(
   start_k = pl.program_id(2)
   curr_k_slice = pl.dslice(start_k * block_kv_dkv, block_kv_dkv)
 
-  dv = jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32)
-  dk = jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32)
+  head_dim_padded = q_ref.shape[-1]
+  dv = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32)
+  dk = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32)
 
-  v = pl.load(v_ref, (curr_k_slice, slice(None)))
-  k = pl.load(k_ref, (curr_k_slice, slice(None)))
+  head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
+  v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
+  k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
   span_k = start_k * block_kv_dkv + jnp.arange(block_kv_dkv)
   kv_segment_ids = (
       None
@@ -443,7 +461,7 @@ def inner_loop_dkdv(start_q, carry):
     dv, dk = carry
     curr_q_slice = pl.dslice(start_q * block_q_dkv, block_q_dkv)
 
-    q = pl.load(q_ref, (curr_q_slice, slice(None)))
+    q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0)
     qk = pl.dot(q, k.T)
     qk_scale = math.log2(math.e)
     if sm_scale != 1.:
@@ -466,7 +484,8 @@ def inner_loop_dkdv(start_q, carry):
 
     lse = pl.load(lse_ref, (curr_q_slice,))
     di = pl.load(delta_ref, (curr_q_slice,))
-    do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
+    do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask,
+                 other=0.0)
 
     p = jnp.exp2(qk - lse[:, None])
     dv = dv + pl.dot(p.astype(do.dtype).T, do)
@@ -483,8 +502,10 @@ def inner_loop_dkdv(start_q, carry):
   dv, dk = lax.fori_loop(
       lower_bound, pl.cdiv(q_seq_len, block_q_dkv), inner_loop_dkdv, (dv, dk)
   )
-  dv_ref[...] = dv.astype(dv_ref.dtype)
-  dk_ref[...] = dk.astype(dk_ref.dtype)
+  pl.store(dv_ref, (slice(None), slice(dv.shape[-1])), dv.astype(dv_ref.dtype),
+           mask=head_mask)
+  pl.store(dk_ref, (slice(None), slice(dk.shape[-1])), dk.astype(dk_ref.dtype),
+           mask=head_mask)
 
   # Scan #2: dQ
   # 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM.
@@ -493,22 +514,23 @@ def inner_loop_dkdv(start_q, carry):
   start_q = pl.program_id(2)
   curr_q_slice = pl.ds(start_q * block_q_dq, block_q_dq)
   span_q = start_q * block_q_dq + jnp.arange(block_q_dq)
-  dq = jnp.zeros([block_q_dq, block_d], dtype=jnp.float32)
+  dq = jnp.zeros([block_q_dq, head_dim_padded], dtype=jnp.float32)
 
-  q = pl.load(q_ref, (curr_q_slice, slice(None)))
+  q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0)
   q_segment_ids = (
       None
      if segment_ids_ref is None
      else pl.load(segment_ids_ref, (curr_q_slice,))
   )
   lse = pl.load(lse_ref, (curr_q_slice,))
-  do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
+  do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask,
+               other=0.0)
   di = pl.load(delta_ref, (curr_q_slice,))
 
   def inner_loop_dq(start_k, dq):
     curr_k_slice = pl.dslice(start_k * block_kv_dq, block_kv_dq)
-    k = pl.load(k_ref, (curr_k_slice, slice(None)))
-    v = pl.load(v_ref, (curr_k_slice, slice(None)))
+    k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
+    v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
 
     qk = pl.dot(q, k.T)
     qk_scale = math.log2(math.e)
@@ -547,7 +569,8 @@ def inner_loop_dq(start_k, dq):
   upper_bound = pl.cdiv(kv_seq_len, block_kv_dq)
 
   dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
-  dq_ref[...] = dq.astype(dq_ref.dtype)
+  pl.store(dq_ref, (slice(None), slice(dq.shape[-1])), dq.astype(dq_ref.dtype),
+           mask=head_mask)
 
 
 def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
@@ -576,6 +599,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
   block_kv_dkv = min(block_sizes.block_kv_dkv, kv_seq_len)
   block_q_dq = min(block_sizes.block_q_dq, q_seq_len)
   block_kv_dq = min(block_sizes.block_kv_dq, kv_seq_len)
+  head_dim_padded = pl.next_power_of_2(head_dim)
 
   if q_seq_len // block_q_dq != kv_seq_len // block_kv_dkv:
     raise ValueError(
@@ -591,28 +615,24 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
   ]
 
   in_specs = [
-      pl.BlockSpec(
-          (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
-      ),
-      pl.BlockSpec(
-          (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
-      ),
-      pl.BlockSpec(
-          (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
-      ),
-      pl.BlockSpec(
-          (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
-      ),
-      pl.BlockSpec(
-          (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
-      ),
+      pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
+                   lambda i, j, _: (i, 0, j, 0)),
+      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
+                   lambda i, j, _: (i, 0, j, 0)),
+      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
+                   lambda i, j, _: (i, 0, j, 0)),
+      pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
+                   lambda i, j, _: (i, 0, j, 0)),
+      pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
+                   lambda i, j, _: (i, 0, j, 0)),
       pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
      pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
   ]
   if segment_ids is None:
     in_specs.insert(3, None)  # type: ignore[arg-type]
   else:
-    in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0)))
+    in_specs.insert(3, pl.BlockSpec((None, kv_seq_len),
+                                    lambda i, j, _: (i, 0)))
 
   grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_kv_dkv))
   num_warps_ = num_warps
@@ -635,22 +655,22 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
           block_kv_dkv=block_kv_dkv,
          block_q_dq=block_q_dq,
          block_kv_dq=block_kv_dq,
-          block_d=head_dim,
+          head_dim=head_dim,
      ),
      out_shape=out_shapes,
      in_specs=in_specs,
      grid=grid,
      out_specs=[
          pl.BlockSpec(
-              (None, block_q_dq, None, head_dim),
+              (None, block_q_dq, None, head_dim_padded),
              lambda i, j, k: (i, k, j, 0),  # dq
          ),
          pl.BlockSpec(
-              (None, block_kv_dkv, None, head_dim),
+              (None, block_kv_dkv, None, head_dim_padded),
              lambda i, j, k: (i, k, j, 0),  # dk
          ),
          pl.BlockSpec(
-              (None, block_kv_dkv, None, head_dim),
+              (None, block_kv_dkv, None, head_dim_padded),
              lambda i, j, k: (i, k, j, 0),  # dv
          ),
      ],
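
For reference, a hypothetical usage sketch (not part of the commit) calling the resulting op with a non-power-of-2 head size. The call signature is assumed from the code in this file (q, k, v, segment_ids, ...), and running it needs a Triton-capable GPU (or the interpret mode this file exposes).

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention

# head_dim=72 is not a power of 2; the kernel pads it to 128 internally.
batch, seq_len, num_heads, head_dim = 2, 1024, 4, 72
keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(kk, (batch, seq_len, num_heads, head_dim),
                             dtype=jnp.float16) for kk in keys)

# seq_len is a multiple of the default block sizes, so the new divisibility
# checks pass; segment_ids is passed as None here.
out = attention.mha(q, k, v, None, causal=True)
print(out.shape)  # (2, 1024, 4, 72)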
