
Commit 42b0625

justinjfu authored and Google-ML-Automation committed
[Mosaic GPU] Add non-collective blackwell matmul example
PiperOrigin-RevId: 761718971
1 parent f5a9d46 commit 42b0625

2 files changed, +245 −0 lines changed


jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 6 additions & 0 deletions
@@ -1498,6 +1498,8 @@ def _slice_lowering_rule(
 
 
 @register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane,
+                        gpu_core.PrimitiveSemantics.Warp)
 @register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Warpgroup)
 def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
   if len(cases) != 2:
@@ -1506,6 +1508,10 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
         f" {len(cases)}"
     )
   pred_aval, *cases_avals = ctx.avals_in
+  if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp:
+    if not all(aval.shape == () for aval in ctx.avals_in):
+      raise NotImplementedError(
+          "Can only select on scalars in warp-level lowering.")
   [out_aval] = ctx.avals_out
   if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
     pred = _ensure_fa(pred, pred_aval.dtype)
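The extra registration above lets `lax.select_n` be lowered inside warp-level regions created with `pl.core_map(plgpu.WarpMesh(...))`, but only when every operand is a scalar; the matmul example below relies on this when it picks a barrier slot. A minimal sketch of the pattern, assuming it sits inside a Pallas:MGPU kernel body (the axis name and values are illustrative, not taken from the commit):

# Sketch only: scalar select_n at warp level. Non-scalar operands raise
# NotImplementedError under the new warp-level rule.
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
  warp_id = lax.axis_index("warp")
  is_first = warp_id == 0
  # select_n picks cases[0] when the predicate is False, cases[1] when True.
  slot = lax.select_n(is_first, jnp.int32(1), jnp.int32(0))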
Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Matrix Multiplication kernel for Blackwell GPUs."""

import dataclasses
import functools
import itertools
import jax
from jax import lax
from jax._src import test_util as jtu  # noqa: F401
from jax.experimental.mosaic.gpu import profiler
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass(frozen=True)
class TuningConfig:
  block_m: int
  block_n: int
  block_k: int
  max_concurrent_steps: int
  collective: bool


def _find_swizzle(dim_size_bits: int):
  """Finds the largest swizzle that fits the dimension size."""
  for swizzle_bytes in (128, 64, 32, 16):
    if dim_size_bits % (swizzle_bytes * 8) == 0:
      return swizzle_bytes
  raise ValueError(
      f"Dimension size has {dim_size_bits} bits, which is not a multiple of 128"
  )


def matmul_kernel(a, b, config: TuningConfig):
  dtype = a.dtype
  if a.dtype != b.dtype:
    raise ValueError(
        f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}"
    )
  m, k = a.shape
  k2, n = b.shape
  if k != k2:
    raise ValueError(
        f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}"
    )
  collective = config.collective
  if collective:
    raise ValueError("Collective matmul is not supported yet.")
  block_m, block_n, block_k = (config.block_m, config.block_n, config.block_k)
  swizzle = _find_swizzle(block_k * jnp.dtype(dtype).itemsize * 8)
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  transforms = (
      plgpu.TilingTransform((8, swizzle_elems)),
      plgpu.SwizzleTransform(swizzle),
  )
  block_lhs = (block_m, block_k)
  block_rhs = (block_k, block_n)
  block_out = (block_m, block_n)
  if m % block_m != 0:
    raise ValueError(f"{m=} must be divisible by {block_m=}")
  if n % block_n != 0:
    raise ValueError(f"{n=} must be divisible by {block_n=}")
  if k % block_k != 0:
    raise ValueError(f"{k=} must be divisible by {block_k=}")
  m_iters = m // block_m
  n_iters = n // block_n
  k_iters = k // block_k
  max_concurrent_steps = config.max_concurrent_steps

  def kernel(a_gmem, b_gmem, out_gmem):
    m_index = lax.axis_index("m")
    n_index = lax.axis_index("n")
    slice_m = pl.ds(m_index * block_m, block_m)
    slice_n = pl.ds(n_index * block_n, block_n)
    acc_slice_m = pl.ds(m_index * block_m, block_m)
    acc_slice_n = pl.ds(n_index * block_n, block_n)

    @functools.partial(
        pl.run_scoped,
        a_smem=plgpu.SMEM(
            (max_concurrent_steps, *block_lhs), dtype, transforms=transforms
        ),
        b_smem=plgpu.SMEM(
            (max_concurrent_steps, *block_rhs), dtype, transforms=transforms
        ),
        acc_tmem=plgpu.TMEM(block_out, jnp.float32, collective=collective),
        scratch_smem=plgpu.SMEM(block_out, dtype, transforms=transforms),
        a_tma_barrier=plgpu.Barrier(
            num_arrivals=1, num_barriers=max_concurrent_steps
        ),
        b_tma_barrier=plgpu.Barrier(
            num_arrivals=1, num_barriers=max_concurrent_steps
        ),
        consumed_barrier=plgpu.Barrier(
            num_arrivals=1,
            num_barriers=max_concurrent_steps + 1,
            for_tensor_core=True,
        ),
    )
    def _scoped(
        a_smem,
        b_smem,
        acc_tmem,
        scratch_smem,
        a_tma_barrier,
        b_tma_barrier,
        consumed_barrier,
    ):
      @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
      def _per_warp():
        warp_id = lax.axis_index("warp")

        @pl.when(warp_id == 0)
        def _memory():
          def _loop_body(ki, _):
            slot = lax.rem(ki, max_concurrent_steps)

            @pl.when(ki >= max_concurrent_steps)
            def _():
              plgpu.barrier_wait(consumed_barrier.at[slot])

            slice_k = pl.ds(ki * block_k, block_k)
            plgpu.copy_gmem_to_smem(
                a_gmem.at[slice_m, slice_k],
                a_smem.at[slot],
                a_tma_barrier.at[slot],
            )
            plgpu.copy_gmem_to_smem(
                b_gmem.at[slice_k, slice_n],
                b_smem.at[slot],
                b_tma_barrier.at[slot],
            )

          lax.fori_loop(0, k_iters, _loop_body, None)

        @pl.when(warp_id == 1)
        def _compute():
          def _loop_body(ki, _):
            slot = lax.rem(ki, max_concurrent_steps)
            plgpu.barrier_wait(a_tma_barrier.at[slot])
            plgpu.barrier_wait(b_tma_barrier.at[slot])
            is_last_iter = ki >= k_iters - 1
            barrier_slot = lax.select_n(is_last_iter,
                                        slot, max_concurrent_steps)
            plgpu.tcgen05_mma(
                acc_tmem,
                a_smem.at[slot],
                b_smem.at[slot],
                consumed_barrier.at[barrier_slot],
                accumulate=(ki > 0),
            )
          lax.fori_loop(0, k_iters, _loop_body, None)

      plgpu.barrier_wait(consumed_barrier.at[max_concurrent_steps])
      scratch_smem[...] = acc_tmem[...].astype(dtype)
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(
          scratch_smem, out_gmem.at[acc_slice_m, acc_slice_n]
      )
      plgpu.wait_smem_to_gmem(0)

  f = plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      grid=(m_iters, n_iters),
      grid_names=("m", "n"),
      # TODO(justinfu): Add collective support.
      cluster_names=(),
      cluster=(),
  )
  return f(a, b)


def main(_) -> None:
  problem_it = itertools.product(
      (1024, 4096, 8192), (1024, 4096, 8192), (1024, 8192)
  )
  for M, N, K in problem_it:
    print(f"==== {M=} {N=} {K=} ====")
    matmul_flops = 2 * M * N * K
    peak_flops = 2.25e15  # f16 TensorCore peak = 2250 TFLOPS
    a = jax.random.uniform(jax.random.key(0), (M, K), jnp.bfloat16)
    b = jax.random.uniform(jax.random.key(1), (K, N), jnp.bfloat16)
    tuning_it = itertools.product(
        (128,), (128, 256), (64, 128), (2, 3, 4), (False,)
    )
    best_util = -float("inf")
    for (block_m, block_n, block_k,
         max_concurrent_steps, collective) in tuning_it:
      config = TuningConfig(
          block_m=block_m,
          block_n=block_n,
          block_k=block_k,
          max_concurrent_steps=max_concurrent_steps,
          collective=collective,
      )
      try:
        out, runtime_ms = profiler.measure(
            functools.partial(matmul_kernel, config=config)
        )(a, b)
      except ValueError as e:
        if "exceeds available shared memory" in e.args[0]:
          continue
        raise
      if M * N * K <= 1024 * 1024 * 1024:
        expected = a @ b
        np.testing.assert_allclose(out, expected)
      runtime_us = float(runtime_ms) * 1e3
      optimal_time = matmul_flops / peak_flops * 1e6  # us
      achieved_tc_util = optimal_time / runtime_us * 100
      if achieved_tc_util > best_util:
        best_util = achieved_tc_util
      print(
          f"{block_m=} {block_n=} {block_k=} {max_concurrent_steps=}: "
          f"{runtime_us:<7.1f}us"
          f" = {achieved_tc_util:4.1f}% TC utilization"
      )
    print(f"\tBest utilization: {best_util:4.1f}%")


if __name__ == "__main__":
  from absl import app

  jax.config.config_with_absl()
  app.run(main)
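For reference, a minimal way to invoke the new kernel directly, outside the autotuning loop in `main`. The shapes and tuning values below are illustrative, not from the commit; with bf16 and `block_k=64`, `_find_swizzle` sees 64 * 2 bytes * 8 = 1024 bits and picks the 128-byte swizzle.

# Sketch only: direct invocation of the example kernel (requires a Blackwell GPU).
a = jax.random.uniform(jax.random.key(0), (1024, 1024), jnp.bfloat16)
b = jax.random.uniform(jax.random.key(1), (1024, 1024), jnp.bfloat16)
config = TuningConfig(
    block_m=128, block_n=128, block_k=64,
    max_concurrent_steps=2, collective=False,
)
out = matmul_kernel(a, b, config)  # (1024, 1024) bf16 result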
