33//! Simplistic MCMC ensemble sampler based on [emcee](https://emcee.readthedocs.io/), the MCMC hammer
44//!
55//! ```
6- //! use hammer_and_sample::{sample, MinChainLen, Model, Serial};
6+ //! use hammer_and_sample::{sample, MinChainLen, Model, Serial, Stretch };
77//! use rand::{Rng, SeedableRng};
88//! use rand_pcg::Pcg64;
99//!
3939//! ([p], rng)
4040//! });
4141//!
42- //! let (chain, _accepted) = sample(&model, walkers, MinChainLen(10 * 1000), Serial);
42+ //! let (chain, _accepted) = sample(&model, &Stretch::default(), walkers, MinChainLen(10 * 1000), Serial);
4343//!
4444//! // 100 iterations of 10 walkers as burn-in
4545//! let chain = &chain[10 * 100..];
4848//! }
4949//! ```
5050use std:: ops:: ControlFlow ;
51+ use std:: ptr;
5152
5253use rand:: {
5354 distr:: { Distribution , StandardUniform , Uniform } ,
5455 Rng ,
5556} ;
57+ use rand_distr:: {
58+ weighted:: { AliasableWeight , WeightedAliasIndex } ,
59+ Normal ,
60+ } ;
5661#[ cfg( feature = "rayon" ) ]
5762use rayon:: iter:: { IntoParallelRefMutIterator , ParallelExtend , ParallelIterator } ;
5863
@@ -117,6 +122,209 @@ impl Params for Box<[f64]> {
117122 }
118123}
119124
125+ /// TODO
126+ pub trait Move < M >
127+ where
128+ M : Model ,
129+ {
130+ /// TODO
131+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , other : O , rng : & mut R ) -> ( M :: Params , f64 )
132+ where
133+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
134+ R : Rng ;
135+ }
136+
137+ /// TODO
138+ pub struct Stretch {
139+ scale : f64 ,
140+ }
141+
142+ impl Stretch {
143+ /// TODO
144+ pub fn new ( scale : f64 ) -> Self {
145+ Self { scale }
146+ }
147+ }
148+
149+ impl Default for Stretch {
150+ fn default ( ) -> Self {
151+ Self :: new ( 2. )
152+ }
153+ }
154+
155+ impl < M > Move < M > for Stretch
156+ where
157+ M : Model ,
158+ {
159+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , mut other : O , rng : & mut R ) -> ( M :: Params , f64 )
160+ where
161+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
162+ R : Rng ,
163+ {
164+ let other = other ( rng) ;
165+
166+ let z = ( ( self . scale - 1. ) * gen_unit ( rng) + 1. ) . powi ( 2 ) / self . scale ;
167+
168+ let new_state = M :: Params :: collect (
169+ self_
170+ . values ( )
171+ . zip ( other. values ( ) )
172+ . map ( |( self_, other) | other - z * ( other - self_) ) ,
173+ ) ;
174+
175+ let factor = ( new_state. dimension ( ) - 1 ) as f64 * z. ln ( ) ;
176+
177+ ( new_state, factor)
178+ }
179+ }
180+
181+ /// TODO
182+ pub struct DifferentialEvolution {
183+ gamma : Normal < f64 > ,
184+ }
185+
186+ impl DifferentialEvolution {
187+ /// TODO
188+ pub fn new ( gamma_mean : f64 , gamma_std_dev : f64 ) -> Self {
189+ Self {
190+ gamma : Normal :: new ( gamma_mean, gamma_std_dev) . unwrap ( ) ,
191+ }
192+ }
193+ }
194+
195+ impl < M > Move < M > for DifferentialEvolution
196+ where
197+ M : Model ,
198+ {
199+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , mut other : O , rng : & mut R ) -> ( M :: Params , f64 )
200+ where
201+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
202+ R : Rng ,
203+ {
204+ let first_other = other ( rng) ;
205+ let mut second_other = other ( rng) ;
206+
207+ while ptr:: eq ( first_other, second_other) {
208+ second_other = other ( rng) ;
209+ }
210+
211+ let gamma = self . gamma . sample ( rng) ;
212+
213+ let new_state = M :: Params :: collect (
214+ self_
215+ . values ( )
216+ . zip ( first_other. values ( ) )
217+ . zip ( second_other. values ( ) )
218+ . map ( |( ( self_, first_other) , second_other) | {
219+ self_ + gamma * ( first_other - second_other)
220+ } ) ,
221+ ) ;
222+
223+ ( new_state, 0. )
224+ }
225+ }
226+
227+ /// TODO
228+ pub struct RandomGaussian {
229+ dist : Normal < f64 > ,
230+ }
231+
232+ impl RandomGaussian {
233+ /// TODO
234+ pub fn new ( scale : f64 ) -> Self {
235+ Self {
236+ dist : Normal :: new ( 0. , scale) . unwrap ( ) ,
237+ }
238+ }
239+ }
240+
241+ impl < M > Move < M > for RandomGaussian
242+ where
243+ M : Model ,
244+ {
245+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , _other : O , rng : & mut R ) -> ( M :: Params , f64 )
246+ where
247+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
248+ R : Rng ,
249+ {
250+ let dir = rng. random_range ( 0 ..self_. dimension ( ) ) ;
251+
252+ let new_state = M :: Params :: collect ( self_. values ( ) . enumerate ( ) . map ( |( idx, value) | {
253+ if idx == dir {
254+ value + self . dist . sample ( rng)
255+ } else {
256+ * value
257+ }
258+ } ) ) ;
259+
260+ ( new_state, 0. )
261+ }
262+ }
263+
264+ /// TODO
265+ pub struct Mixture < W , M > ( WeightedAliasIndex < W > , M )
266+ where
267+ W : AliasableWeight ;
268+
269+ macro_rules! impl_mixture {
270+ ( $( $types: ident @ $weights: ident) ,+ ) => {
271+ impl <W , $( $types ) ,+> From <( $( ( $types, W ) ) ,+ ) > for Mixture <W , ( $( $types ) ,+ ) >
272+ where
273+ W : AliasableWeight
274+ {
275+ #[ allow( non_snake_case) ]
276+ fn from( ( $( ( $types, $weights ) ) ,+ ) : ( $( ( $types, W ) ) ,+ ) ) -> Self {
277+ let index = WeightedAliasIndex :: new( vec![ $( $weights ) ,+] ) . unwrap( ) ;
278+
279+ Self ( index, ( $( $types ) ,+ ) )
280+ }
281+ }
282+
283+ impl <W , $( $types ) ,+, M > Move <M > for Mixture <W , ( $( $types ) ,+ ) >
284+ where
285+ W : AliasableWeight ,
286+ M : Model ,
287+ $( $types: Move <M > ) ,+
288+ {
289+ #[ allow( non_snake_case) ]
290+ fn propose<' a, O , R >( & self , self_: & ' a M :: Params , other: O , rng: & mut R ) -> ( M :: Params , f64 )
291+ where
292+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
293+ R : Rng ,
294+ {
295+ let Self ( index, ( $( $types ) ,+ ) ) = self ;
296+
297+ let chosen_index = index. sample( rng) ;
298+
299+ let mut index = 0 ;
300+
301+ $(
302+
303+ #[ allow( unused_assignments) ]
304+ if chosen_index == index {
305+ return $types. propose( self_, other, rng)
306+ } else {
307+ index += 1 ;
308+ }
309+
310+ ) +
311+
312+ unreachable!( )
313+ }
314+ }
315+ } ;
316+ }
317+
318+ impl_mixture ! ( A @ a, B @ b) ;
319+ impl_mixture ! ( A @ a, B @ b, C @ c) ;
320+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d) ;
321+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e) ;
322+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f) ;
323+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g) ;
324+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h) ;
325+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i) ;
326+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i, J @ j) ;
327+
120328/// Models are defined by the type of their parameters and their probability functions
121329pub trait Model : Send + Sync {
122330 /// Type used to store the model parameters, e.g. `[f64; N]` or `Vec<f64>`
@@ -126,9 +334,6 @@ pub trait Model: Send + Sync {
126334 ///
127335 /// The sampler will only ever consider differences of these values, i.e. any addititive constant that does _not_ depend on `state` can be omitted when computing them.
128336 fn log_prob ( & self , state : & Self :: Params ) -> f64 ;
129-
130- /// Scale parameter for stretch moves
131- const SCALE : f64 = 2. ;
132337}
133338
134339/// Runs the sampler on the given [`model`][Model] using the chosen [`schedule`][Schedule] and [`execution`][Execution] strategy
@@ -138,17 +343,19 @@ pub trait Model: Send + Sync {
138343/// The number of walkers must be non-zero, even and at least twice the number of parameters.
139344///
140345/// A vector of samples and the number of accepted moves are returned.
141- pub fn sample < M , W , R , S , E > (
142- model : & M ,
346+ pub fn sample < MD , MV , W , R , S , E > (
347+ model : & MD ,
348+ move_ : & MV ,
143349 walkers : W ,
144350 mut schedule : S ,
145351 execution : E ,
146- ) -> ( Vec < M :: Params > , usize )
352+ ) -> ( Vec < MD :: Params > , usize )
147353where
148- M : Model ,
149- W : Iterator < Item = ( M :: Params , R ) > ,
354+ MD : Model ,
355+ MV : Move < MD > + Send + Sync ,
356+ W : Iterator < Item = ( MD :: Params , R ) > ,
150357 R : Rng + Send + Sync ,
151- S : Schedule < M :: Params > ,
358+ S : Schedule < MD :: Params > ,
152359 E : Execution ,
153360{
154361 let mut walkers = walkers
@@ -166,10 +373,8 @@ where
166373
167374 let random_index = Uniform :: new ( 0 , half) . unwrap ( ) ;
168375
169- let update_walker = move |walker : & mut Walker < M , R > , other_walkers : & [ Walker < M , R > ] | {
170- let other = & other_walkers[ random_index. sample ( & mut walker. rng ) ] ;
171-
172- walker. move_ ( model, other)
376+ let update_walker = move |walker : & mut Walker < MD , R > , other_walkers : & [ Walker < MD , R > ] | {
377+ walker. move_ ( model, move_, |rng| & other_walkers[ random_index. sample ( rng) ] )
173378 } ;
174379
175380 while schedule. next_step ( & chain) . is_continue ( ) {
@@ -187,22 +392,22 @@ where
187392 ( chain, accepted)
188393}
189394
190- struct Walker < M , R >
395+ struct Walker < MD , R >
191396where
192- M : Model ,
397+ MD : Model ,
193398{
194- state : M :: Params ,
399+ state : MD :: Params ,
195400 log_prob : f64 ,
196401 rng : R ,
197402 accepted : usize ,
198403}
199404
200- impl < M , R > Walker < M , R >
405+ impl < MD , R > Walker < MD , R >
201406where
202- M : Model ,
407+ MD : Model ,
203408 R : Rng ,
204409{
205- fn new ( model : & M , state : M :: Params , rng : R ) -> Self {
410+ fn new ( model : & MD , state : MD :: Params , rng : R ) -> Self {
206411 let log_prob = model. log_prob ( & state) ;
207412
208413 Self {
@@ -213,20 +418,17 @@ where
213418 }
214419 }
215420
216- fn move_ ( & mut self , model : & M , other : & Self ) -> M :: Params {
217- let z = ( ( M :: SCALE - 1. ) * gen_unit ( & mut self . rng ) + 1. ) . powi ( 2 ) / M :: SCALE ;
218-
219- let mut new_state = M :: Params :: collect (
220- self . state
221- . values ( )
222- . zip ( other. state . values ( ) )
223- . map ( |( self_, other) | other - z * ( other - self_) ) ,
224- ) ;
421+ fn move_ < ' a , MV , O > ( & ' a mut self , model : & MD , move_ : & MV , mut other : O ) -> MD :: Params
422+ where
423+ MV : Move < MD > ,
424+ O : FnMut ( & mut R ) -> & ' a Self ,
425+ {
426+ let ( mut new_state, factor) =
427+ move_. propose ( & self . state , |rng| & other ( rng) . state , & mut self . rng ) ;
225428
226429 let new_log_prob = model. log_prob ( & new_state) ;
227430
228- let log_prob_diff =
229- ( new_state. dimension ( ) - 1 ) as f64 * z. ln ( ) + new_log_prob - self . log_prob ;
431+ let log_prob_diff = factor + new_log_prob - self . log_prob ;
230432
231433 if log_prob_diff > gen_unit ( & mut self . rng ) . ln ( ) {
232434 self . state . clone_from ( & new_state) ;
@@ -380,7 +582,7 @@ where
380582/// Runs the inner `schedule` after calling the given `callback`
381583///
382584/// ```
383- /// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, WithProgress};
585+ /// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, Stretch, WithProgress};
384586/// # use rand::SeedableRng;
385587/// # use rand_pcg::Pcg64Mcg;
386588/// #
@@ -407,7 +609,7 @@ where
407609/// callback: |chain: &[_]| eprintln!("{} %", 100 * chain.len() / 100_000),
408610/// };
409611///
410- /// let (chain, accepted) = sample(&model, walkers, schedule, Serial);
612+ /// let (chain, accepted) = sample(&model, &Stretch::default(), walkers, schedule, Serial);
411613/// ```
412614pub struct WithProgress < S , C > {
413615 /// The inner schedule which determines the number of iterations
0 commit comments