 import jax.numpy as jnp


+WARP_SIZE = 32
 WARPGROUP_SIZE = 128

@@ -464,7 +465,7 @@ def _copy_gmem_to_smem_lowering(
    dst_transforms_treedef,
    barrier_transforms_treedef,
    collective_axes,
-    warpgroup_sync: bool = True,
+    for_warpgroup: bool = True,
):
  flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
      util.split_list(
@@ -505,15 +506,23 @@ def _copy_gmem_to_smem_lowering(
  if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
    if bytes % WARPGROUP_SIZE:
      raise NotImplementedError("Only aligned copies are supported")
-    # We arrive uniformly from each thread in the WG, so we need to divide the
-    # number of bytes by the number of threads in the WG.
-    # TODO: apaszke - Relax this. We can just select the WG leader and have it
-    # arrive with the whole transfer size, while everyone else arrives with 0.
-    # But we should continue using this scheme as it's likely to be faster.
-    bytes //= WARPGROUP_SIZE
-    if warpgroup_sync:
+    if for_warpgroup:
+      # We arrive uniformly from each thread in the WG, so we need to divide the
+      # number of bytes by the number of threads in the WG.
+      # TODO: apaszke - Relax this. We can just select the WG leader and have it
+      # arrive with the whole transfer size, while everyone else arrives with 0.
+      # But we should continue using this scheme as it's likely to be faster.
+      bytes //= WARPGROUP_SIZE
      mgpu.warpgroup_barrier()  # Make sure all reads have completed.
-    barrier.arrive_expect_tx(bytes)
+      barrier.arrive_expect_tx(bytes)
+    else:
+      # In Warp-level lowering, we arrive on each CUDA thread in a warp, but
+      # the barrier still expects a full 128 arrivals so we arrive 4 times
+      # on each CUDA thread.
+      bytes //= WARP_SIZE
+      barrier.arrive_nocomplete(3)
+      barrier.arrive_expect_tx(bytes)
+
  ctx.launch_ctx.async_copy(
      src_ref=src,
      dst_ref=dst,
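The warp-level branch above keeps the barrier's arrival count consistent with the warpgroup case. A minimal, standalone sketch of that accounting (the helper name and byte counts are illustrative, not part of this patch): the barrier expects 128 arrivals plus the full transfer size, so a warpgroup has its 128 threads arrive once each with 1/128 of the bytes, while a single warp has its 32 threads arrive 4 times each (3 no-complete arrivals plus one `arrive_expect_tx`) with 1/32 of the bytes.

```python
WARP_SIZE = 32
WARPGROUP_SIZE = 128

def arrivals_per_thread(total_bytes: int, for_warpgroup: bool) -> tuple[int, int]:
  """Hypothetical helper: (arrivals per CUDA thread, tx bytes announced per thread)."""
  assert total_bytes % WARPGROUP_SIZE == 0, "Only aligned copies are supported"
  if for_warpgroup:
    # 128 threads each arrive once, announcing 1/128 of the transfer.
    return 1, total_bytes // WARPGROUP_SIZE
  # Warp-level: 32 threads each arrive 4 times (3 no-complete + 1 expect_tx),
  # announcing 1/32 of the transfer, so the barrier still sees 128 arrivals.
  return 4, total_bytes // WARP_SIZE

assert arrivals_per_thread(4096, for_warpgroup=True) == (1, 32)    # 128 * 1 = 128 arrivals
assert arrivals_per_thread(4096, for_warpgroup=False) == (4, 128)  # 32 * 4 = 128 arrivals
```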
@@ -549,7 +558,7 @@ def _copy_gmem_to_smem_lowering(
    copy_gmem_to_smem_p,
    mgpu.LoweringSemantics.Lane,
    primitive_semantics=gpu_core.PrimitiveSemantics.Warp,
-)(functools.partial(_copy_gmem_to_smem_lowering, warpgroup_sync=False))
+)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False))


def copy_gmem_to_smem(
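The hunk above reuses the same lowering function for the warp-level registration by pre-binding `for_warpgroup=False`. A generic sketch of that `functools.partial` pattern with stand-in names (not the actual Pallas registry):

```python
import functools

def lower_copy(ctx, *args, for_warpgroup: bool = True):
  # Stand-in for _copy_gmem_to_smem_lowering: the flag only picks the mode here.
  return "warpgroup" if for_warpgroup else "warp"

# The warp-level variant binds the flag once; the registry calls it like any other rule.
lower_copy_warp = functools.partial(lower_copy, for_warpgroup=False)

assert lower_copy(None) == "warpgroup"
assert lower_copy_warp(None) == "warp"
```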
@@ -713,6 +722,8 @@ def _barrier_wait_pp_eqn(


@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane)
+@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane,
+                                 gpu_core.PrimitiveSemantics.Warp)
@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup)
def _barrier_wait_lowering(
    ctx: lowering.LoweringRuleContext,
@@ -1187,18 +1198,29 @@ def tcgen05_mma(acc: _Ref,
  else:
    b_transforms_leaves, b_transforms_tree = [], None

+  if isinstance(barrier, pallas_core.TransformedRef):
+    barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten(
+        barrier.transforms)
+    barrier = barrier.ref
+  else:
+    barrier_transforms_leaves, barrier_transforms_tree = [], None
+
  tcgen05_mma_p.bind(acc, a, b, barrier, accumulate,
                     *a_transforms_leaves, *b_transforms_leaves,
+                     *barrier_transforms_leaves,
                     a_transforms_tree=a_transforms_tree,
                     b_transforms_tree=b_transforms_tree,
+                     barrier_transforms_tree=barrier_transforms_tree,
                     collective_axis=collective_axis)

@tcgen05_mma_p.def_abstract_eval
def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
                               *transforms_leaves,
                               a_transforms_tree, b_transforms_tree,
+                               barrier_transforms_tree,
                               collective_axis):
-  del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree)
+  del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree,
+       barrier_transforms_tree)

  if acc.memory_space != gpu_core.TMEM:
    raise ValueError("Accumulator must be a TMEM Ref.")
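To pass the barrier transforms through `tcgen05_mma_p.bind`, the change above flattens them into pytree leaves (traced positionally) plus a treedef (carried as a static parameter) and unflattens them again in the lowering rule. A standalone sketch of that round trip, using made-up stand-ins rather than actual Pallas transform objects:

```python
import jax

transforms = (("untile", (8, 64)), ("index", 2))  # illustrative stand-ins
leaves, treedef = jax.tree.flatten(transforms)
# `leaves` travel through bind as positional arguments; `treedef` rides along statically.
assert treedef.unflatten(leaves) == transforms
```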
@@ -1222,6 +1244,20 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,

  return []

+
+def _split_transforms(all_transforms_leaves, transforms_trees) -> list[Any]:
+  transform_leaves = []
+  for transforms_tree in transforms_trees:
+    if transforms_tree is None:
+      transform_leaves.append([])
+      continue
+    current_leaves, all_transforms_leaves = util.split_list(
+        all_transforms_leaves, [transforms_tree.num_leaves]
+    )
+    transform_leaves.append(current_leaves)
+  return transform_leaves
+
+
@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS)
@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS)
def _tcgen05_mma_lowering(
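A small usage sketch of the `_split_transforms` helper added above (it assumes the helper and `util.split_list` are in scope, as in the patched module; the trees and leaf values are made up). A `None` treedef consumes no leaves and yields an empty group, which is how the lowering rule below splits the flat leaves into a/b/barrier parts:

```python
import jax

a_leaves, a_tree = jax.tree.flatten({"swizzle": 128, "tiling": (8, 64)})
b_tree = None                                  # rhs carries no transforms in this example
bar_leaves, bar_tree = jax.tree.flatten((2,))  # e.g. a barrier index

flat = [*a_leaves, *bar_leaves]                # the flat leaves a lowering rule receives
a_part, b_part, bar_part = _split_transforms(flat, [a_tree, b_tree, bar_tree])
assert a_tree.unflatten(a_part) == {"swizzle": 128, "tiling": (8, 64)}
assert b_part == []
assert bar_tree.unflatten(bar_part) == (2,)
```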
@@ -1234,16 +1270,20 @@ def _tcgen05_mma_lowering(
    *transforms_leaves,
    a_transforms_tree,
    b_transforms_tree,
+    barrier_transforms_tree,
    collective_axis,
):
  _, a_aval, b_aval, *_ = ctx.avals_in
  lhs_swizzle: int | None = None
+  rhs_swizzle: int | None = None
  lhs_transpose: bool = False
-  if a_transforms_tree is not None:
-    a_transforms_leaves, b_transforms_leaves = util.split_list(
-        transforms_leaves, [a_transforms_tree.num_leaves]
-    )
+  rhs_transpose: bool = False

+  a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves = (
+      _split_transforms(transforms_leaves,
+                        [a_transforms_tree, b_transforms_tree, barrier_transforms_tree])
+  )
+  if a_transforms_tree is not None:
    a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
    a_ref, a_transforms = lowering._handle_transforms(
        ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True
@@ -1265,36 +1305,42 @@ def _tcgen05_mma_lowering(
    if lhs_tiling != (8, swizzle_elems):
      raise ValueError("MMA lhs tiling does not fit swizzle. "
                       f"{lhs_tiling=} expected={(8, swizzle_elems)}")
-  else:
-    b_transforms_leaves = transforms_leaves  # type: ignore

-  b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
-  b_ref, b_transforms = lowering._handle_transforms(
-      ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
-  )
-  match b_transforms:
-    case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
-      rhs_transpose = False
-    case (
-        gpu_core.UnswizzleRef(rhs_swizzle),
-        gpu_core.UntileRef(rhs_tiling),
-        gpu_core.TransposeRef((1, 0)),
-    ):
-      rhs_transpose = True
-    case _:
-      raise NotImplementedError(
-          f"Unsupported transforms: {b_transforms}."
-      )
+  if b_transforms_tree is not None:
+    b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
+    b_ref, b_transforms = lowering._handle_transforms(
+        ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True
+    )
+    match b_transforms:
+      case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
+        rhs_transpose = False
+      case (
+          gpu_core.UnswizzleRef(rhs_swizzle),
+          gpu_core.UntileRef(rhs_tiling),
+          gpu_core.TransposeRef((1, 0)),
+      ):
+        rhs_transpose = True
+      case _:
+        raise NotImplementedError(
+            f"Unsupported transforms: {b_transforms}."
+        )
+    swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
+    if rhs_tiling != (8, swizzle_elems):
+      raise ValueError("MMA rhs tiling does not fit swizzle"
+                       f" {rhs_tiling=} expected={(8, swizzle_elems)}")
+
+  if barrier_transforms_tree is not None:
+    barrier_transforms = barrier_transforms_tree.unflatten(
+        barrier_transforms_leaves)
+    indexer = _extract_barrier_indexer(barrier_transforms)
+    if indexer is not None:
+      barrier_ref = barrier_ref.__getitem__(*map(lowering._as_index, indexer.indices))

-  swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize
  if lhs_swizzle is None:
    lhs_swizzle = rhs_swizzle
  elif rhs_swizzle != lhs_swizzle:
    raise ValueError("MMA rhs swizzle must match lhs swizzle."
                     f" {lhs_swizzle=} {rhs_swizzle=}")
-  if rhs_tiling != (8, swizzle_elems):
-    raise ValueError("MMA rhs tiling does not fit swizzle"
-                     f" {rhs_tiling=} expected={(8, swizzle_elems)}")
  if lhs_transpose:
    if isinstance(a_ref, tcgen05.TMEMRef):
      raise ValueError("TMEM transpose not allowed.")
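A worked instance of the swizzle/tiling check above, with hypothetical values: a 128-byte swizzle over a 2-byte element type such as bfloat16 spans 64 elements, so the lowering expects the operand to be tiled as (8, 64).

```python
import numpy as np
import jax.numpy as jnp

rhs_swizzle = 128                            # swizzle width in bytes (illustrative)
itemsize = np.dtype(jnp.bfloat16).itemsize   # 2 bytes per bf16 element
swizzle_elems = rhs_swizzle // itemsize      # 64 elements per swizzle
assert (8, swizzle_elems) == (8, 64)         # the tiling the lowering checks against
```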
@@ -1303,6 +1349,9 @@ def _tcgen05_mma_lowering(
    b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2))
  if isinstance(accumulate, bool):
    accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1))
+  elif isinstance(accumulate, mgpu.FragmentedArray):
+    accumulate = accumulate.registers.item()
+    assert isinstance(accumulate, ir.Value)

  predicate = ctx.module_ctx.single_lane_predicate
  collective = False