Skip to content

Commit 9366503

Browse files
Updates LLVM usage to match [917d1f20aecf](llvm/llvm-project@917d1f20aecf) PiperOrigin-RevId: 823542980
1 parent b815692 commit 9366503

File tree

5 files changed

+207
-94
lines changed

5 files changed

+207
-94
lines changed

jax/_src/lib/mlir/dialects/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any, TYPE_CHECKING
1818

1919
if TYPE_CHECKING:
20+
from jaxlib.mlir.dialects import _gpu_ops_gen as _gpu_ops_gen
2021
from jaxlib.mlir.dialects import arith as arith
2122
from jaxlib.mlir.dialects import builtin as builtin
2223
from jaxlib.mlir.dialects import cf as cf
@@ -35,6 +36,7 @@
3536
else:
3637
from jax._src import lazy_loader as _lazy
3738
__getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [
39+
"_gpu_ops_gen",
3840
"arith",
3941
"builtin",
4042
"cf",

jax/experimental/mosaic/gpu/core.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@
3232
from jax._src import core as jax_core
3333
from jax._src import dtypes
3434
from jax._src import lib
35-
from jax._src import sharding_impls
3635
from jax._src import mesh as mesh_lib
36+
from jax._src import sharding_impls
3737
from jax._src import util as jax_util
3838
from jax._src.interpreters import mlir
3939
from jax._src.lib import mosaic_gpu_dialect as dialect
4040
from jaxlib.mlir import ir
4141
from jaxlib.mlir import passmanager
42+
from jaxlib.mlir.dialects import _gpu_ops_gen
4243
from jaxlib.mlir.dialects import arith
4344
from jaxlib.mlir.dialects import builtin
4445
from jaxlib.mlir.dialects import func
@@ -579,9 +580,14 @@ def _launch(
579580
)
580581
else:
581582
cluster_kwargs = {}
582-
launch_op = gpu.LaunchOp(
583-
token.type, [token], *grid_vals, *block_vals,
584-
dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs)
583+
launch_op = _gpu_ops_gen.LaunchOp(
584+
token.type,
585+
[token],
586+
*grid_vals,
587+
*block_vals,
588+
dynamicSharedMemorySize=c(smem_bytes, i32),
589+
**cluster_kwargs,
590+
)
585591
launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block
586592
with ir.InsertionPoint(launch_op.body.blocks[0]):
587593
dynamic_smem = gpu.dynamic_shared_memory(

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from jax._src.interpreters import mlir as mlir_interpreter
2626
from jax._src.lib import mosaic_gpu_dialect as mgpu
2727
from jax._src.lib.mlir import ir
28+
from jax._src.lib.mlir.dialects import _gpu_ops_gen
2829
from jax._src.lib.mlir.dialects import arith
2930
from jax._src.lib.mlir.dialects import builtin
3031
from jax._src.lib.mlir.dialects import func
@@ -2062,7 +2063,7 @@ def _index_switch_op_lowering_rule(
20622063

20632064

20642065
@_register_lowering(func.FuncOp)
2065-
@_register_lowering(gpu.LaunchOp)
2066+
@_register_lowering(_gpu_ops_gen.LaunchOp)
20662067
def _traverse_op_lowering_rule(
20672068
ctx: LoweringContext, op: ir.OpView
20682069
) -> MlirLoweringRuleResult:

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
2525
from jaxlib.mlir import ir
26+
from jaxlib.mlir.dialects import _gpu_ops_gen
2627
from jaxlib.mlir.dialects import arith
2728
from jaxlib.mlir.dialects import builtin
2829
from jaxlib.mlir.dialects import func
@@ -32,9 +33,9 @@
3233
from jaxlib.mlir.dialects import nvvm
3334
import numpy as np
3435

36+
from . import fragmented_array as fa
3537
from . import profiler
3638
from . import utils
37-
from . import fragmented_array as fa
3839

3940
TMA_DESCRIPTOR_BYTES = 128
4041
TMA_DESCRIPTOR_ALIGNMENT = 64
@@ -304,7 +305,7 @@ class Scratch:
304305
: (!llvm.array<256 x i8>) -> !llvm.ptr
305306
306307
"""
307-
def __init__(self, gpu_launch_op: gpu.LaunchOp):
308+
def __init__(self, gpu_launch_op: _gpu_ops_gen.LaunchOp):
308309
self.next_offset: int = 0
309310
self.host_init: list[Callable[[ir.Value], None]] = []
310311
self._ops_created = False

0 commit comments

Comments
 (0)