@@ -633,7 +633,8 @@ def choice(key: ArrayLike,
633
633
shape : Shape = (),
634
634
replace : bool = True ,
635
635
p : RealArray | None = None ,
636
- axis : int = 0 ) -> Array :
636
+ axis : int = 0 ,
637
+ mode : str | None = None ) -> Array :
637
638
"""Generates a random sample from a given array.
638
639
639
640
.. warning::
@@ -656,6 +657,12 @@ def choice(key: ArrayLike,
656
657
entries in a.
657
658
axis: int, optional. The axis along which the selection is performed.
658
659
The default, 0, selects by row.
660
+ mode: optional, "high" or "low" for how many bits to use in the gumbel sampler
661
+ when `p is None` and `replace = False`. The default is determined by the
662
+ ``use_high_dynamic_range_gumbel`` config, which defaults to "low". With mode="low",
663
+ in float32 sampling will be biased for choices with probability less than about
664
+ 1E-7; with mode="high" this limit is pushed down to about 1E-14. mode="high"
665
+ approximately doubles the cost of sampling.
659
666
660
667
Returns:
661
668
An array of shape `shape` containing samples from `a`.
@@ -701,7 +708,7 @@ def choice(key: ArrayLike,
701
708
ind = jnp .searchsorted (p_cuml , r ).astype (int )
702
709
else :
703
710
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
704
- g = gumbel (key , (n_inputs ,), dtype = p_arr .dtype ) + jnp .log (p_arr )
711
+ g = gumbel (key , (n_inputs ,), dtype = p_arr .dtype , mode = mode ) + jnp .log (p_arr )
705
712
ind = lax .top_k (g , k = n_draws )[1 ].astype (int )
706
713
result = ind if arr .ndim == 0 else jnp .take (arr , ind , axis )
707
714
@@ -940,7 +947,8 @@ def bernoulli(key: ArrayLike,
940
947
mode: optional, "high" or "low" for how many bits to use when sampling.
941
948
default='low'. Set to "high" for correct sampling at small values of
942
949
`p`. When sampling in float32, bernoulli samples with mode='low' produce
943
- incorrect results for p < ~1E-7.
950
+ incorrect results for p < ~1E-7. mode="high" approximately doubles the
951
+ cost of sampling.
944
952
945
953
Returns:
946
954
A random array with boolean dtype and shape given by ``shape`` if ``shape``
@@ -1544,7 +1552,7 @@ def poisson(key: ArrayLike,
1544
1552
def gumbel (key : ArrayLike ,
1545
1553
shape : Shape = (),
1546
1554
dtype : DTypeLikeFloat = float ,
1547
- mode : str | None = None ) -> Array :
1555
+ mode : str | None = None ) -> Array :
1548
1556
"""Sample Gumbel random values with given shape and float dtype.
1549
1557
1550
1558
The values are distributed according to the probability density function:
@@ -1559,6 +1567,11 @@ def gumbel(key: ArrayLike,
1559
1567
dtype: optional, a float dtype for the returned values (default float64 if
1560
1568
jax_enable_x64 is true, otherwise float32).
1561
1569
mode: optional, "high" or "low" for how many bits to use when sampling.
1570
+ The default is determined by the ``use_high_dynamic_range_gumbel`` config,
1571
+ which defaults to "low". When drawing float32 samples, with mode="low" the
1572
+ uniform resolution is such that the largest possible gumbel logit is ~16;
1573
+ with mode="high" this is increased to ~32, at approximately double the
1574
+ computational cost.
1562
1575
1563
1576
Returns:
1564
1577
A random array with the specified shape and dtype.
@@ -1599,6 +1612,7 @@ def categorical(
1599
1612
axis : int = - 1 ,
1600
1613
shape : Shape | None = None ,
1601
1614
replace : bool = True ,
1615
+ mode : str | None = None ,
1602
1616
) -> Array :
1603
1617
"""Sample random values from categorical distributions.
1604
1618
@@ -1615,6 +1629,12 @@ def categorical(
1615
1629
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
1616
1630
replace: If True (default), perform sampling with replacement. If False, perform
1617
1631
sampling without replacement.
1632
+ mode: optional, "high" or "low" for how many bits to use in the gumbel sampler.
1633
+ The default is determined by the ``use_high_dynamic_range_gumbel`` config,
1634
+ which defaults to "low". With mode="low", in float32 sampling will be biased
1635
+ for events with probability less than about 1E-7; with mode="high" this limit
1636
+ is pushed down to about 1E-14. mode="high" approximately doubles the cost of
1637
+ sampling.
1618
1638
1619
1639
Returns:
1620
1640
A random array with int dtype and shape given by ``shape`` if ``shape``
@@ -1644,11 +1664,11 @@ def categorical(
1644
1664
logits_shape = list (shape [len (shape ) - len (batch_shape ):])
1645
1665
logits_shape .insert (axis % len (logits_arr .shape ), logits_arr .shape [axis ])
1646
1666
return jnp .argmax (
1647
- gumbel (key , (* shape_prefix , * logits_shape ), logits_arr .dtype ) +
1667
+ gumbel (key , (* shape_prefix , * logits_shape ), logits_arr .dtype , mode = mode ) +
1648
1668
lax .expand_dims (logits_arr , tuple (range (len (shape_prefix )))),
1649
1669
axis = axis )
1650
1670
else :
1651
- logits_arr += gumbel (key , logits_arr .shape , logits_arr .dtype )
1671
+ logits_arr += gumbel (key , logits_arr .shape , logits_arr .dtype , mode = mode )
1652
1672
k = math .prod (shape_prefix )
1653
1673
if k > logits_arr .shape [axis ]:
1654
1674
raise ValueError (
0 commit comments