diff --git a/tests/random_test.py b/tests/random_test.py index 33027da6f5a4..a1e6e4e4c013 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -246,8 +246,8 @@ def testBernoulli(self, p, dtype): for (p, axis) in [ ([.25] * 4, -1), ([.1, .2, .3, .4], -1), - ([[.25, .25], [.1, .9]], 1), - ([[.25, .1], [.25, .9]], 0), + ([[.5, .5], [.1, .9]], 1), + ([[.5, .1], [.5, .9]], 0), ] for sample_shape in [(10000,), (5000, 2)] for dtype in [onp.float32, onp.float64])) @@ -255,25 +255,24 @@ def testCategorical(self, p, axis, dtype, sample_shape): key = random.PRNGKey(0) p = onp.array(p, dtype=dtype) logits = onp.log(p) - 42 # test unnormalized - shape = sample_shape + tuple(onp.delete(logits.shape, axis)) + out_shape = tuple(onp.delete(logits.shape, axis)) + shape = sample_shape + out_shape rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis) crand = api.jit(rand) uncompiled_samples = rand(key, p) compiled_samples = crand(key, p) - if p.ndim > 1: - self.skipTest("multi-dimensional categorical tests are currently broken!") + if axis < 0: + axis += len(logits.shape) for samples in [uncompiled_samples, compiled_samples]: - if axis < 0: - axis += len(logits.shape) - assert samples.shape == shape - + samples = np.reshape(samples, (10000,) + out_shape) if len(p.shape[:-1]) > 0: - for cat_index, p_ in enumerate(p): - self._CheckChiSquared(samples[:, cat_index], pmf=lambda x: p_[x]) + ps = onp.transpose(p, (1, 0)) if axis == 0 else p + for cat_samples, cat_p in zip(samples.transpose(), ps): + self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x]) else: self._CheckChiSquared(samples, pmf=lambda x: p[x])