Skip to content

Commit 347481b

Browse files
author
jax authors
committed
Merge pull request jax-ml#5768 from jakevdp:fix-take
PiperOrigin-RevId: 358049345
2 parents ba269b4 + 28a19ff commit 347481b

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

docs/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
1717
from JAX ([#5627](https://github.com/google/jax/pull/5627)
1818
and [README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
1919
* Bug fixes:
20+
* {func}`jax.numpy.take` properly handles negative indices
21+
([#5768](https://github.com/google/jax/pull/5768))
2022
* Breaking changes:
2123
* JAX's promotion rules were adjusted to make promotion more consistent and
2224
invariant to JIT. In particular, binary operations can now result in weakly-typed

jax/_src/numpy/lax_numpy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4142,12 +4142,15 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
41424142
else:
41434143
axis_idx = _canonicalize_axis(axis, ndim(a))
41444144

4145-
if mode == "raise":
4145+
if mode is None:
4146+
# lax.gather() does not support negative indices, so we wrap them here
4147+
indices = where(indices < 0, indices + a.shape[axis_idx], indices)
4148+
elif mode == "raise":
41464149
# TODO(phawkins): we have no way to report out of bounds errors yet.
41474150
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
41484151
elif mode == "wrap":
41494152
indices = mod(indices, _constant_like(indices, a.shape[axis_idx]))
4150-
elif mode != "clip" and mode is not None:
4153+
elif mode != "clip":
41514154
raise ValueError("Invalid mode '{}' for np.take".format(mode))
41524155

41534156
index_dims = len(shape(indices))

tests/lax_numpy_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3621,15 +3621,18 @@ def testUnpackbits(self, shape, dtype, axis, bitorder, count):
36213621
[cast(Optional[int], None)])
36223622
for dtype in all_dtypes
36233623
for index_dtype in int_dtypes
3624-
for mode in ['wrap', 'clip']))
3624+
for mode in [None, 'wrap', 'clip']))
36253625
def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode):
36263626
def args_maker():
36273627
x = rng(shape, dtype)
36283628
i = rng_indices(index_shape, index_dtype)
36293629
return x, i
36303630

36313631
rng = jtu.rand_default(self.rng())
3632-
rng_indices = jtu.rand_int(self.rng(), -5, 5)
3632+
if mode is None:
3633+
rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0])
3634+
else:
3635+
rng_indices = jtu.rand_int(self.rng(), -5, 5)
36333636
jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode)
36343637
np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode)
36353638
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)

0 commit comments

Comments
 (0)