Skip to content

Commit

Permalink
Future-proof view test against signaling NaNs (jax-ml#3178)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored May 21, 2020
1 parent c459280 commit bb2127c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)


def rand_fullrange(rng, standardize_nans=True):
def rand_fullrange(rng, standardize_nans=False):
"""Random numbers that span the full range of available bits."""
def gen(shape, dtype, post=lambda x: x):
dtype = np.dtype(dtype)
Expand Down
8 changes: 5 additions & 3 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,12 +2229,14 @@ def testView(self, shape, a_dtype, dtype):
if not FLAGS.jax_enable_x64:
if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_fullrange(self.rng(), standardize_nans=True)
rng = jtu.rand_fullrange(self.rng())
args_maker = lambda: [rng(shape, a_dtype)]
np_op = lambda x: np.asarray(x).view(dtype)
jnp_op = lambda x: jnp.asarray(x).view(dtype)
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
# Above may produce signaling nans; ignore warnings from invalid values.
with np.errstate(invalid='ignore'):
self._CheckAgainstNumpy(jnp_op, np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)

def testPathologicalFloats(self):
args_maker = lambda: [np.array([
Expand Down

0 comments on commit bb2127c

Please sign in to comment.