From 3216f5ca4647a00a5a53663846a56dc0faf1fc01 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 30 Apr 2020 08:31:48 -0700 Subject: [PATCH] err on empty operand in numpy argmin and argmax fixes #2899 --- jax/numpy/lax_numpy.py | 8 +++++--- tests/lax_numpy_test.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index a29f90f51a72..bd7c4bb440e9 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2827,7 +2827,7 @@ def argmax(a, axis=None): if axis is None: a = ravel(a) axis = 0 - return _argminmax(max, a, axis) + return _argminmax("argmax", max, a, axis) _NANARG_DOC = """\ @@ -2850,7 +2850,7 @@ def argmin(a, axis=None): if axis is None: a = ravel(a) axis = 0 - return _argminmax(min, a, axis) + return _argminmax("argmin", min, a, axis) @_wraps(onp.nanargmin, lax_description=_NANARG_DOC.format("min")) @@ -2864,7 +2864,9 @@ def nanargmin(a, axis=None): # TODO(mattjj): redo this lowering with a call to variadic lax.reduce -def _argminmax(op, a, axis): +def _argminmax(name, op, a, axis): + if size(a) == 0: + raise ValueError("attempt to get {} of an empty sequence".format(name)) shape = [1] * a.ndim shape[axis] = a.shape[axis] idxs = lax.tie_in(a, arange(a.shape[axis])).reshape(shape) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a68d30225367..446534ec1a6a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -696,6 +696,16 @@ def jnp_fun(array_to_reduce): raise self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": rec.test_name.capitalize(), + "name": rec.name, "jnp_op": getattr(jnp, rec.name)} + for rec in JAX_ARGMINMAX_RECORDS)) + def testArgMinMaxEmpty(self, name, jnp_op): + name = name[3:] if name.startswith("nan") else name + msg = "attempt to get {} of an empty sequence".format(name) + with self.assertRaises(ValueError, msg=msg): + jnp_op(onp.array([])) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}_{}".format( jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),