
Commit 1b79be4

justinjfu authored and Google-ML-Automation committed
[Mosaic GPU] Add barrier transformation support to tcgen05_mma.
Also fix accumulator argument when it's dynamic.
PiperOrigin-RevId: 761713920
1 parent 8da86ea commit 1b79be4

4 files changed: +304 additions, -98 deletions

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 21 additions & 0 deletions
@@ -1694,6 +1694,19 @@ def convert(ty, x):
     lax.not_p: lambda ctx, x: ~x,
 })

+def _unary_warp_lowering_rule(impl):
+  def _lowering_rule(ctx: LoweringRuleContext, x):
+    if not all(aval_in.shape == () for aval_in in ctx.avals_in):
+      raise NotImplementedError(
+          "Non-scalar arithmetic is not supported in warp-level lowering.")
+    return impl(x)
+  return _lowering_rule
+
+mosaic_lowering_rules[gpu_core.LANExWARP_SEMANTICS].update({
+    lax.neg_p: _unary_warp_lowering_rule(lambda x: -x),
+    lax.not_p: _unary_warp_lowering_rule(lambda x: ~x)
+})
+
 mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({
     lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False),
     lax.not_p: _lower_fun(
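Note: a standalone sketch of the guard pattern the new warp-level rules use (illustrative only, not part of the diff): the rule applies the op only when every input abstract value is a scalar.

def make_unary_warp_rule(impl):
  def rule(operand_shapes, x):
    # Warp-level lowering only handles scalars; anything shaped is rejected.
    if not all(shape == () for shape in operand_shapes):
      raise NotImplementedError(
          "Non-scalar arithmetic is not supported in warp-level lowering.")
    return impl(x)
  return rule

neg_rule = make_unary_warp_rule(lambda x: -x)
assert neg_rule([()], 3) == -3   # scalar input: fine
# neg_rule([(8,)], ...) would raise NotImplementedError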
@@ -2159,6 +2172,8 @@ def _axis_index_warp_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):


 @register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane,
+                        gpu_core.PrimitiveSemantics.Warp)
 def _debug_print_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
@@ -2167,13 +2182,17 @@
 ):
   del has_placeholders  # Unused.
   primitives.check_debug_print_format(fmt, *args)
+  scope = mgpu.ThreadSubset.WARPGROUP
+  if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp:
+    scope = mgpu.ThreadSubset.WARP
   if not any(aval.shape for aval in ctx.avals_in):
     mgpu.debug_print(
         fmt,
         *(
            _ensure_ir_value(arg, aval.dtype)
            for arg, aval in zip(args, ctx.avals_in)
        ),
+        scope=scope
     )
   elif len(ctx.avals_in) == 1:
     [arg] = args
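Note: the practical effect of the warp-level registration above is that a Pallas debug print inside warp-specialized code fires once per warp instead of once per warpgroup. A hedged sketch (the surrounding kernel and the scalar x are hypothetical):

# pl.debug_print uses {}-style placeholders; under warp primitive semantics the
# lowering above emits a WARP-scoped print, i.e. 4 prints per warpgroup.
pl.debug_print("per-warp value: {}", x)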
@@ -2451,6 +2470,8 @@ def loop(loop_index, body_args):

 @register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane)
 @register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup)
+@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane,
+                        gpu_core.PrimitiveSemantics.Warp)
 def _scan_lowering_rule(
     ctx: LoweringRuleContext,
     *args,

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 87 additions & 38 deletions
@@ -49,6 +49,7 @@
 import jax.numpy as jnp


+WARP_SIZE = 32
 WARPGROUP_SIZE = 128

@@ -464,7 +465,7 @@ def _copy_gmem_to_smem_lowering(
     dst_transforms_treedef,
     barrier_transforms_treedef,
     collective_axes,
-    warpgroup_sync: bool = True,
+    for_warpgroup: bool = True,
 ):
   flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
       util.split_list(
@@ -505,15 +506,23 @@ def _copy_gmem_to_smem_lowering(
   if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
     if bytes % WARPGROUP_SIZE:
       raise NotImplementedError("Only aligned copies are supported")
-    # We arrive uniformly from each thread in the WG, so we need to divide the
-    # number of bytes by the number of threads in the WG.
-    # TODO: apaszke - Relax this. We can just select the WG leader and have it
-    # arrive with the whole transfer size, while everyone else arrives with 0.
-    # But we should continue using this scheme as it's likely to be faster.
-    bytes //= WARPGROUP_SIZE
-    if warpgroup_sync:
+    if for_warpgroup:
+      # We arrive uniformly from each thread in the WG, so we need to divide the
+      # number of bytes by the number of threads in the WG.
+      # TODO: apaszke - Relax this. We can just select the WG leader and have it
+      # arrive with the whole transfer size, while everyone else arrives with 0.
+      # But we should continue using this scheme as it's likely to be faster.
+      bytes //= WARPGROUP_SIZE
       mgpu.warpgroup_barrier()  # Make sure all reads have completed.
-    barrier.arrive_expect_tx(bytes)
+      barrier.arrive_expect_tx(bytes)
+    else:
+      # In Warp-level lowering, we arrive on each CUDA thread in a warp, but
+      # the barrier still expects a full 128 arrivals so we arrive 4 times
+      # on each CUDA thread.
+      bytes //= WARP_SIZE
+      barrier.arrive_nocomplete(3)
+      barrier.arrive_expect_tx(bytes)
+
   ctx.launch_ctx.async_copy(
       src_ref=src,
       dst_ref=dst,
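Note: the arithmetic behind the warp-level arrival scheme above: each of the 32 threads in a warp arrives three times via arrive_nocomplete plus once via arrive_expect_tx, matching the 128 arrivals a warpgroup-initialized barrier expects.

# Sanity check of the arrival count (constants as defined in this file).
WARP_SIZE = 32
WARPGROUP_SIZE = 128
arrivals_per_thread = 3 + 1   # arrive_nocomplete(3) + arrive_expect_tx(...)
assert WARP_SIZE * arrivals_per_thread == WARPGROUP_SIZE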
@@ -549,7 +558,7 @@ def _copy_gmem_to_smem_lowering(
     copy_gmem_to_smem_p,
     mgpu.LoweringSemantics.Lane,
     primitive_semantics=gpu_core.PrimitiveSemantics.Warp,
-)(functools.partial(_copy_gmem_to_smem_lowering, warpgroup_sync=False))
+)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False))


 def copy_gmem_to_smem(
@@ -713,6 +722,8 @@ def _barrier_wait_pp_eqn(


 @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane)
+@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane,
+                                 gpu_core.PrimitiveSemantics.Warp)
 @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup)
 def _barrier_wait_lowering(
     ctx: lowering.LoweringRuleContext,
@@ -1187,18 +1198,29 @@ def tcgen05_mma(acc: _Ref,
   else:
     b_transforms_leaves, b_transforms_tree = [], None

+  if isinstance(barrier, pallas_core.TransformedRef):
+    barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten(
+        barrier.transforms)
+    barrier = barrier.ref
+  else:
+    barrier_transforms_leaves, barrier_transforms_tree = [], None
+
   tcgen05_mma_p.bind(acc, a, b, barrier, accumulate,
                      *a_transforms_leaves, *b_transforms_leaves,
+                     *barrier_transforms_leaves,
                      a_transforms_tree=a_transforms_tree,
                      b_transforms_tree=b_transforms_tree,
+                     barrier_transforms_tree=barrier_transforms_tree,
                      collective_axis=collective_axis)

 @tcgen05_mma_p.def_abstract_eval
 def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
                                *transforms_leaves,
                                a_transforms_tree, b_transforms_tree,
+                               barrier_transforms_tree,
                                collective_axis):
-  del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree)
+  del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree,
+       barrier_transforms_tree)

   if acc.memory_space != gpu_core.TMEM:
     raise ValueError("Accumulator must be a TMEM Ref.")
@@ -1222,6 +1244,20 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,

   return []

+
+def _split_transforms(all_transforms_leaves, transforms_trees) -> list[Any]:
+  transform_leaves = []
+  for transforms_tree in transforms_trees:
+    if transforms_tree is None:
+      transform_leaves.append([])
+      continue
+    current_leaves, all_transforms_leaves = util.split_list(
+        all_transforms_leaves, [transforms_tree.num_leaves]
+    )
+    transform_leaves.append(current_leaves)
+  return transform_leaves
+
+
 @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS)
 @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS)
 def _tcgen05_mma_lowering(
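Note: an illustration of how _split_transforms consumes the flat leaves list (values made up): leaves are taken in order, one slice per tree, and a None tree yields an empty slice.

import jax

two_leaf_tree = jax.tree.structure((0, 0))   # num_leaves == 2
one_leaf_tree = jax.tree.structure((0,))     # num_leaves == 1
# _split_transforms([10, 11, 12], [two_leaf_tree, None, one_leaf_tree])
# would return [[10, 11], [], [12]]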
@@ -1234,16 +1270,20 @@ def _tcgen05_mma_lowering(
     *transforms_leaves,
     a_transforms_tree,
     b_transforms_tree,
+    barrier_transforms_tree,
     collective_axis,
 ):
   _, a_aval, b_aval, *_ = ctx.avals_in
   lhs_swizzle: int | None = None
+  rhs_swizzle: int | None = None
   lhs_transpose: bool = False
-  if a_transforms_tree is not None:
-    a_transforms_leaves, b_transforms_leaves = util.split_list(
-        transforms_leaves, [a_transforms_tree.num_leaves]
-    )
+  rhs_transpose: bool = False

+  a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves = (
+      _split_transforms(transforms_leaves,
+                        [a_transforms_tree, b_transforms_tree, barrier_transforms_tree])
+  )
+  if a_transforms_tree is not None:
     a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
     a_ref, a_transforms = lowering._handle_transforms(
         ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True
@@ -1265,36 +1305,42 @@ def _tcgen05_mma_lowering(
     if lhs_tiling != (8, swizzle_elems):
       raise ValueError("MMA lhs tiling does not fit swizzle. "
                        f"{lhs_tiling=} expected={(8, swizzle_elems)}")
-  else:
-    b_transforms_leaves = transforms_leaves  # type: ignore

-  b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
-  b_ref, b_transforms = lowering._handle_transforms(
-      ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
-  )
-  match b_transforms:
-    case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
-      rhs_transpose = False
-    case (
-        gpu_core.UnswizzleRef(rhs_swizzle),
-        gpu_core.UntileRef(rhs_tiling),
-        gpu_core.TransposeRef((1, 0)),
-    ):
-      rhs_transpose = True
-    case _:
-      raise NotImplementedError(
-          f"Unsupported transforms: {b_transforms}."
-      )
+  if b_transforms_tree is not None:
+    b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
+    b_ref, b_transforms = lowering._handle_transforms(
+        ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
+    )
+    match b_transforms:
+      case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
+        rhs_transpose = False
+      case (
+          gpu_core.UnswizzleRef(rhs_swizzle),
+          gpu_core.UntileRef(rhs_tiling),
+          gpu_core.TransposeRef((1, 0)),
+      ):
+        rhs_transpose = True
+      case _:
+        raise NotImplementedError(
+            f"Unsupported transforms: {b_transforms}."
+        )
+    swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
+    if rhs_tiling != (8, swizzle_elems):
+      raise ValueError("MMA rhs tiling does not fit swizzle"
+                       f" {rhs_tiling=} expected={(8, swizzle_elems)}")
+
+  if barrier_transforms_tree is not None:
+    barrier_transforms = barrier_transforms_tree.unflatten(
+        barrier_transforms_leaves)
+    indexer = _extract_barrier_indexer(barrier_transforms)
+    if indexer is not None:
+      barrier_ref = barrier_ref.__getitem__(*map(lowering._as_index, indexer.indices))

-  swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
   if lhs_swizzle is None:
     lhs_swizzle = rhs_swizzle
   elif rhs_swizzle != lhs_swizzle:
     raise ValueError("MMA rhs swizzle must match lhs swizzle."
                      f" {lhs_swizzle=} {rhs_swizzle=}")
-  if rhs_tiling != (8, swizzle_elems):
-    raise ValueError("MMA rhs tiling does not fit swizzle"
-                     f" {rhs_tiling=} expected={(8, swizzle_elems)}")
   if lhs_transpose:
     if isinstance(a_ref, tcgen05.TMEMRef):
       raise ValueError("TMEM transpose not allowed.")
@@ -1303,6 +1349,9 @@ def _tcgen05_mma_lowering(
     b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2))
   if isinstance(accumulate, bool):
     accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1))
+  elif isinstance(accumulate, mgpu.FragmentedArray):
+    accumulate = accumulate.registers.item()
+    assert isinstance(accumulate, ir.Value)

   predicate = ctx.module_ctx.single_lane_predicate
   collective = False
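Note: the accumulator fix covers the case where accumulate is a traced (dynamic) value rather than a Python bool; under lane-level lowering such a scalar arrives as an mgpu.FragmentedArray holding a single register, and .registers.item() extracts the underlying ir.Value. A hedged usage sketch of the dynamic case (loop structure and ref names are hypothetical):

def step(i, _):
  # i > 0 is a traced boolean, so accumulate is dynamic: skip accumulation
  # only on the first iteration.
  plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, barrier, accumulate=(i > 0))
  return ()

jax.lax.fori_loop(0, num_steps, step, ())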

jax/experimental/mosaic/gpu/utils.py

Lines changed: 12 additions & 2 deletions
@@ -144,7 +144,11 @@ def _debug_scalar_ty_format(arg):
     return "%f", arg
   raise NotImplementedError(f"Can't print the type {arg.type}")

-def debug_print(fmt, *args, uniform=True):
+def debug_print(fmt, *args, uniform=True, scope=None):
+  if not uniform and scope is not None:
+    raise ValueError("Cannot specify scope to a non-uniform debug_print.")
+  if scope is None:
+    scope = ThreadSubset.WARPGROUP
   type_formats = []
   new_args = []
   for arg in args:
@@ -168,7 +172,7 @@ def debug_print(fmt, *args, uniform=True):
     raise NotImplementedError(arg.type)
   type_formats.append(ty_format)
   ctx = (
-      functools.partial(single_thread, scope=ThreadSubset.WARPGROUP)
+      functools.partial(single_thread, scope=scope)
       if uniform
       else contextlib.nullcontext
   )
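Note: with the new scope parameter, callers of the Mosaic GPU debug_print utility can restrict a uniform print to one thread per warp instead of one per warpgroup; combining scope with uniform=False raises. A hedged sketch (the value x is hypothetical):

debug_print("value: {}", x)                            # once per warpgroup (default)
debug_print("value: {}", x, scope=ThreadSubset.WARP)   # once per warp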
@@ -822,6 +826,12 @@ def arrive_expect_tx(
     bytes = arith.index_cast(i32, bytes)
     nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes, predicate=predicate)

+  def arrive_nocomplete(self, count: int) -> None:
+    i64 = ir.IntegerType.get_signless(64)
+    if isinstance(count, int):
+      count = c(count, ir.IntegerType.get_signless(32))
+    nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count)
+
   def get_ptr(self):
     ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>")
     i64 = ir.IntegerType.get_signless(64)
