Skip to content

Commit

Permalink
Merge pull request jax-ml#3331 from JuliusKunze/mask-split
Browse files Browse the repository at this point in the history
Allow mask(jnp.split)
  • Loading branch information
mattjj authored Jun 9, 2020
2 parents b7175a3 + ea78222 commit 307701c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
12 changes: 7 additions & 5 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,17 +1354,19 @@ def broadcast_to(arr, shape):

@_wraps(np.split)
def split(ary, indices_or_sections, axis=0):
dummy_val = np.broadcast_to(0, ary.shape) # zero strides
axis = core.concrete_or_error(int, axis, "in jax.numpy.split argument `axis`")
size = ary.shape[axis]
if isinstance(indices_or_sections, (tuple, list) + _arraylike_types):
indices_or_sections = [core.concrete_or_error(int, i_s, "in jax.numpy.split argument 1")
for i_s in indices_or_sections]
split_indices = np.concatenate([[0], indices_or_sections, [size]])
else:
indices_or_sections = core.concrete_or_error(int, indices_or_sections,
"in jax.numpy.split argument 1")
axis = core.concrete_or_error(int, axis, "in jax.numpy.split argument `axis`")

subarrays = np.split(dummy_val, indices_or_sections, axis) # shapes
split_indices = np.cumsum([0] + [np.shape(sub)[axis] for sub in subarrays])
part_size, r = _divmod(size, indices_or_sections)
if r != 0:
raise ValueError("array split does not result in an equal division")
split_indices = np.arange(indices_or_sections + 1) * part_size
starts, ends = [0] * ndim(ary), shape(ary)
_subval = lambda x, i, v: subvals(x, [(i, v)])
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
Expand Down
5 changes: 4 additions & 1 deletion tests/masking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,10 @@ def test_where(self):
{'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))

def test_split(self):
raise SkipTest
self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4),
[(8,)], ['float_'], rand_default(self.rng()))
self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'], dict(n=12),
[(12,)], ['float_'], rand_default(self.rng()))

@parameterized.named_parameters(jtu.cases_from_list([{
'testcase_name': "operator={}".format(operator.__name__), 'operator': operator}
Expand Down

0 comments on commit 307701c

Please sign in to comment.