Skip to content

Commit b183cd5

Browse files
committed
fix: distributions beta,gamma,geometric,nig
1 parent f590c74 commit b183cd5

4 files changed

Lines changed: 53 additions & 26 deletions

File tree

src/stats/distr/beta.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub struct SimdBeta {
1616

1717
impl SimdBeta {
1818
pub fn new(alpha: f32, beta: f32) -> Self {
19-
assert!(alpha >= 1.0 && beta >= 1.0);
19+
assert!(alpha > 0.0 && beta > 0.0);
2020
Self {
2121
alpha,
2222
beta,

src/stats/distr/gamma.rs

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub struct SimdGamma {
1515

1616
impl SimdGamma {
1717
pub fn new(alpha: f32, scale: f32) -> Self {
18-
assert!(alpha >= 1.0 && scale > 0.0);
18+
assert!(alpha > 0.0 && scale > 0.0);
1919
Self {
2020
alpha,
2121
scale,
@@ -25,26 +25,40 @@ impl SimdGamma {
2525
}
2626
}
2727

28-
/// Bulk fill using Marsaglia–Tsang; uses scalar acceptance per sample but reduces per-call overhead.
28+
/// Bulk fill using Marsaglia–Tsang (for alpha >= 1) or Ahrens-Dieter (for alpha < 1)
2929
pub fn fill_slice<R: Rng + ?Sized>(&self, rng: &mut R, out: &mut [f32]) {
30-
let d = self.alpha - 1.0 / 3.0;
31-
let c = 1.0 / (3.0 * d).sqrt();
32-
for x in out.iter_mut() {
33-
let val = loop {
34-
let z = self.normal.sample(rng);
35-
let v = (1.0 + c * z).powi(3);
36-
if v <= 0.0 {
37-
continue;
38-
}
30+
if self.alpha < 1.0 {
31+
// For alpha < 1, use the transformation: if X ~ Gamma(alpha+1, scale), then X*U^(1/alpha) ~ Gamma(alpha, scale)
32+
let gamma_plus_one = SimdGamma::new(self.alpha + 1.0, self.scale);
33+
for x in out.iter_mut() {
34+
let g = gamma_plus_one.sample(rng);
3935
let u: f32 = rng.gen_range(0.0..1.0);
40-
if u < 1.0 - 0.0331 * z.powi(4) {
41-
break self.scale * d * v;
42-
}
43-
if u.ln() < 0.5 * z * z + d * (1.0 - v + v.ln()) {
44-
break self.scale * d * v;
45-
}
46-
};
47-
*x = val;
36+
*x = g * u.powf(1.0 / self.alpha);
37+
}
38+
} else {
39+
// Marsaglia-Tsang for alpha >= 1
40+
let d = self.alpha - 1.0 / 3.0;
41+
let c = 1.0 / (9.0 * d).sqrt();
42+
for x in out.iter_mut() {
43+
let val = loop {
44+
let z = self.normal.sample(rng);
45+
let v = (1.0 + c * z).powi(3);
46+
if v <= 0.0 {
47+
continue;
48+
}
49+
let u: f32 = rng.gen_range(0.0..1.0);
50+
let z2 = z * z;
51+
// Quick acceptance
52+
if u < 1.0 - 0.0331 * z2 * z2 {
53+
break d * v;
54+
}
55+
// Log acceptance
56+
if u.ln() < 0.5 * z2 + d * (1.0 - v + v.ln()) {
57+
break d * v;
58+
}
59+
};
60+
*x = self.scale * val;
61+
}
4862
}
4963
}
5064

src/stats/distr/geometric.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ impl SimdGeometric {
2222
use crate::stats::distr::fill_f32_zero_one;
2323
use wide::f32x8;
2424

25+
// rand_distr Geometric returns number of failures before first success (starts at 0)
26+
// Formula: floor(ln(U) / ln(1-p)) where U ~ Uniform(0,1)
2527
let ln1p = (1.0 - self.p).ln();
2628
let inv_ln1p = f32x8::splat(1.0 / ln1p);
2729
let mut u = [0.0f32; 8];
@@ -30,10 +32,11 @@ impl SimdGeometric {
3032
for chunk in &mut chunks {
3133
fill_f32_zero_one(rng, &mut u);
3234
let v = f32x8::from(u);
33-
let g = (v.ln() * inv_ln1p).floor() + f32x8::splat(1.0);
35+
// Number of failures before success (can be 0)
36+
let g = (v.ln() * inv_ln1p).floor();
3437
let mut tmp = g.to_array();
3538
for t in &mut tmp {
36-
*t = (*t).max(1.0);
39+
*t = (*t).max(0.0);
3740
}
3841
for (o, t) in chunk.iter_mut().zip(tmp.iter()) {
3942
*o = *t as u32;
@@ -43,10 +46,10 @@ impl SimdGeometric {
4346
if !rem.is_empty() {
4447
fill_f32_zero_one(rng, &mut u);
4548
let v = f32x8::from(u);
46-
let g = (v.ln() * inv_ln1p).floor() + f32x8::splat(1.0);
49+
let g = (v.ln() * inv_ln1p).floor();
4750
let tmp = g.to_array();
4851
for i in 0..rem.len() {
49-
let val = tmp[i].max(1.0);
52+
let val = tmp[i].max(0.0);
5053
rem[i] = val as u32;
5154
}
5255
}

src/stats/distr/normal_inverse_gauss.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,19 @@ pub struct SimdNormalInverseGauss {
1818
}
1919

2020
impl SimdNormalInverseGauss {
21+
/// Create a NIG with 4 parameters (alpha, beta, delta, mu)
22+
/// This is the full parameterization.
23+
/// NIG(α, β, δ, μ) where α > |β|, δ > 0
2124
pub fn new(alpha: f32, beta: f32, delta: f32, mu: f32) -> Self {
22-
// Typically alpha> |beta|, delta>0, etc.
23-
let ig = SimdInverseGauss::new(delta * (alpha * alpha - beta * beta).sqrt(), delta);
25+
assert!(
26+
alpha > 0.0 && alpha > beta.abs(),
27+
"NIG: alpha must be > |beta|"
28+
);
29+
assert!(delta > 0.0, "NIG: delta must be positive");
30+
let gamma = (alpha * alpha - beta * beta).sqrt();
31+
let ig_mean = delta / gamma;
32+
let ig_shape = delta * delta;
33+
let ig = SimdInverseGauss::new(ig_mean, ig_shape);
2434
let normal = SimdNormal::new(0.0, 1.0);
2535
Self {
2636
alpha,

0 commit comments

Comments
 (0)