Skip to content

Commit 53a7e85

Browse files
committed
Implement Fast Loaded Dice Roller
1 parent bbb0dff commit 53a7e85

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed

rand_distr/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ pub use num_traits;
115115

116116
#[cfg(feature = "alloc")]
117117
pub mod weighted_alias;
118+
#[cfg(feature = "alloc")]
119+
pub mod weighted_fldr;
118120

119121
mod binomial;
120122
mod cauchy;

rand_distr/src/weighted_fldr.rs

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
//! Implementation of the sampling algorithm in
2+
//!
3+
//! > Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
4+
//! > The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete
5+
//! > Probability Distributions. In AISTATS 2020: Proceedings of the 23rd
6+
//! > International Conference on Artificial Intelligence and Statistics,
7+
//! > Proceedings of Machine Learning Research 108, Palermo, Sicily, Italy, 2020.
8+
use alloc::vec::Vec;
9+
use alloc::vec;
10+
11+
use super::WeightedError;
12+
13+
use crate::Distribution;
14+
use rand::Rng;
15+
16+
/// Returns the number of bits needed to hold `x`: the position of its highest
/// set bit, counted from one.
///
/// `bit_length(0)` is `0`. A negative `x` has its sign bit set, so the result
/// is `32`.
fn bit_length(x: i32) -> i32 {
    const I32_WIDTH: u32 = 32;
    let unused_high_bits = x.leading_zeros();
    (I32_WIDTH - unused_high_bits) as i32
}
19+
20+
/// Distribution of weighted indices with Fast Loaded Dice Roller method.
///
/// All fields are filled in by the constructor and only read while sampling.
#[derive(Debug)]
pub struct WeightedIndex {
    /// Number of weights the distribution was built from.
    n: i32,
    /// Sum of all weights.
    m: i32,
    /// Number of bit levels in the lookup tables (the length of `h1`).
    k: i32,
    /// Padding weight that tops the total up to `2^k`; drawing it triggers a
    /// restart of the sampling loop.
    r: i32,
    /// Per-level count of entries (weights plus the padding weight) whose bit
    /// for that level is set.
    h1: Vec<i32>,
    /// Flattened table of `(n + 1) * k` labels, addressed as `d * k + level`.
    /// A label is a weight index, `n` for the padding weight, or `-1` for an
    /// unused slot.
    h2: Vec<i32>,
}
26+
27+
impl WeightedIndex {
28+
/// Preprocess weights.
29+
pub fn new(weights: Vec<i32>) -> Result<Self, WeightedError> {
30+
let n = weights.len();
31+
if n == 0 {
32+
return Err(WeightedError::NoItem);
33+
} else if n > ::core::i32::MAX as usize {
34+
return Err(WeightedError::TooMany);
35+
}
36+
let n = n as i32;
37+
let mut m = 0;
38+
for &w in &weights {
39+
if w < 0 {
40+
return Err(WeightedError::InvalidWeight);
41+
}
42+
m += w;
43+
}
44+
if m == 0 {
45+
return Err(WeightedError::AllWeightsZero);
46+
}
47+
let k = bit_length(m - 1);
48+
let r = (1 << k) - m;
49+
50+
let mut h1 = vec![0; k as usize];
51+
let mut h2 = vec![-1; ((n + 1) * k) as usize];
52+
53+
let mut d;
54+
for j in 0..k {
55+
d = 0;
56+
for i in 0..n {
57+
let w = (weights[i as usize] >> ((k-1) - j)) & 1;
58+
if w > 0 {
59+
h1[j as usize] += 1;
60+
h2[(d*k + j) as usize] = i;
61+
d += 1;
62+
}
63+
}
64+
let w = (r >> ((k - 1) - j)) & 1;
65+
if w > 0 {
66+
h1[j as usize] += 1;
67+
h2[(d*k + j) as usize] = n;
68+
}
69+
}
70+
71+
Ok(WeightedIndex { n, m, k, r, h1, h2 })
72+
}
73+
}
74+
75+
impl Distribution<i32> for WeightedIndex {
76+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> i32 {
77+
let n = self.n;
78+
let k = self.k;
79+
let h1 = &self.h1;
80+
let h2 = &self.h2;
81+
let mut c: i32 = 0;
82+
let mut d: i32 = 0;
83+
84+
loop {
85+
let b: bool = rng.gen();
86+
let b = b as i32;
87+
d = 2*d + (1 - b);
88+
if d < h1[c as usize] {
89+
let z = h2[(d*k + c) as usize];
90+
if z < n {
91+
return z;
92+
} else {
93+
d = 0;
94+
c = 0;
95+
}
96+
} else {
97+
d -= h1[c as usize];
98+
c += 1;
99+
}
100+
}
101+
}
102+
}
103+
104+
#[cfg(test)]
mod test {
    use super::*;
    use rand::distributions::Uniform;

    /// Samples from randomly chosen weights (one of them forced to zero) and
    /// checks that the observed counts stay close to the expected ones, then
    /// checks the constructor's error cases.
    #[test]
    fn test_weighted_fldr() {
        const NUM_WEIGHTS: i32 = 10;
        const ZERO_WEIGHT_INDEX: i32 = 3;
        const NUM_SAMPLES: i32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        // Random weights in 0..=NUM_WEIGHTS, with one slot pinned to zero.
        let weight_range = Uniform::new_inclusive(0, NUM_WEIGHTS);
        let mut weights: Vec<i32> = (0..NUM_WEIGHTS)
            .map(|_| rng.sample(&weight_range))
            .collect();
        weights[ZERO_WEIGHT_INDEX as usize] = 0;

        let total = weights.iter().sum::<i32>();
        let mut expected: Vec<f64> = Vec::with_capacity(weights.len());
        for &w in &weights {
            expected.push((w as f64) / (total as f64) * NUM_SAMPLES as f64);
        }
        let dist = WeightedIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            let drawn = rng.sample(&dist);
            counts[drawn as usize] += 1;
        }

        // A zero weight must never be drawn.
        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);

        let tolerance = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
        for (&count, &want) in counts.iter().zip(&expected) {
            let deviation = (count as f64 - want).abs();
            assert!(deviation <= tolerance);
        }

        assert_eq!(
            WeightedIndex::new(vec![]).unwrap_err(),
            WeightedError::NoItem
        );
        assert_eq!(
            WeightedIndex::new(vec![0]).unwrap_err(),
            WeightedError::AllWeightsZero
        );

        // Signed integer special cases
        assert_eq!(
            WeightedIndex::new(vec![-1]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(vec![core::i32::MIN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }
}

0 commit comments

Comments
 (0)