diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 379a972be9b0..1ec22bff3f6d 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -554,6 +554,7 @@ def _copy_gmem_to_smem_lowering( ) return () + lowering.register_lowering_rule( copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane, @@ -722,9 +723,14 @@ def _barrier_wait_pp_eqn( @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane, - gpu_core.PrimitiveSemantics.Warp) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup) +@lowering.register_lowering_rule( + barrier_wait_p, + mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp, +) +@lowering.register_lowering_rule( + barrier_wait_p, mgpu.LoweringSemantics.Warpgroup +) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -1198,18 +1204,31 @@ def tcgen05_mma(acc: _Ref, else: b_transforms_leaves, b_transforms_tree = [], None + if isinstance(barrier, pallas_core.TransformedRef): + barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten( + barrier.transforms + ) + barrier = barrier.ref + else: + barrier_transforms_leaves, barrier_transforms_tree = [], None + tcgen05_mma_p.bind(acc, a, b, barrier, accumulate, *a_transforms_leaves, *b_transforms_leaves, + *barrier_transforms_leaves, a_transforms_tree=a_transforms_tree, b_transforms_tree=b_transforms_tree, + barrier_transforms_tree=barrier_transforms_tree, collective_axis=collective_axis) + @tcgen05_mma_p.def_abstract_eval def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, *transforms_leaves, a_transforms_tree, b_transforms_tree, + barrier_transforms_tree, collective_axis): - del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree) + del (accumulate, transforms_leaves, a_transforms_tree, b_transforms_tree, + barrier_transforms_tree) if acc.memory_space != gpu_core.TMEM: raise ValueError("Accumulator must be a TMEM Ref.") @@ -1233,6 +1252,7 @@ def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, return [] + @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS) @lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS) def _tcgen05_mma_lowering( @@ -1245,16 +1265,26 @@ def _tcgen05_mma_lowering( *transforms_leaves, a_transforms_tree, b_transforms_tree, + barrier_transforms_tree, collective_axis, ): _, a_aval, b_aval, *_ = ctx.avals_in lhs_swizzle: int | None = None lhs_transpose: bool = False - if a_transforms_tree is not None: - a_transforms_leaves, b_transforms_leaves = util.split_list( - transforms_leaves, [a_transforms_tree.num_leaves] - ) + transforms_trees = ( + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + ) + (a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = ( + util.split_list( + transforms_leaves, + [getattr(tree, "num_leaves", 0) for tree in transforms_trees], + ) + ) + + if a_transforms_tree is not None: a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) a_ref, a_transforms = lowering._handle_transforms( ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True @@ -1276,9 +1306,8 @@ def _tcgen05_mma_lowering( if lhs_tiling != (8, swizzle_elems): raise ValueError("MMA lhs tiling does not fit swizzle. " f"{lhs_tiling=} expected={(8, swizzle_elems)}") - else: - b_transforms_leaves = transforms_leaves # type: ignore + assert b_transforms_tree is not None b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) b_ref, b_transforms = lowering._handle_transforms( ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True @@ -1296,16 +1325,28 @@ def _tcgen05_mma_lowering( raise NotImplementedError( f"Unsupported transforms: {b_transforms}." ) - swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize + if rhs_tiling != (8, swizzle_elems): + raise ValueError( + "MMA rhs tiling does not fit swizzle" + f" {rhs_tiling=} expected={(8, swizzle_elems)}" + ) + + if barrier_transforms_tree is not None: + barrier_transforms = barrier_transforms_tree.unflatten( + barrier_transforms_leaves + ) + indexer = _extract_barrier_indexer(barrier_transforms) + if indexer is not None: + barrier_ref = barrier_ref.__getitem__( + *map(lowering._as_index, indexer.indices) + ) + if lhs_swizzle is None: lhs_swizzle = rhs_swizzle elif rhs_swizzle != lhs_swizzle: raise ValueError("MMA rhs swizzle must match lhs swizzle." f" {lhs_swizzle=} {rhs_swizzle=}") - if rhs_tiling != (8, swizzle_elems): - raise ValueError("MMA rhs tiling does not fit swizzle" - f" {rhs_tiling=} expected={(8, swizzle_elems)}") if lhs_transpose: if isinstance(a_ref, tcgen05.TMEMRef): raise ValueError("TMEM transpose not allowed.") @@ -1314,6 +1355,9 @@ def _tcgen05_mma_lowering( b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2)) if isinstance(accumulate, bool): accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1)) + elif isinstance(accumulate, mgpu.FragmentedArray): + accumulate = accumulate.registers.item() + assert isinstance(accumulate, ir.Value) predicate = ctx.module_ctx.single_lane_predicate collective = False @@ -1341,8 +1385,8 @@ def _tcgen05_mma_lowering( acc, a_ref, b_ref, - a_swizzle=lhs_swizzle, - b_swizzle=rhs_swizzle, + a_swizzle=int(lhs_swizzle), + b_swizzle=int(rhs_swizzle), accumulate=accumulate, collective=collective, ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 029445143a0f..3c0b463ba1c7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2496,6 +2496,52 @@ def _scoped(a_smem, b_smem, expected = x @ y np.testing.assert_allclose(result, expected, rtol=1e-3) + @parameterized.parameters((0,), (1,)) + def test_mma_barrier_indexing( + self, barrier_index, shape=(128, 128), swizzle=128, dtype=jnp.float16 + ): + self.skip_if_wg_semantics() + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): + plgpu.tcgen05_mma( + acc_tmem, + a_smem, + b_smem, + barrier_ref.at[barrier_index], + accumulate=False, + ) + plgpu.barrier_wait(barrier_ref.at[barrier_index]) + scratch_smem[...] = acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_ref) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.TMEM(shape, jnp.float32, packed=False), + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.Barrier(num_arrivals=1, num_barriers=2, for_tensor_core=True), + ] + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + class PallasCallSm100AWGTest( PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup