[Mosaic GPU] Add barrier transformation support to tcgen05_mma. #28917

Merged · 1 commit · May 23, 2025
76 changes: 60 additions & 16 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -554,6 +554,7 @@ def _copy_gmem_to_smem_lowering(
)
return ()


lowering.register_lowering_rule(
copy_gmem_to_smem_p,
mgpu.LoweringSemantics.Lane,
@@ -722,9 +723,14 @@ def _barrier_wait_pp_eqn(


@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane)
@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane,
gpu_core.PrimitiveSemantics.Warp)
@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup)
@lowering.register_lowering_rule(
barrier_wait_p,
mgpu.LoweringSemantics.Lane,
gpu_core.PrimitiveSemantics.Warp,
)
@lowering.register_lowering_rule(
barrier_wait_p, mgpu.LoweringSemantics.Warpgroup
)
def _barrier_wait_lowering(
ctx: lowering.LoweringRuleContext,
barrier,
@@ -1198,18 +1204,31 @@ def tcgen05_mma(acc: _Ref,
else:
b_transforms_leaves, b_transforms_tree = [], None

if isinstance(barrier, pallas_core.TransformedRef):
barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten(
barrier.transforms
)
barrier = barrier.ref
else:
barrier_transforms_leaves, barrier_transforms_tree = [], None

tcgen05_mma_p.bind(acc, a, b, barrier, accumulate,
*a_transforms_leaves, *b_transforms_leaves,
*barrier_transforms_leaves,
a_transforms_tree=a_transforms_tree,
b_transforms_tree=b_transforms_tree,
barrier_transforms_tree=barrier_transforms_tree,
collective_axis=collective_axis)
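
Not part of the diff, just for orientation: a minimal sketch of the flatten/unflatten round trip this change relies on. Dynamic transform leaves are forwarded to `bind` as positional arguments while the static treedef travels as a primitive parameter and is rebuilt in the lowering rule; the `transforms` tuple below is a hypothetical stand-in, not a real Pallas transform.

```python
import jax

transforms = (("tiling", (8, 64)), ("swizzle", 128))  # hypothetical transform tuple
leaves, treedef = jax.tree.flatten(transforms)

# leaves would be forwarded as *args to the primitive's bind call;
# treedef would be passed as the static *_transforms_tree parameter.
rebuilt = treedef.unflatten(leaves)
assert rebuilt == transforms
```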


@tcgen05_mma_p.def_abstract_eval
def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
*transforms_leaves,
a_transforms_tree, b_transforms_tree,
barrier_transforms_tree,
collective_axis):
del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree)
del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree,
barrier_transforms_tree)

if acc.memory_space != gpu_core.TMEM:
raise ValueError("Accumulator must be a TMEM Ref.")
@@ -1233,6 +1252,7 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,

return []


@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS)
@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS)
def _tcgen05_mma_lowering(
@@ -1245,16 +1265,26 @@ def _tcgen05_mma_lowering(
*transforms_leaves,
a_transforms_tree,
b_transforms_tree,
barrier_transforms_tree,
collective_axis,
):
_, a_aval, b_aval, *_ = ctx.avals_in
lhs_swizzle: int | None = None
lhs_transpose: bool = False
if a_transforms_tree is not None:
a_transforms_leaves, b_transforms_leaves = util.split_list(
transforms_leaves, [a_transforms_tree.num_leaves]
)

transforms_trees = (
a_transforms_tree,
b_transforms_tree,
barrier_transforms_tree,
)
(a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = (
util.split_list(
transforms_leaves,
[getattr(tree, "num_leaves", 0) for tree in transforms_trees],
)
)
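
Not part of the PR, just an illustration of the split above: each treedef claims `getattr(tree, "num_leaves", 0)` leaves from the combined list, so a `None` tree (no transforms supplied for that operand) receives an empty slice. `split_by_sizes` and `FakeTree` below are hypothetical stand-ins for `jax._src.util.split_list` and real PyTreeDefs.

```python
def split_by_sizes(xs, sizes):
    # Split xs into len(sizes) consecutive groups plus the remainder,
    # mirroring what util.split_list does in the lowering above.
    groups, start = [], 0
    for n in sizes:
        groups.append(xs[start:start + n])
        start += n
    groups.append(xs[start:])
    return groups

class FakeTree:
    """Stand-in for a PyTreeDef exposing only num_leaves."""
    def __init__(self, num_leaves):
        self.num_leaves = num_leaves

# (a, b, barrier): the barrier was passed as a bare ref, so its tree is None.
trees = (FakeTree(2), FakeTree(1), None)
sizes = [getattr(t, "num_leaves", 0) for t in trees]
a, b, barrier, rest = split_by_sizes(["a0", "a1", "b0"], sizes)
assert (a, b, barrier, rest) == (["a0", "a1"], ["b0"], [], [])
```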

if a_transforms_tree is not None:
a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
a_ref, a_transforms = lowering._handle_transforms(
ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True
@@ -1276,9 +1306,8 @@ def _tcgen05_mma_lowering(
if lhs_tiling != (8, swizzle_elems):
raise ValueError("MMA lhs tiling does not fit swizzle. "
f"{lhs_tiling=} expected={(8, swizzle_elems)}")
else:
b_transforms_leaves = transforms_leaves # type: ignore

assert b_transforms_tree is not None
b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
b_ref, b_transforms = lowering._handle_transforms(
ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
@@ -1296,16 +1325,28 @@ def _tcgen05_mma_lowering(
raise NotImplementedError(
f"Unsupported transforms: {b_transforms}."
)

swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
if rhs_tiling != (8, swizzle_elems):
raise ValueError(
"MMA rhs tiling does not fit swizzle"
f" {rhs_tiling=} expected={(8, swizzle_elems)}"
)

if barrier_transforms_tree is not None:
barrier_transforms = barrier_transforms_tree.unflatten(
barrier_transforms_leaves
)
indexer = _extract_barrier_indexer(barrier_transforms)
if indexer is not None:
barrier_ref = barrier_ref.__getitem__(
*map(lowering._as_index, indexer.indices)
)
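
For reference, the user-level pattern this lowering path enables (mirrored by the test added below): index into a multi-slot barrier with `.at[...]` and pass the resulting TransformedRef directly to `tcgen05_mma`. This is only a kernel-body fragment, not a complete `pallas_call`; it assumes `acc_tmem`, `a_smem`, `b_smem`, and a two-slot `barrier_ref` are already in scope.

```python
# Sketch of a kernel body (not runnable on its own): `slot` selects one of the
# barriers allocated with plgpu.Barrier(..., num_barriers=2).
slot = 1
plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, barrier_ref.at[slot], accumulate=False)
plgpu.barrier_wait(barrier_ref.at[slot])  # wait on the same slot the MMA signals
```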

if lhs_swizzle is None:
lhs_swizzle = rhs_swizzle
elif rhs_swizzle != lhs_swizzle:
raise ValueError("MMA rhs swizzle must match lhs swizzle."
f" {lhs_swizzle=} {rhs_swizzle=}")
if rhs_tiling != (8, swizzle_elems):
raise ValueError("MMA rhs tiling does not fit swizzle"
f" {rhs_tiling=} expected={(8, swizzle_elems)}")
if lhs_transpose:
if isinstance(a_ref, tcgen05.TMEMRef):
raise ValueError("TMEM transpose not allowed.")
@@ -1314,6 +1355,9 @@ def _tcgen05_mma_lowering(
b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2))
if isinstance(accumulate, bool):
accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1))
elif isinstance(accumulate, mgpu.FragmentedArray):
accumulate = accumulate.registers.item()
assert isinstance(accumulate, ir.Value)

predicate = ctx.module_ctx.single_lane_predicate
collective = False
@@ -1341,8 +1385,8 @@ def _tcgen05_mma_lowering(
acc,
a_ref,
b_ref,
a_swizzle=lhs_swizzle,
b_swizzle=rhs_swizzle,
a_swizzle=int(lhs_swizzle),
b_swizzle=int(rhs_swizzle),
accumulate=accumulate,
collective=collective,
)
46 changes: 46 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -2496,6 +2496,52 @@ def _scoped(a_smem, b_smem,
expected = x @ y
np.testing.assert_allclose(result, expected, rtol=1e-3)

@parameterized.parameters((0,), (1,))
def test_mma_barrier_indexing(
self, barrier_index, shape=(128, 128), swizzle=128, dtype=jnp.float16
):
self.skip_if_wg_semantics()
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
transforms = (
plgpu.TilingTransform((8, swizzle_elems)),
plgpu.SwizzleTransform(swizzle),
)

def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref):
plgpu.tcgen05_mma(
acc_tmem,
a_smem,
b_smem,
barrier_ref.at[barrier_index],
accumulate=False,
)
plgpu.barrier_wait(barrier_ref.at[barrier_index])
scratch_smem[...] = acc_tmem[...].astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(scratch_smem, out_ref)
plgpu.wait_smem_to_gmem(0)

scratch_shapes = [
plgpu.TMEM(shape, jnp.float32, packed=False),
plgpu.SMEM(shape, dtype, transforms=transforms),
plgpu.Barrier(num_arrivals=1, num_barriers=2, for_tensor_core=True),
]
f = self.pallas_call(
kernel,
in_specs=(
plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM),
plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM),
),
out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM),
out_shape=jax.ShapeDtypeStruct(shape, dtype),
scratch_shapes=scratch_shapes,
)
x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype)
y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype)
result = f(x, y)
expected = x @ y
np.testing.assert_allclose(result, expected, rtol=1e-3)


class PallasCallSm100AWGTest(
PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup