diff --git a/CHANGELOG.md b/CHANGELOG.md index 293072c..0e2dd97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.6.0] - unreleased + +### API changes +- `Dirichlet` no longer uses `const` generics, which means that its size is not required at compile time. Essentially a revert of #1292 +- Add `Dirichlet::new_with_size` constructor + ## [0.5.1] ### Testing diff --git a/src/dirichlet.rs b/src/dirichlet.rs index ac17fa2..0e0f85d 100644 --- a/src/dirichlet.rs +++ b/src/dirichlet.rs @@ -21,14 +21,14 @@ use alloc::{boxed::Box, vec, vec::Vec}; #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", serde_as)] -struct DirichletFromGamma +struct DirichletFromGamma where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - samplers: [Gamma; N], + samplers: Vec>, } /// Error type returned from [`DirchletFromGamma::new`]. @@ -36,12 +36,9 @@ where enum DirichletFromGammaError { /// Gamma::new(a, 1) failed. GammmaNewFailed, - - /// gamma_dists.try_into() failed (in theory, this should not happen). - GammaArrayCreationFailed, } -impl DirichletFromGamma +impl DirichletFromGamma where F: Float, StandardNormal: Distribution, @@ -53,30 +50,28 @@ where /// This function is part of a private implementation detail. /// It assumes that the input is correct, so no validation of alpha is done. #[inline] - fn new(alpha: [F; N]) -> Result, DirichletFromGammaError> { + fn new(alpha: &[F]) -> Result, DirichletFromGammaError> { let mut gamma_dists = Vec::new(); for a in alpha { let dist = - Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; + Gamma::new(*a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; gamma_dists.push(dist); } Ok(DirichletFromGamma { - samplers: gamma_dists - .try_into() - .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?, + samplers: gamma_dists, }) } } -impl Distribution<[F; N]> for DirichletFromGamma +impl Distribution> for DirichletFromGamma where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; + fn sample(&self, rng: &mut R) -> Vec { + let mut samples = vec![F::zero(); self.samplers.len()]; let mut sum = F::zero(); for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { @@ -93,7 +88,7 @@ where #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -struct DirichletFromBeta +struct DirichletFromBeta where F: Float, StandardNormal: Distribution, @@ -110,7 +105,7 @@ enum DirichletFromBetaError { BetaNewFailed, } -impl DirichletFromBeta +impl DirichletFromBeta where F: Float, StandardNormal: Distribution, @@ -122,15 +117,16 @@ where /// This function is part of a private implementation detail. /// It assumes that the input is correct, so no validation of alpha is done. #[inline] - fn new(alpha: [F; N]) -> Result, DirichletFromBetaError> { + fn new(alpha: &[F]) -> Result, DirichletFromBetaError> { // `alpha_rev_csum` is the reverse of the cumulative sum of the // reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`. // Note that instances of DirichletFromBeta will always have N >= 2, // so the subtractions of 1, 2 and 3 from N in the following are safe. - let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1]; - for k in 0..(N - 2) { - alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k]; + let n = alpha.len(); + let mut alpha_rev_csum = vec![alpha[n - 1]; n - 1]; + for k in 0..(n - 2) { + alpha_rev_csum[n - 3 - k] = alpha_rev_csum[n - 2 - k] + alpha[n - 2 - k]; } // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example @@ -139,7 +135,7 @@ where // Then pass each tuple to `Beta::new()` to create the `Beta` // instances. let mut beta_dists = Vec::new(); - for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) { + for (&a, &b) in alpha[..(n - 1)].iter().zip(alpha_rev_csum.iter()) { let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?; beta_dists.push(dist); } @@ -149,15 +145,16 @@ where } } -impl Distribution<[F; N]> for DirichletFromBeta +impl Distribution> for DirichletFromBeta where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; + fn sample(&self, rng: &mut R) -> Vec { + let n = self.samplers.len() + 1; + let mut samples = vec![F::zero(); n]; let mut acc = F::one(); for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { @@ -165,14 +162,14 @@ where *s = acc * beta_sample; acc = acc * (F::one() - beta_sample); } - samples[N - 1] = acc; + samples[n - 1] = acc; samples } } #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", serde_as)] -enum DirichletRepr +enum DirichletRepr where F: Float, StandardNormal: Distribution, @@ -180,10 +177,10 @@ where Open01: Distribution, { /// Dirichlet distribution that generates samples using the Gamma distribution. - FromGamma(DirichletFromGamma), + FromGamma(DirichletFromGamma), /// Dirichlet distribution that generates samples using the Beta distribution. - FromBeta(DirichletFromBeta), + FromBeta(DirichletFromBeta), } /// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`. @@ -210,20 +207,20 @@ where /// use rand::prelude::*; /// use rand_distr::Dirichlet; /// -/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); +/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); /// let samples = dirichlet.sample(&mut rand::rng()); -/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); +/// println!("{:?} is from a Dirichlet(&[1.0, 2.0, 3.0]) distribution", samples); /// ``` #[cfg_attr(feature = "serde", serde_as)] #[derive(Clone, Debug, PartialEq)] -pub struct Dirichlet +pub struct Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - repr: DirichletRepr, + repr: DirichletRepr, } /// Error type returned from [`Dirichlet::new`]. @@ -268,7 +265,7 @@ impl fmt::Display for Error { #[cfg(feature = "std")] impl std::error::Error for Error {} -impl Dirichlet +impl Dirichlet where F: Float, StandardNormal: Distribution, @@ -280,8 +277,8 @@ where /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive, /// finite and not subnormal. #[inline] - pub fn new(alpha: [F; N]) -> Result, Error> { - if N < 2 { + pub fn new(alpha: &[F]) -> Result, Error> { + if alpha.len() < 2 { return Err(Error::AlphaTooShort); } for &ai in alpha.iter() { @@ -313,16 +310,46 @@ where }) } } + + /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. + /// + /// Requires `size >= 2`. + #[inline] + pub fn new_with_size(alpha: F, size: usize) -> Result, Error> { + if !(alpha > F::zero()) { + return Err(Error::AlphaTooSmall); + } + if size < 2 { + return Err(Error::SizeTooSmall); + } + if alpha <= NumCast::from(0.1).unwrap() { + // Use the Beta method when alpha is less than 0.1 This + // threshold provides a reasonable compromise between using the faster + // Gamma method for as wide a range as possible while ensuring that + // the probability of generating nans is negligibly small. + let dist = DirichletFromBeta::new(&vec![alpha; size]) + .map_err(|_| Error::FailedToCreateBeta)?; + Ok(Dirichlet { + repr: DirichletRepr::FromBeta(dist), + }) + } else { + let dist = DirichletFromGamma::new(&vec![alpha; size]) + .map_err(|_| Error::FailedToCreateGamma)?; + Ok(Dirichlet { + repr: DirichletRepr::FromGamma(dist), + }) + } + } } -impl Distribution<[F; N]> for Dirichlet +impl Distribution> for Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { + fn sample(&self, rng: &mut R) -> Vec { match &self.repr { DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), @@ -336,7 +363,7 @@ mod test { #[test] fn test_dirichlet() { - let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); + let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); let mut rng = crate::test::rng(221); let samples = d.sample(&mut rng); assert!(samples.into_iter().all(|x: f64| x > 0.0)); @@ -345,42 +372,42 @@ mod test { #[test] #[should_panic] fn test_dirichlet_invalid_length() { - Dirichlet::new([0.5]).unwrap(); + Dirichlet::new(&[0.5]).unwrap(); } #[test] #[should_panic] fn test_dirichlet_alpha_zero() { - Dirichlet::new([0.1, 0.0, 0.3]).unwrap(); + Dirichlet::new(&[0.1, 0.0, 0.3]).unwrap(); } #[test] #[should_panic] fn test_dirichlet_alpha_negative() { - Dirichlet::new([0.1, -1.5, 0.3]).unwrap(); + Dirichlet::new(&[0.1, -1.5, 0.3]).unwrap(); } #[test] #[should_panic] fn test_dirichlet_alpha_nan() { - Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap(); + Dirichlet::new(&[0.5, f64::NAN, 0.25]).unwrap(); } #[test] #[should_panic] fn test_dirichlet_alpha_subnormal() { - Dirichlet::new([0.5, 1.5e-321, 0.25]).unwrap(); + Dirichlet::new(&[0.5, 1.5e-321, 0.25]).unwrap(); } #[test] #[should_panic] fn test_dirichlet_alpha_inf() { - Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap(); + Dirichlet::new(&[0.5, f64::INFINITY, 0.25]).unwrap(); } #[test] fn dirichlet_distributions_can_be_compared() { - assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0])); + assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0])); } /// Check that the means of the components of n samples from @@ -390,7 +417,7 @@ mod test { /// This is a crude statistical test, but it will catch egregious /// mistakes. It will also also fail if any samples contain nan. fn check_dirichlet_means(alpha: [f64; N], n: i32, rtol: f64, seed: u64) { - let d = Dirichlet::new(alpha).unwrap(); + let d = Dirichlet::new(&alpha).unwrap(); let mut rng = crate::test::rng(seed); let mut sums = [0.0; N]; for _ in 0..n { diff --git a/tests/value_stability.rs b/tests/value_stability.rs index 2eb263e..b833c43 100644 --- a/tests/value_stability.rs +++ b/tests/value_stability.rs @@ -502,11 +502,11 @@ fn weibull_stability() { fn dirichlet_stability() { let mut rng = get_rng(223); assert_eq!( - rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), + rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()), [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] ); assert_eq!( - rng.sample(Dirichlet::new([8.0; 5]).unwrap()), + rng.sample(Dirichlet::new(&[8.0; 5]).unwrap()), [ 0.17684200044809556, 0.29915953935953055, @@ -517,7 +517,7 @@ fn dirichlet_stability() { ); // Test stability for the case where all alphas are less than 0.1. assert_eq!( - rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), + rng.sample(Dirichlet::new(&[0.05, 0.025, 0.075, 0.05]).unwrap()), [ 0.00027580456855692104, 2.296135759821706e-20,