Commit c4a90c1

justinjfu authored and Google-ML-Automation committed
[Mosaic GPU] Add barrier transformation support to tcgen05_mma.
Also fix the accumulator argument when it's dynamic.

PiperOrigin-RevId: 762509416
1 parent f429162 commit c4a90c1

File tree

2 files changed (+106, -16 lines)
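
In user-facing terms, plgpu.tcgen05_mma now accepts a barrier ref that carries transforms, such as a single barrier picked out of a multi-barrier allocation with .at[index]. A minimal kernel sketch, adapted from the new test at the bottom of this diff (the ref names follow the test and are not required by the API; treat this as illustrative, not canonical documentation):

def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref):
  # barrier_ref.at[0] is a TransformedRef; its indexing transform is now
  # flattened, threaded through the primitive, and replayed during lowering.
  plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, barrier_ref.at[0],
                    accumulate=False)
  plgpu.barrier_wait(barrier_ref.at[0])  # wait on the same indexed barrier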

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 60 additions & 16 deletions
@@ -554,6 +554,7 @@ def _copy_gmem_to_smem_lowering(
   )
   return ()
 
+
 lowering.register_lowering_rule(
     copy_gmem_to_smem_p,
     mgpu.LoweringSemantics.Lane,
@@ -722,9 +723,14 @@ def _barrier_wait_pp_eqn(
 
 
 @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane)
-@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane,
-                                 gpu_core.PrimitiveSemantics.Warp)
-@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup)
+@lowering.register_lowering_rule(
+    barrier_wait_p,
+    mgpu.LoweringSemantics.Lane,
+    gpu_core.PrimitiveSemantics.Warp,
+)
+@lowering.register_lowering_rule(
+    barrier_wait_p, mgpu.LoweringSemantics.Warpgroup
+)
 def _barrier_wait_lowering(
     ctx: lowering.LoweringRuleContext,
     barrier,
@@ -1198,18 +1204,31 @@ def tcgen05_mma(acc: _Ref,
   else:
     b_transforms_leaves, b_transforms_tree = [], None
 
+  if isinstance(barrier, pallas_core.TransformedRef):
+    barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten(
+        barrier.transforms
+    )
+    barrier = barrier.ref
+  else:
+    barrier_transforms_leaves, barrier_transforms_tree = [], None
+
   tcgen05_mma_p.bind(acc, a, b, barrier, accumulate,
                      *a_transforms_leaves, *b_transforms_leaves,
+                     *barrier_transforms_leaves,
                      a_transforms_tree=a_transforms_tree,
                      b_transforms_tree=b_transforms_tree,
+                     barrier_transforms_tree=barrier_transforms_tree,
                      collective_axis=collective_axis)
 
+
 @tcgen05_mma_p.def_abstract_eval
 def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
                                *transforms_leaves,
                                a_transforms_tree, b_transforms_tree,
+                               barrier_transforms_tree,
                                collective_axis):
-  del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree)
+  del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree,
+       barrier_transforms_tree)
 
   if acc.memory_space != gpu_core.TMEM:
     raise ValueError("Accumulator must be a TMEM Ref.")
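
The barrier plumbing above reuses the pattern already applied to the a and b operands: a structured transforms tuple cannot be passed through bind directly, so it is flattened into traceable leaves plus a static treedef, and the treedef travels as a keyword parameter. A standalone sketch of that round trip using only the public jax.tree API (the transforms value is a stand-in, not a real Pallas transform):

import jax

transforms = ((8, 64), 128)  # stand-in for barrier.transforms
leaves, treedef = jax.tree.flatten(transforms)
# leaves travel as positional args through tcgen05_mma_p.bind;
# treedef is static metadata, passed as a keyword parameter.
assert treedef.unflatten(leaves) == transforms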
@@ -1233,6 +1252,7 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
 
   return []
 
+
 @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS)
 @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS)
 def _tcgen05_mma_lowering(
@@ -1245,16 +1265,26 @@ def _tcgen05_mma_lowering(
     *transforms_leaves,
     a_transforms_tree,
     b_transforms_tree,
+    barrier_transforms_tree,
     collective_axis,
 ):
   _, a_aval, b_aval, *_ = ctx.avals_in
   lhs_swizzle: int | None = None
   lhs_transpose: bool = False
-  if a_transforms_tree is not None:
-    a_transforms_leaves, b_transforms_leaves = util.split_list(
-        transforms_leaves, [a_transforms_tree.num_leaves]
-    )
 
+  transforms_trees = (
+      a_transforms_tree,
+      b_transforms_tree,
+      barrier_transforms_tree,
+  )
+  (a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = (
+      util.split_list(
+          transforms_leaves,
+          [getattr(tree, "num_leaves", 0) for tree in transforms_trees],
+      )
+  )
+
+  if a_transforms_tree is not None:
     a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
     a_ref, a_transforms = lowering._handle_transforms(
         ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True
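
The rewritten prologue splits the flat transforms_leaves sequence into per-operand chunks in a single pass; getattr(tree, "num_leaves", 0) makes a missing (None) tree contribute an empty chunk, and the trailing _ discards the leftover list that util.split_list always appends. A pure-Python sketch of the behavior relied on here (not the real jax._src.util implementation):

def split_list(xs, ns):
  # Split xs into len(ns) prefix chunks plus whatever remains.
  out, i = [], 0
  for n in ns:
    out.append(list(xs[i:i + n]))
    i += n
  out.append(list(xs[i:]))  # leftover tail; empty in the lowering above
  return out

# 2 leaves for a, 1 for b, 0 for an un-transformed barrier:
assert split_list(["a1", "a2", "b1"], [2, 1, 0]) == [["a1", "a2"], ["b1"], [], []]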
@@ -1276,9 +1306,8 @@ def _tcgen05_mma_lowering(
     if lhs_tiling != (8, swizzle_elems):
       raise ValueError("MMA lhs tiling does not fit swizzle. "
                        f"{lhs_tiling=} expected={(8, swizzle_elems)}")
-  else:
-    b_transforms_leaves = transforms_leaves  # type: ignore
 
+  assert b_transforms_tree is not None
   b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
   b_ref, b_transforms = lowering._handle_transforms(
       ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
@@ -1296,16 +1325,28 @@ def _tcgen05_mma_lowering(
     raise NotImplementedError(
         f"Unsupported transforms: {b_transforms}."
     )
-
   swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
+  if rhs_tiling != (8, swizzle_elems):
+    raise ValueError(
+        "MMA rhs tiling does not fit swizzle"
+        f" {rhs_tiling=} expected={(8, swizzle_elems)}"
+    )
+
+  if barrier_transforms_tree is not None:
+    barrier_transforms = barrier_transforms_tree.unflatten(
+        barrier_transforms_leaves
+    )
+    indexer = _extract_barrier_indexer(barrier_transforms)
+    if indexer is not None:
+      barrier_ref = barrier_ref.__getitem__(
+          *map(lowering._as_index, indexer.indices)
+      )
+
   if lhs_swizzle is None:
     lhs_swizzle = rhs_swizzle
   elif rhs_swizzle != lhs_swizzle:
     raise ValueError("MMA rhs swizzle must match lhs swizzle."
                      f" {lhs_swizzle=} {rhs_swizzle=}")
-  if rhs_tiling != (8, swizzle_elems):
-    raise ValueError("MMA rhs tiling does not fit swizzle"
-                     f" {rhs_tiling=} expected={(8, swizzle_elems)}")
   if lhs_transpose:
     if isinstance(a_ref, tcgen05.TMEMRef):
       raise ValueError("TMEM transpose not allowed.")
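
When the barrier arrived with transforms, the lowering unflattens them and, if an indexer was recorded (as barrier_ref.at[i] does), subscripts the Mosaic GPU barrier object with the recovered indices. The explicit __getitem__(*...) spelling is ordinary subscripting, written so a variable-length tuple of converted index values can be splatted in; a plain-Python illustration with hypothetical stand-in values:

barriers = ["mbarrier_0", "mbarrier_1"]  # stand-in for a 2-barrier allocation
indices = (1,)  # what indexer.indices would carry for .at[1]
assert barriers.__getitem__(*indices) == barriers[1]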
@@ -1314,6 +1355,9 @@ def _tcgen05_mma_lowering(
     b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2))
   if isinstance(accumulate, bool):
     accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1))
+  elif isinstance(accumulate, mgpu.FragmentedArray):
+    accumulate = accumulate.registers.item()
+  assert isinstance(accumulate, ir.Value)
 
   predicate = ctx.module_ctx.single_lane_predicate
   collective = False
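
This hunk is the "accumulator argument when it's dynamic" fix from the commit message: accumulate can be a Python bool, constant-folded to an i1 via mgpu.c, or a value computed in the kernel, which reaches the lowering as a single-register mgpu.FragmentedArray; .registers.item() unwraps it to the underlying ir.Value the MMA builder expects. At the user level the two cases look roughly like this (a sketch; k_step and the ref names are placeholders):

plgpu.tcgen05_mma(acc, a, b, barrier, accumulate=False)  # static Python bool
acc_flag = k_step > 0  # traced scalar predicate, dynamic at run time
plgpu.tcgen05_mma(acc, a, b, barrier, accumulate=acc_flag)  # now handled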
@@ -1341,8 +1385,8 @@ def _tcgen05_mma_lowering(
       acc,
       a_ref,
       b_ref,
-      a_swizzle=lhs_swizzle,
-      b_swizzle=rhs_swizzle,
+      a_swizzle=int(lhs_swizzle),
+      b_swizzle=int(rhs_swizzle),
       accumulate=accumulate,
       collective=collective,
   )
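
The int(...) coercions are presumably there because the swizzle values can arrive as NumPy integer scalars (they come out of transform parameters and dtype.itemsize arithmetic), while MLIR attribute builders expect Python ints; that rationale is inferred from the change, not stated in the commit. The distinction being normalized:

import numpy as np

swizzle = np.int64(128)  # what NumPy integer arithmetic can produce
assert not isinstance(swizzle, int)  # np.int64 is not a Python int on Python 3
assert isinstance(int(swizzle), int)  # int() normalizes it before MLIR sees it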

tests/pallas/mosaic_gpu_test.py

Lines changed: 46 additions & 0 deletions
@@ -2496,6 +2496,52 @@ def _scoped(a_smem, b_smem,
     expected = x @ y
     np.testing.assert_allclose(result, expected, rtol=1e-3)
 
+  @parameterized.parameters((0,), (1,))
+  def test_mma_barrier_indexing(
+      self, barrier_index, shape=(128, 128), swizzle=128, dtype=jnp.float16
+  ):
+    self.skip_if_wg_semantics()
+    swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
+    transforms = (
+        plgpu.TilingTransform((8, swizzle_elems)),
+        plgpu.SwizzleTransform(swizzle),
+    )
+
+    def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref):
+      plgpu.tcgen05_mma(
+          acc_tmem,
+          a_smem,
+          b_smem,
+          barrier_ref.at[barrier_index],
+          accumulate=False,
+      )
+      plgpu.barrier_wait(barrier_ref.at[barrier_index])
+      scratch_smem[...] = acc_tmem[...].astype(dtype)
+      plgpu.commit_smem()
+      plgpu.copy_smem_to_gmem(scratch_smem, out_ref)
+      plgpu.wait_smem_to_gmem(0)
+
+    scratch_shapes = [
+        plgpu.TMEM(shape, jnp.float32, packed=False),
+        plgpu.SMEM(shape, dtype, transforms=transforms),
+        plgpu.Barrier(num_arrivals=1, num_barriers=2, for_tensor_core=True),
+    ]
+    f = self.pallas_call(
+        kernel,
+        in_specs=(
+            plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM),
+            plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM),
+        ),
+        out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct(shape, dtype),
+        scratch_shapes=scratch_shapes,
+    )
+    x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype)
+    y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype)
+    result = f(x, y)
+    expected = x @ y
+    np.testing.assert_allclose(result, expected, rtol=1e-3)
+
 
 class PallasCallSm100AWGTest(
     PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
