diff --git a/tensorflow_probability/python/internal/backend/numpy/random_generators.py b/tensorflow_probability/python/internal/backend/numpy/random_generators.py index b14238caa1..004e2bf6f9 100644 --- a/tensorflow_probability/python/internal/backend/numpy/random_generators.py +++ b/tensorflow_probability/python/internal/backend/numpy/random_generators.py @@ -66,9 +66,18 @@ def _bcast_shape(base_shape, args): return bcast_shape +def _rng_from_seed(seed): + if seed is None: + return np.random + elif isinstance(seed, int): + return np.random.RandomState(seed & 0xFFFFFFFF) + else: + return np.random.RandomState(np.array(seed, dtype=np.uint32)) + + def _binomial(shape, seed, counts, probs, output_dtype=np.int32, name=None): # pylint: disable=unused-argument """Massaging dtype and nan handling of np.random.binomial.""" - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) invalid_count = (np.int64(counts) < 0) != (counts < 0) if np.any(invalid_count): raise ValueError('int64 overflow: {} -> {}'.format( @@ -82,7 +91,7 @@ def _binomial(shape, seed, counts, probs, output_dtype=np.int32, name=None): # def _categorical(logits, num_samples, dtype=None, seed=None, name=None): # pylint: disable=unused-argument - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) dtype = utils.numpy_dtype(dtype or np.int64) if not hasattr(logits, 'shape'): logits = np.array(logits, np.float32) @@ -107,7 +116,7 @@ def _categorical_jax(logits, num_samples, dtype=None, seed=None, name=None): # def _gamma(shape, alpha, beta=None, dtype=np.float32, seed=None, name=None): # pylint: disable=unused-argument - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) scale = 1. if beta is None else (1. / beta) shape = _ensure_shape_tuple(shape) return rng.gamma(shape=alpha, scale=scale, size=shape).astype(dtype) @@ -133,7 +142,7 @@ def _gamma_jax(shape, alpha, beta=None, dtype=np.float32, seed=None, name=None): def _normal(shape, mean=0.0, stddev=1.0, dtype=np.float32, seed=None, name=None): # pylint: disable=unused-argument - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) dtype = utils.common_dtype([mean, stddev], dtype_hint=dtype) shape = _bcast_shape(shape, [mean, stddev]) return rng.normal(loc=mean, scale=stddev, size=shape).astype(dtype) @@ -151,7 +160,7 @@ def _normal_jax(shape, mean=0.0, stddev=1.0, dtype=np.float32, seed=None, def _poisson(shape, lam, dtype=np.float32, seed=None, name=None): # pylint: disable=unused-argument - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) dtype = utils.common_dtype([lam], dtype_hint=dtype) shape = _ensure_shape_tuple(shape) return rng.poisson(lam=lam, size=shape).astype(dtype) @@ -209,7 +218,7 @@ def _poisson_jax(shape, lam, dtype=np.float32, seed=None, def _shuffle(value, seed=None, name=None): # pylint: disable=unused-argument - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) ret = np.array(value) rng.shuffle(ret) return ret @@ -225,7 +234,7 @@ def _shuffle_jax(value, seed=None, name=None): # pylint: disable=unused-argumen def _truncated_normal( shape, seed, means=0.0, stddevs=1.0, minvals=-2.0, maxvals=2.0, name=None): # pylint: disable=unused-argument from scipy import stats # pylint: disable=g-import-not-at-top - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) std_low = (minvals - means) / stddevs std_high = (maxvals - means) / stddevs std_samps = stats.truncnorm.rvs( @@ -248,7 +257,7 @@ def _truncated_normal_jax( def _uniform(shape, minval=0, maxval=None, dtype=np.float32, seed=None, name=None): # pylint: disable=unused-argument """Numpy uniform random sampler.""" - rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff) + rng = _rng_from_seed(seed) if minval is not None: minval = ops.convert_to_tensor(minval, dtype=dtype) if maxval is not None: