@@ -15,7 +15,7 @@ pub struct SimdGamma {
1515
1616impl SimdGamma {
1717 pub fn new ( alpha : f32 , scale : f32 ) -> Self {
18- assert ! ( alpha >= 1 .0 && scale > 0.0 ) ;
18+ assert ! ( alpha > 0 .0 && scale > 0.0 ) ;
1919 Self {
2020 alpha,
2121 scale,
@@ -25,26 +25,40 @@ impl SimdGamma {
2525 }
2626 }
2727
28- /// Bulk fill using Marsaglia–Tsang; uses scalar acceptance per sample but reduces per-call overhead.
28+ /// Bulk fill using Marsaglia–Tsang (for alpha >= 1) or Ahrens-Dieter (for alpha < 1)
2929 pub fn fill_slice < R : Rng + ?Sized > ( & self , rng : & mut R , out : & mut [ f32 ] ) {
30- let d = self . alpha - 1.0 / 3.0 ;
31- let c = 1.0 / ( 3.0 * d) . sqrt ( ) ;
32- for x in out. iter_mut ( ) {
33- let val = loop {
34- let z = self . normal . sample ( rng) ;
35- let v = ( 1.0 + c * z) . powi ( 3 ) ;
36- if v <= 0.0 {
37- continue ;
38- }
30+ if self . alpha < 1.0 {
31+ // For alpha < 1, use the transformation: if X ~ Gamma(alpha+1, scale), then X*U^(1/alpha) ~ Gamma(alpha, scale)
32+ let gamma_plus_one = SimdGamma :: new ( self . alpha + 1.0 , self . scale ) ;
33+ for x in out. iter_mut ( ) {
34+ let g = gamma_plus_one. sample ( rng) ;
3935 let u: f32 = rng. gen_range ( 0.0 ..1.0 ) ;
40- if u < 1.0 - 0.0331 * z. powi ( 4 ) {
41- break self . scale * d * v;
42- }
43- if u. ln ( ) < 0.5 * z * z + d * ( 1.0 - v + v. ln ( ) ) {
44- break self . scale * d * v;
45- }
46- } ;
47- * x = val;
36+ * x = g * u. powf ( 1.0 / self . alpha ) ;
37+ }
38+ } else {
39+ // Marsaglia-Tsang for alpha >= 1
40+ let d = self . alpha - 1.0 / 3.0 ;
41+ let c = 1.0 / ( 9.0 * d) . sqrt ( ) ;
42+ for x in out. iter_mut ( ) {
43+ let val = loop {
44+ let z = self . normal . sample ( rng) ;
45+ let v = ( 1.0 + c * z) . powi ( 3 ) ;
46+ if v <= 0.0 {
47+ continue ;
48+ }
49+ let u: f32 = rng. gen_range ( 0.0 ..1.0 ) ;
50+ let z2 = z * z;
51+ // Quick acceptance
52+ if u < 1.0 - 0.0331 * z2 * z2 {
53+ break d * v;
54+ }
55+ // Log acceptance
56+ if u. ln ( ) < 0.5 * z2 + d * ( 1.0 - v + v. ln ( ) ) {
57+ break d * v;
58+ }
59+ } ;
60+ * x = self . scale * val;
61+ }
4862 }
4963 }
5064
0 commit comments