
Commit 33b23b9

rdyro authored and Google-ML-Automation committed
Simplify attention VJP definition
PiperOrigin-RevId: 762181545
1 parent 57d07e1 commit 33b23b9

File tree: 2 files changed, +54 −73 lines changed
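
In short, this commit adds a `return_residuals` flag to the public `mha` entry point: when set, the forward pass also returns the log-sum-exp (lse) residuals, and `_mha_forward` is rewritten to call `mha(..., return_residuals=True)` instead of duplicating the forward `pallas_call`. The sketch below illustrates the intended call pattern; it is not part of the commit, and the shapes/dtypes are borrowed from the new test.

# Sketch (not from this commit): intended usage of the new return_residuals
# flag, with shapes and dtypes borrowed from the test added below.
import jax.numpy as jnp
from jax import random
from jax.experimental.pallas.ops.gpu import attention

batch_size, seq_len, num_heads, head_dim = 2, 128, 2, 128
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(k1, (batch_size, seq_len, num_heads, head_dim), jnp.float16)
k = random.normal(k2, (batch_size, seq_len, num_heads, head_dim), jnp.float16)
v = random.normal(k3, (batch_size, seq_len, num_heads, head_dim), jnp.float16)

# Default behaviour is unchanged: only the attention output is returned.
# (Without a GPU, interpret=True can be passed, as in the test.)
out = attention.mha(q, k, v, segment_ids=None)

# With return_residuals=True the call also returns the log-sum-exp residuals,
# shaped (batch_size, num_heads, q_seq_len) in float32 per the diff.
out, lse = attention.mha(q, k, v, segment_ids=None, return_residuals=True)
assert lse.shape == (batch_size, num_heads, seq_len)

# Note: jax.grad through the return_residuals=True path raises a ValueError
# (see _mha_backward and the new test).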

jax/experimental/pallas/ops/gpu/attention.py

Lines changed: 30 additions & 73 deletions
@@ -152,7 +152,7 @@ def body(start_k, carry):
     # Apply mask to qk.
     qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

-    m_curr = qk.max(axis=-1)
+    m_curr = jnp.max(qk, axis=-1)
     m_next = jnp.maximum(m_prev, m_curr)
     correction = jnp.exp2(m_prev - m_next)
     l_prev_corr = correction * l_prev
@@ -201,7 +201,7 @@ def segment_mask(


 @functools.partial(
-    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12]
+    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
 )
 @functools.partial(
     jax.jit,
@@ -215,6 +215,7 @@ def segment_mask(
         "grid",
         "interpret",
         "debug",
+        "return_residuals",
     ],
 )
 def mha(
@@ -231,6 +232,7 @@ def mha(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
+    return_residuals: bool = False,
 ):
   del backward_pass_impl
   batch_size, q_seq_len, num_heads, head_dim = q.shape
@@ -273,21 +275,27 @@ def mha(
       if segment_ids is None
       else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
   )
-  out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
-  return pl.pallas_call(
+  out_shape = [q]
+  out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded),
+                            lambda i, j, k: (j, i, k, 0))]
+  if return_residuals:
+    out_shape.append(jax.ShapeDtypeStruct(
+        shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32))  # lse
+    out_specs.append(
+        pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)))  # lse
+  out = pl.pallas_call(
       kernel,
       grid=grid_,
       in_specs=in_specs,
-      out_specs=pl.BlockSpec(
-          (None, block_q, None, head_dim_padded), lambda i, j, k: (j, i, k, 0)
-      ),
+      out_specs=out_specs,
       compiler_params=plgpu.TritonCompilerParams(
           num_warps=num_warps_, num_stages=num_stages),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
       name="mha_forward",
   )(q, k, v, segment_ids)
+  return out if return_residuals else out[0]


 def _mha_forward(
@@ -304,71 +312,17 @@ def _mha_forward(
     grid: Any,
     interpret: bool,
     debug: bool,
+    return_residuals: bool,
 ):
-  del backward_pass_impl
-  batch_size, q_seq_len, num_heads, head_dim = q.shape
-  kv_seq_len = k.shape[1]
-  block_q = min(block_sizes.block_q, q_seq_len)
-  block_k = min(block_sizes.block_k, kv_seq_len)
-  if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]):
-    raise ValueError(
-        f"This kernel expects q, k, and v to have the same head dimension, but"
-        f" found {q.shape=}, {k.shape=}, {v.shape=}."
-    )
-  if q_seq_len % block_q != 0:
-    raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}")
-  if kv_seq_len % block_k != 0:
-    raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}")
-  head_dim_padded = pl.next_power_of_2(head_dim)
-
-  # Heuristics.
-  grid_ = grid
-  if grid_ is None:
-    grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)
-
-  num_warps_ = num_warps
-  if num_warps_ is None:
-    num_warps_ = 4 if head_dim <= 64 else 8
-  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
-                             causal=causal, block_q=block_q, block_k=block_k,
-                             head_dim=head_dim)
-  out_shape = [
-      jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
-      jax.ShapeDtypeStruct(
-          shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32  # lse
-      ),
-  ]
-  in_specs = [
-      pl.BlockSpec((None, block_q, None, head_dim_padded),
-                   lambda i, j, k: (j, i, k, 0)),
-      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
-                   lambda _, j, k: (j, 0, k, 0)),
-      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
-                   lambda _, j, k: (j, 0, k, 0)),
-  ]
-  in_specs.append(
-      None  # type: ignore[arg-type]
-      if segment_ids is None
-      else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
-  )
-  out, lse = pl.pallas_call(
-      kernel,
-      grid=grid_,
-      in_specs=in_specs,
-      out_specs=[
-          pl.BlockSpec((None, block_q, None, head_dim_padded),
-                       lambda i, j, k: (j, i, k, 0)),
-          pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
-      ],
-      compiler_params=plgpu.TritonCompilerParams(
-          num_warps=num_warps_, num_stages=num_stages
-      ),
-      out_shape=out_shape,
-      debug=debug,
-      interpret=interpret,
-      name="mha_forward",
-  )(q, k, v, segment_ids)
-  return out, (q, k, v, segment_ids, out, lse)
+  out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale,
+                 causal=causal, block_sizes=block_sizes,
+                 backward_pass_impl=backward_pass_impl,
+                 num_warps=num_warps, num_stages=num_stages,
+                 grid=grid, interpret=interpret, debug=debug,
+                 return_residuals=True)
+  residuals = (q, k, v, segment_ids, out, lse)
+  ret = (out, lse) if return_residuals else out
+  return ret, residuals


 def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int):
@@ -576,9 +530,12 @@ def inner_loop_dq(start_k, dq):
 def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
                   backward_pass_impl: str, num_warps: int | None,
                   num_stages: int, grid: Any, interpret: bool,
-                  debug: bool, res, do):
-  del num_stages, grid
+                  debug: bool, return_residuals: bool, res, do):
+  if return_residuals:
+    raise ValueError(
+        "Kernel differentiation is not supported if return_residuals is True.")
   q, k, v, segment_ids, out, lse = res
+  del num_stages, grid, return_residuals

   if backward_pass_impl == "xla":
     return jax.vjp(
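
The heart of the simplification: `_mha_forward` no longer builds its own `pallas_call`; it reuses `mha` with `return_residuals=True` and packages `(q, k, v, segment_ids, out, lse)` as residuals, while `_mha_backward` rejects differentiation when `return_residuals` is set. A stripped-down, self-contained sketch of this `jax.custom_vjp` pattern follows; the names (`toy_fn`, `_toy_fwd`, `_toy_bwd`) are hypothetical stand-ins for the attention kernel and are not part of the commit.

# Toy illustration (hypothetical names) of the pattern used in this commit:
# the VJP's forward rule reuses the primal function with return_residuals=True
# instead of duplicating the forward computation.
import functools
import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=[1])
def toy_fn(x, return_residuals=False):
  y = jnp.tanh(x)
  return (y, x) if return_residuals else y


def _toy_fwd(x, return_residuals):
  # Reuse the primal to get both the output and the residuals, mirroring how
  # _mha_forward now simply calls mha(..., return_residuals=True).
  y, res = toy_fn(x, True)
  ret = (y, res) if return_residuals else y
  return ret, res


def _toy_bwd(return_residuals, res, g):
  if return_residuals:
    raise ValueError(
        "Differentiation is not supported if return_residuals is True.")
  x = res
  # d/dx tanh(x) = 1 - tanh(x)^2
  return ((1.0 - jnp.tanh(x) ** 2) * g,)


toy_fn.defvjp(_toy_fwd, _toy_bwd)

print(jax.grad(lambda x: toy_fn(x, False))(0.5))  # differentiable as before

Reusing the primal as the forward rule keeps a single source of truth for the kernel launch, which is what turns 73 deleted lines into 30 added ones in attention.py.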

tests/pallas/gpu_ops_test.py

Lines changed: 24 additions & 0 deletions
@@ -313,6 +313,30 @@ def f_ref(q, k, v):
     self.assertAllClose(dk, dk_ref, atol=5e-2)
     self.assertAllClose(dv, dv_ref, atol=5e-2)

+  def test_return_residuals_not_differentiable(self):
+    batch_size, seq_len, num_heads, head_dim = 2, 128, 2, 128
+    causal = False
+    k1, k2, k3 = random.split(random.key(0), 3)
+    q = random.normal(
+        k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
+    )
+    k = random.normal(
+        k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
+    )
+    v = random.normal(
+        k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
+    )
+    segment_ids = None
+
+    def f(q, k, v):
+      return attention.mha(q, k, v, causal=causal, segment_ids=segment_ids,
+                           interpret=self.INTERPRET,
+                           return_residuals=True)[0].sum()
+
+    with self.assertRaisesRegex(ValueError, "Kernel differentiation is not"
+                                " supported if return_residuals is True."):
+      _ = jax.grad(f, argnums=(0, 1, 2))(q, k, v)
+

 class FusedAttentionInterpretTest(FusedAttentionTest):
   INTERPRET = True
