Skip to content

Commit 5ee3ca3

Browse files
teryrorvks
authored andcommitted
Move setup computations into constructor + misc. cleanup
1 parent b6a5f08 commit 5ee3ca3

File tree

2 files changed

+35
-19
lines changed

2 files changed

+35
-19
lines changed

rand_distr/src/geometric.rs

+25-14
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ use core::fmt;
2828
#[derive(Copy, Clone, Debug)]
2929
pub struct Geometric
3030
{
31-
p: f64
31+
p: f64,
32+
pi: f64,
33+
k: u64
3234
}
3335

3436
/// Error type returned from `Geometric::new`.
@@ -55,8 +57,21 @@ impl Geometric {
5557
pub fn new(p: f64) -> Result<Self, Error> {
5658
if !p.is_finite() || p < 0.0 || p > 1.0 {
5759
Err(Error::InvalidProbability)
60+
} else if p == 0.0 || p >= 2.0 / 3.0 {
61+
Ok(Geometric { p, pi: p, k: 0 })
5862
} else {
59-
Ok(Geometric { p })
63+
let (pi, k) = {
64+
// choose smallest k such that pi = (1 - p)^(2^k) <= 0.5
65+
let mut k = 1;
66+
let mut pi = (1.0 - p).powi(2);
67+
while pi > 0.5 {
68+
k += 1;
69+
pi = pi * pi;
70+
}
71+
(pi, k)
72+
};
73+
74+
Ok(Geometric { p, pi, k })
6075
}
6176
}
6277
}
@@ -77,21 +92,14 @@ impl Distribution<u64> for Geometric
7792

7893
if self.p == 0.0 { return core::u64::MAX; }
7994

95+
let Geometric { p, pi, k } = *self;
96+
8097
// Based on the algorithm presented in section 3 of
8198
// Karl Bringmann and Tobias Friedrich (July 2013) - Exact and Efficient
8299
// Generation of Geometric Random Variates and Random Graphs, published
83100
// in International Colloquium on Automata, Languages and Programming
84101
// (pp.267-278)
85-
let (pi, k) = {
86-
// choose smallest k such that pi = (1 - p)^(2^k) <= 0.5
87-
let mut k = 1;
88-
let mut pi = (1.0 - self.p).powi(2);
89-
while pi > 0.5 {
90-
k += 1;
91-
pi = pi * pi;
92-
}
93-
(pi, k)
94-
};
102+
// https://people.mpi-inf.mpg.de/~kbringma/paper/2013ICALP-1.pdf
95103

96104
// Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k:
97105
let d = {
@@ -104,12 +112,15 @@ impl Distribution<u64> for Geometric
104112

105113
// Use rejection sampling for the remainder M from Geo(p) % 2^k:
106114
// choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M
115+
// NOTE: The paper suggests using bitwise sampling here, which is
116+
// currently unsupported, but should improve performance by requiring
117+
// fewer iterations on average. ~ October 28, 2020
107118
let m = loop {
108119
let m = rng.gen::<u64>() & ((1 << k) - 1);
109120
let p_reject = if m <= core::i32::MAX as u64 {
110-
(1.0 - self.p).powi(m as i32)
121+
(1.0 - p).powi(m as i32)
111122
} else {
112-
(1.0 - self.p).powf(m as f64)
123+
(1.0 - p).powf(m as f64)
113124
};
114125

115126
let u = rng.gen::<f64>();

rand_distr/src/hypergeometric.rs

+10-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use crate::Distribution;
44
use rand::Rng;
5+
use rand::distributions::uniform::Uniform;
56
use core::fmt;
67

78
#[derive(Clone, Copy, Debug)]
@@ -112,7 +113,10 @@ fn ln_of_factorial(v: f64) -> f64 {
112113
}
113114

114115
impl Hypergeometric {
115-
/// Constructs a new `Hypergeometric` with the given shape parameters.
116+
/// Constructs a new `Hypergeometric` with the shape parameters
117+
/// `N = total_population_size`,
118+
/// `K = population_with_feature`,
119+
/// `n = sample_size`.
116120
#[allow(clippy::many_single_char_names)] // Same names as in the reference.
117121
pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result<Self, Error> {
118122
if population_with_feature > total_population_size {
@@ -234,13 +238,14 @@ impl Distribution<u64> for Hypergeometric {
234238
x
235239
},
236240
RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => {
241+
let distr_region_select = Uniform::new(0.0, p3);
237242
loop {
238243
let (y, v) = loop {
239-
let u = rng.gen::<f64>() * p3; // for selecting the region
244+
let u = distr_region_select.sample(rng);
240245
let v = rng.gen::<f64>(); // for the accept/reject decision
241246

242247
if u <= p1 {
243-
// Region 1, centrel bell
248+
// Region 1, central bell
244249
let y = (x_l + u).floor();
245250
break (y, v);
246251
} else if u <= p2 {
@@ -276,7 +281,7 @@ impl Distribution<u64> for Hypergeometric {
276281
}
277282
}
278283

279-
if v < f { break y as i64; }
284+
if v <= f { break y as i64; }
280285
} else {
281286
// Step 4.2: Squeezing
282287
let y1 = y + 1.0;
@@ -292,7 +297,7 @@ impl Distribution<u64> for Hypergeometric {
292297
let dg = if g < 0.0 {
293298
1.0 + g
294299
} else {
295-
g
300+
1.0
296301
};
297302
let gu = g * (1.0 + g * (-0.5 + g / 3.0));
298303
let gl = gu - g.powi(4) / (4.0 * dg);

0 commit comments

Comments
 (0)