Skip to content

Commit 187b7d1

Browse files
authored
Merge pull request #485 from rohitjoshi/master
Support for Dirichlet distribution
2 parents ec3d7ef + fde9567 commit 187b7d1

File tree

2 files changed

+145
-2
lines changed

2 files changed

+145
-2
lines changed

src/distributions/dirichlet.rs

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
//! The dirichlet distribution.
12+
13+
use Rng;
14+
use distributions::Distribution;
15+
use distributions::gamma::Gamma;
16+
17+
/// The dirichelet distribution `Dirichlet(alpha)`.
18+
///
19+
/// The Dirichlet distribution is a family of continuous multivariate probability distributions parameterized by
20+
/// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution
21+
/// It is a multivariate generalization of the beta distribution.
22+
///
23+
/// # Example
24+
///
25+
/// ```
26+
/// use rand::prelude::*;
27+
/// use rand::distributions::Dirichlet;
28+
///
29+
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]);
30+
/// let samples = dirichlet.sample(&mut rand::thread_rng());
31+
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
32+
/// ```
33+
34+
#[derive(Clone, Debug)]
35+
pub struct Dirichlet {
36+
/// Concentration parameters (alpha)
37+
alpha: Vec<f64>,
38+
}
39+
40+
impl Dirichlet {
41+
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
42+
///
43+
/// # Panics
44+
/// - if `alpha.len() < 2`
45+
///
46+
#[inline]
47+
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet {
48+
let a = alpha.into();
49+
assert!(a.len() > 1);
50+
for i in 0..a.len() {
51+
assert!(a[i] > 0.0);
52+
}
53+
54+
Dirichlet { alpha: a }
55+
}
56+
57+
/// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
58+
///
59+
/// # Panics
60+
/// - if `alpha <= 0.0`
61+
/// - if `size < 2`
62+
///
63+
#[inline]
64+
pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet {
65+
assert!(alpha > 0.0);
66+
assert!(size > 1);
67+
Dirichlet {
68+
alpha: vec![alpha; size],
69+
}
70+
}
71+
}
72+
73+
impl Distribution<Vec<f64>> for Dirichlet {
74+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
75+
let n = self.alpha.len();
76+
let mut samples = vec![0.0f64; n];
77+
let mut sum = 0.0f64;
78+
79+
for i in 0..n {
80+
let g = Gamma::new(self.alpha[i], 1.0);
81+
samples[i] = g.sample(rng);
82+
sum += samples[i];
83+
}
84+
let invacc = 1.0 / sum;
85+
for i in 0..n {
86+
samples[i] *= invacc;
87+
}
88+
samples
89+
}
90+
}
91+
92+
#[cfg(test)]
93+
mod test {
94+
use super::Dirichlet;
95+
use distributions::Distribution;
96+
97+
#[test]
98+
fn test_dirichlet() {
99+
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]);
100+
let mut rng = ::test::rng(221);
101+
let samples = d.sample(&mut rng);
102+
let _: Vec<f64> = samples
103+
.into_iter()
104+
.map(|x| {
105+
assert!(x > 0.0);
106+
x
107+
})
108+
.collect();
109+
}
110+
111+
#[test]
112+
fn test_dirichlet_with_param() {
113+
let alpha = 0.5f64;
114+
let size = 2;
115+
let d = Dirichlet::new_with_param(alpha, size);
116+
let mut rng = ::test::rng(221);
117+
let samples = d.sample(&mut rng);
118+
let _: Vec<f64> = samples
119+
.into_iter()
120+
.map(|x| {
121+
assert!(x > 0.0);
122+
x
123+
})
124+
.collect();
125+
}
126+
127+
#[test]
128+
#[should_panic]
129+
fn test_dirichlet_invalid_length() {
130+
Dirichlet::new_with_param(0.5f64, 1);
131+
}
132+
133+
#[test]
134+
#[should_panic]
135+
fn test_dirichlet_invalid_alpha() {
136+
Dirichlet::new_with_param(0.0f64, 2);
137+
}
138+
}

src/distributions/mod.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181
//! - Related to real-valued quantities that grow linearly
8282
//! (e.g. errors, offsets):
8383
//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive
84-
//! - [`Cauchy`] distribution
8584
//! - Related to Bernoulli trials (yes/no events, with a given probability):
8685
//! - [`Binomial`] distribution
8786
//! - [`Bernoulli`] distribution, similar to [`Rng::gen_bool`].
@@ -96,7 +95,8 @@
9695
//! - [`ChiSquared`] distribution
9796
//! - [`StudentT`] distribution
9897
//! - [`FisherF`] distribution
99-
//!
98+
//! - Related to continuous multivariate probability distributions
99+
//! - [`Dirichlet`] distribution
100100
//!
101101
//! # Examples
102102
//!
@@ -150,6 +150,7 @@
150150
//! [`Binomial`]: struct.Binomial.html
151151
//! [`Cauchy`]: struct.Cauchy.html
152152
//! [`ChiSquared`]: struct.ChiSquared.html
153+
//! [`Dirichlet`]: struct.Dirichlet.html
153154
//! [`Exp`]: struct.Exp.html
154155
//! [`Exp1`]: struct.Exp1.html
155156
//! [`FisherF`]: struct.FisherF.html
@@ -185,6 +186,8 @@ use Rng;
185186
#[doc(inline)] pub use self::bernoulli::Bernoulli;
186187
#[cfg(feature = "std")]
187188
#[doc(inline)] pub use self::cauchy::Cauchy;
189+
#[cfg(feature = "std")]
190+
#[doc(inline)] pub use self::dirichlet::Dirichlet;
188191

189192
pub mod uniform;
190193
#[cfg(feature="std")]
@@ -202,6 +205,8 @@ pub mod uniform;
202205
#[doc(hidden)] pub mod bernoulli;
203206
#[cfg(feature = "std")]
204207
#[doc(hidden)] pub mod cauchy;
208+
#[cfg(feature = "std")]
209+
#[doc(hidden)] pub mod dirichlet;
205210

206211
mod float;
207212
mod integer;

0 commit comments

Comments
 (0)