
Move jax/_src/interpreters/batching.py into its own BUILD rule #28957


Merged: 1 commit, May 23, 2025
jax/BUILD (19 additions, 1 deletion)

@@ -315,7 +315,6 @@ py_library_providing_imports_info(
         "_src/ffi.py",
         "_src/flatten_util.py",
         "_src/interpreters/__init__.py",
-        "_src/interpreters/batching.py",
         "_src/interpreters/pxla.py",
         "_src/pjit.py",
         "_src/prng.py",
@@ -383,6 +382,7 @@ py_library_providing_imports_info(
         ":ad_util",
         ":api_util",
         ":basearray",
+        ":batching",
         ":cloud_tpu_init",
         ":compilation_cache_internal",
         ":compiler",
@@ -707,6 +707,24 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "batching",
+    srcs = ["_src/interpreters/batching.py"],
+    deps = [
+        ":ad_util",
+        ":config",
+        ":core",
+        ":mesh",
+        ":partial_eval",
+        ":partition_spec",
+        ":sharding_impls",
+        ":source_info_util",
+        ":tree_util",
+        ":typing",
+        ":util",
+    ] + py_deps("numpy"),
+)
+
 pytype_strict_library(
     name = "mlir",
     srcs = ["_src/interpreters/mlir.py"],
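The batching.py diff below is what makes the new BUILD rule possible: the module-level `import jax` is replaced with function-local `from jax import lax` imports, so `:batching` no longer depends on the monolithic `jax` target at build time. A minimal, self-contained sketch of the deferred-import idiom (the function name mirrors the one in the diff; it assumes jax is installed):

# Sketch of the deferred-import idiom used throughout this PR. Importing
# jax.lax inside the function body means the defining module carries no
# module-level dependency on the rest of jax, so Bazel (and pytype) can
# treat it as a small standalone target.

def make_iota(axis_size):
  # Deferred import: resolved at call time, not at module import time.
  from jax import lax
  return lax.iota('int32', int(axis_size))

print(make_iota(4))  # [0 1 2 3]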
jax/_src/interpreters/batching.py (19 additions, 9 deletions)

@@ -21,7 +21,6 @@

 import numpy as np
 
-import jax
 from jax._src import config
 from jax._src import core
 from jax._src import source_info_util
@@ -301,11 +300,14 @@ def _cont(axis_size, elt, axis):
 from_elt_handlers: dict[type, FromEltHandler] = {}
 
 def make_iota(axis_size: AxisSize) -> Array:
+  # Callers of this utility, via batch() or vtile(), must be in a context
+  # where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   handler = make_iota_handlers.get(type(axis_size))
   if handler:
     return handler(axis_size)
   else:
-    return jax.lax.iota('int32', int(axis_size))
+    return lax.iota('int32', int(axis_size))
 make_iota_handlers: dict[type, MakeIotaHandler] = {}
 
 def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
@@ -1019,10 +1021,13 @@ def broadcast_batcher(prim, args, dims, **params):
   return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
 
 def _handle_scalar_broadcasting(nd, x, d):
+  # Callers of this utility, via broadcast_batcher() or defbroadcasting(),
+  # must be in a context where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   if d is not_mapped or nd == np.ndim(x):
     return x
   else:
-    return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
+    return lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
 
 def defreducer(prim, ident):
   primitive_batchers[prim] = partial(reducer_batcher, prim, ident)
@@ -1078,17 +1083,20 @@ def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array:
 
 def _mask_one_ragged_axis(
     operand: Array, ident, axis_spec: RaggedAxis) -> Array:
+  # Callers of this utility, via reducer_batcher() or defreducer(),
+  # must be in a context where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time"
   ragged_axis, segment_lengths = axis_spec.ragged_axes[0]
   value = ident(operand.dtype)
-  positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis)
+  positions = lax.broadcasted_iota('int32', operand.shape, ragged_axis)
   # TODO(mattjj, axch) can't get ._data, need to convert it
-  # lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32')
-  lengths = jax.lax.convert_element_type(segment_lengths, 'int32')
-  limits = jax.lax.broadcast_in_dim(
+  # lengths = lax.convert_element_type(segment_lengths._data, 'int32')
+  lengths = lax.convert_element_type(segment_lengths, 'int32')
+  limits = lax.broadcast_in_dim(
       lengths, operand.shape, [axis_spec.stacked_axis])
   mask = positions < limits
-  return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape))
+  return lax.select(mask, operand, lax.broadcast(value, operand.shape))
 
 def move_stacked_axis(operand, bdim, dst):
   dst = canonicalize_axis(dst, operand.ndim)
@@ -1103,6 +1111,8 @@ def move_stacked_axis(operand, bdim, dst):
 ### general utilities for manipulating axes on jaxpr types (not vmappables)
 
 def broadcast(x, sz, axis, mesh_axis=None):
+  # Callers of this utility must be in a context where lax is importable.
+  from jax import lax  # pytype: disable=import-error
   shape = list(np.shape(x))
   shape.insert(axis, sz)
   broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
@@ -1114,7 +1124,7 @@
   # TODO(dougalm, yashkatariya): Delete this context manager once we figure
   # out how to ensure jaxpr arguments always have the context mesh.
   with mesh_lib.use_abstract_mesh(sharding.mesh):
-    x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
+    x = lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
   if config._check_vma.value:
     # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026
     spmd_names = core.get_axis_env().spmd_axis_names
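For context on the `_mask_one_ragged_axis` hunk above: the iota/compare/select trick builds a per-row validity mask and fills the out-of-bounds tail with the reducer's identity so a ragged batch can be reduced as a dense array. A reduced 2-D sketch of the same technique (the concrete values, lengths, and identity 0 are illustrative, not from the PR):

import jax.numpy as jnp
from jax import lax

operand = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)  # stacked axis 0
lengths = jnp.array([2, 4, 1], dtype=jnp.int32)            # valid length per row

# Column index at every position, with the per-row limit broadcast alongside.
positions = lax.broadcasted_iota('int32', operand.shape, 1)
limits = lax.broadcast_in_dim(lengths, operand.shape, [0])

# Keep in-range entries; replace the ragged tail with the identity value.
fill = lax.broadcast(jnp.zeros((), operand.dtype), operand.shape)
masked = lax.select(positions < limits, operand, fill)
print(masked)
# [[ 0.  1.  0.  0.]
#  [ 4.  5.  6.  7.]
#  [ 8.  0.  0.  0.]]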