From 492178926e059249c9c07345430f573d024b37fa Mon Sep 17 00:00:00 2001
From: Amir Samani
Date: Sun, 11 May 2025 23:25:31 -0700
Subject: [PATCH 1/2] add binding to get gpu SM count

---
 jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
index decdbaef28e1..f75154445797 100644
--- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
+++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -262,6 +262,16 @@ NB_MODULE(_mosaic_gpu_ext, m) {
         return profiler_state.timings;
       },
       nb::arg("finalize") = true);
+  m.def("_get_gpu_sm_count", []() {
+    int dev = 0;
+    gpuDeviceProp deviceProp;
+    gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev);
+    if (err != gpuSuccess) {
+      throw std::runtime_error("Failed to get GPU properties!");
+    }
+    int sm_count = deviceProp.multiProcessorCount;
+    return sm_count;
+  });
 }
 
 }  // namespace

From a0ef68fa962e32b2f425c64d053d55a240c50635 Mon Sep 17 00:00:00 2001
From: Amir Samani
Date: Sun, 11 May 2025 23:27:23 -0700
Subject: [PATCH 2/2] add mc ptr support to tma with overlapped gemm and all reduce examples

---
 jax/_src/pallas/mosaic_gpu/primitives.py      |   5 +
 jax/experimental/mosaic/gpu/core.py           |  19 +-
 .../mosaic/gpu/examples/gemm_ar_one_shot.py   | 495 ++++++++++++++++
 .../mosaic/gpu/examples/gemm_ar_two_shot.py   | 530 ++++++++++++++++++
 jax/experimental/mosaic/gpu/launch_context.py |  26 +-
 jax/experimental/mosaic/gpu/utils.py          |  77 +++
 jaxlib/mosaic/gpu/nvshmem.h                   |   8 +
 jaxlib/mosaic/gpu/runtime.cc                  |  12 +-
 8 files changed, 1164 insertions(+), 8 deletions(-)
 create mode 100644 jax/experimental/mosaic/gpu/examples/gemm_ar_one_shot.py
 create mode 100644 jax/experimental/mosaic/gpu/examples/gemm_ar_two_shot.py

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 40eccca7c711..f9be9507195e 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -237,6 +237,7 @@ def _copy_smem_to_gmem_lowering(
     has_user_predicate,
     commit_group,
     reduction_op,
+    team_id,
 ):
   if has_user_predicate:
     flat_args, user_predicate = flat_args[:-1], flat_args[-1]
@@ -268,6 +269,7 @@ def _copy_smem_to_gmem_lowering(
         predicate=predicate,
         arrive=commit_group,
         reduction_op=reduction_op,
+        team_id=team_id,
         **copy_params,
     )
   return ()
@@ -347,6 +349,7 @@ def copy_smem_to_gmem(
     *,
     commit_group: bool = True,
     reduction_op: mgpu.ReductionOp | None = None,
+    team_id: int | None = None,
 ) -> None:
   """Asynchronously copies a SMEM reference to a GMEM reference.
 
@@ -361,6 +364,7 @@
     reduction_op: If set, perform the specified reduction operation when
       storing to GMEM. For example, using ``"add"`` is conceptually equivalent
       to doing ``src += dst``.
+    team_id: If set, the destination ref is translated to a multicast memory address.
 
   See also:
     :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
@@ -389,6 +393,7 @@
       has_user_predicate=predicate is not None,
       commit_group=commit_group,
       reduction_op=reduction_op,
+      team_id=team_id,
   )
   return None
 
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index c20c5252a27f..659f701d6713 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -116,8 +116,14 @@ def supports_cross_device_collectives():
 
 
 @mosaic_gpu_p.def_abstract_eval
-def _mosaic_gpu_abstract_eval(*_, module, out_types):
+def _mosaic_gpu_abstract_eval(
+    *_,
+    module,
+    out_types,
+    input_output_aliases,
+):
   del module  # Unused.
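+  # Note: aliasing only affects how buffers are donated/reused during
+  # lowering; it does not change the abstract result types computed here.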
+ del input_output_aliases # Unused. return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] # TODO(apaszke): Implement a proper system for managing kernel lifetimes @@ -618,8 +624,9 @@ def _run_serde_pass( def _declare_runtime_functions(): """Declares the runtime functions that can be used by the generated code.""" ptr_ty = ir.Type.parse("!llvm.ptr") + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) - arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty, i32] init_tma_desc_type = ir.FunctionType.get(arg_tys, []) func.FuncOp( "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" @@ -639,6 +646,7 @@ def as_gpu_kernel( kernel_name: str | None = None, ir_version: int | None = None, thread_semantics: LoweringSemantics = LoweringSemantics.Lane, + input_output_aliases: tuple[tuple[int, int], ...] = (), ): if isinstance(in_shape, list): in_shape = tuple(in_shape) @@ -680,7 +688,12 @@ def _check_args(*args): ) def bind(*args) -> Any: - return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape) + return mosaic_gpu_p.bind( + *args, + module=module, + out_types=out_shape, + input_output_aliases=input_output_aliases, + ) if prof_spec is not None: @jax.jit diff --git a/jax/experimental/mosaic/gpu/examples/gemm_ar_one_shot.py b/jax/experimental/mosaic/gpu/examples/gemm_ar_one_shot.py new file mode 100644 index 000000000000..169be73bda7a --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/gemm_ar_one_shot.py @@ -0,0 +1,495 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""simple one-shot distributed persistent GEMM overlapped with all reduce for H100.""" + +import dataclasses +import itertools +from typing import Any +from functools import partial + +import jax +from jax import random +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax._src.interpreters import mlir +import jax._src.lib.mosaic_gpu as mosaic_gpu_lib +from jax.experimental import mesh_utils, shard_map +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.mosaic.gpu import * # noqa: F403 +from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import core as mgpu_core +import jax.numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import scf +import numpy as np + +# mypy: ignore-errors +# ruff: noqa: F405 +# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access + +SmemRef = ir.Value + +P = jax.sharding.PartitionSpec + + +@dataclasses.dataclass(frozen=True) +class Tiling: + m: int + n: int + k: int + + # Allow access by .mk, .kn, .mn, etc. + def __getattr__(self, name): + if len(name) == 1: + return super().__getattribute__(name) + return tuple(getattr(self, d) for d in name) + + +class WGMMADefaultImpl: + """Default WGMMA implementation. + + The kernel can accept any class that satisfies the same interface as this + class. + """ + + @staticmethod + def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator: + return WGMMAAccumulator.zero(tile_m, tile_n) + + @staticmethod + def smem_shape_extra( + block_tiling: Tiling, + tma_tiling: Tiling, + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, + rhs_transpose: bool, + ) -> dict[str, jax.ShapeDtypeStruct]: + del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose # Unused. + return () + + @staticmethod + def get_result(acc: WGMMAAccumulator) -> FragmentedArray: + return acc.value + + @staticmethod + def wgmma( + smem_scratch: Any, # pylint: disable=unused-argument + acc: WGMMAAccumulator, + a_slice: SmemRef, + b_slice: SmemRef, + swizzle: int, + ) -> dict[str, WGMMAAccumulator]: + """Perform a matrix multiplication. + + This function must guarantee that all WGMMA operations queued before it was + called have completed before returning. 
+ """ + acc = wgmma(acc, a_slice, b_slice, swizzle=swizzle) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(1) + return acc + + +def mlir_context(f): + def wrap(*args, **kw): + with mlir.make_ir_context(), ir.Location.unknown(): + return f(*args, **kw) + + return wrap + + +def worker_for(worker_id, worker_count, work_count): + def wrapper(f): + for_op = scf.ForOp(worker_id, work_count, worker_count, []) + with ir.InsertionPoint(for_op.body): + f(for_op.induction_variable) + scf.yield_([]) + return wrapper + + +@mlir_context +def build_kernel( + m, n, k, + lhs_dtype, rhs_dtype, out_dtype, + cta_count: int = 132, # 132 SMs for Hopper + num_gpus: int = 8, + stages: int = 2, + tile_m: int = 128, + tile_n: int = 128, + swizzle: int = 128, + rhs_transpose: bool = False, + wgmma_impl=WGMMADefaultImpl, + profiler_spec: profiler.ProfilerSpec | None = None, +): + if tile_m % 64 != 0: + raise ValueError(f"{tile_m=} must be divisible by 64") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if stages < 2: + raise ValueError(f"Need at least 2 stages, but got {stages=}") + if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: + raise ValueError(f"Transpose only supported for 16bit types (got: {rhs_transpose=}, {rhs_dtype=})") + if swizzle not in {32, 64, 128}: + raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}") + + out_mlir_dtype = mlir.dtype_to_ir_type(out_dtype) + out_swizzle = swizzle + if bytewidth(out_mlir_dtype) == 4: + if tile_n % 32 == 0: + out_swizzle = 128 + elif tile_n % 16 == 0: + out_swizzle = 64 + else: + raise NotImplementedError( + f"{tile_n=} must by divisible by 16 for 32-bit output" + ) + out_swizzle_elems = out_swizzle // bytewidth(out_mlir_dtype) + out_tiling = (64, out_swizzle_elems) + out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), out_dtype) + + lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) + rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) + lhs_swizzle_elems = swizzle // lhs_elem_bytes + rhs_swizzle_elems = swizzle // rhs_elem_bytes + tile_k = max(lhs_swizzle_elems, rhs_swizzle_elems) + + if tile_n % rhs_swizzle_elems != 0: + raise ValueError( + f"{tile_n=} must be divisible by {swizzle} bytes =" + f" {((lhs_swizzle_elems, lhs_dtype), (rhs_swizzle_elems, rhs_dtype))}" + ) + + if k % tile_k != 0: + raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}") + + block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k) + tma_tiling = Tiling(m=64, n=rhs_swizzle_elems, k=lhs_swizzle_elems) + k_steps = k // block_tiling.k + stages = min(stages, k_steps) + + def safe_div(x, y): + assert x % y == 0, (x, y) + return x // y + + def partial_tile_div(x, y): + if x % y == 0: + return x // y + else: + return (x // y) + 1 + + m_tile_count = partial_tile_div(m, block_tiling.m) + n_tile_count = n // tile_n + tile_count = m_tile_count * n_tile_count + + block = (128, 1, 1) + + c = arith.ConstantOp.create_index + + compute_scratch_shape = ( + jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype), + jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.kn, tma_tiling.kn)), rhs_dtype), + wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose), + ) + epilogue_scratch_shape = jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype) + smem_shape = Union([compute_scratch_shape, epilogue_scratch_shape]) + + def _main(ctx, a_device, b_device, c_device, start_sem, done_sem, c_dev_alias, 
start_sem_alias, done_sem_alias, smem): + ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), tma_barriers = smem + + i64_ty = ir.IntegerType.get_signless(64) + index = ir.IndexType.get() + + cta = gpu.block_id(gpu.Dimension.x) + cta_idx = arith.index_cast(i64_ty,gpu.block_id(gpu.Dimension.x)) + + # sync to begin the kernel + bsem_uc_memref = memref_slice( + start_sem, arith.index_cast(ir.IndexType.get(), cta_idx) + ) + bsem_mc_ptr = mgpu_utils.to_mc_ptr( + bsem_uc_memref, + team_id=0, + ) + bsem_uc_ptr = mgpu.utils.memref_ptr( + memref_slice(start_sem, arith.index_cast(ir.IndexType.get(), cta_idx)) + ) + with ctx.named_region("sync to begin"): + mgpu_utils.warpgroup_barrier() + with single_thread(per_block=True): + mgpu_utils.signal_with_red(bsem_mc_ptr, is_relaxed=True) + mgpu_utils.wait_loop(bsem_uc_ptr, num_gpus, is_relaxed=True) + @worker_for(worker_id=cta, worker_count=mgpu_core.c(cta_count, index), work_count=mgpu_core.c(tile_count, index)) + def body(work_id): + m_idx = arith.divui(work_id, mgpu_core.c(n_tile_count, index)) + n_idx = arith.remui(work_id, mgpu_core.c(n_tile_count, index)) + m_start = arith.muli(m_idx, mgpu_core.c(tile_m, index)) + n_start = arith.muli(n_idx, mgpu_core.c(tile_n, index)) + + def fetch(slot, ki): + barrier = tma_barriers[slot] + k_start = arith.muli(c(block_tiling.k), ki) + lhs_tma_tile_bytes = int(np.prod(block_tiling.mk) * lhs_elem_bytes) + rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) + txcount = lhs_tma_tile_bytes + rhs_tma_tile_bytes + common_copy_args = dict( + swizzle=swizzle, barrier=barrier, arrive=False, uniform=False, + ) + with single_thread(): + barrier.arrive_expect_tx(txcount) + ctx.async_copy( + src_ref=a_device, + dst_ref=memref_slice(lhs_smem, slot), + gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)), + gmem_transform=TileTransform(tma_tiling.mk), + collective=(gpu.Dimension.x, gpu.Dimension.z), + **common_copy_args, + ) + rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n)) + rhs_transform = (TileTransform(tma_tiling.kn),) + if rhs_transpose: + rhs_slice = rhs_slice[::-1] + rhs_transform += (TransposeTransform((1, 0, 2, 3)),) + assert tma_tiling.n == tma_tiling.k, block_tiling # No need to flip the tiling. + ctx.async_copy( + src_ref=b_device, + dst_ref=memref_slice(rhs_smem, slot), + gmem_slice=rhs_slice, + gmem_transform=rhs_transform, + collective=gpu.Dimension.y, + **common_copy_args, + ) + + accs = wgmma_impl.zero_accs(block_tiling.m, block_tiling.n) + + with ctx.named_region("TMA warmup"): + for i in range(stages): + fetch(c(i), c(i)) + + @fori(c(k_steps), accs) + def stage_loop_body(ki, accs): + si = arith.remui(ki, c(stages)) + + with ctx.named_region("TMA wait"): + tma_barriers[si].wait() + + with ctx.named_region("WGMMA"): + a_slice = memref_slice(lhs_smem, si) + b_slice = memref_slice(rhs_smem, si) + if rhs_transpose: + b_slice = memref_transpose(b_slice, (0, 1, 3, 2)) + accs = wgmma_impl.wgmma( + impl_smem, accs, a_slice, b_slice, swizzle=swizzle + ) + + with ctx.named_region("TMA start"): + tma_ki = arith.addi(ki, c(stages - 1)) + tma_si = arith.remui(tma_ki, c(stages)) + not_first_step = arith.cmpi(arith.CmpIPredicate.ne, ki, c(0)) + tma_ki_in_bounds = arith.cmpi( + arith.CmpIPredicate.slt, tma_ki, c(k_steps) + ) + do_tma = arith.andi(not_first_step, tma_ki_in_bounds) + with ir.InsertionPoint(scf.IfOp(do_tma).then_block): + fetch(tma_si, tma_ki) + scf.yield_([]) + + return accs + + # Wait until WGMMA is complete and we can safely read the accumulator. 
+ with ctx.named_region("WGMMA drain"): + nvvm.wgmma_wait_group_sync_aligned(0) + + with ctx.named_region("SMEM store"): + acc_val = wgmma_impl.get_result(stage_loop_body.result) + acc_val.astype(out_mlir_dtype).store_tiled(epilogue_smem, swizzle=out_swizzle) + commit_shared() # Make sure the stores are visible to TMA. + + with ctx.named_region("GMEM store"): + ctx.async_copy( + src_ref=epilogue_smem, + dst_ref=c_device, + gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)), + gmem_transform=TileTransform(out_tiling), + swizzle=out_swizzle, + reduction_op="add", + team_id=0, + ) + ctx.await_async_copy(0) + + # sync to end the kernel + sem_uc_memref = memref_slice( + done_sem, arith.index_cast(ir.IndexType.get(), cta_idx) + ) + sem_mc_ptr = mgpu_utils.to_mc_ptr( + sem_uc_memref, + team_id=0, + ) + sem_uc_ptr = mgpu.utils.memref_ptr( + memref_slice(done_sem, arith.index_cast(ir.IndexType.get(), cta_idx)) + ) + + with ctx.named_region("sync to end kernel"): + mgpu_utils.warpgroup_barrier() + with single_thread(per_block=True): + mgpu_utils.signal_with_red(sem_mc_ptr) + mgpu_utils.wait_loop(sem_uc_ptr, num_gpus) + + return as_gpu_kernel( + _main, + (cta_count, 1, 1), + block, + ( + jax.ShapeDtypeStruct((m, k), lhs_dtype), + jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype), + jax.ShapeDtypeStruct((m, n), out_dtype), + jax.ShapeDtypeStruct((cta_count,), jnp.int32), + jax.ShapeDtypeStruct((cta_count,), jnp.int32), + ), + ( + jax.ShapeDtypeStruct((m, n), out_dtype), + jax.ShapeDtypeStruct((cta_count,), jnp.int32), + jax.ShapeDtypeStruct((cta_count,), jnp.int32), + ), + ( + smem_shape, + TMABarrier(num_barriers=stages), + ), + profiler_spec, + input_output_aliases=((2,0),(3,1),(4,2),), + ) + + +def verify(x, y, actual_output): + dimension_numbers = (((1,), (0,)), ((), ())) + def lax_dot_general_psum(x, y): + matmul_result = jax.lax.dot_general( + x, + y, + dimension_numbers=dimension_numbers, + preferred_element_type=dtype, + ) + #Sum the result from all devices along the "x" axis + ar_result = jax.lax.psum(matmul_result, axis_name='x') + return ar_result.astype(dtype) + + jitted_ref_f = jax.jit( + shard_map.shard_map( + lax_dot_general_psum, + mesh=mesh, + in_specs=(P(None, 'x'), P('x', None)), + out_specs=P(None, None), + ) + ) + + desired_output = jitted_ref_f(x, y) + np.testing.assert_allclose( + actual_output.astype(jnp.float32), desired_output.astype(jnp.float32), atol=1e-3, rtol=1e-3 + ) + + +if __name__ == "__main__": + jax.distributed.initialize() + num_gpus = jax.device_count() + assert num_gpus == 8, f"Expected 8 devices, but got {num_gpus}." + devices = mesh_utils.create_device_mesh((num_gpus,)) + mesh = jax.sharding.Mesh(devices, ("x",)) + sharding_x = jax.sharding.NamedSharding(mesh, P(None, 'x')) + sharding_y = jax.sharding.NamedSharding(mesh, P('x', None)) + + dtype = jnp.dtype(jnp.float16) + m, k, n = 64, 32768, 8192 + kx, ky = random.split(random.key(1234)) + x = random.uniform(kx, (m, k), dtype=dtype) * 0.001 + x = jax.device_put(x, sharding_x) + y = random.uniform(ky, (k, n), dtype=dtype) * 0.001 + y = jax.device_put(y, sharding_y) + assert k % 8 == 0, f"Expected k to be divisible by {num_gpus} got {k}." 
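+  # x is sharded P(None, 'x') and y is sharded P('x', None), so each device
+  # holds a k/num_gpus slice of the contraction dimension; the kernel below is
+  # therefore built with local_k rather than the global k.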
+ local_k = k//8 + + tile_m = tile_n = (64, 128) + swizzle = (128,) + stages = (2, 4, 5, 6) + cta_count = mosaic_gpu_lib._mosaic_gpu_ext._get_gpu_sm_count() + configs = itertools.product(tile_m, tile_n, stages, swizzle) + names = ("tile_m", "tile_n", "stages", "swizzle") + best_runtime = float("inf") + best_kwargs = {} + + for config in configs: + kwargs = dict(zip(names, config)) + if n < kwargs["tile_n"]: + continue + try: + @partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)) + def gemm_ar(inp_a, inp_b, dtype, m, n, k, cta_count, num_gpus, stages, tile_m, tile_n, swizzle): + + gemm_ar_kernel = build_kernel( + m, n, k, dtype, dtype, dtype, cta_count = cta_count, num_gpus=num_gpus, stages=stages, + tile_m=tile_m, tile_n=tile_n, swizzle=swizzle, wgmma_impl=WGMMADefaultImpl, + ) + + sharded_gemm_ar_kernel = shard_map.shard_map( + gemm_ar_kernel, + mesh=mesh, + in_specs=(P(None, 'x'), P('x', None), P(None, None), P(None,), P(None,)), + out_specs=P(None, None), + check_rep=False, + ) + + def kernel_call(i, init_val): + z = init_val + z = jnp.zeros(m * n,dtype=dtype).reshape(m, n) + sem_in = jnp.zeros(cta_count,dtype=jnp.int32).reshape(cta_count,) + sem_out = jnp.zeros(cta_count,dtype=jnp.int32).reshape(cta_count,) + z, sem_in, sem_out = sharded_gemm_ar_kernel(inp_a,inp_b, z, sem_in, sem_out) + return z + + z = jnp.zeros(m * n,dtype=dtype).reshape(m, n) + z = lax.fori_loop(0, 5, kernel_call, (z)) + return z + + # warm up call + jax.experimental.multihost_utils.sync_global_devices('barrier') + z = jax.block_until_ready(gemm_ar( + x, y, dtype, m, n, local_k, cta_count, num_gpus, kwargs["stages"], + kwargs["tile_m"], kwargs["tile_n"], kwargs["swizzle"], + ) + ) + jax.experimental.multihost_utils.sync_global_devices('barrier') + + # profile gemm+ar kernel + jax.experimental.multihost_utils.sync_global_devices('barrier') + (z), mgpu_runtime_ms = profiler.measure( + gemm_ar,mode='cupti',aggregate=False)( + x, y, dtype, m, n, local_k, cta_count, num_gpus, kwargs["stages"], + kwargs["tile_m"],kwargs["tile_n"], kwargs["swizzle"] + ) + jax.experimental.multihost_utils.sync_global_devices('barrier') + last_mosaic_gpu_runtime = None + for det_runtime in mgpu_runtime_ms: + if 'mosaic_gpu__main_kernel' in det_runtime[0]: + last_mosaic_gpu_runtime = det_runtime[1] + if last_mosaic_gpu_runtime < best_runtime: + best_runtime = last_mosaic_gpu_runtime + best_kwargs = kwargs + verify(x,y,z) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + print("Best parameters for GEMM+AR: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())) + print(f"GEMM+AR mosaic-gpu kernel time={best_runtime:.4f}") diff --git a/jax/experimental/mosaic/gpu/examples/gemm_ar_two_shot.py b/jax/experimental/mosaic/gpu/examples/gemm_ar_two_shot.py new file mode 100644 index 000000000000..67f7fd2e9e8a --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/gemm_ar_two_shot.py @@ -0,0 +1,530 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""simple two-shot distributed persistent GEMM overlapped with all reduce for H100.""" + +import dataclasses +import itertools +from typing import Any +from functools import partial + +import jax +from jax import random +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax._src.interpreters import mlir +import jax._src.lib.mosaic_gpu as mosaic_gpu_lib +from jax.experimental import mesh_utils, shard_map +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.mosaic.gpu import * # noqa: F403 +from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import core as mgpu_core +import jax.numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import scf +import numpy as np + +# mypy: ignore-errors +# ruff: noqa: F405 +# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access + +SmemRef = ir.Value + +P = jax.sharding.PartitionSpec + + +@dataclasses.dataclass(frozen=True) +class Tiling: + m: int + n: int + k: int + + # Allow access by .mk, .kn, .mn, etc. + def __getattr__(self, name): + if len(name) == 1: + return super().__getattribute__(name) + return tuple(getattr(self, d) for d in name) + + +class WGMMADefaultImpl: + """Default WGMMA implementation. + + The kernel can accept any class that satisfies the same interface as this + class. + """ + + @staticmethod + def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator: + return WGMMAAccumulator.zero(tile_m, tile_n) + + @staticmethod + def smem_shape_extra( + block_tiling: Tiling, + tma_tiling: Tiling, + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, + rhs_transpose: bool, + ) -> dict[str, jax.ShapeDtypeStruct]: + del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose # Unused. + return () + + @staticmethod + def get_result(acc: WGMMAAccumulator) -> FragmentedArray: + return acc.value + + @staticmethod + def wgmma( + smem_scratch: Any, # pylint: disable=unused-argument + acc: WGMMAAccumulator, + a_slice: SmemRef, + b_slice: SmemRef, + swizzle: int, + ) -> dict[str, WGMMAAccumulator]: + """Perform a matrix multiplication. + + This function must guarantee that all WGMMA operations queued before it was + called have completed before returning. 
+ """ + acc = wgmma(acc, a_slice, b_slice, swizzle=swizzle) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(1) + return acc + + +def mlir_context(f): + def wrap(*args, **kw): + with mlir.make_ir_context(), ir.Location.unknown(): + return f(*args, **kw) + + return wrap + + +def worker_for(worker_id, worker_count, work_count, init_state): + def wrapper(f): + for_op = scf.ForOp(worker_id, work_count, worker_count, [init_state]) + with ir.InsertionPoint(for_op.body): + new_state = f(for_op.induction_variable, for_op.inner_iter_args[0]) + scf.yield_([new_state]) + return wrapper + + +@mlir_context +def build_kernel( + m, n, k, + lhs_dtype, rhs_dtype, out_dtype, + cta_count: int = 132, # 132 SMs for Hopper + num_gpus: int = 8, + stages: int = 2, + tile_m: int = 128, + tile_n: int = 128, + swizzle: int = 128, + rhs_transpose: bool = False, + wgmma_impl=WGMMADefaultImpl, + profiler_spec: profiler.ProfilerSpec | None = None, +): + if tile_m % 64 != 0: + raise ValueError(f"{tile_m=} must be divisible by 64") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if stages < 2: + raise ValueError(f"Need at least 2 stages, but got {stages=}") + if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: + raise ValueError(f"Transpose only supported for 16bit types (got: {rhs_transpose=}, {rhs_dtype=})") + if swizzle not in {32, 64, 128}: + raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}") + + out_mlir_dtype = mlir.dtype_to_ir_type(out_dtype) + out_swizzle = swizzle + if bytewidth(out_mlir_dtype) == 4: + if tile_n % 32 == 0: + out_swizzle = 128 + elif tile_n % 16 == 0: + out_swizzle = 64 + else: + raise NotImplementedError( + f"{tile_n=} must by divisible by 16 for 32-bit output" + ) + out_swizzle_elems = out_swizzle // bytewidth(out_mlir_dtype) + out_tiling = (64, out_swizzle_elems) + out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), out_dtype) + + lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) + rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) + lhs_swizzle_elems = swizzle // lhs_elem_bytes + rhs_swizzle_elems = swizzle // rhs_elem_bytes + tile_k = max(lhs_swizzle_elems, rhs_swizzle_elems) + + if tile_n % rhs_swizzle_elems != 0: + raise ValueError( + f"{tile_n=} must be divisible by {swizzle} bytes =" + f" {((lhs_swizzle_elems, lhs_dtype), (rhs_swizzle_elems, rhs_dtype))}" + ) + + if k % tile_k != 0: + raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}") + + block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k) + tma_tiling = Tiling(m=64, n=rhs_swizzle_elems, k=lhs_swizzle_elems) + k_steps = k // block_tiling.k + stages = min(stages, k_steps) + + def safe_div(x, y): + assert x % y == 0, (x, y) + return x // y + + def partial_tile_div(x, y): + if x % y == 0: + return x // y + else: + return (x // y) + 1 + + m_tile_count = partial_tile_div(m, block_tiling.m) + n_tile_count = n // tile_n + tile_count = m_tile_count * n_tile_count + + block = (128, 1, 1) + + c = arith.ConstantOp.create_index + + compute_scratch_shape = ( + jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype), + jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.kn, tma_tiling.kn)), rhs_dtype), + wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose), + ) + epilogue_scratch_shape = jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype) + smem_shape = Union([compute_scratch_shape, epilogue_scratch_shape]) + + def _main(ctx, 
a_device, b_device, c_device, start_sem, done_sem, c_dev_alias, start_sem_alias, done_sem_alias, smem): + ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), tma_barriers = smem + + i64_ty = ir.IntegerType.get_signless(64) + i32_ty = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + + cta = gpu.block_id(gpu.Dimension.x) + cta_idx = arith.index_cast(i64_ty,gpu.block_id(gpu.Dimension.x)) + sync_flip_flop = mgpu_core.c(0, ir.IndexType.get()) + + @worker_for(worker_id=cta, worker_count=mgpu_core.c(cta_count, index), work_count=mgpu_core.c(tile_count, index), init_state=sync_flip_flop) + def body(work_id, sync_flip_flop): + m_idx = arith.divui(work_id, mgpu_core.c(n_tile_count, index)) + n_idx = arith.remui(work_id, mgpu_core.c(n_tile_count, index)) + m_start = arith.muli(m_idx, mgpu_core.c(tile_m, index)) + n_start = arith.muli(n_idx, mgpu_core.c(tile_n, index)) + + def fetch(slot, ki): + barrier = tma_barriers[slot] + k_start = arith.muli(c(block_tiling.k), ki) + lhs_tma_tile_bytes = int(np.prod(block_tiling.mk) * lhs_elem_bytes) + rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) + txcount = lhs_tma_tile_bytes + rhs_tma_tile_bytes + common_copy_args = dict( + swizzle=swizzle, barrier=barrier, arrive=False, uniform=False, + ) + with single_thread(): + barrier.arrive_expect_tx(txcount) + ctx.async_copy( + src_ref=a_device, + dst_ref=memref_slice(lhs_smem, slot), + gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)), + gmem_transform=TileTransform(tma_tiling.mk), + collective=(gpu.Dimension.x, gpu.Dimension.z), + **common_copy_args, + ) + rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n)) + rhs_transform = (TileTransform(tma_tiling.kn),) + if rhs_transpose: + rhs_slice = rhs_slice[::-1] + rhs_transform += (TransposeTransform((1, 0, 2, 3)),) + assert tma_tiling.n == tma_tiling.k, block_tiling # No need to flip the tiling. + ctx.async_copy( + src_ref=b_device, + dst_ref=memref_slice(rhs_smem, slot), + gmem_slice=rhs_slice, + gmem_transform=rhs_transform, + collective=gpu.Dimension.y, + **common_copy_args, + ) + + accs = wgmma_impl.zero_accs(block_tiling.m, block_tiling.n) + + with ctx.named_region("TMA warmup"): + for i in range(stages): + fetch(c(i), c(i)) + + @fori(c(k_steps), accs) + def stage_loop_body(ki, accs): + si = arith.remui(ki, c(stages)) + + with ctx.named_region("TMA wait"): + tma_barriers[si].wait() + + with ctx.named_region("WGMMA"): + a_slice = memref_slice(lhs_smem, si) + b_slice = memref_slice(rhs_smem, si) + if rhs_transpose: + b_slice = memref_transpose(b_slice, (0, 1, 3, 2)) + accs = wgmma_impl.wgmma( + impl_smem, accs, a_slice, b_slice, swizzle=swizzle + ) + + with ctx.named_region("TMA start"): + tma_ki = arith.addi(ki, c(stages - 1)) + tma_si = arith.remui(tma_ki, c(stages)) + not_first_step = arith.cmpi(arith.CmpIPredicate.ne, ki, c(0)) + tma_ki_in_bounds = arith.cmpi( + arith.CmpIPredicate.slt, tma_ki, c(k_steps) + ) + do_tma = arith.andi(not_first_step, tma_ki_in_bounds) + with ir.InsertionPoint(scf.IfOp(do_tma).then_block): + fetch(tma_si, tma_ki) + scf.yield_([]) + + return accs + + # Wait until WGMMA is complete and we can safely read the accumulator. + with ctx.named_region("WGMMA drain"): + nvvm.wgmma_wait_group_sync_aligned(0) + + with ctx.named_region("SMEM store"): + acc_val = wgmma_impl.get_result(stage_loop_body.result) + acc_val.astype(out_mlir_dtype).store_tiled(epilogue_smem, swizzle=out_swizzle) + commit_shared() # Make sure the stores are visible to TMA. 
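+      # Two-shot flow: the TMA store below writes only this device's partial
+      # tile to its local symmetric buffer; the cross-device sum happens later
+      # via multimem.ld_reduce/st on the multicast pointer, after all blocks
+      # and devices have synchronized.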
+ + with ctx.named_region("GMEM store"): + ctx.async_copy( + src_ref=epilogue_smem, + dst_ref=c_device, + gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)), + gmem_transform=TileTransform(out_tiling), + swizzle=out_swizzle, + ) + # ensure all write are done + ctx.await_async_copy(0) + utils.global_membar() + + # sync all blocks + bsem_uc_memref = memref_slice( + start_sem, arith.addi( + arith.index_cast(ir.IndexType.get(), cta_idx), + arith.muli(sync_flip_flop, mgpu_core.c(cta_count, ir.IndexType.get())) + ) + ) + bsem_mc_ptr = ctx.to_remote_mc_ptr( + bsem_uc_memref, + team=mgpu.utils.c(0, i32_ty), + ) + bsem_uc_ptr = mgpu.utils.memref_ptr( + memref_slice(start_sem, arith.index_cast(ir.IndexType.get(), cta_idx)) + ) + with ctx.named_region("sync all blocks before reduction"): + mgpu_utils.warpgroup_barrier() + with single_thread(scope=ThreadSubset.BLOCK): + mgpu_utils.signal_with_red(bsem_mc_ptr, is_relaxed=True) + mgpu_utils.wait_loop(bsem_uc_ptr, num_gpus, is_relaxed=True) + sync_flip_flop = arith.xori(sync_flip_flop, mgpu_core.c(1, ir.IndexType.get())) + # multimem load reduce and multimem store + with ctx.named_region("ld red and st"): + mgpu_utils.warpgroup_barrier() + device_idx = arith.index_cast(index, ctx.device_id()) + num_red_elements = mgpu.utils.c(8, index) + world_size = mgpu.utils.c(8, index) + num_m = arith.minui(mgpu.utils.c(m, index), mgpu.utils.c(tile_m, index)) + num_rows_per_gpu = arith.ceildivui(num_m, world_size) + num_threads_per_gpu = arith.divui(mgpu_core.c(tile_n, index), num_red_elements) + thread_idx = arith.index_cast(index, utils.thread_idx()) + if_in_bound = scf.IfOp(arith.cmpi(arith.CmpIPredicate.ult, thread_idx, num_threads_per_gpu),hasElse=False) + with ir.InsertionPoint(if_in_bound.then_block): + for_op = scf.ForOp(mgpu.utils.c(0, index),num_rows_per_gpu,mgpu.utils.c(1, index)) + with ir.InsertionPoint(for_op.body): + m_offset = arith.addi(arith.muli(for_op.induction_variable,world_size), device_idx) + if_in_bound_m = scf.IfOp(arith.cmpi(arith.CmpIPredicate.ult, m_offset, num_m),hasElse=False) + with ir.InsertionPoint(if_in_bound_m.then_block): + n_offset = arith.muli(thread_idx, num_red_elements) + m_idx = arith.addi(m_start, m_offset) + n_idx = arith.addi(n_start, n_offset) + uc_memref = memref_slice(c_device, (ds(m_idx, 1), ds(n_idx, 1))) + mc_ptr = ctx.to_remote_mc_ptr(uc_memref,team=mgpu.utils.c(0, i32_ty)) + x, y, z, w = utils.multimem_ld_reduce_128(mc_ptr) + utils.multimem_st_128(mc_ptr,x, y, z, w) + scf.yield_([]) + scf.yield_([]) + scf.yield_([]) + return sync_flip_flop + + # sync to end the kernel + sem_uc_memref = memref_slice( + done_sem, arith.index_cast(ir.IndexType.get(), cta_idx) + ) + sem_mc_ptr = ctx.to_remote_mc_ptr( + sem_uc_memref, + team=mgpu.utils.c(0, i32_ty), + ) + sem_uc_ptr = mgpu.utils.memref_ptr( + memref_slice(done_sem, arith.index_cast(ir.IndexType.get(), cta_idx)) + ) + + with ctx.named_region("sync to end kernel"): + mgpu_utils.warpgroup_barrier() + with single_thread(scope=ThreadSubset.BLOCK): + mgpu_utils.signal_with_red(sem_mc_ptr) + mgpu_utils.wait_loop(sem_uc_ptr, num_gpus) + + return as_gpu_kernel( + _main, + (cta_count, 1, 1), + block, + ( + jax.ShapeDtypeStruct((m, k), lhs_dtype), + jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype), + jax.ShapeDtypeStruct((m, n), out_dtype), + jax.ShapeDtypeStruct((2*cta_count,), jnp.int32), + jax.ShapeDtypeStruct((cta_count,), jnp.int32), + ), + ( + jax.ShapeDtypeStruct((m, n), out_dtype), + jax.ShapeDtypeStruct((cta_count,), jnp.int32), + 
jax.ShapeDtypeStruct((cta_count,), jnp.int32), + ), + ( + smem_shape, + TMABarrier(num_barriers=stages), + ), + profiler_spec, + input_output_aliases=((2,0),(3,1),(4,2),), + ) + + +def verify(x, y, actual_output): + dimension_numbers = (((1,), (0,)), ((), ())) + def lax_dot_general_psum(x, y): + matmul_result = jax.lax.dot_general( + x, + y, + dimension_numbers=dimension_numbers, + preferred_element_type=dtype, + ) + #Sum the result from all devices along the "x" axis + ar_result = jax.lax.psum(matmul_result, axis_name='x') + return ar_result.astype(dtype) + + jitted_ref_f = jax.jit( + shard_map.shard_map( + lax_dot_general_psum, + mesh=mesh, + in_specs=(P(None, 'x'), P('x', None)), + out_specs=P(None, None), + ) + ) + + desired_output = jitted_ref_f(x, y) + np.testing.assert_allclose( + actual_output.astype(jnp.float32), desired_output.astype(jnp.float32), atol=1e-3, rtol=1e-3 + ) + + +if __name__ == "__main__": + jax.distributed.initialize() + num_gpus = jax.device_count() + assert num_gpus == 8, f"Expected 8 devices, but got {num_gpus}." + devices = mesh_utils.create_device_mesh((num_gpus,)) + mesh = jax.sharding.Mesh(devices, ("x",)) + sharding_x = jax.sharding.NamedSharding(mesh, P(None, 'x')) + sharding_y = jax.sharding.NamedSharding(mesh, P('x', None)) + + dtype = jnp.dtype(jnp.float16) + m, k, n = 64, 32768, 8192 + kx, ky = random.split(random.key(1234)) + x = random.uniform(kx, (m, k), dtype=dtype) * 0.001 + x = jax.device_put(x, sharding_x) + y = random.uniform(ky, (k, n), dtype=dtype) * 0.001 + y = jax.device_put(y, sharding_y) + assert k % 8 == 0, f"Expected k to be divisible by {num_gpus} got {k}." + local_k = k//8 + + tile_m = tile_n = (64, 128) + swizzle = (128,) + stages = (2, 4, 5, 6) + cta_count = mosaic_gpu_lib._mosaic_gpu_ext._get_gpu_sm_count() + configs = itertools.product(tile_m, tile_n, stages, swizzle) + names = ("tile_m", "tile_n", "stages", "swizzle") + best_runtime = float("inf") + best_kwargs = {} + + for config in configs: + kwargs = dict(zip(names, config)) + if n < kwargs["tile_n"]: + continue + try: + @partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)) + def gemm_ar(inp_a, inp_b, dtype, m, n, k, cta_count, num_gpus, stages, tile_m, tile_n, swizzle): + + gemm_ar_kernel = build_kernel( + m, n, k, dtype, dtype, dtype, cta_count = cta_count, num_gpus=num_gpus, stages=stages, + tile_m=tile_m, tile_n=tile_n, swizzle=swizzle, wgmma_impl=WGMMADefaultImpl, + ) + + sharded_gemm_ar_kernel = shard_map.shard_map( + gemm_ar_kernel, + mesh=mesh, + in_specs=(P(None, 'x'), P('x', None), P(None, None), P(None,), P(None,)), + out_specs=P(None, None), + check_rep=False, + ) + + def kernel_call(i, init_val): + z = init_val + z = jnp.zeros(m * n,dtype=dtype).reshape(m, n) + sem_in = jnp.zeros(cta_count,dtype=jnp.int32).reshape(cta_count,) + sem_out = jnp.zeros(cta_count,dtype=jnp.int32).reshape(cta_count,) + z, sem_in, sem_out = sharded_gemm_ar_kernel(inp_a,inp_b, z, sem_in, sem_out) + return z + + z = jnp.zeros(m * n,dtype=dtype).reshape(m, n) + z = lax.fori_loop(0, 5, kernel_call, (z)) + return z + + # warm up call + jax.experimental.multihost_utils.sync_global_devices('barrier') + z = jax.block_until_ready(gemm_ar( + x, y, dtype, m, n, local_k, cta_count, num_gpus, kwargs["stages"], + kwargs["tile_m"], kwargs["tile_n"], kwargs["swizzle"], + ) + ) + jax.experimental.multihost_utils.sync_global_devices('barrier') + + # profile gemm+ar kernel + jax.experimental.multihost_utils.sync_global_devices('barrier') + (z), mgpu_runtime_ms = profiler.measure( + 
gemm_ar,mode='cupti',aggregate=False)( + x, y, dtype, m, n, local_k, cta_count, num_gpus, kwargs["stages"], + kwargs["tile_m"],kwargs["tile_n"], kwargs["swizzle"] + ) + jax.experimental.multihost_utils.sync_global_devices('barrier') + last_mosaic_gpu_runtime = None + for det_runtime in mgpu_runtime_ms: + if 'mosaic_gpu__main_kernel' in det_runtime[0]: + last_mosaic_gpu_runtime = det_runtime[1] + if last_mosaic_gpu_runtime < best_runtime: + best_runtime = last_mosaic_gpu_runtime + best_kwargs = kwargs + verify(x,y,z) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + print("Best parameters for GEMM+AR: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())) + print(f"GEMM+AR mosaic-gpu kernel time={best_runtime:.4f}") diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index d169c448a80e..fe1f68f4f933 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -401,10 +401,12 @@ def _get_tma_desc( reduction_op: Literal[ "add","min","max","inc","dec","and","or","xor" ] | None, + team_id, ): - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, team_id) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: i64 = ir.IntegerType.get_signless(64) + i32 = ir.IntegerType.get_signless(32) ptr_ty = ir.Type.parse("!llvm.ptr") def init_tma_desc(host_ptr): ref = gmem_ref @@ -474,6 +476,7 @@ def init_tma_desc(host_ptr): utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), c(swizzle_arg, i64), utils.pack_array([c(v, i64) for v in transformed_slice_shape]), + c(team_id, i32), ] func.call([], "mosaic_gpu_init_tma_desc", args) def cast_tma_desc(device_ptr): @@ -506,6 +509,7 @@ def async_copy( ir.Value | None ) = None, # Should select 0 or 1 threads from the WG. reduction_op: ReductionOp | None = None, + team_id: int | None = None, ): """Initiates an async copy between GMEM and SMEM. @@ -734,9 +738,14 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): else: multicast_mask = None - tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op, - ) + if team_id is None: + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op, -1, + ) + else: + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op, team_id, + ) # We constuct TMA descriptors in column-major order. 
     rev_dyn_base_indices = [
@@ -861,6 +870,10 @@ def _ensure_nvshmem_decls(self):
           ir.Type.parse("!llvm.func<ptr (ptr, i32)>")
       )
       llvm.LLVMFuncOp("nvshmem_ptr", nvshmem_ptr_type, sym_visibility="private")
+      nvshmemx_mc_ptr_type = ir.TypeAttr.get(
+          ir.Type.parse("!llvm.func<ptr (i32, ptr)>")
+      )
+      llvm.LLVMFuncOp("nvshmemx_mc_ptr", nvshmemx_mc_ptr_type, sym_visibility="private")
 
   def to_remote(self, ref: ir.Value, peer: ir.Value):
     self._ensure_nvshmem_decls()
@@ -874,6 +887,11 @@ def to_remote(self, ref: ir.Value, peer: ir.Value):
       raise ValueError(f"peer index must be an i32, got {peer.type}")
     return llvm.call(ref.type, [ref, peer], [], [], callee="nvshmem_ptr")
 
+  def to_remote_mc_ptr(self, ref, team):
+    self._ensure_nvshmem_decls()
+    ref_ptr = utils.memref_ptr(ref)
+    return llvm.call(ir.Type.parse("!llvm.ptr"), [team, ref_ptr], [], [], callee="nvshmemx_mc_ptr")
+
   def device_id(self) -> ir.Value:
     self._ensure_nvshmem_decls()
     i32 = ir.IntegerType.get_signless(32)
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index bd11c3a07544..4b2a308a2850 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -1365,3 +1365,80 @@ def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
     result = vector.insertelement(elem, result, position=c(offset + i, index))
     offset += vty.shape[0]
   return result
+
+
+def signal_with_red(mc_ptr, is_relaxed=False):
+  mode = "relaxed" if is_relaxed else "release"
+  asm_instr = f"""
+  {{
+    multimem.red.{mode}.sys.global.add.u32 [$0], 1;
+    fence.proxy.alias;
+  }}"""
+  llvm.inline_asm(
+      ir.Type.parse("!llvm.void"),
+      [mc_ptr],
+      asm_instr,
+      "l",
+      has_side_effects=True,
+      asm_dialect=0,
+  )
+
+
+def wait_loop(uc_ptr, num_gpus=8, is_relaxed=False):
+  mode = "relaxed" if is_relaxed else "acquire"
+  asm_instr = f"""
+  {{
+    .reg .u32 %tmp32_<1>;
+    .reg .pred %p<1>;
+
+    wait_signal:
+      atom.global.sys.{mode}.cas.b32 %tmp32_0, [$0], {num_gpus}, 0;
+      setp.eq.u32 %p0, %tmp32_0, {num_gpus};
+      @!%p0 bra wait_signal;
+  }}"""
+  llvm.inline_asm(
+      ir.Type.parse("!llvm.void"),
+      [uc_ptr],
+      asm_instr,
+      "l",
+      has_side_effects=True,
+      asm_dialect=0,
+  )
+
+
+def global_membar():
+  llvm.inline_asm(
+      ir.Type.parse("!llvm.void"),
+      [],
+      "membar.gl;",
+      "",
+      has_side_effects=True,
+  )
+
+
+def multimem_ld_reduce_128(mc_ptr):
+  i32 = ir.IntegerType.get_signless(32)
+  return_struct = llvm.inline_asm(
+      ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"),
+      [mc_ptr],
+      "multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {$0, $1, $2, $3}, [$4];",
+      "=r,=r,=r,=r,l",
+      has_side_effects=True,
+      asm_dialect=0,
+  )
+  return_regs = [
+      llvm.extractvalue(i32, return_struct, [i]) for i in range(4)
+  ]
+  return return_regs[0], return_regs[1], return_regs[2], return_regs[3]
+
+
+def multimem_st_128(mc_ptr, x, y, z, w):
+  i32 = ir.IntegerType.get_signless(32)
+  llvm.inline_asm(
+      i32,
+      [mc_ptr, x, y, z, w],
+      "multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5};",
+      "=r,l,r,r,r,r",
+      has_side_effects=True,
+      asm_dialect=0,
+  )
diff --git a/jaxlib/mosaic/gpu/nvshmem.h b/jaxlib/mosaic/gpu/nvshmem.h
index dbd11aa1d373..ce4421db7287 100644
--- a/jaxlib/mosaic/gpu/nvshmem.h
+++ b/jaxlib/mosaic/gpu/nvshmem.h
@@ -36,6 +36,8 @@ namespace gpu {
     fprintf(stderr, #FnName " not available in this library."); \
   }
 
+typedef int32_t nvshmem_team_t;
+
 class NvshmemApi {
  public:
   // Returns a default NvshmemApi for a current process.
@@ -61,6 +63,10 @@ class NvshmemApi { bool is_loaded() { return nvshmemx_init_status != nullptr && nvshmemx_init_status() == 2; } + + void* mc_ptr(nvshmem_team_t team, void* addr){ + return nvshmemx_mc_ptr(team,addr); + } NvshmemApi(NvshmemApi const&) = delete; void operator=(NvshmemApi const&) = delete; @@ -79,11 +85,13 @@ class NvshmemApi { NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) NVSHMEM_SET_FN(nvshmemx_cumodule_init) NVSHMEM_SET_FN(nvshmemx_init_status) + NVSHMEM_SET_FN(nvshmemx_mc_ptr) } int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); int (*nvshmemx_cumodule_init)(CUmodule); int (*nvshmemx_init_status)(); + void* (*nvshmemx_mc_ptr)(int, void*); std::mutex mutex_; }; diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index cb48a20dc3d5..cc6709306aef 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -25,7 +25,8 @@ extern "C" { void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, int64_t elem_type, int64_t rank, int64_t *sizes, int64_t *strides, - int64_t swizzle_bytes, int64_t *window_shape) { + int64_t swizzle_bytes, int64_t *window_shape, + mosaic::gpu::nvshmem_team_t team_id) { if (((uintptr_t)tma_desc) % 64 != 0) { fprintf(stderr, "TMA descriptor address must be 64 byte aligned, but got: %p\n", @@ -155,6 +156,15 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes); abort(); } + + if (team_id != -1){ + base_addr = mosaic::gpu::NvshmemApi::Default().mc_ptr(team_id, base_addr); + if (base_addr == nullptr) { + fprintf(stderr, "Failed to translate base_addr to multicast addr for TMA transfer.\n"); + abort(); + } + } + CUresult result = cuTensorMapEncodeTiled( tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides, tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
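For orientation, here is a minimal sketch of how the pieces in this series fit together, condensed from the __main__ block of the new gemm_ar_one_shot.py example. It assumes an NVSHMEM-enabled multi-process run with one GPU per process; the shapes and tile parameters are illustrative, not prescriptive.

# Sketch only: mirrors gemm_ar_one_shot.py's __main__ block.
import jax
import jax.numpy as jnp
import jax._src.lib.mosaic_gpu as mosaic_gpu_lib
from jax.experimental.mosaic.gpu.examples import gemm_ar_one_shot

jax.distributed.initialize()
num_gpus = jax.device_count()

# Patch 1: query the SM count so the persistent grid covers the whole GPU.
cta_count = mosaic_gpu_lib._mosaic_gpu_ext._get_gpu_sm_count()

m, n, local_k = 64, 8192, 4096  # local_k is the per-device shard of K
dtype = jnp.dtype(jnp.float16)
kernel = gemm_ar_one_shot.build_kernel(
    m, n, local_k, dtype, dtype, dtype,
    cta_count=cta_count, num_gpus=num_gpus,
    stages=2, tile_m=64, tile_n=128, swizzle=128,
)
# The kernel takes (lhs_shard, rhs_shard, out, start_sem, done_sem); the last
# three operands are aliased to its outputs via input_output_aliases in
# core.py, and its epilogue TMA store uses reduction_op="add" with team_id=0
# so the partial products accumulate into the multicast address of `out`.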