Skip to content

Commit 7fe350c

Browse files
Hypergeo fix (#1510)
1 parent ad67294 commit 7fe350c

File tree

4 files changed

+22
-4
lines changed

4 files changed

+22
-4
lines changed

rand_distr/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Mark `WeightError`, `PoissonError`, `BinomialError` as `#[non_exhaustive]` (#1480).
1414
- Remove support for generating `isize` and `usize` values with `Standard`, `Uniform` and `Fill` and usage as a `WeightedAliasIndex` weight (#1487)
1515
- Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498)
16+
- Fix bug in `Hypergeometric`; this is a value-breaking change (#1510)
1617
- Change parameter type of `Zipf::new`: `n` is now floating-point (#1518)
1718

1819
### Added

rand_distr/src/hypergeometric.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,17 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64,
131131
result
132132
}
133133

134+
const LOGSQRT2PI: f64 = 0.91893853320467274178; // log(sqrt(2*pi))
135+
134136
fn ln_of_factorial(v: f64) -> f64 {
135137
// the paper calls for ln(v!), but also wants to pass in fractions,
136138
// so we need to use Stirling's approximation to fill in the gaps:
137-
v * v.ln() - v
139+
140+
// shift v by 3, because Stirling's approximation is inaccurate for small values
141+
let v_3 = v + 3.0;
142+
let ln_fac = (v_3 + 0.5) * v_3.ln() - v_3 + LOGSQRT2PI + 1.0 / (12.0 * v_3);
143+
// undo the shift: ln(v!) = ln((v + 3)!) - ln((v + 3) * (v + 2) * (v + 1))
144+
ln_fac - ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln()
138145
}
139146

140147
impl Hypergeometric {
@@ -359,7 +366,7 @@ impl Distribution<u64> for Hypergeometric {
359366
} else {
360367
for i in (y as u64 + 1)..=(m as u64) {
361368
f *= i as f64 * (n2 - k + i) as f64;
362-
f /= (n1 - i) as f64 * (k - i) as f64;
369+
f /= (n1 - i + 1) as f64 * (k - i + 1) as f64;
363370
}
364371
}
365372

@@ -441,6 +448,7 @@ impl Distribution<u64> for Hypergeometric {
441448

442449
#[cfg(test)]
443450
mod test {
451+
444452
use super::*;
445453

446454
#[test]
@@ -494,4 +502,13 @@ mod test {
494502
fn hypergeometric_distributions_can_be_compared() {
495503
assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
496504
}
505+
506+
#[test]
507+
fn stirling() {
508+
let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
509+
for &v in test.iter() {
510+
let ln_fac = ln_of_factorial(v);
511+
assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4);
512+
}
513+
}
497514
}

rand_distr/tests/cdf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ fn hypergeometric() {
598598
(60, 10, 7),
599599
(70, 20, 50),
600600
(100, 50, 10),
601-
// (100, 50, 49), // Fail case
601+
(100, 50, 49),
602602
];
603603

604604
for (seed, (n, k, n_)) in parameters.into_iter().enumerate() {

rand_distr/tests/value_stability.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ fn hypergeometric_stability() {
105105
test_samples(
106106
7221,
107107
Hypergeometric::new(100, 50, 50).unwrap(),
108-
&[23, 27, 26, 27, 22, 24, 31, 22],
108+
&[23, 27, 26, 27, 22, 25, 31, 25],
109109
); // Algorithm H2PE
110110
}
111111

0 commit comments

Comments
 (0)