@@ -554,6 +554,7 @@ def _copy_gmem_to_smem_lowering(
554
554
)
555
555
return ()
556
556
557
+
557
558
lowering .register_lowering_rule (
558
559
copy_gmem_to_smem_p ,
559
560
mgpu .LoweringSemantics .Lane ,
@@ -722,9 +723,14 @@ def _barrier_wait_pp_eqn(
722
723
723
724
724
725
@lowering .register_lowering_rule (barrier_wait_p , mgpu .LoweringSemantics .Lane )
725
- @lowering .register_lowering_rule (barrier_wait_p , mgpu .LoweringSemantics .Lane ,
726
- gpu_core .PrimitiveSemantics .Warp )
727
- @lowering .register_lowering_rule (barrier_wait_p , mgpu .LoweringSemantics .Warpgroup )
726
+ @lowering .register_lowering_rule (
727
+ barrier_wait_p ,
728
+ mgpu .LoweringSemantics .Lane ,
729
+ gpu_core .PrimitiveSemantics .Warp ,
730
+ )
731
+ @lowering .register_lowering_rule (
732
+ barrier_wait_p , mgpu .LoweringSemantics .Warpgroup
733
+ )
728
734
def _barrier_wait_lowering (
729
735
ctx : lowering .LoweringRuleContext ,
730
736
barrier ,
@@ -1198,18 +1204,31 @@ def tcgen05_mma(acc: _Ref,
1198
1204
else :
1199
1205
b_transforms_leaves , b_transforms_tree = [], None
1200
1206
1207
+ if isinstance (barrier , pallas_core .TransformedRef ):
1208
+ barrier_transforms_leaves , barrier_transforms_tree = jax .tree .flatten (
1209
+ barrier .transforms
1210
+ )
1211
+ barrier = barrier .ref
1212
+ else :
1213
+ barrier_transforms_leaves , barrier_transforms_tree = [], None
1214
+
1201
1215
tcgen05_mma_p .bind (acc , a , b , barrier , accumulate ,
1202
1216
* a_transforms_leaves , * b_transforms_leaves ,
1217
+ * barrier_transforms_leaves ,
1203
1218
a_transforms_tree = a_transforms_tree ,
1204
1219
b_transforms_tree = b_transforms_tree ,
1220
+ barrier_transforms_tree = barrier_transforms_tree ,
1205
1221
collective_axis = collective_axis )
1206
1222
1223
+
1207
1224
@tcgen05_mma_p .def_abstract_eval
1208
1225
def _tcgen05_mma_abstract_eval (acc , a , b , barrier , accumulate ,
1209
1226
* transforms_leaves ,
1210
1227
a_transforms_tree , b_transforms_tree ,
1228
+ barrier_transforms_tree ,
1211
1229
collective_axis ):
1212
- del (accumulate , transforms_leaves , a_transforms_tree , b_transforms_tree )
1230
+ del (accumulate , transforms_leaves , a_transforms_tree , b_transforms_tree ,
1231
+ barrier_transforms_tree )
1213
1232
1214
1233
if acc .memory_space != gpu_core .TMEM :
1215
1234
raise ValueError ("Accumulator must be a TMEM Ref." )
@@ -1233,6 +1252,7 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate,
1233
1252
1234
1253
return []
1235
1254
1255
+
1236
1256
@lowering .register_lowering_rule (tcgen05_mma_p , * gpu_core .LANExWG_SEMANTICS )
1237
1257
@lowering .register_lowering_rule (tcgen05_mma_p , * gpu_core .LANExWARP_SEMANTICS )
1238
1258
def _tcgen05_mma_lowering (
@@ -1245,16 +1265,26 @@ def _tcgen05_mma_lowering(
1245
1265
* transforms_leaves ,
1246
1266
a_transforms_tree ,
1247
1267
b_transforms_tree ,
1268
+ barrier_transforms_tree ,
1248
1269
collective_axis ,
1249
1270
):
1250
1271
_ , a_aval , b_aval , * _ = ctx .avals_in
1251
1272
lhs_swizzle : int | None = None
1252
1273
lhs_transpose : bool = False
1253
- if a_transforms_tree is not None :
1254
- a_transforms_leaves , b_transforms_leaves = util .split_list (
1255
- transforms_leaves , [a_transforms_tree .num_leaves ]
1256
- )
1257
1274
1275
+ transforms_trees = (
1276
+ a_transforms_tree ,
1277
+ b_transforms_tree ,
1278
+ barrier_transforms_tree ,
1279
+ )
1280
+ (a_transforms_leaves , b_transforms_leaves , barrier_transforms_leaves , _ ) = (
1281
+ util .split_list (
1282
+ transforms_leaves ,
1283
+ [getattr (tree , "num_leaves" , 0 ) for tree in transforms_trees ],
1284
+ )
1285
+ )
1286
+
1287
+ if a_transforms_tree is not None :
1258
1288
a_transforms = a_transforms_tree .unflatten (a_transforms_leaves )
1259
1289
a_ref , a_transforms = lowering ._handle_transforms (
1260
1290
ctx , a_ref , a_transforms , handle_transposes = False , handle_reshapes = True
@@ -1276,9 +1306,8 @@ def _tcgen05_mma_lowering(
1276
1306
if lhs_tiling != (8 , swizzle_elems ):
1277
1307
raise ValueError ("MMA lhs tiling does not fit swizzle. "
1278
1308
f"{ lhs_tiling = } expected={ (8 , swizzle_elems )} " )
1279
- else :
1280
- b_transforms_leaves = transforms_leaves # type: ignore
1281
1309
1310
+ assert b_transforms_tree is not None
1282
1311
b_transforms = b_transforms_tree .unflatten (b_transforms_leaves )
1283
1312
b_ref , b_transforms = lowering ._handle_transforms (
1284
1313
ctx , b_ref , b_transforms , handle_transposes = False , handle_reshapes = True
@@ -1296,16 +1325,28 @@ def _tcgen05_mma_lowering(
1296
1325
raise NotImplementedError (
1297
1326
f"Unsupported transforms: { b_transforms } ."
1298
1327
)
1299
-
1300
1328
swizzle_elems = rhs_swizzle // b_aval .dtype .itemsize
1329
+ if rhs_tiling != (8 , swizzle_elems ):
1330
+ raise ValueError (
1331
+ "MMA rhs tiling does not fit swizzle"
1332
+ f" { rhs_tiling = } expected={ (8 , swizzle_elems )} "
1333
+ )
1334
+
1335
+ if barrier_transforms_tree is not None :
1336
+ barrier_transforms = barrier_transforms_tree .unflatten (
1337
+ barrier_transforms_leaves
1338
+ )
1339
+ indexer = _extract_barrier_indexer (barrier_transforms )
1340
+ if indexer is not None :
1341
+ barrier_ref = barrier_ref .__getitem__ (
1342
+ * map (lowering ._as_index , indexer .indices )
1343
+ )
1344
+
1301
1345
if lhs_swizzle is None :
1302
1346
lhs_swizzle = rhs_swizzle
1303
1347
elif rhs_swizzle != lhs_swizzle :
1304
1348
raise ValueError ("MMA rhs swizzle must match lhs swizzle."
1305
1349
f" { lhs_swizzle = } { rhs_swizzle = } " )
1306
- if rhs_tiling != (8 , swizzle_elems ):
1307
- raise ValueError ("MMA rhs tiling does not fit swizzle"
1308
- f" { rhs_tiling = } expected={ (8 , swizzle_elems )} " )
1309
1350
if lhs_transpose :
1310
1351
if isinstance (a_ref , tcgen05 .TMEMRef ):
1311
1352
raise ValueError ("TMEM transpose not allowed." )
@@ -1314,6 +1355,9 @@ def _tcgen05_mma_lowering(
1314
1355
b_ref = mgpu .memref_transpose (b_ref , (1 , 0 , 3 , 2 ))
1315
1356
if isinstance (accumulate , bool ):
1316
1357
accumulate = mgpu .c (accumulate , ir .IntegerType .get_signless (1 ))
1358
+ elif isinstance (accumulate , mgpu .FragmentedArray ):
1359
+ accumulate = accumulate .registers .item ()
1360
+ assert isinstance (accumulate , ir .Value )
1317
1361
1318
1362
predicate = ctx .module_ctx .single_lane_predicate
1319
1363
collective = False
@@ -1341,8 +1385,8 @@ def _tcgen05_mma_lowering(
1341
1385
acc ,
1342
1386
a_ref ,
1343
1387
b_ref ,
1344
- a_swizzle = lhs_swizzle ,
1345
- b_swizzle = rhs_swizzle ,
1388
+ a_swizzle = int ( lhs_swizzle ) ,
1389
+ b_swizzle = int ( rhs_swizzle ) ,
1346
1390
accumulate = accumulate ,
1347
1391
collective = collective ,
1348
1392
)
0 commit comments