Skip to content

Commit fdf6e1f

Browse files
Merge pull request #28946 from jakevdp:random-mode
PiperOrigin-RevId: 762092464
2 parents 1aaec81 + 437e32b commit fdf6e1f

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

jax/_src/random.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,8 @@ def choice(key: ArrayLike,
633633
shape: Shape = (),
634634
replace: bool = True,
635635
p: RealArray | None = None,
636-
axis: int = 0) -> Array:
636+
axis: int = 0,
637+
mode: str | None = None) -> Array:
637638
"""Generates a random sample from a given array.
638639
639640
.. warning::
@@ -656,6 +657,12 @@ def choice(key: ArrayLike,
656657
entries in a.
657658
axis: int, optional. The axis along which the selection is performed.
658659
The default, 0, selects by row.
660+
mode: optional, "high" or "low" for how many bits to use in the gumbel sampler
661+
when `p is None` and `replace = False`. The default is determined by the
662+
``use_high_dynamic_range_gumbel`` config, which defaults to "low". With mode="low",
663+
in float32 sampling will be biased for choices with probability less than about
664+
1E-7; with mode="high" this limit is pushed down to about 1E-14. mode="high"
665+
approximately doubles the cost of sampling.
659666
660667
Returns:
661668
An array of shape `shape` containing samples from `a`.
@@ -701,7 +708,7 @@ def choice(key: ArrayLike,
701708
ind = jnp.searchsorted(p_cuml, r).astype(int)
702709
else:
703710
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
704-
g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
711+
g = gumbel(key, (n_inputs,), dtype=p_arr.dtype, mode=mode) + jnp.log(p_arr)
705712
ind = lax.top_k(g, k=n_draws)[1].astype(int)
706713
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
707714

@@ -940,7 +947,8 @@ def bernoulli(key: ArrayLike,
940947
mode: optional, "high" or "low" for how many bits to use when sampling.
941948
default='low'. Set to "high" for correct sampling at small values of
942949
`p`. When sampling in float32, bernoulli samples with mode='low' produce
943-
incorrect results for p < ~1E-7.
950+
incorrect results for p < ~1E-7. mode="high" approximately doubles the
951+
cost of sampling.
944952
945953
Returns:
946954
A random array with boolean dtype and shape given by ``shape`` if ``shape``
@@ -1544,7 +1552,7 @@ def poisson(key: ArrayLike,
15441552
def gumbel(key: ArrayLike,
15451553
shape: Shape = (),
15461554
dtype: DTypeLikeFloat = float,
1547-
mode: str | None =None) -> Array:
1555+
mode: str | None = None) -> Array:
15481556
"""Sample Gumbel random values with given shape and float dtype.
15491557
15501558
The values are distributed according to the probability density function:
@@ -1559,6 +1567,11 @@ def gumbel(key: ArrayLike,
15591567
dtype: optional, a float dtype for the returned values (default float64 if
15601568
jax_enable_x64 is true, otherwise float32).
15611569
mode: optional, "high" or "low" for how many bits to use when sampling.
1570+
The default is determined by the ``use_high_dynamic_range_gumbel`` config,
1571+
which defaults to "low". When drawing float32 samples, with mode="low" the
1572+
uniform resolution is such that the largest possible gumbel logit is ~16;
1573+
with mode="high" this is increased to ~32, at approximately double the
1574+
computational cost.
15621575
15631576
Returns:
15641577
A random array with the specified shape and dtype.
@@ -1599,6 +1612,7 @@ def categorical(
15991612
axis: int = -1,
16001613
shape: Shape | None = None,
16011614
replace: bool = True,
1615+
mode: str | None = None,
16021616
) -> Array:
16031617
"""Sample random values from categorical distributions.
16041618
@@ -1615,6 +1629,12 @@ def categorical(
16151629
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
16161630
replace: If True (default), perform sampling with replacement. If False, perform
16171631
sampling without replacement.
1632+
mode: optional, "high" or "low" for how many bits to use in the gumbel sampler.
1633+
The default is determined by the ``use_high_dynamic_range_gumbel`` config,
1634+
which defaults to "low". With mode="low", in float32 sampling will be biased
1635+
for events with probability less than about 1E-7; with mode="high" this limit
1636+
is pushed down to about 1E-14. mode="high" approximately doubles the cost of
1637+
sampling.
16181638
16191639
Returns:
16201640
A random array with int dtype and shape given by ``shape`` if ``shape``
@@ -1644,11 +1664,11 @@ def categorical(
16441664
logits_shape = list(shape[len(shape) - len(batch_shape):])
16451665
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
16461666
return jnp.argmax(
1647-
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
1667+
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype, mode=mode) +
16481668
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
16491669
axis=axis)
16501670
else:
1651-
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
1671+
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype, mode=mode)
16521672
k = math.prod(shape_prefix)
16531673
if k > logits_arr.shape[axis]:
16541674
raise ValueError(

tests/random_lax_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,9 @@ def testTruncatedNormal(self, dtype):
286286
],
287287
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
288288
weighted=[True, False],
289+
mode=[None, 'low', 'high']
289290
)
290-
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
291+
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis, mode):
291292
# This is the function API that we test against (note that self.rng().choice differs)
292293
np_choice = np.random.default_rng(0).choice
293294
p_dtype = dtypes.to_inexact_dtype(dtype)
@@ -303,7 +304,7 @@ def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis
303304
p /= p.sum()
304305
else:
305306
p = None
306-
rand = lambda key, x: random.choice(key, x, shape, replace, p, axis)
307+
rand = lambda key, x: random.choice(key, x, shape, replace, p, axis, mode=mode)
307308
sample = rand(key(), x)
308309
if not is_range:
309310
self.assertEqual(dtype, sample.dtype)
@@ -397,15 +398,16 @@ def testBernoulli(self, p, dtype, mode):
397398
]
398399
],
399400
sample_shape=[(10000,), (5000, 2)],
401+
mode=[None, 'low', 'high'],
400402
dtype=jtu.dtypes.floating,
401403
)
402-
def testCategorical(self, p, axis, dtype, sample_shape):
404+
def testCategorical(self, p, axis, dtype, sample_shape, mode):
403405
key = lambda: self.make_key(0)
404406
p = np.array(p, dtype=dtype)
405407
logits = np.log(p) - 42 # test unnormalized
406408
out_shape = tuple(np.delete(logits.shape, axis))
407409
shape = sample_shape + out_shape
408-
rand = partial(random.categorical, shape=shape, axis=axis)
410+
rand = partial(random.categorical, shape=shape, axis=axis, mode=mode)
409411
crand = jax.jit(rand)
410412

411413
uncompiled_samples = rand(key(), logits)

0 commit comments

Comments
 (0)