
Commit 617afc3

Jake VanderPlas authored and committed
Move jax/_src/interpreters/batching.py into its own BUILD rule
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

Unfortunately this is not a clean build refactor, because batching depends on jax.lax, which in turn depends on batching. However, the problematic functions are only called within contexts where jax.lax is available for import. We have a few options here:

1. Continue to bundle the batching.py source with the main build.
2. Build separately, but do the local import workaround in this CL (a pattern we use elsewhere).
3. Build separately, but move some batching definitions into jax.lax for a stricter dependency graph, or pass the `lax` namespace explicitly to the function at the call site.

I opted for (2) here because I judged the benefits of a refactored build to be worth the cost of localized impure dependencies, and the kind of refactoring in (3) would affect some downstream users.

PiperOrigin-RevId: 762110930
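To make option (2) concrete, here is a minimal sketch of the deferred-import pattern, using hypothetical modules `lowlevel` and `highlevel` rather than JAX's real ones:

# highlevel.py (hypothetical): imports lowlevel at module scope.
#
#   import lowlevel
#   def iota(n): ...

# lowlevel.py (hypothetical): must not import highlevel at module
# scope, or the two imports would form a cycle.
def make_range(n):
  # Deferred to call time: by the time callers invoke make_range(),
  # highlevel has finished importing, so the cycle never bites.
  import highlevel
  return highlevel.iota(n)

The cost is an import statement that executes on every first call path and a dependency that the build graph cannot see, which is why the commit message calls these "localized impure dependencies".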
1 parent e71d5d5 commit 617afc3


2 files changed: +38 additions, -10 deletions


jax/BUILD

Lines changed: 19 additions & 1 deletion
@@ -316,7 +316,6 @@ py_library_providing_imports_info(
         "_src/ffi.py",
         "_src/flatten_util.py",
         "_src/interpreters/__init__.py",
-        "_src/interpreters/batching.py",
         "_src/interpreters/pxla.py",
         "_src/pjit.py",
         "_src/prng.py",
@@ -384,6 +383,7 @@ py_library_providing_imports_info(
         ":ad_util",
         ":api_util",
         ":basearray",
+        ":batching",
         ":cloud_tpu_init",
         ":compilation_cache_internal",
         ":compiler",
@@ -688,6 +688,24 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "batching",
+    srcs = ["_src/interpreters/batching.py"],
+    deps = [
+        ":ad_util",
+        ":config",
+        ":core",
+        ":mesh",
+        ":partial_eval",
+        ":partition_spec",
+        ":sharding_impls",
+        ":source_info_util",
+        ":tree_util",
+        ":typing",
+        ":util",
+    ] + py_deps("numpy"),
+)
+
 pytype_strict_library(
     name = "mlir",
     srcs = ["_src/interpreters/mlir.py"],

jax/_src/interpreters/batching.py

Lines changed: 19 additions & 9 deletions
@@ -21,7 +21,6 @@
 
 import numpy as np
 
-import jax
 from jax._src import config
 from jax._src import core
 from jax._src import source_info_util
@@ -301,11 +300,14 @@ def _cont(axis_size, elt, axis):
 from_elt_handlers: dict[type, FromEltHandler] = {}
 
 def make_iota(axis_size: AxisSize) -> Array:
+  # Callers of this utility, via batch() or vtile(), must be in a context
+  # where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   handler = make_iota_handlers.get(type(axis_size))
   if handler:
     return handler(axis_size)
   else:
-    return jax.lax.iota('int32', int(axis_size))
+    return lax.iota('int32', int(axis_size))
 make_iota_handlers: dict[type, MakeIotaHandler] = {}
 
 def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
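As a standalone illustration (not part of the diff), the fallback that this hunk rewrites produces a simple counting array:

from jax import lax

# make_iota's fallback for a plain integer axis size:
print(lax.iota('int32', 5))  # [0 1 2 3 4]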
@@ -1019,10 +1021,13 @@ def broadcast_batcher(prim, args, dims, **params):
   return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
 
 def _handle_scalar_broadcasting(nd, x, d):
+  # Callers of this utility, via broadcast_batcher() or defbroadcasting(),
+  # must be in a context where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   if d is not_mapped or nd == np.ndim(x):
     return x
   else:
-    return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
+    return lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
 
 def defreducer(prim, ident):
   primitive_batchers[prim] = partial(reducer_batcher, prim, ident)
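For context, `lax.expand_dims` with `tuple(range(np.ndim(x), nd))` appends trailing unit axes until the rank reaches `nd`; a standalone sketch with made-up values:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.)  # shape (3,)
nd = 3              # hypothetical target rank
y = lax.expand_dims(x, tuple(range(x.ndim, nd)))
print(y.shape)      # (3, 1, 1): trailing size-1 axes appended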
@@ -1078,17 +1083,20 @@ def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array:
 
 def _mask_one_ragged_axis(
     operand: Array, ident, axis_spec: RaggedAxis) -> Array:
+  # Callers of this utility, via reducer_batcher() or defreducer(),
+  # must be in a context where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time"
   ragged_axis, segment_lengths = axis_spec.ragged_axes[0]
   value = ident(operand.dtype)
-  positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis)
+  positions = lax.broadcasted_iota('int32', operand.shape, ragged_axis)
   # TODO(mattjj, axch) can't get ._data, need to convert it
-  # lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32')
-  lengths = jax.lax.convert_element_type(segment_lengths, 'int32')
-  limits = jax.lax.broadcast_in_dim(
+  # lengths = lax.convert_element_type(segment_lengths._data, 'int32')
+  lengths = lax.convert_element_type(segment_lengths, 'int32')
+  limits = lax.broadcast_in_dim(
       lengths, operand.shape, [axis_spec.stacked_axis])
   mask = positions < limits
-  return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape))
+  return lax.select(mask, operand, lax.broadcast(value, operand.shape))
 
 def move_stacked_axis(operand, bdim, dst):
   dst = canonicalize_axis(dst, operand.ndim)
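The masking idiom above can be exercised in isolation. This sketch uses made-up shapes and sum's identity (0.0) in place of the diff's `ident` callback: positions past each row's segment length are replaced with the identity value.

import jax.numpy as jnp
from jax import lax

operand = jnp.ones((2, 4))                  # stacked axis 0, ragged axis 1
lengths = jnp.array([2, 3], dtype='int32')  # valid length of each row
positions = lax.broadcasted_iota('int32', operand.shape, 1)
limits = lax.broadcast_in_dim(lengths, operand.shape, [0])
masked = lax.select(positions < limits, operand,
                    lax.broadcast(0.0, operand.shape))
print(masked)  # [[1. 1. 0. 0.]
               #  [1. 1. 1. 0.]]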
@@ -1103,6 +1111,8 @@ def move_stacked_axis(operand, bdim, dst):
 ### general utilities for manipulating axes on jaxpr types (not vmappables)
 
 def broadcast(x, sz, axis, mesh_axis=None):
+  # Callers of this utility must be in a context where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   shape = list(np.shape(x))
   shape.insert(axis, sz)
   broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
@@ -1114,7 +1124,7 @@ def broadcast(x, sz, axis, mesh_axis=None):
   # TODO(dougalm, yashkatariya): Delete this context manager once we figure
   # out how to ensure jaxpr arguments always have the context mesh.
   with mesh_lib.use_abstract_mesh(sharding.mesh):
-    x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
+    x = lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
   if config._check_vma.value:
     # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026
     spmd_names = core.get_axis_env().spmd_axis_names
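Finally, a standalone sketch of what `broadcast` builds with `lax.broadcast_in_dim`, inserting a new batch axis (the `out_sharding` argument from the diff is omitted here; single-device by assumption):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.)  # shape (4,)
# Insert a new axis of size 3 at position 0; x's only axis maps to output axis 1.
y = lax.broadcast_in_dim(x, shape=(3, 4), broadcast_dimensions=(1,))
print(y.shape)      # (3, 4); each row is a copy of x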
