51
51
//! Those methods should include an assertion to check the range is valid (i.e.
52
52
//! `low < high`). The example below merely wraps another back-end.
53
53
//!
54
- //! The `new`, `new_inclusive` and `sample_single` functions use arguments of
54
+ //! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive`
55
+ //! functions use arguments of
55
56
//! type `SampleBorrow<X>` to support passing in values by reference or
56
57
//! by value. In the implementation of these functions, you can choose to
57
58
//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose
@@ -207,6 +208,11 @@ impl<X: SampleUniform> Uniform<X> {
207
208
/// Create a new `Uniform` instance, which samples uniformly from the half
208
209
/// open range `[low, high)` (excluding `high`).
209
210
///
211
+ /// For discrete types (e.g. integers), samples will always be strictly less
212
+ /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
213
+ /// samples may equal `high` due to loss of precision but may not be
214
+ /// greater than `high`.
215
+ ///
210
216
/// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is
211
217
/// non-finite. In release mode, only the range is checked.
212
218
pub fn new < B1 , B2 > ( low : B1 , high : B2 ) -> Result < Uniform < X > , Error >
@@ -265,6 +271,11 @@ pub trait UniformSampler: Sized {
265
271
266
272
/// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`.
267
273
///
274
+ /// For discrete types (e.g. integers), samples will always be strictly less
275
+ /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
276
+ /// samples may equal `high` due to loss of precision but may not be
277
+ /// greater than `high`.
278
+ ///
268
279
/// Usually users should not call this directly but prefer to use
269
280
/// [`Uniform::new`].
270
281
fn new < B1 , B2 > ( low : B1 , high : B2 ) -> Result < Self , Error >
@@ -287,6 +298,11 @@ pub trait UniformSampler: Sized {
287
298
/// Sample a single value uniformly from a range with inclusive lower bound
288
299
/// and exclusive upper bound `[low, high)`.
289
300
///
301
+ /// For discrete types (e.g. integers), samples will always be strictly less
302
+ /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
303
+ /// samples may equal `high` due to loss of precision but may not be
304
+ /// greater than `high`.
305
+ ///
290
306
/// By default this is implemented using
291
307
/// `UniformSampler::new(low, high).sample(rng)`. However, for some types
292
308
/// more optimal implementations for single usage may be provided via this
@@ -908,6 +924,33 @@ pub struct UniformFloat<X> {
908
924
909
925
macro_rules! uniform_float_impl {
910
926
( $( $meta: meta) ?, $ty: ty, $uty: ident, $f_scalar: ident, $u_scalar: ident, $bits_to_discard: expr) => {
927
+ $( #[ cfg( $meta) ] ) ?
928
+ impl UniformFloat <$ty> {
929
+ /// Construct, reducing `scale` as required to ensure that rounding
930
+ /// can never yield values greater than `high`.
931
+ ///
932
+ /// Note: though it may be tempting to use a variant of this method
933
+ /// to ensure that samples from `[low, high)` are always strictly
934
+ /// less than `high`, this approach may be very slow where
935
+ /// `scale.abs()` is much smaller than `high.abs()`
936
+ /// (example: `low=0.99999999997819644, high=1.`).
937
+ fn new_bounded( low: $ty, high: $ty, mut scale: $ty) -> Self {
938
+ let max_rand = <$ty>:: splat( 1.0 as $f_scalar - $f_scalar:: EPSILON ) ;
939
+
940
+ loop {
941
+ let mask = ( scale * max_rand + low) . gt_mask( high) ;
942
+ if !mask. any( ) {
943
+ break ;
944
+ }
945
+ scale = scale. decrease_masked( mask) ;
946
+ }
947
+
948
+ debug_assert!( <$ty>:: splat( 0.0 ) . all_le( scale) ) ;
949
+
950
+ UniformFloat { low, scale }
951
+ }
952
+ }
953
+
911
954
$( #[ cfg( $meta) ] ) ?
912
955
impl SampleUniform for $ty {
913
956
type Sampler = UniformFloat <$ty>;
@@ -931,26 +974,13 @@ macro_rules! uniform_float_impl {
931
974
if !( low. all_lt( high) ) {
932
975
return Err ( Error :: EmptyRange ) ;
933
976
}
934
- let max_rand = <$ty>:: splat(
935
- ( $u_scalar:: MAX >> $bits_to_discard) . into_float_with_exponent( 0 ) - 1.0 ,
936
- ) ;
937
977
938
- let mut scale = high - low;
978
+ let scale = high - low;
939
979
if !( scale. all_finite( ) ) {
940
980
return Err ( Error :: NonFinite ) ;
941
981
}
942
982
943
- loop {
944
- let mask = ( scale * max_rand + low) . ge_mask( high) ;
945
- if !mask. any( ) {
946
- break ;
947
- }
948
- scale = scale. decrease_masked( mask) ;
949
- }
950
-
951
- debug_assert!( <$ty>:: splat( 0.0 ) . all_le( scale) ) ;
952
-
953
- Ok ( UniformFloat { low, scale } )
983
+ Ok ( Self :: new_bounded( low, high, scale) )
954
984
}
955
985
956
986
fn new_inclusive<B1 , B2 >( low_b: B1 , high_b: B2 ) -> Result <Self , Error >
@@ -967,26 +997,14 @@ macro_rules! uniform_float_impl {
967
997
if !low. all_le( high) {
968
998
return Err ( Error :: EmptyRange ) ;
969
999
}
970
- let max_rand = <$ty>:: splat(
971
- ( $u_scalar:: MAX >> $bits_to_discard) . into_float_with_exponent( 0 ) - 1.0 ,
972
- ) ;
973
1000
974
- let mut scale = ( high - low) / max_rand;
1001
+ let max_rand = <$ty>:: splat( 1.0 as $f_scalar - $f_scalar:: EPSILON ) ;
1002
+ let scale = ( high - low) / max_rand;
975
1003
if !scale. all_finite( ) {
976
1004
return Err ( Error :: NonFinite ) ;
977
1005
}
978
1006
979
- loop {
980
- let mask = ( scale * max_rand + low) . gt_mask( high) ;
981
- if !mask. any( ) {
982
- break ;
983
- }
984
- scale = scale. decrease_masked( mask) ;
985
- }
986
-
987
- debug_assert!( <$ty>:: splat( 0.0 ) . all_le( scale) ) ;
988
-
989
- Ok ( UniformFloat { low, scale } )
1007
+ Ok ( Self :: new_bounded( low, high, scale) )
990
1008
}
991
1009
992
1010
fn sample<R : Rng + ?Sized >( & self , rng: & mut R ) -> Self :: X {
@@ -1010,72 +1028,7 @@ macro_rules! uniform_float_impl {
1010
1028
B1 : SampleBorrow <Self :: X > + Sized ,
1011
1029
B2 : SampleBorrow <Self :: X > + Sized ,
1012
1030
{
1013
- let low = * low_b. borrow( ) ;
1014
- let high = * high_b. borrow( ) ;
1015
- #[ cfg( debug_assertions) ]
1016
- if !low. all_finite( ) || !high. all_finite( ) {
1017
- return Err ( Error :: NonFinite ) ;
1018
- }
1019
- if !low. all_lt( high) {
1020
- return Err ( Error :: EmptyRange ) ;
1021
- }
1022
- let mut scale = high - low;
1023
- if !scale. all_finite( ) {
1024
- return Err ( Error :: NonFinite ) ;
1025
- }
1026
-
1027
- loop {
1028
- // Generate a value in the range [1, 2)
1029
- let value1_2 =
1030
- ( rng. random:: <$uty>( ) >> $uty:: splat( $bits_to_discard) ) . into_float_with_exponent( 0 ) ;
1031
-
1032
- // Get a value in the range [0, 1) to avoid overflow when multiplying by scale
1033
- let value0_1 = value1_2 - <$ty>:: splat( 1.0 ) ;
1034
-
1035
- // Doing multiply before addition allows some architectures
1036
- // to use a single instruction.
1037
- let res = value0_1 * scale + low;
1038
-
1039
- debug_assert!( low. all_le( res) || !scale. all_finite( ) ) ;
1040
- if res. all_lt( high) {
1041
- return Ok ( res) ;
1042
- }
1043
-
1044
- // This handles a number of edge cases.
1045
- // * `low` or `high` is NaN. In this case `scale` and
1046
- // `res` are going to end up as NaN.
1047
- // * `low` is negative infinity and `high` is finite.
1048
- // `scale` is going to be infinite and `res` will be
1049
- // NaN.
1050
- // * `high` is positive infinity and `low` is finite.
1051
- // `scale` is going to be infinite and `res` will
1052
- // be infinite or NaN (if value0_1 is 0).
1053
- // * `low` is negative infinity and `high` is positive
1054
- // infinity. `scale` will be infinite and `res` will
1055
- // be NaN.
1056
- // * `low` and `high` are finite, but `high - low`
1057
- // overflows to infinite. `scale` will be infinite
1058
- // and `res` will be infinite or NaN (if value0_1 is 0).
1059
- // So if `high` or `low` are non-finite, we are guaranteed
1060
- // to fail the `res < high` check above and end up here.
1061
- //
1062
- // While we technically should check for non-finite `low`
1063
- // and `high` before entering the loop, by doing the checks
1064
- // here instead, we allow the common case to avoid these
1065
- // checks. But we are still guaranteed that if `low` or
1066
- // `high` are non-finite we'll end up here and can do the
1067
- // appropriate checks.
1068
- //
1069
- // Likewise, `high - low` overflowing to infinity is also
1070
- // rare, so handle it here after the common case.
1071
- let mask = !scale. finite_mask( ) ;
1072
- if mask. any( ) {
1073
- if !( low. all_finite( ) && high. all_finite( ) ) {
1074
- return Err ( Error :: NonFinite ) ;
1075
- }
1076
- scale = scale. decrease_masked( mask) ;
1077
- }
1078
- }
1031
+ Self :: sample_single_inclusive( low_b, high_b, rng)
1079
1032
}
1080
1033
1081
1034
#[ inline]
@@ -1465,14 +1418,14 @@ mod tests {
1465
1418
let my_incl_uniform = Uniform :: new_inclusive( low, high) . unwrap( ) ;
1466
1419
for _ in 0 ..100 {
1467
1420
let v = rng. sample( my_uniform) . extract( lane) ;
1468
- assert!( low_scalar <= v && v < high_scalar) ;
1421
+ assert!( low_scalar <= v && v <= high_scalar) ;
1469
1422
let v = rng. sample( my_incl_uniform) . extract( lane) ;
1470
1423
assert!( low_scalar <= v && v <= high_scalar) ;
1471
1424
let v =
1472
1425
<$ty as SampleUniform >:: Sampler :: sample_single( low, high, & mut rng)
1473
1426
. unwrap( )
1474
1427
. extract( lane) ;
1475
- assert!( low_scalar <= v && v < high_scalar) ;
1428
+ assert!( low_scalar <= v && v <= high_scalar) ;
1476
1429
let v = <$ty as SampleUniform >:: Sampler :: sample_single_inclusive(
1477
1430
low, high, & mut rng,
1478
1431
)
@@ -1510,12 +1463,12 @@ mod tests {
1510
1463
low_scalar
1511
1464
) ;
1512
1465
1513
- assert!( max_rng. sample( my_uniform) . extract( lane) < high_scalar) ;
1466
+ assert!( max_rng. sample( my_uniform) . extract( lane) <= high_scalar) ;
1514
1467
assert!( max_rng. sample( my_incl_uniform) . extract( lane) <= high_scalar) ;
1515
1468
// sample_single cannot cope with max_rng:
1516
1469
// assert!(<$ty as SampleUniform>::Sampler
1517
1470
// ::sample_single(low, high, &mut max_rng).unwrap()
1518
- // .extract(lane) < high_scalar);
1471
+ // .extract(lane) <= high_scalar);
1519
1472
assert!(
1520
1473
<$ty as SampleUniform >:: Sampler :: sample_single_inclusive(
1521
1474
low,
@@ -1543,7 +1496,7 @@ mod tests {
1543
1496
)
1544
1497
. unwrap( )
1545
1498
. extract( lane)
1546
- < high_scalar
1499
+ <= high_scalar
1547
1500
) ;
1548
1501
}
1549
1502
}
@@ -1590,10 +1543,9 @@ mod tests {
1590
1543
#[ cfg( all( feature = "std" , panic = "unwind" ) ) ]
1591
1544
fn test_float_assertions ( ) {
1592
1545
use super :: SampleUniform ;
1593
- use std:: panic:: catch_unwind;
1594
- fn range < T : SampleUniform > ( low : T , high : T ) {
1546
+ fn range < T : SampleUniform > ( low : T , high : T ) -> Result < T , Error > {
1595
1547
let mut rng = crate :: test:: rng ( 253 ) ;
1596
- T :: Sampler :: sample_single ( low, high, & mut rng) . unwrap ( ) ;
1548
+ T :: Sampler :: sample_single ( low, high, & mut rng)
1597
1549
}
1598
1550
1599
1551
macro_rules! t {
@@ -1616,10 +1568,9 @@ mod tests {
1616
1568
for lane in 0 ..<$ty>:: LEN {
1617
1569
let low = <$ty>:: splat( 0.0 as $f_scalar) . replace( lane, low_scalar) ;
1618
1570
let high = <$ty>:: splat( 1.0 as $f_scalar) . replace( lane, high_scalar) ;
1619
- assert!( catch_unwind ( || range( low, high) ) . is_err( ) ) ;
1571
+ assert!( range( low, high) . is_err( ) ) ;
1620
1572
assert!( Uniform :: new( low, high) . is_err( ) ) ;
1621
1573
assert!( Uniform :: new_inclusive( low, high) . is_err( ) ) ;
1622
- assert!( catch_unwind( || range( low, low) ) . is_err( ) ) ;
1623
1574
assert!( Uniform :: new( low, low) . is_err( ) ) ;
1624
1575
}
1625
1576
}
0 commit comments