Skip to content

Commit 4fddf73

Browse files
justinjfu and Google-ML-Automation
authored and committed
[Mosaic GPU] Add non-collective blackwell matmul example
PiperOrigin-RevId: 761718971
1 parent e71d5d5 commit 4fddf73

File tree

4 files changed

+377
-38
lines changed

4 files changed

+377
-38
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 87 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import jax.numpy as jnp
5050

5151

52+
WARP_SIZE = 32
5253
WARPGROUP_SIZE = 128
5354

5455

@@ -464,7 +465,7 @@ def _copy_gmem_to_smem_lowering(
464465
dst_transforms_treedef,
465466
barrier_transforms_treedef,
466467
collective_axes,
467-
warpgroup_sync: bool = True,
468+
for_warpgroup: bool = True,
468469
):
469470
flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
470471
util.split_list(
@@ -505,15 +506,23 @@ def _copy_gmem_to_smem_lowering(
505506
if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
506507
if bytes % WARPGROUP_SIZE:
507508
raise NotImplementedError("Only aligned copies are supported")
508-
# We arrive uniformly from each thread in the WG, so we need to divide the
509-
# number of bytes by the number of threads in the WG.
510-
# TODO: apaszke - Relax this. We can just select the WG leader and have it
511-
# arrive with the whole transfer size, while everyone else arrives with 0.
512-
# But we should continue using this scheme as it's likely to be faster.
513-
bytes //= WARPGROUP_SIZE
514-
if warpgroup_sync:
509+
if for_warpgroup:
510+
# We arrive uniformly from each thread in the WG, so we need to divide the
511+
# number of bytes by the number of threads in the WG.
512+
# TODO: apaszke - Relax this. We can just select the WG leader and have it
513+
# arrive with the whole transfer size, while everyone else arrives with 0.
514+
# But we should continue using this scheme as it's likely to be faster.
515+
bytes //= WARPGROUP_SIZE
515516
mgpu.warpgroup_barrier() # Make sure all reads have completed.
516-
barrier.arrive_expect_tx(bytes)
517+
barrier.arrive_expect_tx(bytes)
518+
else:
519+
# In Warp-level lowering, we arrive on each CUDA thread in a warp, but
520+
# the barrier still expects a full 128 arrivals so we arrive 4 times
521+
# on each CUDA thread.
522+
bytes //= WARP_SIZE
523+
barrier.arrive_nocomplete(3)
524+
barrier.arrive_expect_tx(bytes)
525+
517526
ctx.launch_ctx.async_copy(
518527
src_ref=src,
519528
dst_ref=dst,
@@ -549,7 +558,7 @@ def _copy_gmem_to_smem_lowering(
549558
copy_gmem_to_smem_p,
550559
mgpu.LoweringSemantics.Lane,
551560
primitive_semantics=gpu_core.PrimitiveSemantics.Warp,
552-
)(functools.partial(_copy_gmem_to_smem_lowering, warpgroup_sync=False))
561+
)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False))
553562

554563

555564
def copy_gmem_to_smem(
@@ -713,6 +722,8 @@ def _barrier_wait_pp_eqn(
713722

714723

715724
@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane)
725+
@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane,
726+
gpu_core.PrimitiveSemantics.Warp)
716727
@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup)
717728
def _barrier_wait_lowering(
718729
ctx: lowering.LoweringRuleContext,
@@ -1187,18 +1198,29 @@ def tcgen05_mma(acc: _Ref,
11871198
else:
11881199
b_transforms_leaves, b_transforms_tree = [], None
11891200

1201+
if isinstance(barrier, pallas_core.TransformedRef):
1202+
barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten(
1203+
barrier.transforms)
1204+
barrier = barrier.ref
1205+
else:
1206+
barrier_transforms_leaves, barrier_transforms_tree = [], None
1207+
11901208
tcgen05_mma_p.bind(acc, a, b, barrier, accumulate,
11911209
*a_transforms_leaves, *b_transforms_leaves,
1210+
*barrier_transforms_leaves,
11921211
a_transforms_tree=a_transforms_tree,
11931212
b_transforms_tree=b_transforms_tree,
1213+
barrier_transforms_tree=barrier_transforms_tree,
11941214
collective_axis=collective_axis)
11951215

11961216
@tcgen05_mma_p.def_abstract_eval
11971217
def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
11981218
*transforms_leaves,
11991219
a_transforms_tree, b_transforms_tree,
1220+
barrier_transforms_tree,
12001221
collective_axis):
1201-
del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree)
1222+
del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree,
1223+
barrier_transforms_tree)
12021224

12031225
if acc.memory_space != gpu_core.TMEM:
12041226
raise ValueError("Accumulator must be a TMEM Ref.")
@@ -1222,6 +1244,20 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
12221244

12231245
return []
12241246

1247+
1248+
def _split_transforms(all_transforms_leaves, transforms_trees) -> list[Any]:
1249+
transform_leaves = []
1250+
for transforms_tree in transforms_trees:
1251+
if transforms_tree is None:
1252+
transform_leaves.append([])
1253+
continue
1254+
current_leaves, all_transforms_leaves = util.split_list(
1255+
all_transforms_leaves, [transforms_tree.num_leaves]
1256+
)
1257+
transform_leaves.append(current_leaves)
1258+
return transform_leaves
1259+
1260+
12251261
@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS)
12261262
@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS)
12271263
def _tcgen05_mma_lowering(
@@ -1234,16 +1270,20 @@ def _tcgen05_mma_lowering(
12341270
*transforms_leaves,
12351271
a_transforms_tree,
12361272
b_transforms_tree,
1273+
barrier_transforms_tree,
12371274
collective_axis,
12381275
):
12391276
_, a_aval, b_aval, *_ = ctx.avals_in
12401277
lhs_swizzle: int | None = None
1278+
rhs_swizzle: int | None = None
12411279
lhs_transpose: bool = False
1242-
if a_transforms_tree is not None:
1243-
a_transforms_leaves, b_transforms_leaves = util.split_list(
1244-
transforms_leaves, [a_transforms_tree.num_leaves]
1245-
)
1280+
rhs_transpose: bool = False
12461281

1282+
a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves = (
1283+
_split_transforms(transforms_leaves,
1284+
[a_transforms_tree, b_transforms_tree, barrier_transforms_tree])
1285+
)
1286+
if a_transforms_tree is not None:
12471287
a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
12481288
a_ref, a_transforms = lowering._handle_transforms(
12491289
ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True
@@ -1265,36 +1305,42 @@ def _tcgen05_mma_lowering(
12651305
if lhs_tiling != (8, swizzle_elems):
12661306
raise ValueError("MMA lhs tiling does not fit swizzle. "
12671307
f"{lhs_tiling=} expected={(8, swizzle_elems)}")
1268-
else:
1269-
b_transforms_leaves = transforms_leaves # type: ignore
12701308

1271-
b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
1272-
b_ref, b_transforms = lowering._handle_transforms(
1273-
ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
1274-
)
1275-
match b_transforms:
1276-
case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
1277-
rhs_transpose = False
1278-
case (
1279-
gpu_core.UnswizzleRef(rhs_swizzle),
1280-
gpu_core.UntileRef(rhs_tiling),
1281-
gpu_core.TransposeRef((1, 0)),
1282-
):
1283-
rhs_transpose = True
1284-
case _:
1285-
raise NotImplementedError(
1286-
f"Unsupported transforms: {b_transforms}."
1287-
)
1309+
if b_transforms_tree is not None:
1310+
b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
1311+
b_ref, b_transforms = lowering._handle_transforms(
1312+
ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
1313+
)
1314+
match b_transforms:
1315+
case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
1316+
rhs_transpose = False
1317+
case (
1318+
gpu_core.UnswizzleRef(rhs_swizzle),
1319+
gpu_core.UntileRef(rhs_tiling),
1320+
gpu_core.TransposeRef((1, 0)),
1321+
):
1322+
rhs_transpose = True
1323+
case _:
1324+
raise NotImplementedError(
1325+
f"Unsupported transforms: {b_transforms}."
1326+
)
1327+
swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
1328+
if rhs_tiling != (8, swizzle_elems):
1329+
raise ValueError("MMA rhs tiling does not fit swizzle"
1330+
f" {rhs_tiling=} expected={(8, swizzle_elems)}")
1331+
1332+
if barrier_transforms_tree is not None:
1333+
barrier_transforms = barrier_transforms_tree.unflatten(
1334+
barrier_transforms_leaves)
1335+
indexer = _extract_barrier_indexer(barrier_transforms)
1336+
if indexer is not None:
1337+
barrier_ref = barrier_ref.__getitem__(*map(lowering._as_index, indexer.indices))
12881338

1289-
swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
12901339
if lhs_swizzle is None:
12911340
lhs_swizzle = rhs_swizzle
12921341
elif rhs_swizzle != lhs_swizzle:
12931342
raise ValueError("MMA rhs swizzle must match lhs swizzle."
12941343
f" {lhs_swizzle=} {rhs_swizzle=}")
1295-
if rhs_tiling != (8, swizzle_elems):
1296-
raise ValueError("MMA rhs tiling does not fit swizzle"
1297-
f" {rhs_tiling=} expected={(8, swizzle_elems)}")
12981344
if lhs_transpose:
12991345
if isinstance(a_ref, tcgen05.TMEMRef):
13001346
raise ValueError("TMEM transpose not allowed.")
@@ -1303,6 +1349,9 @@ def _tcgen05_mma_lowering(
13031349
b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2))
13041350
if isinstance(accumulate, bool):
13051351
accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1))
1352+
elif isinstance(accumulate, mgpu.FragmentedArray):
1353+
accumulate = accumulate.registers.item()
1354+
assert isinstance(accumulate, ir.Value)
13061355

13071356
predicate = ctx.module_ctx.single_lane_predicate
13081357
collective = False

jax/experimental/mosaic/gpu/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,12 @@ def arrive_expect_tx(
830830
bytes = arith.index_cast(i32, bytes)
831831
nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes, predicate=predicate)
832832

833+
def arrive_nocomplete(self, count: int) -> None:
834+
i64 = ir.IntegerType.get_signless(64)
835+
if isinstance(count, int):
836+
count = c(count, ir.IntegerType.get_signless(32))
837+
nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count)
838+
833839
def get_ptr(self):
834840
ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>")
835841
i64 = ir.IntegerType.get_signless(64)

0 commit comments

Comments (0)