Commit 7014bde

apaszke authored and Google-ML-Automation committed

[Pallas:MGPU] Add a first prototype of an all_gather collective matmul kernel

It's not very optimized at the moment and is unlikely to outperform the baseline of a raw all_gather followed by a matmul, but it computes the right numbers. We are already aware of a few places that could be optimized and we'll start rolling those optimizations out soon.

PiperOrigin-RevId: 761939624

1 parent 2232201 commit 7014bde
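
For reference, the unfused baseline mentioned above is an all_gather of the row-sharded lhs followed by a local matmul on each device. A minimal sketch, assuming a one-dimensional mesh with axis name "x" (mirroring the test added in this commit; the function name is illustrative):

    import functools

    import jax
    from jax import lax
    from jax.sharding import PartitionSpec as P

    mesh = jax.sharding.Mesh(jax.devices(), ["x"])

    @jax.jit
    @functools.partial(
        jax.shard_map,
        mesh=mesh,
        in_specs=(P("x", None), P(None, "x")),  # lhs sharded over rows, rhs over columns
        out_specs=P(None, "x"),
        check_vma=False,
    )
    def all_gather_then_matmul(lhs_shard, rhs_shard):
      # Gather the full lhs on every device, then multiply by the local rhs shard.
      lhs_full = lax.all_gather(lhs_shard, "x", axis=0, tiled=True)
      return lhs_full @ rhs_shard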

File tree: 4 files changed, +342 −1 lines

  jax/_src/pallas/core.py
  jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
  tests/pallas/BUILD
  tests/pallas/mgpu_collective_matmul_test.py

jax/_src/pallas/core.py

Lines changed: 1 addition & 1 deletion
@@ -971,7 +971,7 @@ def _convert_block_spec_to_block_mapping(
 class ScratchShape(Protocol):
   def get_array_aval(self) -> jax_core.AbstractValue:
     ...
-  def get_ref_aval(self) -> state.AbstractRef:
+  def get_ref_aval(self) -> state.AbstractRef | TransformedRef:
     ...

jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py (new file)

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A collective matmul kernel implemented using Mosaic GPU."""

import functools
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp


def _find_swizzle(dim_size_bits: int, what: str):
  for swizzle_bytes in (128, 64, 32, 16):
    if dim_size_bits % (swizzle_bytes * 8) == 0:
      return swizzle_bytes
  raise ValueError(
      f"No valid out swizzle for {what}: its minor dimension has"
      f" {dim_size_bits} bits, which is not a multiple of 128"
  )


# TODO(apaszke): Add grid tiling
def all_gather_lhs_matmul(
    lhs: jax.Array,
    rhs: jax.Array,
    axis_name,
    *,
    block_m: int,
    block_n: int,
    block_k: int,
    max_concurrent_steps: int,
) -> jax.Array:
  if (num_devices := jax.device_count()) != jax.process_count():
    raise ValueError("The kernel only supports one device per process")
  if (axis_size := lax.axis_size(axis_name)) != num_devices:
    raise ValueError("The kernel can only work over all devices in a Mesh.")
  if max_concurrent_steps < 2:
    raise ValueError("max_concurrent_steps must be >= 2")

  num_sms = 132  # There are 132 SMs on an H100 SXM GPU.

  m_shard, k = lhs.shape
  k2, n_shard = rhs.shape
  if k != k2:
    raise ValueError(
        f"lhs and rhs must have the same contraction size, got {k} and {k2}."
    )
  if (element_type := lhs.dtype) != rhs.dtype:
    raise ValueError(
        f"lhs and rhs must have the same element type, got {element_type} and"
        f" {rhs.dtype}."
    )
  if k % block_k != 0:
    raise NotImplementedError(f"k={k} must be a multiple of block_k={block_k}")
  if m_shard % block_m != 0:
    raise NotImplementedError(f"m_shard={m_shard} must be a multiple of block_m={block_m}")
  if n_shard % block_n != 0:
    raise NotImplementedError(f"n_shard={n_shard} must be a multiple of block_n={block_n}")
  if n_shard != block_n:
    raise NotImplementedError(
        f"n_shard={n_shard} must be equal to block_n={block_n}"
    )

  swizzle = min(
      _find_swizzle(block_k * jnp.finfo(element_type).bits, "lhs"),
      _find_swizzle(block_n * jnp.finfo(element_type).bits, "rhs"),
  )
  transforms = (
      plgpu.TilingTransform((8, swizzle // jnp.dtype(element_type).itemsize)),
      plgpu.SwizzleTransform(swizzle),
  )

  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, capacity_sem, received_sem):
    sm_id = lax.axis_index('sm')
    scratch_ref = scratch_ref.at[sm_id]

    dev_id = lax.axis_index(axis_name)
    send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size)
    recv_dev_id = lax.rem(dev_id + 1, axis_size)
    # NOTE: Technically we should signal the recv_dev_id (and our signal would
    # be received from send_dev_id), but if everyone signals in a ring after a
    # barrier then it's equivalent to a local signal.
    pl.semaphore_signal(capacity_sem)
    send_scratch_ref = plgpu.remote_ref(
        scratch_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL
    )

    def m_loop(mi, _):
      mi = mi * lax.axis_size('sm') + sm_id
      m_tile_slice = pl.ds(mi * block_m, block_m)

      # For some reason ptxas spills if we unroll the loop over k
      copy_block = 32
      def k_copy_loop(ki, _):
        k_slice = pl.ds(ki * copy_block, copy_block)
        scratch_ref[0, :, k_slice] = lhs_ref[m_tile_slice, k_slice]
      jax.lax.fori_loop(0, k // copy_block, k_copy_loop, None)

      def device_loop(device_offset, _):
        # Loop invariant: scratch_ref.at[scratch_slot] is ready to be used.
        # We're double buffering the scratch space. At each step, we read from
        # scratch_ref.at[scratch_slot] and write to scratch_ref.at[next_scratch_slot]
        # located on the send_dev_id. We swap the slots after completing a step,
        # which lets us overlap the copy with compute.
        scratch_slot = lax.rem(device_offset, 2)
        next_scratch_slot = 1 - scratch_slot

        @functools.partial(
            pl.run_scoped,
            acc_ref=plgpu.ACC((block_m, block_n)),
            out_smem=plgpu.SMEM((block_m, block_n), jnp.float16, transforms=transforms),
        )
        def _(acc_ref, out_smem):
          pl.semaphore_wait(capacity_sem)
          @functools.partial(
              plgpu.emit_pipeline,
              grid=(k // block_k,),
              in_specs=[
                  plgpu.BlockSpec((block_m, block_k), lambda k: (0, k), transforms=transforms),
                  plgpu.BlockSpec((block_k, block_n), lambda k: (k, 0), transforms=transforms),
              ],
              max_concurrent_steps=max_concurrent_steps,
              delay_release=1,
          )
          def k_loop(idxs, lhs_smem, rhs_smem):
            (ki,) = idxs
            plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
            k_slice = pl.ds(ki * block_k, block_k)
            # TODO(apaszke): No need to send on the last step
            # TODO(apaszke): Use an async copy. This is uncoalesced.
            send_scratch_ref[next_scratch_slot, :, k_slice] = lhs_smem[...]
          k_loop(scratch_ref.at[scratch_slot], rhs_ref)
          # TODO(apaszke): Both of those semaphores perform a .sys release.
          # This is very expensive and we should only do a single .sys fence.
          pl.semaphore_signal(capacity_sem, device_id=recv_dev_id, device_id_type=pl.DeviceIdType.LOGICAL)
          pl.semaphore_signal(received_sem, device_id=send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL)
          # Make sure all TMAs have read SMEM before we overwrite it.
          plgpu.wait_smem_to_gmem(0, wait_read_only=True)
          out_smem[...] = acc_ref[...].astype(out_smem.dtype)
          plgpu.commit_smem()
          device_m_slice = pl.ds(
              lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m
          )
          plgpu.copy_smem_to_gmem(
              out_smem, out_ref.at[device_m_slice].at[m_tile_slice]
          )
        # Wait for the next scratch to arrive --- see the loop invariant.
        pl.semaphore_wait(received_sem)
      jax.lax.fori_loop(0, num_devices, device_loop, None)
    grid_size = m_shard // block_m
    m_steps = grid_size // num_sms + jnp.int32(sm_id < grid_size % num_sms)
    # TODO(apaszke): Use the ND-loop helper.
    jax.lax.fori_loop(0, m_steps, m_loop, None)

  result, _ = plgpu.kernel(
      kernel_body,
      out_shape=[jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), jnp.float16),
                 jax.ShapeDtypeStruct((num_sms, 2, block_m, k), jnp.float16)],
      scratch_shapes=[
          plgpu.SemaphoreType.REGULAR, plgpu.SemaphoreType.REGULAR,
      ],
      grid=(num_sms,),
      grid_names=('sm',),
  )(lhs, rhs)
  return result
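
As a side note on the schedule above: device_loop walks a ring in which each device starts from its own lhs shard and, at step device_offset, computes with the shard originally owned by device (dev_id + device_offset) % num_devices while forwarding it to its send_dev_id neighbor. A host-side sketch (an illustration assuming 4 devices, not part of the kernel) that checks this indexing against the output offset used above:

    # Simulates the ring schedule of device_loop on the host; nothing here
    # touches a GPU. The assert mirrors lax.rem(device_offset + dev_id, num_devices).
    num_devices = 4
    for dev_id in range(num_devices):
      held_shard = dev_id  # step 0: each device has copied its own lhs shard into scratch
      for device_offset in range(num_devices):
        out_block = (device_offset + dev_id) % num_devices  # row block written this step
        assert out_block == held_shard
        # Send the current shard to dev_id - 1 and receive the next one from dev_id + 1.
        held_shard = (held_shard + 1) % num_devices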

tests/pallas/BUILD

Lines changed: 29 additions & 0 deletions
@@ -842,6 +842,35 @@ jax_multiplatform_test(
     ]),
 )
 
+jax_multiplatform_test(
+    name = "mgpu_collective_matmul_test",
+    srcs = ["mgpu_collective_matmul_test.py"],
+    args = [
+        "--num_processes=2",
+        "--gpus_per_process=1",
+    ],
+    enable_backends = [],
+    enable_configs = [
+        "gpu_h100x2",
+    ],
+    env = {
+        "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true",
+        "JAX_PALLAS_USE_MOSAIC_GPU": "1",
+    },
+    shard_count = 4,
+    tags = [
+        "manual",
+        "multiaccelerator",
+        "notap",
+    ],
+    deps = [
+        "//jax:pallas",
+        "//jax:pallas_experimental_gpu_ops",
+        "//jax:pallas_mosaic_gpu",
+        "//jax:test_multiprocess",
+    ] + py_deps("absl/testing") + py_deps("numpy"),
+)
+
 jax_multiplatform_test(
     name = "fuser_block_spec_test",
     srcs = [

tests/pallas/mgpu_collective_matmul_test.py (new file)

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test different parameterizations of our Mosaic GPU collective matmul."""

import contextlib
import functools
import os

from absl.testing import parameterized  # pylint: disable=g-multiple-import
import jax
from jax import lax
from jax import random
from jax._src import test_multiprocess as jt_multiprocess
from jax._src import test_util as jtu
from jax._src.pallas import pallas_call
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.pallas.ops.gpu import collective_matmul_mgpu
import jax.numpy as jnp
import numpy as np


P = jax.sharding.PartitionSpec


@jtu.with_config(jax_traceback_filtering="off")
class CollectiveMatmulTestCase(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if collective_matmul_mgpu is None:
      self.skipTest("Mosaic GPU not available.")
    if (not jtu.test_device_matches(["cuda"]) or
        not jtu.is_cuda_compute_capability_equal("9.0")):
      self.skipTest("Only works on GPU with capability sm90a")
    if not mgpu.supports_cross_device_collectives():
      self.skipTest("NVSHMEM library unavailable.")
    if jax.process_count() == 1:
      self.skipTest("Test requires multiple processes.")
    context_stack = contextlib.ExitStack()
    context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True))
    self.addCleanup(context_stack.close)

  @parameterized.product(
      m_shard=(1024, 8192),
      n_shard=(64, 128, 192),
      k=(256, 8192),
      block_m=(64, 128, 192),
      block_n=(64, 128, 192),
      block_k=(64, 128),
      max_concurrent_steps=(2, 4),
  )
  def test_all_gather_lhs_matmul(
      self,
      m_shard,
      n_shard,
      k,
      block_m,
      block_n,
      block_k,
      max_concurrent_steps,
  ):
    num_devices = jax.device_count()
    dtype = jnp.float16
    lhs_smem_size = block_m * block_k * max_concurrent_steps * 2
    rhs_smem_size = block_k * block_n * max_concurrent_steps * 2
    # H100 SMEM limit is 228kB.
    if lhs_smem_size + rhs_smem_size > 228_000:
      self.skipTest("This configuration requires too much SMEM.")
    if n_shard != block_n:
      self.skipTest("n_shard must be equal to block_n for now.")
    if n_shard % block_n:
      self.skipTest("n_shard must be divisible by block_n for now.")
    if m_shard % block_m:
      self.skipTest("m_shard must be divisible by block_m for now.")

    k1, k2 = random.split(random.key(1234), num=2)
    lhs = random.normal(k1, (num_devices * m_shard, k), dtype)
    rhs = random.normal(k2, (k, num_devices * n_shard), dtype)

    mesh = jax.sharding.Mesh(jax.devices(), ["x"])
    lhs = jax.device_put(lhs, jax.sharding.NamedSharding(mesh, P("x", None)))
    rhs = jax.device_put(rhs, jax.sharding.NamedSharding(mesh, P(None, "x")))

    def run(body):
      out = jax.jit(
          jax.shard_map(
              body,
              mesh=mesh,
              in_specs=(P("x", None), P(None, "x")),
              out_specs=P(None, "x"),
              check_vma=False,
          )
      )(lhs, rhs)
      # Gather output, for NumPy comparison on the host.
      out = jax.shard_map(
          lambda x: lax.all_gather(x, "x", axis=1, tiled=True),
          mesh=mesh,
          in_specs=P(None, "x"),
          out_specs=P(None),
          check_vma=False,
      )(out)
      return out

    out = run(
        functools.partial(
            collective_matmul_mgpu.all_gather_lhs_matmul,
            axis_name="x",
            block_m=block_m,
            block_n=block_n,
            block_k=block_k,
            max_concurrent_steps=max_concurrent_steps,
        )
    )
    ref_out = run(lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y)
    np.testing.assert_allclose(out, ref_out)


if __name__ == "__main__":
  os.environ["XLA_FLAGS"] = (
      os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"
  )
  jt_multiprocess.main()
