
Commit d4ab826

justinjfu authored and Google-ML-Automation committed
[Mosaic GPU] Add support for copy_gmem_to_smem in Warp semantics.
PiperOrigin-RevId: 762475094
1 parent 9928409 commit d4ab826

3 files changed: +56 -12 lines

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 21 additions & 10 deletions
@@ -49,6 +49,7 @@
 import jax.numpy as jnp


+WARP_SIZE = 32
 WARPGROUP_SIZE = 128


@@ -464,7 +465,7 @@ def _copy_gmem_to_smem_lowering(
     dst_transforms_treedef,
     barrier_transforms_treedef,
     collective_axes,
-    warpgroup_sync: bool = True,
+    for_warpgroup: bool = True,
 ):
   flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
       util.split_list(
@@ -505,15 +506,23 @@ def _copy_gmem_to_smem_lowering(
   if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
     if bytes % WARPGROUP_SIZE:
       raise NotImplementedError("Only aligned copies are supported")
-    # We arrive uniformly from each thread in the WG, so we need to divide the
-    # number of bytes by the number of threads in the WG.
-    # TODO: apaszke - Relax this. We can just select the WG leader and have it
-    # arrive with the whole transfer size, while everyone else arrives with 0.
-    # But we should continue using this scheme as it's likely to be faster.
-    bytes //= WARPGROUP_SIZE
-    if warpgroup_sync:
+    if for_warpgroup:
+      # We arrive uniformly from each thread in the WG, so we need to divide the
+      # number of bytes by the number of threads in the WG.
+      # TODO: apaszke - Relax this. We can just select the WG leader and have it
+      # arrive with the whole transfer size, while everyone else arrives with 0.
+      # But we should continue using this scheme as it's likely to be faster.
+      bytes //= WARPGROUP_SIZE
       mgpu.warpgroup_barrier()  # Make sure all reads have completed.
-    barrier.arrive_expect_tx(bytes)
+      barrier.arrive_expect_tx(bytes)
+    else:
+      # In Warp-level lowering, we arrive on each CUDA thread in a warp, but
+      # the barrier still expects a full 128 arrivals so we arrive 4 times
+      # on each CUDA thread instead.
+      bytes //= WARP_SIZE
+      barrier.arrive(arrival_count=3, can_complete=False)
+      barrier.arrive_expect_tx(bytes)
+
   ctx.launch_ctx.async_copy(
       src_ref=src,
       dst_ref=dst,
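The arithmetic behind the new else-branch is worth spelling out: the mbarrier is still sized for one arrival per warpgroup thread (128 in total), but only the 32 threads of the issuing warp reach it, so each thread contributes 4 arrivals (3 no-complete arrivals plus the final arrive_expect_tx) and expects 1/32 of the transfer bytes. A minimal standalone sketch of that accounting (plain Python, not part of the change; the tile size is borrowed from the new test below):

WARP_SIZE = 32
WARPGROUP_SIZE = 128

transfer_bytes = 32 * 32 * 4  # bytes in the (32, 32) float32 tile used by the test

# Each CUDA thread in the issuing warp arrives 4 times: 3 no-complete arrivals
# plus the final arrive_expect_tx.
arrivals_per_thread = 3 + 1
assert arrivals_per_thread * WARP_SIZE == WARPGROUP_SIZE  # 128 arrivals total

# Each thread also expects its 1/32 share of the transaction bytes.
per_thread_tx = transfer_bytes // WARP_SIZE
assert per_thread_tx * WARP_SIZE == transfer_bytes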
@@ -549,7 +558,7 @@ def _copy_gmem_to_smem_lowering(
     copy_gmem_to_smem_p,
     mgpu.LoweringSemantics.Lane,
     primitive_semantics=gpu_core.PrimitiveSemantics.Warp,
-)(functools.partial(_copy_gmem_to_smem_lowering, warpgroup_sync=False))
+)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False))


 def copy_gmem_to_smem(
@@ -713,6 +722,8 @@ def _barrier_wait_pp_eqn(


 @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane)
+@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane,
+                                 gpu_core.PrimitiveSemantics.Warp)
 @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup)
 def _barrier_wait_lowering(
     ctx: lowering.LoweringRuleContext,

jax/experimental/mosaic/gpu/utils.py

Lines changed: 9 additions & 2 deletions
@@ -816,9 +816,16 @@ def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]:
     )
     return parity, arith.xori(parities, bitmask)

-  def arrive(self):
+  def arrive(self, arrival_count: int = 1, can_complete: bool = True):
     i64 = ir.IntegerType.get_signless(64)
-    nvvm.mbarrier_arrive_shared(i64, self.get_ptr())
+    if can_complete:
+      if arrival_count > 1:
+        count = c(arrival_count - 1, ir.IntegerType.get_signless(32))
+        nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count)
+      nvvm.mbarrier_arrive_shared(i64, self.get_ptr())
+    else:
+      count = c(arrival_count, ir.IntegerType.get_signless(32))
+      nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count)

   def arrive_expect_tx(
       self, bytes: int | ir.Value, predicate: ir.Value | None = None
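For reference, the Warp-level lowering in primitives.py above is the caller that needs this extended signature. A short usage sketch, assuming a variable named barrier bound to an instance of this class (the variable name is illustrative):

# One arrival that may complete the current barrier phase (the old behavior).
barrier.arrive()

# Three arrivals that must not complete the phase; the warp-level copy lowering
# pairs this with a final arrive_expect_tx so each thread arrives 4 times.
barrier.arrive(arrival_count=3, can_complete=False)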

tests/pallas/mosaic_gpu_test.py

Lines changed: 26 additions & 0 deletions
@@ -1878,6 +1878,32 @@ def _():
         },
     )

+  def test_copy_gmem_to_smem_from_different_warps(self):
+    # In this test, we issue a copy from warp 0 and await it in warp 1.
+    warp_mesh = plgpu.WarpMesh(axis_name="warp")
+    @functools.partial(plgpu.kernel,
+                       out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32))
+    def kernel(x_ref, y_ref):
+      def scope(smem_ref, tma_barrier):
+        @pl.core_map(warp_mesh)
+        def _():
+          warp_id = lax.axis_index("warp")
+          @pl.when(warp_id == 0)
+          def _():
+            plgpu.copy_gmem_to_smem(x_ref.at[32:64], smem_ref, tma_barrier)
+
+          @pl.when(warp_id == 1)
+          def _():
+            plgpu.barrier_wait(tma_barrier)
+        plgpu.copy_smem_to_gmem(smem_ref, y_ref)
+        plgpu.wait_smem_to_gmem(0)
+      pl.run_scoped(scope,
+                    smem_ref=plgpu.SMEM((32, 32), jnp.float32),
+                    tma_barrier=plgpu.Barrier(num_arrivals=1))
+    x = jax.random.uniform(jax.random.key(42), (64, 32), jnp.float32)
+    result = kernel(x)
+    np.testing.assert_array_equal(result, x[32:64])
+

 class PallasCallWGTest(
     PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
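A note on the barrier sizing in the test above: plgpu.Barrier(num_arrivals=1) still lowers to an mbarrier that, per the comment in the primitives.py hunk, expects a full 128 arrivals; the warp-level copy path makes up the difference by arriving 4 times from each of the 32 threads of the issuing warp, so warp 1's barrier_wait completes once the transfer lands.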

Comments (0)