Commit 5a448b8

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for async copies to peer devices
PiperOrigin-RevId: 762370447
1 parent dc0cdf7 commit 5a448b8

4 files changed: +177 −1 lines changed

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 54 additions & 1 deletion
@@ -400,6 +400,7 @@ def _get_tma_desc(
       self,
       gmem_ref,
       gmem_transform: tuple[MemRefTransform, ...],
+      gmem_peer_id: int | ir.Value | None,
       transformed_slice_shape: tuple[int, ...],
       swizzle: int | None,
       reduction_op: Literal[
@@ -408,6 +409,7 @@ def _get_tma_desc(
   ):
     tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform)
     if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
+      i32 = ir.IntegerType.get_signless(32)
       i64 = ir.IntegerType.get_signless(64)
       ptr_ty = ir.Type.parse("!llvm.ptr")
       def init_tma_desc(host_ptr):
@@ -432,6 +434,25 @@ def init_tma_desc(host_ptr):
         base_ptr = llvm.getelementptr(
             ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, llvm.GEPNoWrapFlags.none,
         )
+        if gmem_peer_id is not None:
+          if not isinstance(gmem_peer_id, ir.Value):
+            peer_id = c(gmem_peer_id, i32)
+          else:
+            try:
+              # We try to reproduce the gmem_peer_id computation on the host.
+              peer_id = _recompute_peer_id(gmem_peer_id)
+            except ReplicationError as e:
+              raise ValueError(
+                  "Failed to recompute the async_copy peer id on the host"
+              ) from e
+          self._ensure_nvshmem_decls()
+          base_ptr = llvm.call(
+              base_ptr.type,
+              [base_ptr, peer_id],
+              [],
+              [],
+              callee="nvshmem_ptr",
+          )
         rank = ref_ty.rank
         assert rank * 2 == len(sizes_and_strides)
         swizzle_arg = (
@@ -507,6 +528,7 @@ def async_copy(
       dst_ref,
       gmem_slice: Any = (),
       gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (),
+      gmem_peer_id: int | ir.Value | None = None,
       barrier: utils.BarrierRef | None = None,
       swizzle: int | None = None,
       arrive: bool | None = None,
@@ -750,7 +772,8 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
       multicast_mask = None

     tma_desc = self._get_tma_desc(
-        gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op,
+        gmem_ref, gmem_transform, gmem_peer_id,
+        tuple(slice_shape), swizzle, reduction_op,
     )

     # We constuct TMA descriptors in column-major order.
@@ -893,3 +916,33 @@ def device_id(self) -> ir.Value:
     self._ensure_nvshmem_decls()
     i32 = ir.IntegerType.get_signless(32)
     return llvm.call(i32, [], [], [], callee="nvshmem_my_pe")
+
+
+class ReplicationError(Exception):
+  pass
+
+def _recompute_peer_id(peer_id: ir.Value, fuel=8) -> ir.Value:
+  if fuel == 0:
+    raise ReplicationError(
+        "gmem_peer_id computation is too complicated to recompute on the host"
+    )
+  if isinstance(peer_id, ir.BlockArgument):
+    raise ReplicationError("Can't recompute a value that's a block argument")
+  op = peer_id.owner.opview
+  # We accept all arith ops
+  if op.OPERATION_NAME.startswith("arith."):
+    new_operands = [_recompute_peer_id(x, fuel - 1) for x in op.operands]
+    result_types = [r.type for r in op.results]
+    new_attributes = {na.name: na.attr for na in op.attributes}
+    new_op = ir.Operation.create(
+        op.OPERATION_NAME, result_types, new_operands, new_attributes
+    )
+    return new_op.results if len(new_op.results) > 1 else new_op.result
+  # nvshmem_my_pe queries the device id of the current process and works on both
+  # the host and the device.
+  if isinstance(op, llvm.CallOp) and op.callee.value == "nvshmem_my_pe":
+    i32 = ir.IntegerType.get_signless(32)
+    return llvm.call(i32, [], [], [], callee="nvshmem_my_pe")
+  raise ReplicationError(
+      f"Unrecognized op can't be recomputed on the host: {op}"
+  )
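Usage note (not part of the diff): the new gmem_peer_id argument to async_copy accepts either a Python int or an ir.Value holding the peer's device id. When it is an ir.Value, _get_tma_desc replays the computation on the host via _recompute_peer_id (only arith ops and nvshmem_my_pe calls are supported) and then rewrites the TMA base pointer through nvshmem_ptr so the descriptor targets the peer's copy of the buffer. A minimal kernel-body sketch, mirroring the test added in this commit and assuming exactly two devices:

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith

def kernel(ctx, src, dst, scratch):
  tmp, barrier = scratch
  i32 = ir.IntegerType.get_signless(32)
  # With two devices, the peer is simply 1 - my_device_id (an assumption of
  # this sketch, not something the API requires).
  peer = arith.subi(arith.constant(i32, 1), ctx.device_id())
  # Stage the local GMEM data into SMEM first.
  ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier)
  barrier.wait()
  # Push the staged tile into the peer device's `dst` buffer; since `peer` is
  # an ir.Value built only from arith ops and ctx.device_id(), the host can
  # recompute it before constructing the TMA descriptor.
  ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=peer)
  ctx.await_async_copy(0)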

jaxlib/mosaic/gpu/BUILD

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,8 @@ cc_library(
     # Linker may prune these symbols if they are not explicitly exported.
     linkopts = [
         "-Wl,--export-dynamic-symbol='mosaic_gpu_*'",
+        "-Wl,--export-dynamic-symbol='nvshmem_my_pe'",
+        "-Wl,--export-dynamic-symbol='nvshmem_ptr'",
         "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'",
         "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'",
         "-Wl,--export-dynamic-symbol='nvshmemx_init_status'",

tests/mosaic/BUILD

Lines changed: 21 additions & 0 deletions
@@ -63,6 +63,27 @@ jax_multiplatform_test(
     ]),
 )

+jax_multiplatform_test(
+    name = "gpu_test_distributed",
+    srcs = ["gpu_test_distributed.py"],
+    args = [
+        "--num_processes=2",
+        "--gpus_per_process=1",
+    ],
+    enable_backends = [],
+    enable_configs = ["gpu_h100x2"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0 --xla_gpu_experimental_enable_nvshmem=true"},
+    tags = ["multiaccelerator"],
+    deps = [
+        "//jax:experimental",
+        "//jax:mosaic_gpu",
+        "//jax:test_multiprocess",
+    ] + py_deps([
+        "absl/testing",
+        "numpy",
+    ]),
+)
+
 jax_py_test(
     name = "gpu_dialect_test",
     srcs = ["gpu_dialect_test.py"],

tests/mosaic/gpu_test_distributed.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright 2025 The JAX Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from absl.testing import parameterized
+import jax
+from jax._src import config
+from jax._src import test_util as jtu
+from jax._src import test_multiprocess as jt_multiprocess
+from jax._src.interpreters import mlir
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith
+from jax.experimental.mosaic.gpu import dialect as mgpu_dialect  # pylint: disable=g-importing-member
+from jax.experimental import shard
+from jax.experimental import multihost_utils
+import jax.numpy as jnp
+import numpy as np
+try:
+  import jax._src.lib.mosaic_gpu  # noqa: F401
+  HAS_MOSAIC_GPU = True
+except ImportError:
+  HAS_MOSAIC_GPU = False
+else:
+  import jax.experimental.mosaic.gpu as mgpu
+
+
+# ruff: noqa: F405
+# pylint: disable=g-complex-comprehension
+P = jax.sharding.PartitionSpec
+
+
+class TestCase(parameterized.TestCase):
+
+  def setUp(self):
+    if not HAS_MOSAIC_GPU:
+      self.skipTest("jaxlib built without Mosaic GPU")
+    if (not jtu.test_device_matches(["cuda"]) or
+        not jtu.is_cuda_compute_capability_at_least("9.0")):
+      self.skipTest("Only works on GPU with capability >= sm90")
+    if not mgpu.supports_cross_device_collectives():
+      self.skipTest("NVSHMEM library unavailable.")
+    if jax.process_count() == 1:
+      self.skipTest("Test requires multiple processes.")
+    if jax.device_count() != jax.process_count():
+      self.skipTest("Need 1 device per process")
+    super().setUp()
+    self.prng = np.random.default_rng(1234)
+    self.context = mlir.make_ir_context()
+    if mgpu_dialect is not None:
+      mgpu_dialect.register_dialect(self.context)
+    self.enter_context(config.traceback_filtering("off"))
+    self.enter_context(self.context)
+    self.enter_context(ir.Location.unknown())
+
+
+class ProfilerTest(TestCase):
+
+  def test_remote_async_copy(self):
+    i32 = ir.IntegerType.get_signless(32)
+    def kernel(ctx, src, dst, scratch):
+      tmp, barrier = scratch
+      other_device = arith.subi(arith.constant(i32, 1), ctx.device_id())
+      ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier)
+      barrier.wait()
+      ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=other_device)
+      ctx.await_async_copy(0)
+    mesh = jax.make_mesh(
+        (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
+    )
+    with jax.sharding.use_mesh(mesh):
+      x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
+      x = shard.reshard(x_np, P("x"))
+      y = jax.jit(
+          jax.shard_map(
+              lambda x: mgpu.as_gpu_kernel(
+                  kernel, (1, 1, 1), (128, 1, 1), x, x, (x, mgpu.TMABarrier())
+              )(x),
+              out_specs=P("x"),
+              check_vma=False,
+          )
+      )(x)
+      y_np = multihost_utils.process_allgather(y, tiled=True)
+      np.testing.assert_array_equal(
+          y_np, np.concatenate(np.split(x_np, 2)[::-1], axis=0)
+      )
+
+
+if __name__ == "__main__":
+  jt_multiprocess.main()
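The final assertion encodes the expected data movement: each of the two processes holds a 32x64 shard of x along axis 0, the kernel writes its shard into the peer's output buffer, and the gathered result is therefore the input with its two row-halves swapped. A standalone numpy restatement of that expectation (illustration only, not part of the commit):

import numpy as np

x_np = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
top, bottom = np.split(x_np, 2)                  # shard 0 and shard 1 along rows
swapped = np.concatenate([bottom, top], axis=0)  # each shard lands on the peer
# Same expression the test uses for the expected value:
assert np.array_equal(swapped, np.concatenate(np.split(x_np, 2)[::-1], axis=0))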
