diff --git a/jax/BUILD b/jax/BUILD index e7f1fad3121d..7c8847e2e94b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -315,7 +315,6 @@ py_library_providing_imports_info( "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", - "_src/interpreters/batching.py", "_src/interpreters/pxla.py", "_src/pjit.py", "_src/prng.py", @@ -383,6 +382,7 @@ py_library_providing_imports_info( ":ad_util", ":api_util", ":basearray", + ":batching", ":cloud_tpu_init", ":compilation_cache_internal", ":compiler", @@ -707,6 +707,24 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "batching", + srcs = ["_src/interpreters/batching.py"], + deps = [ + ":ad_util", + ":config", + ":core", + ":mesh", + ":partial_eval", + ":partition_spec", + ":sharding_impls", + ":source_info_util", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "mlir", srcs = ["_src/interpreters/mlir.py"], diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 0fbe54a30672..55769aa307fc 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -21,7 +21,6 @@ import numpy as np -import jax from jax._src import config from jax._src import core from jax._src import source_info_util @@ -301,11 +300,14 @@ def _cont(axis_size, elt, axis): from_elt_handlers: dict[type, FromEltHandler] = {} def make_iota(axis_size: AxisSize) -> Array: + # Callers of this utility, via batch() or vtile(), must be in a context + # where lax is importable. + from jax import lax # pytype: disable=import-error handler = make_iota_handlers.get(type(axis_size)) if handler: return handler(axis_size) else: - return jax.lax.iota('int32', int(axis_size)) + return lax.iota('int32', int(axis_size)) make_iota_handlers: dict[type, MakeIotaHandler] = {} def register_vmappable(data_type: type, spec_type: type, axis_size_type: type, @@ -1019,10 +1021,13 @@ def broadcast_batcher(prim, args, dims, **params): return (out, (0,) * len(out)) if prim.multiple_results else (out, 0) def _handle_scalar_broadcasting(nd, x, d): + # Callers of this utility, via broadcast_batcher() or defbroadcasting(), + # must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error if d is not_mapped or nd == np.ndim(x): return x else: - return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd))) + return lax.expand_dims(x, tuple(range(np.ndim(x), nd))) def defreducer(prim, ident): primitive_batchers[prim] = partial(reducer_batcher, prim, ident) @@ -1078,17 +1083,20 @@ def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array: def _mask_one_ragged_axis( operand: Array, ident, axis_spec: RaggedAxis) -> Array: + # Callers of this utility, via reducer_batcher() or defreducer(), + # must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time" ragged_axis, segment_lengths = axis_spec.ragged_axes[0] value = ident(operand.dtype) - positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis) + positions = lax.broadcasted_iota('int32', operand.shape, ragged_axis) # TODO(mattjj, axch) can't get ._data, need to convert it - # lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32') - lengths = jax.lax.convert_element_type(segment_lengths, 'int32') - limits = jax.lax.broadcast_in_dim( + # lengths = lax.convert_element_type(segment_lengths._data, 'int32') + lengths = lax.convert_element_type(segment_lengths, 'int32') + limits = lax.broadcast_in_dim( lengths, operand.shape, [axis_spec.stacked_axis]) mask = positions < limits - return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape)) + return lax.select(mask, operand, lax.broadcast(value, operand.shape)) def move_stacked_axis(operand, bdim, dst): dst = canonicalize_axis(dst, operand.ndim) @@ -1103,6 +1111,8 @@ def move_stacked_axis(operand, bdim, dst): ### general utilities for manipulating axes on jaxpr types (not vmappables) def broadcast(x, sz, axis, mesh_axis=None): + # Callers of this utility must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error shape = list(np.shape(x)) shape.insert(axis, sz) broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis)) @@ -1114,7 +1124,7 @@ def broadcast(x, sz, axis, mesh_axis=None): # TODO(dougalm, yashkatariya): Delete this context manager once we figure # out how to ensure jaxpr arguments always have the context mesh. with mesh_lib.use_abstract_mesh(sharding.mesh): - x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) + x = lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) if config._check_vma.value: # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026 spmd_names = core.get_axis_env().spmd_axis_names