
Commit eb54e7f

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for async copies to peer devices
PiperOrigin-RevId: 761977946
1 parent 2302a2e commit eb54e7f

File tree

7 files changed: +200 −18 lines changed


jax/experimental/mosaic/gpu/core.py

Lines changed: 12 additions & 1 deletion
@@ -677,14 +677,25 @@ def as_gpu_kernel(
   if launch_ctx.is_device_collective and not supports_cross_device_collectives():
     raise RuntimeError("Kernel is a cross-device collective but no support is available.")
 
-  expected_arg_treedef = jax.tree.structure(in_shape)
+  expected_arg_tys, expected_arg_treedef = jax.tree.flatten(in_shape)
   def _check_args(*args):
     arg_treedef = jax.tree.structure(args)
     if arg_treedef != expected_arg_treedef:
       raise ValueError(
           f"Invalid argument structure: expected {expected_arg_treedef}, got"
           f" {arg_treedef}, ({args=})"
       )
+    for arg, expected_ty in zip(args, expected_arg_tys):
+      if arg.shape != expected_ty.shape:
+        raise ValueError(
+            f"Argument shape mismatch: expected {expected_ty.shape}, got"
+            f" {arg.shape}"
+        )
+      if arg.dtype != expected_ty.dtype:
+        raise ValueError(
+            f"Argument dtype mismatch: expected {expected_ty.dtype}, got"
+            f" {arg.dtype}"
+        )
 
   def bind(*args) -> Any:
     return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape)
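For context, the check added above works because jax.tree.flatten returns both the flattened leaves (the jax.ShapeDtypeStructs passed as in_shape) and the treedef, so the wrapper can validate structure, shapes, and dtypes before binding the kernel. A minimal standalone sketch of that pattern (not the as_gpu_kernel code path itself; the example spec is made up):

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical expected input spec, standing in for the in_shape argument.
in_shape = (jax.ShapeDtypeStruct((64, 64), jnp.float32),)
expected_arg_tys, expected_arg_treedef = jax.tree.flatten(in_shape)

def check_args(*args):
  # The structure must match before shapes and dtypes are compared pairwise.
  if jax.tree.structure(args) != expected_arg_treedef:
    raise ValueError("argument structure mismatch")
  for arg, expected_ty in zip(args, expected_arg_tys):
    if arg.shape != expected_ty.shape:
      raise ValueError(f"expected shape {expected_ty.shape}, got {arg.shape}")
    if arg.dtype != expected_ty.dtype:
      raise ValueError(f"expected dtype {expected_ty.dtype}, got {arg.dtype}")

check_args(np.zeros((64, 64), np.float32))    # passes
# check_args(np.zeros((32, 64), np.float32))  # would raise: shape mismatch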

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 53 additions & 1 deletion
@@ -400,6 +400,7 @@ def _get_tma_desc(
       self,
       gmem_ref,
       gmem_transform: tuple[MemRefTransform, ...],
+      gmem_peer_id: int | ir.Value | None,
       transformed_slice_shape: tuple[int, ...],
       swizzle: int | None,
       reduction_op: Literal[
@@ -408,6 +409,7 @@ def _get_tma_desc(
   ):
     tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform)
     if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
+      i32 = ir.IntegerType.get_signless(32)
       i64 = ir.IntegerType.get_signless(64)
       ptr_ty = ir.Type.parse("!llvm.ptr")
       def init_tma_desc(host_ptr):
@@ -432,6 +434,24 @@ def init_tma_desc(host_ptr):
         base_ptr = llvm.getelementptr(
             ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, llvm.GEPNoWrapFlags.none,
         )
+        if gmem_peer_id is not None:
+          if not isinstance(gmem_peer_id, ir.Value):
+            peer_id = c(gmem_peer_id, i32)
+          else:
+            try:
+              peer_id = _replicate_peer_id_computation(gmem_peer_id)
+            except ReplicationError as e:
+              raise ValueError(
+                  "Failed to reproduce the gmem_peer_id computation on the host"
+              ) from e
+          self._ensure_nvshmem_decls()
+          base_ptr = llvm.call(
+              base_ptr.type,
+              [base_ptr, peer_id],
+              [],
+              [],
+              callee="nvshmem_ptr",
+          )
         rank = ref_ty.rank
         assert rank * 2 == len(sizes_and_strides)
         swizzle_arg = (
@@ -507,6 +527,7 @@ def async_copy(
       dst_ref,
       gmem_slice: Any = (),
       gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (),
+      gmem_peer_id: int | ir.Value | None = None,
       barrier: utils.BarrierRef | None = None,
       swizzle: int | None = None,
       arrive: bool | None = None,
@@ -750,7 +771,8 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
       multicast_mask = None
 
     tma_desc = self._get_tma_desc(
-        gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op,
+        gmem_ref, gmem_transform, gmem_peer_id,
+        tuple(slice_shape), swizzle, reduction_op,
     )
 
     # We constuct TMA descriptors in column-major order.
@@ -893,3 +915,33 @@ def device_id(self) -> ir.Value:
     self._ensure_nvshmem_decls()
     i32 = ir.IntegerType.get_signless(32)
     return llvm.call(i32, [], [], [], callee="nvshmem_my_pe")
+
+
+class ReplicationError(Exception):
+  pass
+
+def _replicate_peer_id_computation(peer_id: ir.Value, fuel=8) -> ir.Value:
+  if fuel == 0:
+    raise ReplicationError(
+        "gmem_peer_id computation is too complicated to recompute on the host"
+    )
+  if isinstance(peer_id, ir.BlockArgument):
+    raise ReplicationError("Can't recompute a value that's a block argument")
+  op = peer_id.owner.opview
+  # We accept all arith ops
+  if op.OPERATION_NAME.startswith("arith."):
+    new_operands = [
+        _replicate_peer_id_computation(x, fuel - 1) for x in op.operands
+    ]
+    result_types = [r.type for r in op.results]
+    new_attributes = {na.name: na.attr for na in op.attributes}
+    new_op = ir.Operation.create(
+        op.OPERATION_NAME, result_types, new_operands, new_attributes
+    )
+    return new_op.results if len(new_op.results) > 1 else new_op.result
+  if isinstance(op, llvm.CallOp) and op.callee.value == "nvshmem_my_pe":
+    i32 = ir.IntegerType.get_signless(32)
+    return llvm.call(i32, [], [], [], callee="nvshmem_my_pe")
+  raise ReplicationError(
+      f"Unrecognized op can't be recomputed on the host: {op}"
+  )
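Taken together, these changes let async_copy target a peer device's memory: when gmem_peer_id is set, the TMA descriptor is built from the address returned by nvshmem_ptr(local_ptr, peer_id), and if the peer id is an in-kernel MLIR value (e.g. derived from nvshmem_my_pe via ctx.device_id()), _replicate_peer_id_computation re-creates that computation on the host, where descriptor initialization runs. A sketch of the kernel-side usage, mirroring the new test below (shapes and scratch layout are illustrative):

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith

def kernel(ctx, src, dst, scratch):
  tmp, barrier = scratch
  i32 = ir.IntegerType.get_signless(32)
  # On a 2-device mesh, the peer is simply the other device: 1 - device_id.
  peer = arith.subi(arith.constant(i32, 1), ctx.device_id())
  # Stage the local GMEM shard into SMEM...
  ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier)
  barrier.wait()
  # ...then write it into the peer device's GMEM output via gmem_peer_id.
  ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=peer)
  ctx.await_async_copy(0)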

jaxlib/mosaic/gpu/BUILD

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,8 @@ cc_library(
     # Linker may prune these symbols if they are not explicitly exported.
     linkopts = [
         "-Wl,--export-dynamic-symbol='mosaic_gpu_*'",
+        "-Wl,--export-dynamic-symbol='nvshmem_my_pe'",
+        "-Wl,--export-dynamic-symbol='nvshmem_ptr'",
         "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'",
         "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'",
         "-Wl,--export-dynamic-symbol='nvshmemx_init_status'",

tests/mosaic/BUILD

Lines changed: 21 additions & 0 deletions
@@ -63,6 +63,27 @@ jax_multiplatform_test(
     ]),
 )
 
+jax_multiplatform_test(
+    name = "gpu_test_distributed",
+    srcs = ["gpu_test_distributed.py"],
+    args = [
+        "--num_processes=2",
+        "--gpus_per_process=1",
+    ],
+    enable_backends = [],
+    enable_configs = ["gpu_h100x2"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0 --xla_gpu_experimental_enable_nvshmem=true"},
+    tags = ["multiaccelerator"],
+    deps = [
+        "//jax:experimental",
+        "//jax:mosaic_gpu",
+        "//jax:test_multiprocess",
+    ] + py_deps([
+        "absl/testing",
+        "numpy",
+    ]),
+)
+
 jax_py_test(
     name = "gpu_dialect_test",
     srcs = ["gpu_dialect_test.py"],

tests/mosaic/gpu_test_distributed.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright 2025 The JAX Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from absl.testing import parameterized
+import jax
+from jax._src import config
+from jax._src import test_util as jtu
+from jax._src import test_multiprocess as jt_multiprocess
+from jax._src.interpreters import mlir
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith
+from jax.experimental.mosaic.gpu import dialect as mgpu_dialect  # pylint: disable=g-importing-member
+from jax.experimental import shard
+from jax.experimental import multihost_utils
+import jax.numpy as jnp
+import numpy as np
+try:
+  import jax._src.lib.mosaic_gpu  # noqa: F401
+  HAS_MOSAIC_GPU = True
+except ImportError:
+  HAS_MOSAIC_GPU = False
+else:
+  import jax.experimental.mosaic.gpu as mgpu
+
+
+# ruff: noqa: F405
+# pylint: disable=g-complex-comprehension
+P = jax.sharding.PartitionSpec
+
+
+class TestCase(parameterized.TestCase):
+
+  def setUp(self):
+    if not HAS_MOSAIC_GPU:
+      self.skipTest("jaxlib built without Mosaic GPU")
+    if (not jtu.test_device_matches(["cuda"]) or
+        not jtu.is_cuda_compute_capability_at_least("9.0")):
+      self.skipTest("Only works on GPU with capability >= sm90")
+    if not mgpu.supports_cross_device_collectives():
+      self.skipTest("NVSHMEM library unavailable.")
+    if jax.process_count() == 1:
+      self.skipTest("Test requires multiple processes.")
+    if jax.device_count() != jax.process_count():
+      self.skipTest("Need 1 device per process")
+    super().setUp()
+    self.prng = np.random.default_rng(1234)
+    self.context = mlir.make_ir_context()
+    if mgpu_dialect is not None:
+      mgpu_dialect.register_dialect(self.context)
+    self.enter_context(config.traceback_filtering("off"))
+    self.enter_context(self.context)
+    self.enter_context(ir.Location.unknown())
+
+
+class ProfilerTest(TestCase):
+
+  def test_remote_async_copy(self):
+    i32 = ir.IntegerType.get_signless(32)
+    def kernel(ctx, src, dst, scratch):
+      tmp, barrier = scratch
+      other_device = arith.subi(arith.constant(i32, 1), ctx.device_id())
+      ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier)
+      barrier.wait()
+      ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=other_device)
+      ctx.await_async_copy(0)
+    mesh = jax.make_mesh(
+        (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
+    )
+    with jax.sharding.use_mesh(mesh):
+      x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
+      x = shard.reshard(x_np, P("x"))
+      y = jax.jit(
+          jax.shard_map(
+              lambda x: mgpu.as_gpu_kernel(
+                  kernel, (1, 1, 1), (128, 1, 1), x, x, (x, mgpu.TMABarrier())
+              )(x),
+              out_specs=P("x"),
+              check_vma=False,
+          )
+      )(x)
+      y_np = multihost_utils.process_allgather(y, tiled=True)
+      np.testing.assert_array_equal(
+          y_np, np.concatenate(np.split(x_np, 2)[::-1], axis=0)
+      )
+
+
+if __name__ == "__main__":
+  jt_multiprocess.main()
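The assertion above encodes the expected result of the exchange: with two processes, each device writes its 32-row shard into the other device's output, so the gathered array is simply the two shards swapped. A quick NumPy check of that expression, for reference:

import numpy as np

x_np = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
top, bottom = np.split(x_np, 2)              # the two 32-row shards
swapped = np.concatenate([bottom, top], axis=0)
np.testing.assert_array_equal(
    swapped, np.concatenate(np.split(x_np, 2)[::-1], axis=0)
)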

tests/pallas/BUILD

Lines changed: 1 addition & 0 deletions
@@ -864,6 +864,7 @@ jax_multiplatform_test(
         "notap",
     ],
     deps = [
+        "//jax:experimental",
         "//jax:pallas",
         "//jax:pallas_experimental_gpu_ops",
        "//jax:pallas_mosaic_gpu",

tests/pallas/mgpu_collective_matmul_test.py

Lines changed: 11 additions & 16 deletions
@@ -27,6 +27,7 @@
 from jax._src.pallas import pallas_call
 from jax.experimental.mosaic import gpu as mgpu
 from jax.experimental.pallas.ops.gpu import collective_matmul_mgpu
+from jax.experimental import shard
 import jax.numpy as jnp
 import numpy as np
 
@@ -51,8 +52,13 @@ def setUp(self):
     if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform":
       self.skipTest("NVSHMEM doesn't work with the platform allocator.")
     context_stack = contextlib.ExitStack()
-    context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True))
     self.addCleanup(context_stack.close)
+    context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True))
+    num_devices = jax.device_count()
+    mesh = jax.make_mesh(
+        (num_devices,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
+    )
+    context_stack.enter_context(jax.sharding.use_mesh(mesh))
 
   @parameterized.product(
       m_shard=(1024, 8192),
@@ -90,28 +96,17 @@ def test_all_gather_lhs_matmul(
     k1, k2 = random.split(random.key(1234), num=2)
     lhs = random.normal(k1, (num_devices * m_shard, k), dtype)
     rhs = random.normal(k2, (k, num_devices * n_shard), dtype)
-
-    mesh = jax.sharding.Mesh(jax.devices(), ["x"])
-    lhs = jax.device_put(lhs, jax.sharding.NamedSharding(mesh, P("x", None)))
-    rhs = jax.device_put(rhs, jax.sharding.NamedSharding(mesh, P(None, "x")))
+    lhs = shard.reshard(lhs, P("x", None))
+    rhs = shard.reshard(rhs, P(None, "x"))
 
     def run(body):
       out = jax.jit(
-          jax.shard_map(
-              body,
-              mesh=mesh,
-              in_specs=(P("x", None), P(None, "x")),
-              out_specs=P(None, "x"),
-              check_vma=False,
-          )
+          jax.shard_map(body, out_specs=P(None, "x"), check_vma=False)
       )(lhs, rhs)
       # Gather output, for NumPy comparison on the host.
       out = jax.shard_map(
           lambda x: lax.all_gather(x, "x", axis=1, tiled=True),
-          mesh=mesh,
-          in_specs=P(None, "x"),
-          out_specs=P(None),
-          check_vma=False,
+          out_specs=P(None), check_vma=False,
       )(out)
       return out
 
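The updated test relies on the explicit-mesh workflow: jax.sharding.use_mesh installs the mesh as ambient context, shard.reshard places the operands according to a PartitionSpec, and jax.shard_map then picks up the mesh and input specs implicitly. A condensed sketch of that pattern, under the assumption of toy shapes and a trivial body:

import jax
import jax.numpy as jnp
from jax.experimental import shard

P = jax.sharding.PartitionSpec
mesh = jax.make_mesh(
    (jax.device_count(),), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
)
with jax.sharding.use_mesh(mesh):
  x = shard.reshard(jnp.ones((8 * jax.device_count(), 8)), P("x", None))
  # With an explicit mesh in scope, shard_map needs no mesh= or in_specs=.
  y = jax.jit(
      jax.shard_map(lambda a: a * 2, out_specs=P("x", None), check_vma=False)
  )(x)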
