New and improved _shard_device_array function. (jax-ml#2958)
This gets the performance of sharding DeviceArray arguments to pmap roughly back to what it was prior to jax-ml@07571ae. It does so by re-introducing a _shard_device_array function that can handle arbitrary array slices.

Benchmark results compared to jax-ml@87d9590 (i.e. just prior to the regression):
```
---------Benchmark summary for pmap_shard_device_array---------
  nargs    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8  0.0479975  12.0865      1                1.09631
    100          8  0.32916     5.7446      6.85786          1.10263
    500          8  1.5563      2.68041    32.4246           1.10066
    100          2  0.136431    8.33826     2.84245          1.15886
    100          4  0.198815    5.91716     4.1422           1.11409
    100          8  0.31788     4.80559     6.62285          1.06637
```

This is still slightly slower than before the regression (mean/baseline ratios of roughly 1.07–1.16 in the table above), but it recovers most of the lost performance. We can optimize further in future changes if needed.
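For context, a minimal sketch of the code path this change optimizes (not from the commit; the device count and shapes are illustrative assumptions):

```python
# Illustrative sketch: passing an existing DeviceArray to a pmapped function
# triggers the sharding path patched here, which slices the leading axis
# into one shard per device.
import jax
import jax.numpy as jnp

ndev = jax.local_device_count()              # e.g. 8 in the benchmarks above
x = jnp.arange(ndev * 4.0).reshape(ndev, 4)  # materialized as a DeviceArray
y = jax.pmap(lambda s: s * 2.0)(x)           # x is sharded via _shard_device_array
```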

Fixes jax-ml#2958 (hopefully)
skye authored May 5, 2020
1 parent 61a34f5 commit 0eba939
Showing 2 changed files with 55 additions and 11 deletions.
35 changes: 33 additions & 2 deletions jax/interpreters/pxla.py
```diff
@@ -47,7 +47,7 @@
 from .. import lazy
 from ..abstract_arrays import (ConcreteArray, ShapedArray, array_types,
                                raise_to_shaped)
-from ..util import (partial, unzip2, prod, safe_map, safe_zip,
+from ..util import (partial, unzip2, unzip3, prod, safe_map, safe_zip,
                     extend_name_stack, wrap_name)
 from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
```
```diff
@@ -211,7 +211,38 @@ def _shard_array(x, devices, indices):
   return [xla.device_put(x[i], d) for (i, d) in zip(indices, devices)]
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array
-shard_arg_handlers[xla.DeviceArray] = _shard_array
+
+def _shard_device_array(x, devices, indices):
+  start_indices, limit_indices, removed_dims = map(tuple, unzip3(
+      _as_slice_indices(x, idx) for idx in indices))
+  shards = x._multi_slice(start_indices, limit_indices, removed_dims)
+  return [xla.device_put(s, d) for s, d in zip(shards, devices)]
+shard_arg_handlers[xla.DeviceArray] = _shard_device_array
+
+# NOTE(skye): we could refactor to generate _multi_slice parameters directly
+# from the input ShardingSpec, rather than the indices. However, this would
+# require duplicating the ordering logic of spec_to_indices, which is more
+# subtle and more likely to change than the index logic we have to support
+# here.
+def _as_slice_indices(arr: xla.DeviceArray, idx: Index) -> Tuple[
+    Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
+  """Returns start_indices, limit_indices, removed_dims"""
+  start_indices = [0] * arr.ndim
+  limit_indices = list(arr.shape)
+  removed_dims = []
+
+  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
+  for dim, sub_idx in enumerate(tuple_idx):
+    if isinstance(sub_idx, int):
+      start_indices[dim] = sub_idx
+      limit_indices[dim] = sub_idx + 1
+      removed_dims.append(dim)
+    else:
+      assert isinstance(sub_idx, slice)
+      start_indices[dim] = sub_idx.start
+      limit_indices[dim] = sub_idx.stop
+
+  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)
+
+
 def shard_aval(size, aval):
   try:
```
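For intuition, here is a standalone sketch of the index translation `_as_slice_indices` performs, with plain NumPy standing in for `xla.DeviceArray`; the array shape and shard indices below are made-up examples:

```python
# Standalone sketch of _as_slice_indices' translation; numpy stands in for
# xla.DeviceArray, and the shapes/indices are illustrative assumptions.
import numpy as np

def as_slice_indices(arr, idx):
    """Turn a basic index (ints and explicit slices, as produced by
    spec_to_indices) into lax.slice-style start/limit/removed triples."""
    start_indices = [0] * arr.ndim
    limit_indices = list(arr.shape)
    removed_dims = []
    tuple_idx = idx if isinstance(idx, tuple) else (idx,)
    for dim, sub_idx in enumerate(tuple_idx):
        if isinstance(sub_idx, int):
            # An integer index becomes a length-1 slice whose dimension is
            # later squeezed out (hence "removed").
            start_indices[dim] = sub_idx
            limit_indices[dim] = sub_idx + 1
            removed_dims.append(dim)
        else:
            start_indices[dim] = sub_idx.start
            limit_indices[dim] = sub_idx.stop
    return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)

x = np.zeros((8, 3))
print(as_slice_indices(x, 2))            # ((2, 0), (3, 3), (0,))  -> shard shape (3,)
print(as_slice_indices(x, slice(4, 6)))  # ((4, 0), (6, 3), ())    -> shard shape (2, 3)
```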
31 changes: 22 additions & 9 deletions jax/numpy/lax_numpy.py
```diff
@@ -48,7 +48,8 @@
 from ..interpreters.masking import Poly
 from .. import lax
 from .. import ops
-from ..util import partial, get_module_functions, unzip2, prod as _prod, subvals
+from ..util import (partial, get_module_functions, unzip2, prod as _prod,
+                    subvals, safe_zip)
 from ..lib import pytree
 from ..lib import xla_client
```
```diff
@@ -390,7 +391,7 @@ def fmax(x1, x2):
   return where((x1 > x2) | isnan(x2), x1, x2)

 @_wraps(onp.finfo)
-def finfo(dtype):
+def finfo(dtype):
   return dtypes.finfo(dtype)

 @_wraps(onp.issubdtype)
```
```diff
@@ -724,7 +725,7 @@ def _conv(x, y, mode, op, precision):
   if ndim(x) != 1 or ndim(y) != 1:
     raise ValueError(f"{op}() only support 1-dimensional inputs.")
   x, y = _promote_dtypes_inexact(x, y)
-
+
   out_order = slice(None)
   if len(x) < len(y):
     x, y = y, x
```
```diff
@@ -3967,12 +3968,24 @@ def _operator_round(number, ndigits=None):
 setattr(DeviceArray, "broadcast_in_dim", lax.broadcast_in_dim)
 setattr(DeviceArray, "split", split)

-@jit
-def _unstack(x):
-  if x.ndim == 0:
-    raise ValueError("Argument to _unstack must be non-scalar")
-  return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
-setattr(DeviceArray, "_unstack", _unstack)
+@partial(jit, static_argnums=(1, 2, 3))
+def _multi_slice(arr: DeviceArray,
+                 start_indices: Tuple[Tuple[int, ...]],
+                 limit_indices: Tuple[Tuple[int, ...]],
+                 removed_dims: Tuple[Tuple[int, ...]]):
+  """Extracts multiple slices from `arr`.
+
+  This is used to shard DeviceArray arguments to pmap. It's implemented as a
+  DeviceArray method here to avoid circular imports.
+  """
+  results = []
+  for starts, limits, removed in safe_zip(start_indices, limit_indices,
+                                          removed_dims):
+    sliced = lax.slice(arr, starts, limits)
+    if removed:
+      sliced = sliced.reshape(onp.delete(arr.shape, removed))
+    results.append(sliced)
+  return results
+setattr(DeviceArray, "_multi_slice", _multi_slice)


 # Syntactic sugar for scatter operations.
```
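A hedged sketch of how the two new pieces fit together: `_shard_device_array` feeds one `(starts, limits, removed)` triple per shard into `_multi_slice`. NumPy slicing stands in for `lax.slice` here, and the shapes and indices are made up:

```python
# Sketch only: mimics what DeviceArray._multi_slice computes, with numpy
# slicing in place of lax.slice. Shapes and indices are illustrative.
import numpy as onp

def multi_slice(arr, start_indices, limit_indices, removed_dims):
    results = []
    for starts, limits, removed in zip(start_indices, limit_indices, removed_dims):
        # Equivalent of lax.slice(arr, starts, limits) for unit strides.
        sliced = arr[tuple(slice(s, l) for s, l in zip(starts, limits))]
        if removed:
            # Squeeze out dimensions that came from integer indices.
            sliced = sliced.reshape(onp.delete(arr.shape, removed))
        results.append(sliced)
    return results

x = onp.arange(8.0).reshape(4, 2)
# Two shards covering rows 0-1 and 2-3, as spec_to_indices might emit:
shards = multi_slice(x, ((0, 0), (2, 0)), ((2, 2), (4, 2)), ((), ()))
assert [s.shape for s in shards] == [(2, 2), (2, 2)]
```

Note the `static_argnums=(1, 2, 3)` on the real `_multi_slice`: the index tuples are hashable compile-time constants, so jit traces one sliced computation per distinct sharding layout and reuses it across calls with the same layout.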
