Simplify attention VJP definition #28960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 1 commit, merged on May 23, 2025
103 changes: 30 additions & 73 deletions jax/experimental/pallas/ops/gpu/attention.py
@@ -152,7 +152,7 @@ def body(start_k, carry):
       # Apply mask to qk.
       qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
 
-    m_curr = qk.max(axis=-1)
+    m_curr = jnp.max(qk, axis=-1)
     m_next = jnp.maximum(m_prev, m_curr)
     correction = jnp.exp2(m_prev - m_next)
     l_prev_corr = correction * l_prev
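For context: this hunk sits inside the forward kernel's streaming-softmax loop, and only the reduction spelling changes (method call to `jnp.max`). A minimal sketch of the running-max recurrence these lines implement, with illustrative `acc_prev`/`v_block` handling (the kernel's actual accumulator bookkeeping differs in details):

```python
import jax.numpy as jnp

def online_softmax_step(m_prev, l_prev, acc_prev, qk, v_block):
  m_curr = jnp.max(qk, axis=-1)            # block-local row max
  m_next = jnp.maximum(m_prev, m_curr)     # new running max
  correction = jnp.exp2(m_prev - m_next)   # rescale factor for old stats
  l_prev_corr = correction * l_prev
  s_curr = jnp.exp2(qk - m_next[:, None])  # unnormalized block probabilities
  l_next = l_prev_corr + s_curr.sum(axis=-1)
  acc_next = acc_prev * correction[:, None] + s_curr @ v_block
  return m_next, l_next, acc_next
```

The base-2 exponentials match the kernel's use of `jnp.exp2` (scores are pre-scaled accordingly in the forward pass).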
@@ -201,7 +201,7 @@ def segment_mask(
 
 
 @functools.partial(
-    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12]
+    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
 )
 @functools.partial(
     jax.jit,
@@ -215,6 +215,7 @@ def segment_mask(
         "grid",
         "interpret",
         "debug",
+        "return_residuals",
     ],
 )
 def mha(
@@ -231,6 +232,7 @@ def mha(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
+    return_residuals: bool = False,
 ):
   del backward_pass_impl
   batch_size, q_seq_len, num_heads, head_dim = q.shape
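Note that `return_residuals` is added both to `nondiff_argnums` (index 13) and to `static_argnames`: it changes the *structure* of the output (one array vs. an `(out, lse)` tuple), so it cannot be a traced value under `jit` nor a differentiable argument under `custom_vjp`. A toy illustration of the same constraint (`g` and `return_aux` are hypothetical names, not from this PR):

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=["return_aux"])
def g(x, return_aux=False):
  y = x * 2.0
  return (y, x.sum()) if return_aux else y

print(g(jnp.ones(4)))                   # compiled with return_aux=False
print(g(jnp.ones(4), return_aux=True))  # recompiled; returns a tuple
```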
@@ -273,21 +275,27 @@
       if segment_ids is None
       else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
   )
-  out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
-  return pl.pallas_call(
+  out_shape = [q]
+  out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded),
+                            lambda i, j, k: (j, i, k, 0))]
+  if return_residuals:
+    out_shape.append(jax.ShapeDtypeStruct(
+        shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32))  # lse
+    out_specs.append(
+        pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)))  # lse
+  out = pl.pallas_call(
       kernel,
       grid=grid_,
       in_specs=in_specs,
-      out_specs=pl.BlockSpec(
-          (None, block_q, None, head_dim_padded), lambda i, j, k: (j, i, k, 0)
-      ),
+      out_specs=out_specs,
       compiler_params=plgpu.TritonCompilerParams(
           num_warps=num_warps_, num_stages=num_stages),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
       name="mha_forward",
   )(q, k, v, segment_ids)
+  return out if return_residuals else out[0]
 
 
 def _mha_forward(
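With this change, `mha` can emit the log-sum-exp residual alongside the output. A hedged usage sketch (the shapes are illustrative; running it needs a GPU with Triton, or `interpret=True`):

```python
import jax.numpy as jnp
from jax import random
from jax.experimental.pallas.ops.gpu import attention

k1, k2, k3 = random.split(random.key(0), 3)
shape = (1, 128, 2, 64)  # (batch, q_seq_len, num_heads, head_dim)
q = random.normal(k1, shape, dtype=jnp.float16)
k = random.normal(k2, shape, dtype=jnp.float16)
v = random.normal(k3, shape, dtype=jnp.float16)

out = attention.mha(q, k, v, segment_ids=None)  # unchanged default path
out, lse = attention.mha(q, k, v, segment_ids=None, return_residuals=True)
# Per this diff, lse is float32 with shape (batch, num_heads, q_seq_len).
```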
@@ -304,71 +312,17 @@ def _mha_forward(
     grid: Any,
     interpret: bool,
     debug: bool,
+    return_residuals: bool,
 ):
-  del backward_pass_impl
-  batch_size, q_seq_len, num_heads, head_dim = q.shape
-  kv_seq_len = k.shape[1]
-  block_q = min(block_sizes.block_q, q_seq_len)
-  block_k = min(block_sizes.block_k, kv_seq_len)
-  if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]):
-    raise ValueError(
-        f"This kernel expects q, k, and v to have the same head dimension, but"
-        f" found {q.shape=}, {k.shape=}, {v.shape=}."
-    )
-  if q_seq_len % block_q != 0:
-    raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}")
-  if kv_seq_len % block_k != 0:
-    raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}")
-  head_dim_padded = pl.next_power_of_2(head_dim)
-
-  # Heuristics.
-  grid_ = grid
-  if grid_ is None:
-    grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)
-
-  num_warps_ = num_warps
-  if num_warps_ is None:
-    num_warps_ = 4 if head_dim <= 64 else 8
-  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
-                             causal=causal, block_q=block_q, block_k=block_k,
-                             head_dim=head_dim)
-  out_shape = [
-      jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
-      jax.ShapeDtypeStruct(
-          shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32  # lse
-      ),
-  ]
-  in_specs = [
-      pl.BlockSpec((None, block_q, None, head_dim_padded),
-                   lambda i, j, k: (j, i, k, 0)),
-      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
-                   lambda _, j, k: (j, 0, k, 0)),
-      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
-                   lambda _, j, k: (j, 0, k, 0)),
-  ]
-  in_specs.append(
-      None  # type: ignore[arg-type]
-      if segment_ids is None
-      else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
-  )
-  out, lse = pl.pallas_call(
-      kernel,
-      grid=grid_,
-      in_specs=in_specs,
-      out_specs=[
-          pl.BlockSpec((None, block_q, None, head_dim_padded),
-                       lambda i, j, k: (j, i, k, 0)),
-          pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
-      ],
-      compiler_params=plgpu.TritonCompilerParams(
-          num_warps=num_warps_, num_stages=num_stages
-      ),
-      out_shape=out_shape,
-      debug=debug,
-      interpret=interpret,
-      name="mha_forward",
-  )(q, k, v, segment_ids)
-  return out, (q, k, v, segment_ids, out, lse)
+  out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale,
+                 causal=causal, block_sizes=block_sizes,
+                 backward_pass_impl=backward_pass_impl,
+                 num_warps=num_warps, num_stages=num_stages,
+                 grid=grid, interpret=interpret, debug=debug,
+                 return_residuals=True)
+  residuals = (q, k, v, segment_ids, out, lse)
+  ret = (out, lse) if return_residuals else out
+  return ret, residuals
 
 
 def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int):
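This is the core simplification: the forward rule no longer duplicates the `pallas_call` setup, it just calls the primal `mha` with `return_residuals=True`. Calling the `custom_vjp`-decorated function inside its own forward rule is safe because `custom_vjp` only overrides differentiation, not the primal computation. A minimal sketch of the general pattern with a toy `f` (not the attention kernel):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.tanh(x)

def f_fwd(x):
  y = f(x)          # reuse the primal computation instead of duplicating it
  return y, (y,)    # (output, residuals saved for the backward pass)

def f_bwd(res, g):
  (y,) = res
  return (g * (1 - y * y),)  # d tanh(x)/dx = 1 - tanh(x)**2

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(0.5))
```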
Expand Down Expand Up @@ -576,9 +530,12 @@ def inner_loop_dq(start_k, dq):
def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
backward_pass_impl: str, num_warps: int | None,
num_stages: int, grid: Any, interpret: bool,
debug: bool, res, do):
del num_stages, grid
debug: bool, return_residuals: bool, res, do):
if return_residuals:
raise ValueError(
"Kernel differentiation is not supported if return_residuals is True.")
q, k, v, segment_ids, out, lse = res
del num_stages, grid, return_residuals

if backward_pass_impl == "xla":
return jax.vjp(
Expand Down
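The backward pass keeps only `lse` (and `out`) as residuals rather than the full attention matrix; probabilities are recomputed from `lse`. A sketch of that identity, assuming the kernel's base-2 convention (scores pre-scaled so `exp2` replaces `exp`, matching the `jnp.exp2` calls in the forward kernel — worth verifying against the kernel source):

```python
import jax.numpy as jnp

def recompute_probs(qk, lse_row):
  # qk: (block_q, block_k) scores in the kernel's base-2 scale;
  # lse_row: (block_q,) saved log2-sum-exp. Since 2**lse_row is the row
  # normalizer, exp2(qk - lse_row) reproduces the softmax rows exactly.
  return jnp.exp2(qk - lse_row[:, None])
```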
24 changes: 24 additions & 0 deletions tests/pallas/gpu_ops_test.py
@@ -313,6 +313,30 @@ def f_ref(q, k, v):
     self.assertAllClose(dk, dk_ref, atol=5e-2)
     self.assertAllClose(dv, dv_ref, atol=5e-2)
 
+  def test_return_residuals_not_differentiable(self):
+    batch_size, seq_len, num_heads, head_dim = 2, 128, 2, 128
+    causal = False
+    k1, k2, k3 = random.split(random.key(0), 3)
+    q = random.normal(
+        k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
+    )
+    k = random.normal(
+        k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
+    )
+    v = random.normal(
+        k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
+    )
+    segment_ids = None
+
+    def f(q, k, v):
+      return attention.mha(q, k, v, causal=causal, segment_ids=segment_ids,
+                           interpret=self.INTERPRET,
+                           return_residuals=True)[0].sum()
+
+    with self.assertRaisesRegex(ValueError, "Kernel differentiation is not"
+                                " supported if return_residuals is True."):
+      _ = jax.grad(f, argnums=(0, 1, 2))(q, k, v)
+
+
 class FusedAttentionInterpretTest(FusedAttentionTest):
   INTERPRET = True
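The test pins down the new failure mode. For contrast, a sketch of the supported path (reusing the `q`, `k`, `v` arrays from the test above; `interpret=True` is only there so it runs without a GPU):

```python
def loss(q, k, v):
  # The default return_residuals=False path remains differentiable.
  return attention.mha(q, k, v, segment_ids=None, causal=False,
                       interpret=True).sum()

dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
```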