Skip to content

Commit

Permalink
err on empty operand in numpy argmin and argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Apr 30, 2020
1 parent 8d4b685 commit 3216f5c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
8 changes: 5 additions & 3 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2827,7 +2827,7 @@ def argmax(a, axis=None):
if axis is None:
a = ravel(a)
axis = 0
return _argminmax(max, a, axis)
return _argminmax("argmax", max, a, axis)


_NANARG_DOC = """\
Expand All @@ -2850,7 +2850,7 @@ def argmin(a, axis=None):
if axis is None:
a = ravel(a)
axis = 0
return _argminmax(min, a, axis)
return _argminmax("argmin", min, a, axis)


@_wraps(onp.nanargmin, lax_description=_NANARG_DOC.format("min"))
Expand All @@ -2864,7 +2864,9 @@ def nanargmin(a, axis=None):


# TODO(mattjj): redo this lowering with a call to variadic lax.reduce
def _argminmax(op, a, axis):
def _argminmax(name, op, a, axis):
if size(a) == 0:
raise ValueError("attempt to get {} of an empty sequence".format(name))
shape = [1] * a.ndim
shape[axis] = a.shape[axis]
idxs = lax.tie_in(a, arange(a.shape[axis])).reshape(shape)
Expand Down
10 changes: 10 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,16 @@ def jnp_fun(array_to_reduce):
raise
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": rec.test_name.capitalize(),
"name": rec.name, "jnp_op": getattr(jnp, rec.name)}
for rec in JAX_ARGMINMAX_RECORDS))
def testArgMinMaxEmpty(self, name, jnp_op):
name = name[3:] if name.startswith("nan") else name
msg = "attempt to get {} of an empty sequence".format(name)
with self.assertRaises(ValueError, msg=msg):
jnp_op(onp.array([]))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}_{}".format(
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
Expand Down

0 comments on commit 3216f5c

Please sign in to comment.