Skip to content

Commit 1e96eb4

Browse files
wainwrightmark and vks authored
Added new versions of choose and choose_stable (#1268)
* Added new versions of choose and choose_stable * Removed coin_flipper tests which were unnecessary and not building on ci * Performance optimizations in coin_flipper * Clippy fixes and more documentation * Added a correctness fix for coin_flipper * Update benches/seq.rs Co-authored-by: Vinzent Steinberg <[email protected]> * Update benches/seq.rs Co-authored-by: Vinzent Steinberg <[email protected]> * Removed old version of choose and choose stable and updated value stability tests * Moved sequence choose benchmarks to their own file * Reworked coin_flipper * Use criterion for seq_choose benches * Removed an old comment * Change how c is estimated in coin_flipper * Revert "Use criterion for seq_choose benches" This reverts commit 2339539. * Added seq_choose benches for smaller numbers * Removed some unneeded lines from seq_choose * Improvements in coin_flipper.rs * Small refactor of coin_flipper * Tidied comments in coin_flipper * Use criterion for seq_choose benchmarks * Made choose not generate a random number if len=1 * small change to IteratorRandom::choose * Made it easier to change seq_choose benchmarks RNG * Added Pcg64 benchmarks for seq_choose * Added TODO to coin_flipper * Changed criterion settings in seq_choose Co-authored-by: Vinzent Steinberg <[email protected]>
1 parent 3107a54 commit 1e96eb4

File tree

5 files changed

+460
-185
lines changed

5 files changed

+460
-185
lines changed

Cargo.toml

+5
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,8 @@ rand_pcg = { path = "rand_pcg", version = "0.4.0" }
7575
bincode = "1.2.1"
7676
rayon = "1.5.3"
7777
criterion = { version = "0.4" }
78+
79+
[[bench]]
80+
name = "seq_choose"
81+
path = "benches/seq_choose.rs"
82+
harness = false

benches/seq.rs

+1-71
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ extern crate test;
1313

1414
use test::Bencher;
1515

16+
use core::mem::size_of;
1617
use rand::prelude::*;
1718
use rand::seq::*;
18-
use core::mem::size_of;
1919

2020
// We force use of 32-bit RNG since seq code is optimised for use with 32-bit
2121
// generators on all platforms.
@@ -74,76 +74,6 @@ seq_slice_choose_multiple!(seq_slice_choose_multiple_950_of_1000, 950, 1000);
7474
seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100);
7575
seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100);
7676

77-
#[bench]
78-
fn seq_iter_choose_from_1000(b: &mut Bencher) {
79-
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
80-
let x: &mut [usize] = &mut [1; 1000];
81-
for (i, r) in x.iter_mut().enumerate() {
82-
*r = i;
83-
}
84-
b.iter(|| {
85-
let mut s = 0;
86-
for _ in 0..RAND_BENCH_N {
87-
s += x.iter().choose(&mut rng).unwrap();
88-
}
89-
s
90-
});
91-
b.bytes = size_of::<usize>() as u64 * crate::RAND_BENCH_N;
92-
}
93-
94-
#[derive(Clone)]
95-
struct UnhintedIterator<I: Iterator + Clone> {
96-
iter: I,
97-
}
98-
impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
99-
type Item = I::Item;
100-
101-
fn next(&mut self) -> Option<Self::Item> {
102-
self.iter.next()
103-
}
104-
}
105-
106-
#[derive(Clone)]
107-
struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
108-
iter: I,
109-
window_size: usize,
110-
}
111-
impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
112-
type Item = I::Item;
113-
114-
fn next(&mut self) -> Option<Self::Item> {
115-
self.iter.next()
116-
}
117-
118-
fn size_hint(&self) -> (usize, Option<usize>) {
119-
(core::cmp::min(self.iter.len(), self.window_size), None)
120-
}
121-
}
122-
123-
#[bench]
124-
fn seq_iter_unhinted_choose_from_1000(b: &mut Bencher) {
125-
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
126-
let x: &[usize] = &[1; 1000];
127-
b.iter(|| {
128-
UnhintedIterator { iter: x.iter() }
129-
.choose(&mut rng)
130-
.unwrap()
131-
})
132-
}
133-
134-
#[bench]
135-
fn seq_iter_window_hinted_choose_from_1000(b: &mut Bencher) {
136-
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
137-
let x: &[usize] = &[1; 1000];
138-
b.iter(|| {
139-
WindowHintedIterator {
140-
iter: x.iter(),
141-
window_size: 7,
142-
}
143-
.choose(&mut rng)
144-
})
145-
}
146-
14777
#[bench]
14878
fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) {
14979
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();

benches/seq_choose.rs

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Copyright 2018-2022 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
9+
use rand::prelude::*;
10+
use rand::SeedableRng;
11+
12+
criterion_group!(
13+
name = benches;
14+
config = Criterion::default();
15+
targets = bench
16+
);
17+
criterion_main!(benches);
18+
19+
/// Criterion entry point: registers the `choose` benchmark family once per
/// generator type, tagging each benchmark id with the generator's short name.
pub fn bench(c: &mut Criterion) {
    bench_rng::<rand_chacha::ChaCha20Rng>(c, "ChaCha20");
    bench_rng::<rand_pcg::Pcg32>(c, "Pcg32");
    bench_rng::<rand_pcg::Pcg64>(c, "Pcg64");
}
24+
25+
fn bench_rng<Rng: RngCore + SeedableRng>(c: &mut Criterion, rng_name: &'static str) {
26+
for length in [1, 2, 3, 10, 100, 1000].map(|x| black_box(x)) {
27+
c.bench_function(
28+
format!("choose_size-hinted_from_{length}_{rng_name}").as_str(),
29+
|b| {
30+
let mut rng = Rng::seed_from_u64(123);
31+
b.iter(|| choose_size_hinted(length, &mut rng))
32+
},
33+
);
34+
35+
c.bench_function(
36+
format!("choose_stable_from_{length}_{rng_name}").as_str(),
37+
|b| {
38+
let mut rng = Rng::seed_from_u64(123);
39+
b.iter(|| choose_stable(length, &mut rng))
40+
},
41+
);
42+
43+
c.bench_function(
44+
format!("choose_unhinted_from_{length}_{rng_name}").as_str(),
45+
|b| {
46+
let mut rng = Rng::seed_from_u64(123);
47+
b.iter(|| choose_unhinted(length, &mut rng))
48+
},
49+
);
50+
51+
c.bench_function(
52+
format!("choose_windowed_from_{length}_{rng_name}").as_str(),
53+
|b| {
54+
let mut rng = Rng::seed_from_u64(123);
55+
b.iter(|| choose_windowed(length, 7, &mut rng))
56+
},
57+
);
58+
}
59+
}
60+
61+
fn choose_size_hinted<R: Rng>(max: usize, rng: &mut R) -> Option<usize> {
62+
let iterator = 0..max;
63+
iterator.choose(rng)
64+
}
65+
66+
fn choose_stable<R: Rng>(max: usize, rng: &mut R) -> Option<usize> {
67+
let iterator = 0..max;
68+
iterator.choose_stable(rng)
69+
}
70+
71+
/// Calls `choose` on `0..max` wrapped in `UnhintedIterator`, so the call is
/// made through an iterator type other than the bare range.
fn choose_unhinted<R: Rng>(max: usize, rng: &mut R) -> Option<usize> {
    let unhinted = UnhintedIterator { iter: 0..max };
    unhinted.choose(rng)
}
75+
76+
/// Calls `choose` on `0..max` wrapped in `WindowHintedIterator` configured
/// with the given `window_size`.
fn choose_windowed<R: Rng>(max: usize, window_size: usize, rng: &mut R) -> Option<usize> {
    let windowed = WindowHintedIterator {
        window_size,
        iter: 0..max,
    };
    windowed.choose(rng)
}
83+
84+
/// Iterator adaptor that forwards every item of `iter` unchanged.
///
/// Only `next` is implemented, so `size_hint` is the `Iterator` trait's
/// default `(0, None)` regardless of what the wrapped iterator would report.
///
/// The struct itself carries no trait bounds; `derive(Clone)` adds the
/// `I: Clone` requirement only to the `Clone` impl.
#[derive(Clone)]
struct UnhintedIterator<I> {
    iter: I,
}

impl<I: Iterator> Iterator for UnhintedIterator<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }
}
95+
96+
/// Iterator adaptor that forwards every item of `iter` unchanged but reports
/// a deliberately limited size hint.
///
/// The lower bound is the number of remaining items capped at `window_size`;
/// the upper bound is always `None`.
///
/// The struct itself carries no trait bounds; `derive(Clone)` adds the
/// `I: Clone` requirement only to the `Clone` impl.
#[derive(Clone)]
struct WindowHintedIterator<I> {
    iter: I,
    window_size: usize,
}

// `ExactSizeIterator` already implies `Iterator`, so one bound is enough.
impl<I: ExactSizeIterator> Iterator for WindowHintedIterator<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.iter.len().min(self.window_size), None)
    }
}

src/seq/coin_flipper.rs

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
use crate::RngCore;
2+
3+
pub(crate) struct CoinFlipper<R: RngCore> {
4+
pub rng: R,
5+
chunk: u32, //TODO(opt): this should depend on RNG word size
6+
chunk_remaining: u32,
7+
}
8+
9+
impl<R: RngCore> CoinFlipper<R> {
10+
pub fn new(rng: R) -> Self {
11+
Self {
12+
rng,
13+
chunk: 0,
14+
chunk_remaining: 0,
15+
}
16+
}
17+
18+
#[inline]
19+
/// Returns true with a probability of 1 / d
20+
/// Uses an expected two bits of randomness
21+
/// Panics if d == 0
22+
pub fn gen_ratio_one_over(&mut self, d: usize) -> bool {
23+
debug_assert_ne!(d, 0);
24+
// This uses the same logic as `gen_ratio` but is optimized for the case that
25+
// the starting numerator is one (which it always is for `Sequence::Choose()`)
26+
27+
// In this case (but not `gen_ratio`), this way of calculating c is always accurate
28+
let c = (usize::BITS - 1 - d.leading_zeros()).min(32);
29+
30+
if self.flip_c_heads(c) {
31+
let numerator = 1 << c;
32+
return self.gen_ratio(numerator, d);
33+
} else {
34+
return false;
35+
}
36+
}
37+
38+
#[inline]
39+
/// Returns true with a probability of n / d
40+
/// Uses an expected two bits of randomness
41+
fn gen_ratio(&mut self, mut n: usize, d: usize) -> bool {
42+
// Explanation:
43+
// We are trying to return true with a probability of n / d
44+
// If n >= d, we can just return true
45+
// Otherwise there are two possibilities 2n < d and 2n >= d
46+
// In either case we flip a coin.
47+
// If 2n < d
48+
// If it comes up tails, return false
49+
// If it comes up heads, double n and start again
50+
// This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d
51+
// (if 2n was greater than d we would effectively round it down to 1
52+
// by returning true)
53+
// If 2n >= d
54+
// If it comes up tails, set n to 2n - d and start again
55+
// If it comes up heads, return true
56+
// This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d
57+
// Note that if 2n = d and the coin comes up tails, n will be set to 0
58+
// before restarting which is equivalent to returning false.
59+
60+
// As a performance optimization we can flip multiple coins at once
61+
// This is efficient because we can use the `lzcnt` intrinsic
62+
// We can check up to 32 flips at once but we only receive one bit of information
63+
// - all heads or at least one tail.
64+
65+
// Let c be the number of coins to flip. 1 <= c <= 32
66+
// If 2n < d, n * 2^c < d
67+
// If the result is all heads, then set n to n * 2^c
68+
// If there was at least one tail, return false
69+
// If 2n >= d, the order of results matters so we flip one coin at a time so c = 1
70+
// Ideally, c will be as high as possible within these constraints
71+
72+
while n < d {
73+
// Find a good value for c by counting leading zeros
74+
// This will either give the highest possible c, or 1 less than that
75+
let c = n
76+
.leading_zeros()
77+
.saturating_sub(d.leading_zeros() + 1)
78+
.clamp(1, 32);
79+
80+
if self.flip_c_heads(c) {
81+
// All heads
82+
// Set n to n * 2^c
83+
// If 2n >= d, the while loop will exit and we will return `true`
84+
// If n * 2^c > `usize::MAX` we always return `true` anyway
85+
n = n.saturating_mul(2_usize.pow(c));
86+
} else {
87+
//At least one tail
88+
if c == 1 {
89+
// Calculate 2n - d.
90+
// We need to use wrapping as 2n might be greater than `usize::MAX`
91+
let next_n = n.wrapping_add(n).wrapping_sub(d);
92+
if next_n == 0 || next_n > n {
93+
// This will happen if 2n < d
94+
return false;
95+
}
96+
n = next_n;
97+
} else {
98+
// c > 1 so 2n < d so we can return false
99+
return false;
100+
}
101+
}
102+
}
103+
true
104+
}
105+
106+
/// If the next `c` bits of randomness all represent heads, consume them, return true
107+
/// Otherwise return false and consume the number of heads plus one.
108+
/// Generates new bits of randomness when necessary (in 32 bit chunks)
109+
/// Has a 1 in 2 to the `c` chance of returning true
110+
/// `c` must be less than or equal to 32
111+
fn flip_c_heads(&mut self, mut c: u32) -> bool {
112+
debug_assert!(c <= 32);
113+
// Note that zeros on the left of the chunk represent heads.
114+
// It needs to be this way round because zeros are filled in when left shifting
115+
loop {
116+
let zeros = self.chunk.leading_zeros();
117+
118+
if zeros < c {
119+
// The happy path - we found a 1 and can return false
120+
// Note that because a 1 bit was detected,
121+
// We cannot have run out of random bits so we don't need to check
122+
123+
// First consume all of the bits read
124+
// Using shl seems to give worse performance for size-hinted iterators
125+
self.chunk = self.chunk.wrapping_shl(zeros + 1);
126+
127+
self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1);
128+
return false;
129+
} else {
130+
// The number of zeros is larger than `c`
131+
// There are two possibilities
132+
if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) {
133+
// Those zeroes were all part of our random chunk,
134+
// throw away `c` bits of randomness and return true
135+
self.chunk_remaining = new_remaining;
136+
self.chunk <<= c;
137+
return true;
138+
} else {
139+
// Some of those zeroes were part of the random chunk
140+
// and some were part of the space behind it
141+
// We need to take into account only the zeroes that were random
142+
c -= self.chunk_remaining;
143+
144+
// Generate a new chunk
145+
self.chunk = self.rng.next_u32();
146+
self.chunk_remaining = 32;
147+
// Go back to start of loop
148+
}
149+
}
150+
}
151+
}
152+
}

0 commit comments

Comments
 (0)