|
| 1 | +# Copyright 2025 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Matrix Multiplication kernel for Blackwell GPUs.""" |
| 15 | + |
| 16 | +import dataclasses |
| 17 | +import functools |
| 18 | +import itertools |
| 19 | +import jax |
| 20 | +from jax import lax |
| 21 | +from jax._src import test_util as jtu # noqa: F401 |
| 22 | +from jax.experimental.mosaic.gpu import profiler |
| 23 | +import jax.experimental.pallas as pl |
| 24 | +import jax.experimental.pallas.mosaic_gpu as plgpu |
| 25 | +import jax.numpy as jnp |
| 26 | +import numpy as np |
| 27 | + |
| 28 | + |
| 29 | +@dataclasses.dataclass(frozen=True) |
| 30 | +class TuningConfig: |
| 31 | + block_m: int |
| 32 | + block_n: int |
| 33 | + block_k: int |
| 34 | + max_concurrent_steps: int |
| 35 | + collective: bool |
| 36 | + |
| 37 | + |
| 38 | +def _find_swizzle(dim_size_bits: int): |
| 39 | + """Finds the largest swizzle that fits the dimension size.""" |
| 40 | + for swizzle_bytes in (128, 64, 32, 16): |
| 41 | + if dim_size_bits % (swizzle_bytes * 8) == 0: |
| 42 | + return swizzle_bytes |
| 43 | + raise ValueError( |
| 44 | + f"Dimension size has {dim_size_bits} bits, which is not a multiple of 128" |
| 45 | + ) |
| 46 | + |
| 47 | + |
| 48 | +def matmul_kernel(a, b, config: TuningConfig): |
| 49 | + dtype = a.dtype |
| 50 | + if a.dtype != b.dtype: |
| 51 | + raise ValueError( |
| 52 | + f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}" |
| 53 | + ) |
| 54 | + m, k = a.shape |
| 55 | + k2, n = b.shape |
| 56 | + if k != k2: |
| 57 | + raise ValueError( |
| 58 | + f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}" |
| 59 | + ) |
| 60 | + collective = config.collective |
| 61 | + if collective: |
| 62 | + raise ValueError("Collective matmul is not supported yet.") |
| 63 | + block_m, block_n, block_k = (config.block_m, config.block_n, config.block_k) |
| 64 | + swizzle = _find_swizzle(block_k * jnp.dtype(dtype).itemsize * 8) |
| 65 | + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize |
| 66 | + transforms = ( |
| 67 | + plgpu.TilingTransform((8, swizzle_elems)), |
| 68 | + plgpu.SwizzleTransform(swizzle), |
| 69 | + ) |
| 70 | + block_lhs = (block_m, block_k) |
| 71 | + block_rhs = (block_k, block_n) |
| 72 | + block_out = (block_m, block_n) |
| 73 | + if m % block_m != 0: |
| 74 | + raise ValueError(f"{m=} must be divisible by {block_m=}") |
| 75 | + if n % block_n != 0: |
| 76 | + raise ValueError(f"{n=} must be divisible by {block_n=}") |
| 77 | + if k % block_k != 0: |
| 78 | + raise ValueError(f"{k=} must be divisible by {block_k=}") |
| 79 | + m_iters = m // block_m |
| 80 | + n_iters = n // block_n |
| 81 | + k_iters = k // block_k |
| 82 | + max_concurrent_steps = config.max_concurrent_steps |
| 83 | + |
| 84 | + def kernel(a_gmem, b_gmem, out_gmem): |
| 85 | + m_index = lax.axis_index("m") |
| 86 | + n_index = lax.axis_index("n") |
| 87 | + slice_m = pl.ds(m_index * block_m, block_m) |
| 88 | + slice_n = pl.ds(n_index * block_n, block_n) |
| 89 | + acc_slice_m = pl.ds(m_index * block_m, block_m) |
| 90 | + acc_slice_n = pl.ds(n_index * block_n, block_n) |
| 91 | + |
| 92 | + @functools.partial( |
| 93 | + pl.run_scoped, |
| 94 | + a_smem=plgpu.SMEM( |
| 95 | + (max_concurrent_steps, *block_lhs), dtype, transforms=transforms |
| 96 | + ), |
| 97 | + b_smem=plgpu.SMEM( |
| 98 | + (max_concurrent_steps, *block_rhs), dtype, transforms=transforms |
| 99 | + ), |
| 100 | + acc_tmem=plgpu.TMEM(block_out, jnp.float32, collective=collective), |
| 101 | + scratch_smem=plgpu.SMEM(block_out, dtype, transforms=transforms), |
| 102 | + a_tma_barrier=plgpu.Barrier( |
| 103 | + num_arrivals=1, num_barriers=max_concurrent_steps |
| 104 | + ), |
| 105 | + b_tma_barrier=plgpu.Barrier( |
| 106 | + num_arrivals=1, num_barriers=max_concurrent_steps |
| 107 | + ), |
| 108 | + consumed_barrier=plgpu.Barrier( |
| 109 | + num_arrivals=1, |
| 110 | + num_barriers=max_concurrent_steps + 1, |
| 111 | + for_tensor_core=True, |
| 112 | + ), |
| 113 | + ) |
| 114 | + def _scoped( |
| 115 | + a_smem, |
| 116 | + b_smem, |
| 117 | + acc_tmem, |
| 118 | + scratch_smem, |
| 119 | + a_tma_barrier, |
| 120 | + b_tma_barrier, |
| 121 | + consumed_barrier, |
| 122 | + ): |
| 123 | + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) |
| 124 | + def _per_warp(): |
| 125 | + warp_id = lax.axis_index("warp") |
| 126 | + |
| 127 | + @pl.when(warp_id == 0) |
| 128 | + def _memory(): |
| 129 | + def _loop_body(ki, _): |
| 130 | + slot = lax.rem(ki, max_concurrent_steps) |
| 131 | + |
| 132 | + @pl.when(ki >= max_concurrent_steps) |
| 133 | + def _(): |
| 134 | + plgpu.barrier_wait(consumed_barrier.at[slot]) |
| 135 | + |
| 136 | + slice_k = pl.ds(ki * block_k, block_k) |
| 137 | + plgpu.copy_gmem_to_smem( |
| 138 | + a_gmem.at[slice_m, slice_k], |
| 139 | + a_smem.at[slot], |
| 140 | + a_tma_barrier.at[slot], |
| 141 | + ) |
| 142 | + plgpu.copy_gmem_to_smem( |
| 143 | + b_gmem.at[slice_k, slice_n], |
| 144 | + b_smem.at[slot], |
| 145 | + b_tma_barrier.at[slot], |
| 146 | + ) |
| 147 | + |
| 148 | + lax.fori_loop(0, k_iters, _loop_body, None) |
| 149 | + |
| 150 | + @pl.when(warp_id == 1) |
| 151 | + def _compute(): |
| 152 | + def _loop_body(ki, _): |
| 153 | + slot = lax.rem(ki, max_concurrent_steps) |
| 154 | + plgpu.barrier_wait(a_tma_barrier.at[slot]) |
| 155 | + plgpu.barrier_wait(b_tma_barrier.at[slot]) |
| 156 | + is_last_iter = ki >= k_iters - 1 |
| 157 | + barrier_slot = lax.select_n(is_last_iter, |
| 158 | + slot, max_concurrent_steps) |
| 159 | + plgpu.tcgen05_mma( |
| 160 | + acc_tmem, |
| 161 | + a_smem.at[slot], |
| 162 | + b_smem.at[slot], |
| 163 | + consumed_barrier.at[barrier_slot], |
| 164 | + accumulate=(ki > 0), |
| 165 | + ) |
| 166 | + lax.fori_loop(0, k_iters, _loop_body, None) |
| 167 | + |
| 168 | + plgpu.barrier_wait(consumed_barrier.at[max_concurrent_steps]) |
| 169 | + scratch_smem[...] = acc_tmem[...].astype(dtype) |
| 170 | + plgpu.commit_smem() |
| 171 | + plgpu.copy_smem_to_gmem( |
| 172 | + scratch_smem, out_gmem.at[acc_slice_m, acc_slice_n] |
| 173 | + ) |
| 174 | + plgpu.wait_smem_to_gmem(0) |
| 175 | + |
| 176 | + f = plgpu.kernel( |
| 177 | + kernel, |
| 178 | + out_shape=jax.ShapeDtypeStruct((m, n), dtype), |
| 179 | + grid=(m_iters, n_iters), |
| 180 | + grid_names=("m", "n"), |
| 181 | + # TODO(justinfu): Add collective support. |
| 182 | + cluster_names=(), |
| 183 | + cluster=(), |
| 184 | + ) |
| 185 | + return f(a, b) |
| 186 | + |
| 187 | + |
| 188 | +def main(_) -> None: |
| 189 | + problem_it = itertools.product( |
| 190 | + (1024, 4096, 8192), (1024, 4096, 8192), (1024, 8192) |
| 191 | + ) |
| 192 | + for M, N, K in problem_it: |
| 193 | + print(f"==== {M=} {N=} {K=} ====") |
| 194 | + matmul_flops = 2 * M * N * K |
| 195 | + peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS |
| 196 | + a = jax.random.uniform(jax.random.key(0), (M, K), jnp.bfloat16) |
| 197 | + b = jax.random.uniform(jax.random.key(1), (K, N), jnp.bfloat16) |
| 198 | + tuning_it = itertools.product( |
| 199 | + (128,), (128, 256), (64, 128), (2, 3, 4), (False,) |
| 200 | + ) |
| 201 | + best_util = -float("inf") |
| 202 | + for (block_m, block_n, block_k, |
| 203 | + max_concurrent_steps, collective) in tuning_it: |
| 204 | + config = TuningConfig( |
| 205 | + block_m=block_m, |
| 206 | + block_n=block_n, |
| 207 | + block_k=block_k, |
| 208 | + max_concurrent_steps=max_concurrent_steps, |
| 209 | + collective=collective, |
| 210 | + ) |
| 211 | + try: |
| 212 | + out, runtime_ms = profiler.measure( |
| 213 | + functools.partial(matmul_kernel, config=config) |
| 214 | + )(a, b) |
| 215 | + except ValueError as e: |
| 216 | + if "exceeds available shared memory" in e.args[0]: |
| 217 | + continue |
| 218 | + raise |
| 219 | + if M * N * K <= 1024 * 1024 * 1024: |
| 220 | + expected = a @ b |
| 221 | + np.testing.assert_allclose(out, expected) |
| 222 | + runtime_us = float(runtime_ms) * 1e3 |
| 223 | + optimal_time = matmul_flops / peak_flops * 1e6 # us |
| 224 | + achieved_tc_util = optimal_time / runtime_us * 100 |
| 225 | + if achieved_tc_util > best_util: |
| 226 | + best_util = achieved_tc_util |
| 227 | + print( |
| 228 | + f"{block_m=} {block_n=} {block_k=} {max_concurrent_steps=}: " |
| 229 | + f"{runtime_us:<7.1f}us" |
| 230 | + f" = {achieved_tc_util:4.1f}% TC utilization" |
| 231 | + ) |
| 232 | + print(f"\tBest utilization: {best_util:4.1f}%") |
| 233 | + |
| 234 | + |
| 235 | +if __name__ == "__main__": |
| 236 | + from absl import app |
| 237 | + |
| 238 | + jax.config.config_with_absl() |
| 239 | + app.run(main) |
0 commit comments