Improve JAX test PRNG APIs to fix correlations between test cases. (jax-ml#2957)

* Improve JAX test PRNG APIs to fix correlations between test cases.

In jax-ml#2863, we observed that we were missing gradient problems because the generated random test cases were too similar: they were all formed from identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds an rng() method to JaxTestCase that forms a RandomState seeded on the test case name (see the sketch after the property list below).

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.
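
A minimal sketch of the seeding scheme (illustrative only: SeededTestCase and the example tests are hypothetical names, but the adler32-of-test-name idea mirrors the JaxTestCase change in jax/test_util.py below):

import unittest
import zlib

import numpy.random as npr

class SeededTestCase(unittest.TestCase):
  """Each test method gets its own deterministic RandomState."""

  def setUp(self):
    super().setUp()
    # adler32 is deterministic from run to run (unlike Python's randomized
    # hash()) and yields a value in the 32-bit range RandomState accepts.
    self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))

  def rng(self):
    return self._rng

class ExampleTest(SeededTestCase):
  def test_a(self):
    x = self.rng().randn(3)  # draws seeded from the name "test_a"

  def test_b(self):
    y = self.rng().randn(3)  # seeded from "test_b", uncorrelated with test_a

Because the seed depends only on the test name, the same test draws the same values on every run, regardless of test ordering or sharding.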

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
hawkinsp authored May 5, 2020
1 parent 3cd409e commit 7116cc5
Showing 16 changed files with 518 additions and 555 deletions.
49 changes: 11 additions & 38 deletions jax/numpy/lax_numpy.py
@@ -1021,47 +1021,20 @@ def diff(a, n=1, axis=-1,):

return a

@_wraps(onp.ediff1d)
_EDIFF1D_DOC = """\
Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not
issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary``
loses precision.
"""

@_wraps(onp.ediff1d, lax_description=_EDIFF1D_DOC)
def ediff1d(ary, to_end=None, to_begin=None):
# convert into 1d array
ary = ravel(asarray(ary))

# default case
if to_begin is None and to_end is None:
return lax.sub(ary[1:], ary[:-1])

# enforce propagation of the dtype of input ary to returned array
dtype_req = ary.dtype

if to_begin is None:
l_begin = 0
else:
# check if to_begin can be cast to ary dtype
if not can_cast(asarray(to_begin), dtype_req):
raise ValueError("cannot convert 'to_begin' to array with dtype "
"'%r' as required for input array operand" % dtype_req)
# convert to_begin to flat array
to_begin = ravel(asarray(to_begin, dtype=dtype_req))
l_begin = len(to_begin)

if to_end is None:
l_end = 0
else:
# check if to_end can be cast to ary dtype
if not can_cast(asarray(to_end), dtype_req):
raise ValueError("cannot convert 'to_end' to array with dtype "
"'%r' as required for input array operand" % dtype_req)
# convert to_end to flat array
to_end = ravel(asarray(to_end, dtype=dtype_req))
l_end = len(to_end)

# calculate difference and copy to_begin and to_end
l_diff = _max(len(ary) - 1, 0)
result = lax.sub(ary[1:], ary[:-1])
if l_begin > 0:
result = concatenate((to_begin, result))
if l_end > 0:
result = concatenate((result, to_end))
if to_begin is not None:
result = concatenate((ravel(asarray(to_begin, dtype=ary.dtype)), result))
if to_end is not None:
result = concatenate((result, ravel(asarray(to_end, dtype=ary.dtype))))
return result


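As a rough illustration of the behavior the new docstring describes (a sketch based on the rewritten ediff1d above; the commented results are expected values, not captured output):

import jax.numpy as jnp

x = jnp.array([1, 2, 4, 7], dtype=jnp.int32)
jnp.ediff1d(x)                          # [1, 2, 3]
jnp.ediff1d(x, to_begin=-1, to_end=99)  # [-1, 1, 2, 3, 99]

# NumPy's ediff1d rejects a to_begin/to_end that cannot be safely cast to
# x's dtype; the JAX version casts to ary.dtype instead, dropping precision.
jnp.ediff1d(x, to_begin=1.5)            # expected: [1, 1, 2, 3]
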
83 changes: 38 additions & 45 deletions jax/test_util.py
@@ -22,6 +22,7 @@
import sys
import unittest
import warnings
import zlib

from absl.testing import absltest
from absl.testing import parameterized
@@ -470,48 +471,39 @@ def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
return _cast_to_shape(onp.asarray(post(vals), dtype), shape, dtype)


def rand_default(scale=3):
randn = npr.RandomState(0).randn
return partial(_rand_dtype, randn, scale=scale)
def rand_default(rng, scale=3):
return partial(_rand_dtype, rng.randn, scale=scale)


def rand_nonzero():
def rand_nonzero(rng):
post = lambda x: onp.where(x == 0, onp.array(1, dtype=x.dtype), x)
randn = npr.RandomState(0).randn
return partial(_rand_dtype, randn, scale=3, post=post)
return partial(_rand_dtype, rng.randn, scale=3, post=post)


def rand_positive():
def rand_positive(rng):
post = lambda x: x + 1
rand = npr.RandomState(0).rand
return partial(_rand_dtype, rand, scale=2, post=post)
return partial(_rand_dtype, rng.rand, scale=2, post=post)


def rand_small():
randn = npr.RandomState(0).randn
return partial(_rand_dtype, randn, scale=1e-3)
def rand_small(rng):
return partial(_rand_dtype, rng.randn, scale=1e-3)


def rand_not_small(offset=10.):
def rand_not_small(rng, offset=10.):
post = lambda x: x + onp.where(x > 0, offset, -offset)
randn = npr.RandomState(0).randn
return partial(_rand_dtype, randn, scale=3., post=post)
return partial(_rand_dtype, rng.randn, scale=3., post=post)


def rand_small_positive():
rand = npr.RandomState(0).rand
return partial(_rand_dtype, rand, scale=2e-5)
def rand_small_positive(rng):
return partial(_rand_dtype, rng.rand, scale=2e-5)

def rand_uniform(low=0.0, high=1.0):
def rand_uniform(rng, low=0.0, high=1.0):
assert low < high
rand = npr.RandomState(0).rand
post = lambda x: x * (high - low) + low
return partial(_rand_dtype, rand, post=post)
return partial(_rand_dtype, rng.rand, post=post)


def rand_some_equal():
randn = npr.RandomState(0).randn
rng = npr.RandomState(0)
def rand_some_equal(rng):

def post(x):
x_ravel = x.ravel()
@@ -520,13 +512,12 @@ def post(x):
flips = rng.rand(*onp.shape(x)) < 0.5
return onp.where(flips, x_ravel[0], x)

return partial(_rand_dtype, randn, scale=100., post=post)
return partial(_rand_dtype, rng.randn, scale=100., post=post)


def rand_some_inf():
def rand_some_inf(rng):
"""Return a random sampler that produces infinities in floating types."""
rng = npr.RandomState(1)
base_rand = rand_default()
base_rand = rand_default(rng)

"""
TODO: Complex numbers are not correctly tested
@@ -556,10 +547,9 @@ def rand(shape, dtype):

return rand

def rand_some_nan():
def rand_some_nan(rng):
"""Return a random sampler that produces nans in floating types."""
rng = npr.RandomState(1)
base_rand = rand_default()
base_rand = rand_default(rng)

def rand(shape, dtype):
"""The random sampler function."""
@@ -583,10 +573,9 @@ def rand(shape, dtype):

return rand

def rand_some_inf_and_nan():
def rand_some_inf_and_nan(rng):
"""Return a random sampler that produces infinities in floating types."""
rng = npr.RandomState(1)
base_rand = rand_default()
base_rand = rand_default(rng)

"""
TODO: Complex numbers are not correctly tested
@@ -619,10 +608,9 @@ def rand(shape, dtype):
return rand

# TODO(mattjj): doesn't handle complex types
def rand_some_zero():
def rand_some_zero(rng):
"""Return a random sampler that produces some zeros."""
rng = npr.RandomState(1)
base_rand = rand_default()
base_rand = rand_default(rng)

def rand(shape, dtype):
"""The random sampler function."""
@@ -637,21 +625,18 @@ def rand(shape, dtype):
return rand


def rand_int(low, high=None):
randint = npr.RandomState(0).randint
def rand_int(rng, low=0, high=None):
def fn(shape, dtype):
return randint(low, high=high, size=shape, dtype=dtype)
return rng.randint(low, high=high, size=shape, dtype=dtype)
return fn

def rand_unique_int(high=None):
randchoice = npr.RandomState(0).choice
def rand_unique_int(rng, high=None):
def fn(shape, dtype):
return randchoice(onp.arange(high or onp.prod(shape), dtype=dtype),
return rng.choice(onp.arange(high or onp.prod(shape), dtype=dtype),
size=shape, replace=False)
return fn

def rand_bool():
rng = npr.RandomState(0)
def rand_bool(rng):
def generator(shape, dtype):
return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype)
return generator
@@ -718,7 +703,15 @@ class JaxTestCase(parameterized.TestCase):
# assert core.reset_trace_state()

def setUp(self):
super(JaxTestCase, self).setUp()
core.skip_checks = False
# We use the adler32 hash for two reasons.
# a) it is deterministic run to run, unlike hash() which is randomized.
# b) it returns values in int32 range, which RandomState requires.
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))

def rng(self):
return self._rng

def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
"""Assert that x and y are close (up to numerical tolerances)."""
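To make the new calling convention concrete, here is a small usage sketch (FooTest and testBar are hypothetical; jtu stands for jax.test_util, as in the tests below). Each rand_* factory now takes the per-test RandomState explicitly and returns a sampler that is called with a shape and dtype:

import numpy as onp
from jax import test_util as jtu

class FooTest(jtu.JaxTestCase):

  def testBar(self):
    # Both samplers are seeded from this test's name via self.rng().
    rand = jtu.rand_default(self.rng())
    operand = rand((3, 4), onp.float32)       # dense random operand
    rand_idx = jtu.rand_int(self.rng(), high=3)
    idxs = rand_idx((2,), onp.int32)          # random indices in [0, 3)
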
10 changes: 6 additions & 4 deletions tests/array_interoperability_test.py
@@ -54,6 +54,7 @@

class DLPackTest(jtu.JaxTestCase):
def setUp(self):
super(DLPackTest, self).setUp()
if jtu.device_under_test() == "tpu":
self.skipTest("DLPack not supported on TPU")

@@ -64,7 +65,7 @@ def setUp(self):
for shape in all_shapes
for dtype in dlpack_dtypes))
def testJaxRoundTrip(self, shape, dtype):
rng = jtu.rand_default()
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
x = jnp.array(np)
dlpack = jax.dlpack.to_dlpack(x)
@@ -83,7 +84,7 @@ def testJaxRoundTrip(self, shape, dtype):
for dtype in torch_dtypes))
@unittest.skipIf(not torch, "Test requires PyTorch")
def testTorchToJax(self, shape, dtype):
rng = jtu.rand_default()
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
x = torch.from_numpy(np)
x = x.cuda() if jtu.device_under_test() == "gpu" else x
@@ -99,7 +100,7 @@ def testTorchToJax(self, shape, dtype):
for dtype in torch_dtypes))
@unittest.skipIf(not torch, "Test requires PyTorch")
def testJaxToTorch(self, shape, dtype):
rng = jtu.rand_default()
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
x = jnp.array(np)
dlpack = jax.dlpack.to_dlpack(x)
@@ -110,6 +111,7 @@ def testJaxToTorch(self, shape, dtype):
class CudaArrayInterfaceTest(jtu.JaxTestCase):

def setUp(self):
super(CudaArrayInterfaceTest, self).setUp()
if jtu.device_under_test() != "gpu":
self.skipTest("__cuda_array_interface__ is only supported on GPU")

@@ -121,7 +123,7 @@ def setUp(self):
for dtype in dlpack_dtypes))
@unittest.skipIf(not cupy, "Test requires CuPy")
def testJaxToCuPy(self, shape, dtype):
rng = jtu.rand_default()
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jnp.array(x)
z = cupy.asarray(y)
36 changes: 18 additions & 18 deletions tests/batching_test.py
@@ -673,12 +673,12 @@ def testLaxLinalgTriangularSolve(self):
start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
rng = rng_factory()
rng_idx = rng_idx_factory()
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
operand = rng(shape, dtype)
ans = vmap(fun, (axis, None))(operand, idxs)
@@ -709,12 +709,12 @@ def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
offset_dims=(1,), collapsed_slice_dims=(0,),
start_index_map=(0, 1)),
(1, 3)), ]
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
rng = rng_factory()
rng_idx = rng_idx_factory()
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
operand = rng(shape, dtype)
@@ -744,12 +744,12 @@ def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
rng = rng_factory()
rng_idx = rng_idx_factory()
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
operand = rng(shape, dtype)
ans = vmap(fun, (None, axis))(operand, idxs)
@@ -778,12 +778,12 @@ def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
rng = rng_factory()
rng_idx = rng_idx_factory()
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
operand = rng(shape, dtype)
@@ -819,12 +819,12 @@ def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
rng = rng_factory()
rng_idx = rng_idx_factory()
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
operand = rng(shape, dtype)
assert operand.shape[op_axis] == idxs.shape[idxs_axis]
@@ -861,12 +861,12 @@ def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
for rng_factory in [jtu.rand_default])
def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
slice_sizes, rng_factory, rng_idx_factory):
rng = rng_factory()
rng_idx = rng_idx_factory()
rng = rng_factory(self.rng())
rng_idx = rng_idx_factory(self.rng())
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
operand = rng(shape, dtype)
(Diffs for the remaining 12 changed files are not shown.)
