|
| 1 | +# Copyright 2025 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""A collective matmul kernel implemented using Mosaic GPU.""" |
| 16 | + |
| 17 | +import functools |
| 18 | +import jax |
| 19 | +from jax import lax |
| 20 | +from jax.experimental import pallas as pl |
| 21 | +from jax.experimental.pallas import mosaic_gpu as plgpu |
| 22 | +import jax.numpy as jnp |
| 23 | + |
| 24 | + |
| 25 | +def _find_swizzle(dim_size_bits: int, what: str): |
| 26 | + for swizzle_bytes in (128, 64, 32, 16): |
| 27 | + if dim_size_bits % (swizzle_bytes * 8) == 0: |
| 28 | + return swizzle_bytes |
| 29 | + raise ValueError( |
| 30 | + f"No valid out swizzle for {what}: its minor dimension has" |
| 31 | + f" {dim_size_bits} bits, which is not a multiple of 128" |
| 32 | + ) |
| 33 | + |
| 34 | + |
| 35 | +# TODO(apaszke): Add grid tiling |
| 36 | +def all_gather_lhs_matmul( |
| 37 | + lhs: jax.Array, |
| 38 | + rhs: jax.Array, |
| 39 | + axis_name, |
| 40 | + *, |
| 41 | + block_m: int, |
| 42 | + block_n: int, |
| 43 | + block_k: int, |
| 44 | + max_concurrent_steps: int, |
| 45 | +) -> jax.Array: |
| 46 | + if (num_devices := jax.device_count()) != jax.process_count(): |
| 47 | + raise ValueError("The kernel only supports one device per process") |
| 48 | + if (axis_size := lax.axis_size(axis_name)) != num_devices: |
| 49 | + raise ValueError("The kernel can only work over all devices in a Mesh.") |
| 50 | + if max_concurrent_steps < 2: |
| 51 | + raise ValueError("max_concurrent_steps must be >= 2") |
| 52 | + |
| 53 | + num_sms = 132 # There are 132 SMs on a H100 SXM GPU. |
| 54 | + |
| 55 | + m_shard, k = lhs.shape |
| 56 | + k2, n_shard = rhs.shape |
| 57 | + if k != k2: |
| 58 | + raise ValueError( |
| 59 | + f"lhs and rhs must have the same contraction size, got {k} and {k2}." |
| 60 | + ) |
| 61 | + if (element_type := lhs.dtype) != rhs.dtype: |
| 62 | + raise ValueError( |
| 63 | + f"lhs and rhs must have the same element type, got {element_type} and" |
| 64 | + f" {rhs.dtype}." |
| 65 | + ) |
| 66 | + if k % block_k != 0: |
| 67 | + raise NotImplementedError(f"k={k} must be a multiple of block_k={block_k}") |
| 68 | + if m_shard % block_m != 0: |
| 69 | + raise NotImplementedError(f"m_shard={m_shard} must be a multiple of block_m={block_m}") |
| 70 | + if n_shard % block_n != 0: |
| 71 | + raise NotImplementedError(f"n_shard={n_shard} must be a multiple of block_n={block_n}") |
| 72 | + if n_shard != block_n: |
| 73 | + raise NotImplementedError( |
| 74 | + f"n_shard={n_shard} must be equal to block_n={block_n}" |
| 75 | + ) |
| 76 | + |
| 77 | + swizzle = min( |
| 78 | + _find_swizzle(block_k * jnp.finfo(element_type).bits, "lhs"), |
| 79 | + _find_swizzle(block_n * jnp.finfo(element_type).bits, "rhs"), |
| 80 | + ) |
| 81 | + transforms = ( |
| 82 | + plgpu.TilingTransform((8, swizzle // jnp.dtype(element_type).itemsize)), |
| 83 | + plgpu.SwizzleTransform(swizzle), |
| 84 | + ) |
| 85 | + |
| 86 | + def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, capacity_sem, received_sem): |
| 87 | + sm_id = lax.axis_index('sm') |
| 88 | + scratch_ref = scratch_ref.at[sm_id] |
| 89 | + |
| 90 | + dev_id = lax.axis_index(axis_name) |
| 91 | + send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size) |
| 92 | + recv_dev_id = lax.rem(dev_id + 1, axis_size) |
| 93 | + # NOTE: Technically we should signal the recv_dev_id (and our signal would |
| 94 | + # be received from send_dev_id), but if everyone signals in a ring after a |
| 95 | + # barrier then it's equivalent to a local signal. |
| 96 | + pl.semaphore_signal(capacity_sem) |
| 97 | + send_scratch_ref = plgpu.remote_ref( |
| 98 | + scratch_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL |
| 99 | + ) |
| 100 | + |
| 101 | + def m_loop(mi, _): |
| 102 | + mi = mi * lax.axis_size('sm') + sm_id |
| 103 | + m_tile_slice = pl.ds(mi * block_m, block_m) |
| 104 | + |
| 105 | + # For some reason ptxas spills if we unroll the loop over k |
| 106 | + copy_block = 32 |
| 107 | + def k_copy_loop(ki, _): |
| 108 | + k_slice = pl.ds(ki * copy_block, copy_block) |
| 109 | + scratch_ref[0, :, k_slice] = lhs_ref[m_tile_slice, k_slice] |
| 110 | + jax.lax.fori_loop(0, k // copy_block, k_copy_loop, None) |
| 111 | + |
| 112 | + def device_loop(device_offset, _): |
| 113 | + # Loop invariant: scratch_ref.at[scratch_slot] is ready to be used |
| 114 | + # We're double buffering the scratch space. At each step, we read from |
| 115 | + # scratch_ref.at[scratch_slot] and write to scratch_ref.at[next_scratch_slot] |
| 116 | + # located on the send_dev_id. We swap the slots after completing a step, |
| 117 | + # which lets us overlap the copy with compute. |
| 118 | + scratch_slot = lax.rem(device_offset, 2) |
| 119 | + next_scratch_slot = 1 - scratch_slot |
| 120 | + |
| 121 | + @functools.partial( |
| 122 | + pl.run_scoped, |
| 123 | + acc_ref=plgpu.ACC((block_m, block_n)), |
| 124 | + out_smem=plgpu.SMEM((block_m, block_n), jnp.float16, transforms=transforms), |
| 125 | + ) |
| 126 | + def _(acc_ref, out_smem): |
| 127 | + pl.semaphore_wait(capacity_sem) |
| 128 | + @functools.partial( |
| 129 | + plgpu.emit_pipeline, |
| 130 | + grid=(k // block_k,), |
| 131 | + in_specs=[ |
| 132 | + plgpu.BlockSpec((block_m, block_k), lambda k: (0, k), transforms=transforms), |
| 133 | + plgpu.BlockSpec((block_k, block_n), lambda k: (k, 0), transforms=transforms), |
| 134 | + ], |
| 135 | + max_concurrent_steps=max_concurrent_steps, |
| 136 | + delay_release=1, |
| 137 | + ) |
| 138 | + def k_loop(idxs, lhs_smem, rhs_smem): |
| 139 | + (ki,) = idxs |
| 140 | + plgpu.wgmma(acc_ref, lhs_smem, rhs_smem) |
| 141 | + k_slice = pl.ds(ki * block_k, block_k) |
| 142 | + # TODO(apaszke): No need to send on the last step |
| 143 | + # TODO(apaszke): Use an async copy. This is uncoalesced. |
| 144 | + send_scratch_ref[next_scratch_slot, :, k_slice] = lhs_smem[...] |
| 145 | + k_loop(scratch_ref.at[scratch_slot], rhs_ref) |
| 146 | + # TODO(apaszke): Both of those semaphores perform a .sys release. |
| 147 | + # This is very expensive and we should only do a single .sys fence. |
| 148 | + pl.semaphore_signal(capacity_sem, device_id=recv_dev_id, device_id_type=pl.DeviceIdType.LOGICAL) |
| 149 | + pl.semaphore_signal(received_sem, device_id=send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL) |
| 150 | + # Make sure all TMAs have read SMEM before we overwrite it. |
| 151 | + plgpu.wait_smem_to_gmem(0, wait_read_only=True) |
| 152 | + out_smem[...] = acc_ref[...].astype(out_smem.dtype) |
| 153 | + plgpu.commit_smem() |
| 154 | + device_m_slice = pl.ds( |
| 155 | + lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m |
| 156 | + ) |
| 157 | + plgpu.copy_smem_to_gmem( |
| 158 | + out_smem, out_ref.at[device_m_slice].at[m_tile_slice] |
| 159 | + ) |
| 160 | + # Wait for the next scratch to arrive --- see the loop invariant. |
| 161 | + pl.semaphore_wait(received_sem) |
| 162 | + jax.lax.fori_loop(0, num_devices, device_loop, None) |
| 163 | + grid_size = m_shard // block_m |
| 164 | + m_steps = grid_size // num_sms + jnp.int32(sm_id < grid_size % num_sms) |
| 165 | + # TODO(apaszke): Use the ND-loop helper. |
| 166 | + jax.lax.fori_loop(0, m_steps, m_loop, None) |
| 167 | + |
| 168 | + result, _ = plgpu.kernel( |
| 169 | + kernel_body, |
| 170 | + out_shape=[jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), jnp.float16), |
| 171 | + jax.ShapeDtypeStruct((num_sms, 2, block_m, k), jnp.float16)], |
| 172 | + scratch_shapes=[ |
| 173 | + plgpu.SemaphoreType.REGULAR, plgpu.SemaphoreType.REGULAR, |
| 174 | + ], |
| 175 | + grid=(num_sms,), |
| 176 | + grid_names=('sm',), |
| 177 | + )(lhs, rhs) |
| 178 | + return result |
0 commit comments