|
| 1 | +use crate::RngCore; |
| 2 | + |
| 3 | +pub(crate) struct CoinFlipper<R: RngCore> { |
| 4 | + pub rng: R, |
| 5 | + chunk: u32, //TODO(opt): this should depend on RNG word size |
| 6 | + chunk_remaining: u32, |
| 7 | +} |
| 8 | + |
| 9 | +impl<R: RngCore> CoinFlipper<R> { |
| 10 | + pub fn new(rng: R) -> Self { |
| 11 | + Self { |
| 12 | + rng, |
| 13 | + chunk: 0, |
| 14 | + chunk_remaining: 0, |
| 15 | + } |
| 16 | + } |
| 17 | + |
| 18 | + #[inline] |
| 19 | + /// Returns true with a probability of 1 / d |
| 20 | + /// Uses an expected two bits of randomness |
| 21 | + /// Panics if d == 0 |
| 22 | + pub fn gen_ratio_one_over(&mut self, d: usize) -> bool { |
| 23 | + debug_assert_ne!(d, 0); |
| 24 | + // This uses the same logic as `gen_ratio` but is optimized for the case that |
| 25 | + // the starting numerator is one (which it always is for `Sequence::Choose()`) |
| 26 | + |
| 27 | + // In this case (but not `gen_ratio`), this way of calculating c is always accurate |
| 28 | + let c = (usize::BITS - 1 - d.leading_zeros()).min(32); |
| 29 | + |
| 30 | + if self.flip_c_heads(c) { |
| 31 | + let numerator = 1 << c; |
| 32 | + return self.gen_ratio(numerator, d); |
| 33 | + } else { |
| 34 | + return false; |
| 35 | + } |
| 36 | + } |
| 37 | + |
| 38 | + #[inline] |
| 39 | + /// Returns true with a probability of n / d |
| 40 | + /// Uses an expected two bits of randomness |
| 41 | + fn gen_ratio(&mut self, mut n: usize, d: usize) -> bool { |
| 42 | + // Explanation: |
| 43 | + // We are trying to return true with a probability of n / d |
| 44 | + // If n >= d, we can just return true |
| 45 | + // Otherwise there are two possibilities 2n < d and 2n >= d |
| 46 | + // In either case we flip a coin. |
| 47 | + // If 2n < d |
| 48 | + // If it comes up tails, return false |
| 49 | + // If it comes up heads, double n and start again |
| 50 | + // This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d |
| 51 | + // (if 2n was greater than d we would effectively round it down to 1 |
| 52 | + // by returning true) |
| 53 | + // If 2n >= d |
| 54 | + // If it comes up tails, set n to 2n - d and start again |
| 55 | + // If it comes up heads, return true |
| 56 | + // This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d |
| 57 | + // Note that if 2n = d and the coin comes up tails, n will be set to 0 |
| 58 | + // before restarting which is equivalent to returning false. |
| 59 | + |
| 60 | + // As a performance optimization we can flip multiple coins at once |
| 61 | + // This is efficient because we can use the `lzcnt` intrinsic |
| 62 | + // We can check up to 32 flips at once but we only receive one bit of information |
| 63 | + // - all heads or at least one tail. |
| 64 | + |
| 65 | + // Let c be the number of coins to flip. 1 <= c <= 32 |
| 66 | + // If 2n < d, n * 2^c < d |
| 67 | + // If the result is all heads, then set n to n * 2^c |
| 68 | + // If there was at least one tail, return false |
| 69 | + // If 2n >= d, the order of results matters so we flip one coin at a time so c = 1 |
| 70 | + // Ideally, c will be as high as possible within these constraints |
| 71 | + |
| 72 | + while n < d { |
| 73 | + // Find a good value for c by counting leading zeros |
| 74 | + // This will either give the highest possible c, or 1 less than that |
| 75 | + let c = n |
| 76 | + .leading_zeros() |
| 77 | + .saturating_sub(d.leading_zeros() + 1) |
| 78 | + .clamp(1, 32); |
| 79 | + |
| 80 | + if self.flip_c_heads(c) { |
| 81 | + // All heads |
| 82 | + // Set n to n * 2^c |
| 83 | + // If 2n >= d, the while loop will exit and we will return `true` |
| 84 | + // If n * 2^c > `usize::MAX` we always return `true` anyway |
| 85 | + n = n.saturating_mul(2_usize.pow(c)); |
| 86 | + } else { |
| 87 | + //At least one tail |
| 88 | + if c == 1 { |
| 89 | + // Calculate 2n - d. |
| 90 | + // We need to use wrapping as 2n might be greater than `usize::MAX` |
| 91 | + let next_n = n.wrapping_add(n).wrapping_sub(d); |
| 92 | + if next_n == 0 || next_n > n { |
| 93 | + // This will happen if 2n < d |
| 94 | + return false; |
| 95 | + } |
| 96 | + n = next_n; |
| 97 | + } else { |
| 98 | + // c > 1 so 2n < d so we can return false |
| 99 | + return false; |
| 100 | + } |
| 101 | + } |
| 102 | + } |
| 103 | + true |
| 104 | + } |
| 105 | + |
| 106 | + /// If the next `c` bits of randomness all represent heads, consume them, return true |
| 107 | + /// Otherwise return false and consume the number of heads plus one. |
| 108 | + /// Generates new bits of randomness when necessary (in 32 bit chunks) |
| 109 | + /// Has a 1 in 2 to the `c` chance of returning true |
| 110 | + /// `c` must be less than or equal to 32 |
| 111 | + fn flip_c_heads(&mut self, mut c: u32) -> bool { |
| 112 | + debug_assert!(c <= 32); |
| 113 | + // Note that zeros on the left of the chunk represent heads. |
| 114 | + // It needs to be this way round because zeros are filled in when left shifting |
| 115 | + loop { |
| 116 | + let zeros = self.chunk.leading_zeros(); |
| 117 | + |
| 118 | + if zeros < c { |
| 119 | + // The happy path - we found a 1 and can return false |
| 120 | + // Note that because a 1 bit was detected, |
| 121 | + // We cannot have run out of random bits so we don't need to check |
| 122 | + |
| 123 | + // First consume all of the bits read |
| 124 | + // Using shl seems to give worse performance for size-hinted iterators |
| 125 | + self.chunk = self.chunk.wrapping_shl(zeros + 1); |
| 126 | + |
| 127 | + self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1); |
| 128 | + return false; |
| 129 | + } else { |
| 130 | + // The number of zeros is larger than `c` |
| 131 | + // There are two possibilities |
| 132 | + if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) { |
| 133 | + // Those zeroes were all part of our random chunk, |
| 134 | + // throw away `c` bits of randomness and return true |
| 135 | + self.chunk_remaining = new_remaining; |
| 136 | + self.chunk <<= c; |
| 137 | + return true; |
| 138 | + } else { |
| 139 | + // Some of those zeroes were part of the random chunk |
| 140 | + // and some were part of the space behind it |
| 141 | + // We need to take into account only the zeroes that were random |
| 142 | + c -= self.chunk_remaining; |
| 143 | + |
| 144 | + // Generate a new chunk |
| 145 | + self.chunk = self.rng.next_u32(); |
| 146 | + self.chunk_remaining = 32; |
| 147 | + // Go back to start of loop |
| 148 | + } |
| 149 | + } |
| 150 | + } |
| 151 | + } |
| 152 | +} |
0 commit comments