From bb2127cebd3ade161109ee4919a92aaff5c788c1 Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Thu, 21 May 2020 09:20:59 -0700 Subject: [PATCH] Future-proof view test against signaling NaNs (#3178) --- jax/test_util.py | 2 +- tests/lax_numpy_test.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/jax/test_util.py b/jax/test_util.py index e34885fc8050..c8b85a341317 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -463,7 +463,7 @@ def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x): return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype) -def rand_fullrange(rng, standardize_nans=True): +def rand_fullrange(rng, standardize_nans=False): """Random numbers that span the full range of available bits.""" def gen(shape, dtype, post=lambda x: x): dtype = np.dtype(dtype) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e0290678cd47..cd5637550073 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2229,12 +2229,14 @@ def testView(self, shape, a_dtype, dtype): if not FLAGS.jax_enable_x64: if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8: self.skipTest("x64 types are disabled by jax_enable_x64") - rng = jtu.rand_fullrange(self.rng(), standardize_nans=True) + rng = jtu.rand_fullrange(self.rng()) args_maker = lambda: [rng(shape, a_dtype)] np_op = lambda x: np.asarray(x).view(dtype) jnp_op = lambda x: jnp.asarray(x).view(dtype) - self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) + # Above may produce signaling nans; ignore warnings from invalid values. + with np.errstate(invalid='ignore'): + self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True) def testPathologicalFloats(self): args_maker = lambda: [np.array([