def _shard_device_array(x, devices, indices):
  """Shards the DeviceArray `x` across `devices` according to `indices`.

  Each entry of `indices` is converted into (start, limit, removed_dims)
  slice parameters, every slice is extracted with one call to
  `x._multi_slice`, and the resulting shards are paired positionally with
  `devices` and placed with `xla.device_put`.

  Args:
    x: the array to shard; must provide a `_multi_slice` method.
    devices: the target devices, in the same order as `indices`.
    indices: one index per shard, in the form accepted by `_as_slice_indices`.

  Returns:
    A list holding the result of `xla.device_put` for each (shard, device)
    pair.
  """
  slice_params = [_as_slice_indices(x, index) for index in indices]
  starts, limits, removed = (tuple(group) for group in unzip3(slice_params))
  pieces = x._multi_slice(starts, limits, removed)
  placed = []
  for piece, device in zip(pieces, devices):
    placed.append(xla.device_put(piece, device))
  return placed
shard_arg_handlers[xla.DeviceArray] = _shard_device_array

# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def _as_slice_indices(arr: xla.DeviceArray, idx: Index) -> Tuple[
    Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
  """Returns start_indices, limit_indices, removed_dims for `arr[idx]`.

  Converts a basic index into explicit per-dimension bounds.

  Args:
    arr: the array being indexed; only `arr.ndim` and `arr.shape` are read.
    idx: a single int or slice, or a tuple of ints and slices. Dimensions not
      covered by `idx` are taken in full.

  Returns:
    A triple `(start_indices, limit_indices, removed_dims)`. The first two
    have one concrete int per dimension of `arr`. `removed_dims` lists the
    dimensions that were indexed with an int and therefore do not appear in
    the result of `arr[idx]`.

  Raises:
    AssertionError: if an element of `idx` is neither an int nor a slice, or
      if a slice has a step other than 1.
  """
  start_indices = [0] * arr.ndim
  limit_indices = list(arr.shape)
  removed_dims = []

  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
  for dim, sub_idx in enumerate(tuple_idx):
    dim_size = arr.shape[dim]
    if isinstance(sub_idx, int):
      # Wrap negative indices so that start < limit always holds.
      if sub_idx < 0:
        sub_idx += dim_size
      start_indices[dim] = sub_idx
      limit_indices[dim] = sub_idx + 1
      removed_dims.append(dim)
    else:
      assert isinstance(sub_idx, slice)
      # slice.indices() resolves None and negative bounds against the
      # dimension size, so callers never see None in the returned tuples.
      start, stop, step = sub_idx.indices(dim_size)
      assert step == 1, "only unit-stride slices are supported"
      start_indices[dim] = start
      # An inverted slice such as 4:2 selects nothing; keep limit >= start.
      limit_indices[dim] = max(start, stop)

  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr: DeviceArray,
                 start_indices: Tuple[Tuple[int, ...]],
                 limit_indices: Tuple[Tuple[int, ...]],
                 removed_dims: Tuple[Tuple[int, ...]]):
  """Extracts multiple slices from `arr`.

  This is used to shard DeviceArray arguments to pmap. It's implemented as a
  DeviceArray method here to avoid circular imports.

  Args:
    arr: the array to slice.
    start_indices: for each slice, the start index along every dimension.
    limit_indices: for each slice, the limit index along every dimension.
    removed_dims: for each slice, the dimensions to drop from that slice's
      result.

  Returns:
    A list with one array per slice, in the order the slices were given.

  The three index arguments are static arguments of the jitted function and
  must all have the same length (enforced by `safe_zip`).
  """
  results = []
  for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
    sliced = lax.slice(arr, starts, limits)
    # Use this slice's own `removed` dims and its own post-slice shape: the
    # outer `removed_dims` covers every slice, and `arr.shape` is only right
    # when every kept dimension is taken in full.
    if removed:
      sliced = sliced.reshape(onp.delete(sliced.shape, removed))
    results.append(sliced)
  return results
setattr(DeviceArray, "_multi_slice", _multi_slice)