Skip to content

Commit 1e381d1

Browse files
authored
UniformFloat: allow inclusion of high in all cases (#1462)
Fix #1299 by removing the logic that emulated a half-open range by excluding `high` from the result.
1 parent 2584f48 commit 1e381d1

File tree

3 files changed

+59
-129
lines changed

3 files changed

+59
-129
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
1616
- Move all benchmarks to new `benches` crate (#1439)
1717
- Annotate panicking methods with `#[track_caller]` (#1442, #1447)
1818
- Enable feature `small_rng` by default (#1455)
19+
- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462)
1920

2021
## [0.9.0-alpha.1] - 2024-03-18
2122
- Add the `Slice::num_choices` method to the Slice distribution (#1402)

src/distributions/uniform.rs

Lines changed: 58 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@
5151
//! Those methods should include an assertion to check the range is valid (i.e.
5252
//! `low < high`). The example below merely wraps another back-end.
5353
//!
54-
//! The `new`, `new_inclusive` and `sample_single` functions use arguments of
54+
//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive`
55+
//! functions use arguments of
5556
//! type `SampleBorrow<X>` to support passing in values by reference or
5657
//! by value. In the implementation of these functions, you can choose to
5758
//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose
@@ -207,6 +208,11 @@ impl<X: SampleUniform> Uniform<X> {
207208
/// Create a new `Uniform` instance, which samples uniformly from the half
208209
/// open range `[low, high)` (excluding `high`).
209210
///
211+
/// For discrete types (e.g. integers), samples will always be strictly less
212+
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
213+
/// samples may equal `high` due to loss of precision but will never be
214+
/// greater than `high`.
215+
///
210216
/// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is
211217
/// non-finite. In release mode, only the range is checked.
212218
pub fn new<B1, B2>(low: B1, high: B2) -> Result<Uniform<X>, Error>
@@ -265,6 +271,11 @@ pub trait UniformSampler: Sized {
265271

266272
/// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`.
267273
///
274+
/// For discrete types (e.g. integers), samples will always be strictly less
275+
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
276+
/// samples may equal `high` due to loss of precision but will never be
277+
/// greater than `high`.
278+
///
268279
/// Usually users should not call this directly but prefer to use
269280
/// [`Uniform::new`].
270281
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
@@ -287,6 +298,11 @@ pub trait UniformSampler: Sized {
287298
/// Sample a single value uniformly from a range with inclusive lower bound
288299
/// and exclusive upper bound `[low, high)`.
289300
///
301+
/// For discrete types (e.g. integers), samples will always be strictly less
302+
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
303+
/// samples may equal `high` due to loss of precision but will never be
304+
/// greater than `high`.
305+
///
290306
/// By default this is implemented using
291307
/// `UniformSampler::new(low, high).sample(rng)`. However, for some types
292308
/// more optimal implementations for single usage may be provided via this
@@ -908,6 +924,33 @@ pub struct UniformFloat<X> {
908924

909925
macro_rules! uniform_float_impl {
910926
($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
927+
$(#[cfg($meta)])?
928+
impl UniformFloat<$ty> {
929+
/// Construct, reducing `scale` as required to ensure that rounding
930+
/// can never yield values greater than `high`.
931+
///
932+
/// Note: though it may be tempting to use a variant of this method
933+
/// to ensure that samples from `[low, high)` are always strictly
934+
/// less than `high`, this approach may be very slow where
935+
/// `scale.abs()` is much smaller than `high.abs()`
936+
/// (example: `low=0.99999999997819644, high=1.`).
937+
fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self {
938+
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);
939+
940+
loop {
941+
let mask = (scale * max_rand + low).gt_mask(high);
942+
if !mask.any() {
943+
break;
944+
}
945+
scale = scale.decrease_masked(mask);
946+
}
947+
948+
debug_assert!(<$ty>::splat(0.0).all_le(scale));
949+
950+
UniformFloat { low, scale }
951+
}
952+
}
953+
911954
$(#[cfg($meta)])?
912955
impl SampleUniform for $ty {
913956
type Sampler = UniformFloat<$ty>;
@@ -931,26 +974,13 @@ macro_rules! uniform_float_impl {
931974
if !(low.all_lt(high)) {
932975
return Err(Error::EmptyRange);
933976
}
934-
let max_rand = <$ty>::splat(
935-
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
936-
);
937977

938-
let mut scale = high - low;
978+
let scale = high - low;
939979
if !(scale.all_finite()) {
940980
return Err(Error::NonFinite);
941981
}
942982

943-
loop {
944-
let mask = (scale * max_rand + low).ge_mask(high);
945-
if !mask.any() {
946-
break;
947-
}
948-
scale = scale.decrease_masked(mask);
949-
}
950-
951-
debug_assert!(<$ty>::splat(0.0).all_le(scale));
952-
953-
Ok(UniformFloat { low, scale })
983+
Ok(Self::new_bounded(low, high, scale))
954984
}
955985

956986
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
@@ -967,26 +997,14 @@ macro_rules! uniform_float_impl {
967997
if !low.all_le(high) {
968998
return Err(Error::EmptyRange);
969999
}
970-
let max_rand = <$ty>::splat(
971-
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
972-
);
9731000

974-
let mut scale = (high - low) / max_rand;
1001+
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);
1002+
let scale = (high - low) / max_rand;
9751003
if !scale.all_finite() {
9761004
return Err(Error::NonFinite);
9771005
}
9781006

979-
loop {
980-
let mask = (scale * max_rand + low).gt_mask(high);
981-
if !mask.any() {
982-
break;
983-
}
984-
scale = scale.decrease_masked(mask);
985-
}
986-
987-
debug_assert!(<$ty>::splat(0.0).all_le(scale));
988-
989-
Ok(UniformFloat { low, scale })
1007+
Ok(Self::new_bounded(low, high, scale))
9901008
}
9911009

9921010
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
@@ -1010,72 +1028,7 @@ macro_rules! uniform_float_impl {
10101028
B1: SampleBorrow<Self::X> + Sized,
10111029
B2: SampleBorrow<Self::X> + Sized,
10121030
{
1013-
let low = *low_b.borrow();
1014-
let high = *high_b.borrow();
1015-
#[cfg(debug_assertions)]
1016-
if !low.all_finite() || !high.all_finite() {
1017-
return Err(Error::NonFinite);
1018-
}
1019-
if !low.all_lt(high) {
1020-
return Err(Error::EmptyRange);
1021-
}
1022-
let mut scale = high - low;
1023-
if !scale.all_finite() {
1024-
return Err(Error::NonFinite);
1025-
}
1026-
1027-
loop {
1028-
// Generate a value in the range [1, 2)
1029-
let value1_2 =
1030-
(rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0);
1031-
1032-
// Get a value in the range [0, 1) to avoid overflow when multiplying by scale
1033-
let value0_1 = value1_2 - <$ty>::splat(1.0);
1034-
1035-
// Doing multiply before addition allows some architectures
1036-
// to use a single instruction.
1037-
let res = value0_1 * scale + low;
1038-
1039-
debug_assert!(low.all_le(res) || !scale.all_finite());
1040-
if res.all_lt(high) {
1041-
return Ok(res);
1042-
}
1043-
1044-
// This handles a number of edge cases.
1045-
// * `low` or `high` is NaN. In this case `scale` and
1046-
// `res` are going to end up as NaN.
1047-
// * `low` is negative infinity and `high` is finite.
1048-
// `scale` is going to be infinite and `res` will be
1049-
// NaN.
1050-
// * `high` is positive infinity and `low` is finite.
1051-
// `scale` is going to be infinite and `res` will
1052-
// be infinite or NaN (if value0_1 is 0).
1053-
// * `low` is negative infinity and `high` is positive
1054-
// infinity. `scale` will be infinite and `res` will
1055-
// be NaN.
1056-
// * `low` and `high` are finite, but `high - low`
1057-
// overflows to infinite. `scale` will be infinite
1058-
// and `res` will be infinite or NaN (if value0_1 is 0).
1059-
// So if `high` or `low` are non-finite, we are guaranteed
1060-
// to fail the `res < high` check above and end up here.
1061-
//
1062-
// While we technically should check for non-finite `low`
1063-
// and `high` before entering the loop, by doing the checks
1064-
// here instead, we allow the common case to avoid these
1065-
// checks. But we are still guaranteed that if `low` or
1066-
// `high` are non-finite we'll end up here and can do the
1067-
// appropriate checks.
1068-
//
1069-
// Likewise, `high - low` overflowing to infinity is also
1070-
// rare, so handle it here after the common case.
1071-
let mask = !scale.finite_mask();
1072-
if mask.any() {
1073-
if !(low.all_finite() && high.all_finite()) {
1074-
return Err(Error::NonFinite);
1075-
}
1076-
scale = scale.decrease_masked(mask);
1077-
}
1078-
}
1031+
Self::sample_single_inclusive(low_b, high_b, rng)
10791032
}
10801033

10811034
#[inline]
@@ -1465,14 +1418,14 @@ mod tests {
14651418
let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap();
14661419
for _ in 0..100 {
14671420
let v = rng.sample(my_uniform).extract(lane);
1468-
assert!(low_scalar <= v && v < high_scalar);
1421+
assert!(low_scalar <= v && v <= high_scalar);
14691422
let v = rng.sample(my_incl_uniform).extract(lane);
14701423
assert!(low_scalar <= v && v <= high_scalar);
14711424
let v =
14721425
<$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
14731426
.unwrap()
14741427
.extract(lane);
1475-
assert!(low_scalar <= v && v < high_scalar);
1428+
assert!(low_scalar <= v && v <= high_scalar);
14761429
let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
14771430
low, high, &mut rng,
14781431
)
@@ -1510,12 +1463,12 @@ mod tests {
15101463
low_scalar
15111464
);
15121465

1513-
assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
1466+
assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar);
15141467
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
15151468
// sample_single cannot cope with max_rng:
15161469
// assert!(<$ty as SampleUniform>::Sampler
15171470
// ::sample_single(low, high, &mut max_rng).unwrap()
1518-
// .extract(lane) < high_scalar);
1471+
// .extract(lane) <= high_scalar);
15191472
assert!(
15201473
<$ty as SampleUniform>::Sampler::sample_single_inclusive(
15211474
low,
@@ -1543,7 +1496,7 @@ mod tests {
15431496
)
15441497
.unwrap()
15451498
.extract(lane)
1546-
< high_scalar
1499+
<= high_scalar
15471500
);
15481501
}
15491502
}
@@ -1590,10 +1543,9 @@ mod tests {
15901543
#[cfg(all(feature = "std", panic = "unwind"))]
15911544
fn test_float_assertions() {
15921545
use super::SampleUniform;
1593-
use std::panic::catch_unwind;
1594-
fn range<T: SampleUniform>(low: T, high: T) {
1546+
fn range<T: SampleUniform>(low: T, high: T) -> Result<T, Error> {
15951547
let mut rng = crate::test::rng(253);
1596-
T::Sampler::sample_single(low, high, &mut rng).unwrap();
1548+
T::Sampler::sample_single(low, high, &mut rng)
15971549
}
15981550

15991551
macro_rules! t {
@@ -1616,10 +1568,9 @@ mod tests {
16161568
for lane in 0..<$ty>::LEN {
16171569
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
16181570
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
1619-
assert!(catch_unwind(|| range(low, high)).is_err());
1571+
assert!(range(low, high).is_err());
16201572
assert!(Uniform::new(low, high).is_err());
16211573
assert!(Uniform::new_inclusive(low, high).is_err());
1622-
assert!(catch_unwind(|| range(low, low)).is_err());
16231574
assert!(Uniform::new(low, low).is_err());
16241575
}
16251576
}

src/distributions/utils.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,7 @@ pub(crate) trait FloatSIMDUtils {
218218
fn all_finite(self) -> bool;
219219

220220
type Mask;
221-
fn finite_mask(self) -> Self::Mask;
222221
fn gt_mask(self, other: Self) -> Self::Mask;
223-
fn ge_mask(self, other: Self) -> Self::Mask;
224222

225223
// Decrease all lanes where the mask is `true` to the next lower value
226224
// representable by the floating-point type. At least one of the lanes
@@ -292,21 +290,11 @@ macro_rules! scalar_float_impl {
292290
self.is_finite()
293291
}
294292

295-
#[inline(always)]
296-
fn finite_mask(self) -> Self::Mask {
297-
self.is_finite()
298-
}
299-
300293
#[inline(always)]
301294
fn gt_mask(self, other: Self) -> Self::Mask {
302295
self > other
303296
}
304297

305-
#[inline(always)]
306-
fn ge_mask(self, other: Self) -> Self::Mask {
307-
self >= other
308-
}
309-
310298
#[inline(always)]
311299
fn decrease_masked(self, mask: Self::Mask) -> Self {
312300
debug_assert!(mask, "At least one lane must be set");
@@ -368,21 +356,11 @@ macro_rules! simd_impl {
368356
self.is_finite().all()
369357
}
370358

371-
#[inline(always)]
372-
fn finite_mask(self) -> Self::Mask {
373-
self.is_finite()
374-
}
375-
376359
#[inline(always)]
377360
fn gt_mask(self, other: Self) -> Self::Mask {
378361
self.simd_gt(other)
379362
}
380363

381-
#[inline(always)]
382-
fn ge_mask(self, other: Self) -> Self::Mask {
383-
self.simd_ge(other)
384-
}
385-
386364
#[inline(always)]
387365
fn decrease_masked(self, mask: Self::Mask) -> Self {
388366
// Casting a mask into ints will produce all bits set for

0 commit comments

Comments
 (0)