diff --git a/benchmark/Single.hs b/benchmark/Single.hs index d7473b4c..3c1f749b 100644 --- a/benchmark/Single.hs +++ b/benchmark/Single.hs @@ -1,92 +1,113 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE ImportQualifiedPost #-} -import Control.Monad.Bayes.Class +import Control.Monad.Bayes.Class (MonadInfer) import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..), Proposal (SingleSiteMH)) -import Control.Monad.Bayes.Inference.RMSMC +import Control.Monad.Bayes.Inference.RMSMC (rmsmcBasic) import Control.Monad.Bayes.Inference.SMC + ( SMCConfig (SMCConfig, numParticles, numSteps, resampler), + smc, + ) import Control.Monad.Bayes.Population -import Control.Monad.Bayes.Population (population) + ( population, + resampleSystematic, + ) import Control.Monad.Bayes.Sampler -import Control.Monad.Bayes.Traced -import Control.Monad.Bayes.Weighted + ( Sampler, + sampleSTfixed, + sampleWith, + ) +import Control.Monad.Bayes.Traced (mh) +import Control.Monad.Bayes.Weighted (unweighted) import Control.Monad.ST (runST) -import Data.Time +import Data.Time (diffUTCTime, getCurrentTime) import HMM qualified import LDA qualified import LogReg qualified import Options.Applicative + ( Applicative (liftA2), + ParserInfo, + auto, + execParser, + fullDesc, + help, + info, + long, + maybeReader, + option, + short, + ) import System.Random.MWC (GenIO, createSystemRandom) -data Model = LR Int | HMM Int | LDA (Int, Int) - deriving stock (Show, Read) +-- data Model = LR Int | HMM Int | LDA (Int, Int) +-- deriving stock (Show, Read) -parseModel :: String -> Maybe Model -parseModel s = - case s of - 'L' : 'R' : n -> Just $ LR (read n) - 'H' : 'M' : 'M' : n -> Just $ HMM (read n) - 'L' : 'D' : 'A' : n -> Just $ LDA (5, read n) - _ -> Nothing +-- parseModel :: String -> Maybe Model +-- parseModel s = +-- case s of +-- 'L' : 'R' : n -> Just $ LR (read n) +-- 'H' : 'M' : 'M' : n -> Just $ HMM (read n) +-- 'L' : 'D' : 'A' : n -> Just $ LDA (5, read n) +-- _ -> Nothing -getModel :: MonadInfer m => Model -> (Int, m String) -getModel model = (size model, program model) - where - size (LR n) = n - size (HMM n) = n - size (LDA (d, w)) = d * w - program (LR n) = show <$> (LogReg.logisticRegression (runST $ sampleSTfixed (LogReg.syntheticData n))) - program (HMM n) = show <$> (HMM.hmm (runST $ sampleSTfixed (HMM.syntheticData n))) - program (LDA (d, w)) = show <$> (LDA.lda (runST $ sampleSTfixed (LDA.syntheticData d w))) +-- getModel :: MonadInfer m => Model -> (Int, m String) +-- getModel model = (size model, program model) +-- where +-- size (LR n) = n +-- size (HMM n) = n +-- size (LDA (d, w)) = d * w +-- program (LR n) = show <$> (LogReg.logisticRegression (runST $ sampleSTfixed (LogReg.syntheticData n))) +-- program (HMM n) = show <$> (HMM.hmm (runST $ sampleSTfixed (HMM.syntheticData n))) +-- program (LDA (d, w)) = show <$> (LDA.lda (runST $ sampleSTfixed (LDA.syntheticData d w))) -data Alg = SMC | MH | RMSMC - deriving stock (Read, Show) +-- data Alg = SMC | MH | RMSMC +-- deriving stock (Read, Show) -runAlg :: Model -> Alg -> Sampler GenIO IO String -runAlg model alg = - case alg of - SMC -> - let n = 100 - (k, m) = getModel model - in show <$> population (smc SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic} m) - MH -> - let t = 100 - (_, m) = getModel model - in show <$> unweighted (mh t m) - RMSMC -> - let n = 10 - t = 1 - (k, m) = getModel model - in show <$> population (rmsmcBasic MCMCConfig {numMCMCSteps = t, numBurnIn = 0, proposal = SingleSiteMH} (SMCConfig {numSteps = k, numParticles = n, 
resampler = resampleSystematic}) m) +-- runAlg :: Model -> Alg -> Sampler GenIO IO String +-- runAlg model alg = +-- case alg of +-- SMC -> +-- let n = 100 +-- (k, m) = getModel model +-- in show <$> population (smc SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic} m) +-- MH -> +-- let t = 100 +-- (_, m) = getModel model +-- in show <$> unweighted (mh t m) +-- RMSMC -> +-- let n = 10 +-- t = 1 +-- (k, m) = getModel model +-- in show <$> population (rmsmcBasic MCMCConfig {numMCMCSteps = t, numBurnIn = 0, proposal = SingleSiteMH} (SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic}) m) -infer :: Model -> Alg -> IO () -infer model alg = do - g <- createSystemRandom - x <- sampleWith (runAlg model alg) g - print x +-- infer :: Model -> Alg -> IO () +-- infer model alg = do +-- g <- createSystemRandom +-- x <- sampleWith (runAlg model alg) g +-- print x -opts :: ParserInfo (Model, Alg) -opts = flip info fullDesc $ liftA2 (,) model alg - where - model = - option - (maybeReader parseModel) - ( long "model" - <> short 'm' - <> help "Model" - ) - alg = - option - auto - ( long "alg" - <> short 'a' - <> help "Inference algorithm" - ) +-- opts :: ParserInfo (Model, Alg) +-- opts = flip info fullDesc $ liftA2 (,) model alg +-- where +-- model = +-- option +-- (maybeReader parseModel) +-- ( long "model" +-- <> short 'm' +-- <> help "Model" +-- ) +-- alg = +-- option +-- auto +-- ( long "alg" +-- <> short 'a' +-- <> help "Inference algorithm" +-- ) -main :: IO () -main = do - (model, alg) <- execParser opts - startTime <- getCurrentTime - infer model alg - endTime <- getCurrentTime - print (diffUTCTime endTime startTime) +-- main :: IO () +-- main = do +-- (model, alg) <- execParser opts +-- startTime <- getCurrentTime +-- infer model alg +-- endTime <- getCurrentTime +-- print (diffUTCTime endTime startTime) diff --git a/docs/source/usage.md b/docs/source/usage.md index 8b4ac610..b7272ccd 100644 --- a/docs/source/usage.md +++ b/docs/source/usage.md @@ -391,6 +391,7 @@ Summary of key info on `Sequential`: - `instance MonadSample m => instance MonadSample (Sequential m)` - `instance MonadCond m => instance MonadCond (Sequential m)` + ```haskell newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a} @@ -474,34 +475,36 @@ hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a hoistFirst f = Sequential . Coroutine . f . resume . runSequential ``` - - When `m` is `Population n` for some other `n`, then `resampleGeneric` gives us one example of the natural transformation we want. In other words, operating in `Sequential (Population n)` works, and not only works but does something statistically interesting: particle filtering (aka SMC). -### FreeSampler +### Density -Summary of key info on `FreeSampler`: +Summary of key info on `Density`: -- `FreeSampler :: (Type -> Type) -> (Type -> Type)` -- `instance MonadSample (FreeSampler m)` +- `Density :: (Type -> Type) -> (Type -> Type)` +- `instance MonadSample (Density m)` - **No** instance for `MonadCond` -`FreeSampler m` is not often going to be used on its own, but instead as part of the `Traced` type, defined below. A `FreeSampler m a` represents a reified execution of the program. +A *trace* of a program of type `MonadSample m => m a` is an execution of the program, so a choice for each of the random values. 
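For instance, here is a minimal hypothetical model (not taken from the library or these docs) that we can use to talk about traces:

```haskell
-- A toy model making two primitive random choices; a single execution
-- (a trace) fixes a concrete value for each of those choices.
example :: MonadSample m => m Bool
example = do
  x <- random          -- first random choice
  b <- bernoulli 0.5   -- second random choice (one call to random underneath)
  return (b && x > 0.5)
```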
Recall that `random` underlies all of the random values in a program, so a trace for a program is fully specified by a list of `Double`s, giving the value of each call to `random`. + +With this in mind, a `Density m a` is an interpretation of a probabilistic program as a function from a trace to the *density* of that execution of the program. + +Monad-bayes offers two implementations, in `Control.Monad.Bayes.Density.State` and `Control.Monad.Bayes.Density.Free`. The first is slow but easy to understand, the second is more sophisticated, but faster. -`FreeSampler m` is best understood if you're familiar with the standard use of a free monad to construct a domain specific language. For probability in particular, see this [blog post](https://jtobin.io/simple-probabilistic-programming). Here's the definition: +The former is relatively straightforward: the `MonadSample` instance implements `random` as `get`ting the trace (using `get` from `MonadState`), using (and removing) the first element (`put` from `MonadState`), and writing that element to the output (using `tell` from `MonadWriter`). If the trace is empty, the `random` from the underlying monad is used, but the result is still written with `tell`. + +The latter is best understood if you're familiar with the standard use of a free monad to construct a domain specific language. For probability in particular, see this [blog post](https://jtobin.io/simple-probabilistic-programming). Here's the definition: ```haskell newtype SamF a = Random (Double -> a) -newtype FreeSampler m a = - FreeSampler {runFreeSampler :: FT SamF m a} +newtype Density m a = + Density {density :: FT SamF m a} -instance Monad m => MonadSample (FreeSampler m) where - random = FreeSampler $ liftF (Random id) +instance Monad m => MonadSample (Density m) where + random = Density $ liftF (Random id) ``` The monad-bayes implementation uses a more efficient implementation of `FreeT`, namely `FT` from the `free` package, known as the *Church transformed Free monad*. This is a technique explained in https://begriffs.com/posts/2016-02-04-difference-lists-and-codennsity.html. But that only changes the operational semantics - performance aside, it works just the same as the standard `FreeT` datatype. @@ -509,21 +512,16 @@ The monad-bayes implementation uses a more efficient implementation of `FreeT`, If you unpack the definition, you get: ```haskell -FreeSampler m a ~ m (Either a (Double -> (FreeSampler m a))) +Density m a ~ m (Either a (Double -> (Density m a))) ``` -As you can see, this is rather like `Coroutine`, except to "resume", you must provide a new `Double`, corresponding to the value of some particular random choice. - Since `FreeT` is a transformer, we can use `lift` to get a `MonadSample` instance. - -A *trace* of a program of type `MonadSample m => m a` is an execution of the program, so a choice for each of the random values. Recall that `random` underlies all of the random values in a program, so a trace for a program is fully specified by a list of `Double`s, giving the value of each call to `random`. - -Given a probabilistic program interpreted in `FreeSampler m`, we can "run" it to produce a program in the underlying monad `m`. For simplicity, consider the case of a program `bernoulli 0.5 :: FreeSampler SamplerIO Bool`. 
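As a quick aside, `withRandomness` (exported from the same module) runs such a program against a fully specified trace; a minimal sketch, assuming the `SamplerIO` base monad used elsewhere in these docs:

```haskell
-- bernoulli 0.5 makes exactly one call to random; fixing that call to 0.3
-- yields (0.3 < 0.5) = True, and no actual sampling takes place.
deterministicRun :: SamplerIO Bool
deterministicRun = withRandomness [0.3] (bernoulli 0.5)
```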
We can then use the following function: +`density` is then defined using the canonical property of the free monad (transformer), embodied by `iterFT`, which interprets `SamF` in the appropriate way: ```haskell -withPartialRandomness :: MonadSample m => [Double] -> FreeSampler m a -> m (a, [Double]) -withPartialRandomness randomness (FreeSampler m) = +density :: MonadSample m => [Double] -> Density m a -> m (a, [Double]) +density randomness (Density m) = runWriterT $ evalStateT (iterTM f $ hoistFT lift m) randomness where f (Random k) = do @@ -538,7 +536,7 @@ withPartialRandomness randomness (FreeSampler m) = k x ``` -This takes a list of `Double`s (a representation of a trace), and a probabilistic program like `example`, and gives back a `SamplerIO (Bool, [Double])`. At each call to `random` in `example`, the next double in the list is used. If the list of doubles runs out, calls are made to `random` using the underlying monad, which in our example is `SamplerIO`. Hence "with*Partial*Randomness". +This takes a list of `Double`s (a representation of a trace), and a probabilistic program like `example`, and gives back a `SamplerIO (Bool, [Double])`. At each call to `random` in `example`, the next double in the list is used. If the list of doubles runs out, calls are made to `random` using the underlying monad. @@ -554,7 +552,7 @@ Summary of key info on `Traced`: - `instance MonadSample m => MonadSample (Traced m)` - `instance MonadCond m => MonadCond (Traced m)` -`Traced m` is actually several related interpretations, each built on top of `FreeSampler`. These range in complexity. +`Traced m` is actually several related interpretations, each built on top of `Density`. These range in complexity. @@ -576,12 +574,12 @@ data Trace a = Trace } ``` -We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `FreeSampler` is for. +We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `Density` is for. The simplest version of `Traced` is in `Control.Monad.Bayes.Traced.Basic` ```haskell -Traced m a ~ (FreeSampler Identity a, Log Double), m (Trace a)) +Traced m a ~ (Density Identity a, Log Double), m (Trace a)) ``` A `Traced` interpretation of a model is a particular run of the model with its corresponding probability, alongside a distribution over `Trace` info, which records: the value of each call to `random`, the value of the final output, and the density of this program trace. @@ -707,7 +705,7 @@ A single step in this chain (in Metropolis Hasting MCMC) looks like this: ```haskell mhTrans :: MonadSample m => - Weighted (FreeSampler m) a -> Trace a -> m (Trace a) + Weighted (Density m) a -> Trace a -> m (Trace a) mhTrans m t@Trace {variables = us, density = p} = do let n = length us us' <- do @@ -717,15 +715,14 @@ mhTrans m t@Trace {variables = us, density = p} = do (xs, _ : ys) -> return $ xs ++ (u' : ys) _ -> error "impossible" ((b, q), vs) <- - runWriterT $ weighted - $ Weighted.hoist (WriterT . withPartialRandomness us') m + runWriterT $ weighted $ Weighted.hoist (WriterT . density us') m let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs))) accept <- bernoulli ratio return $ if accept then Trace vs b q else t ``` -Our probabilistic program is interpreted in the type `Weighted (FreeSampler m) a`, which is an instance of `MonadInfer`. We use this to define our kernel on traces. 
We begin by perturbing the list of doubles contained in the trace by selecting a random position in the list and resampling there. We could do this *proposal* in a variety of ways, but here, we do so by choosing a double from the list at random and resampling it (hence, *single site* trace MCMC). We then run the program on this new list of doubles; `((b,q), vs)` is the outcome, probability, and result of all calls to `random`, respectively (recalling that the list of doubles may be shorter than the number of calls to `random`). The value of these is probabilistic in the underlying monad `m`. We then use the MH criterion to decide whether to accept the new list of doubles as our trace. +Our probabilistic program is interpreted in the type `Weighted (Density m) a`, which is an instance of `MonadInfer`. We use this to define our kernel on traces. We begin by perturbing the list of doubles contained in the trace by selecting a random position in the list and resampling there. We could do this *proposal* in a variety of ways, but here, we do so by choosing a double from the list at random and resampling it (hence, *single site* trace MCMC). We then run the program on this new list of doubles; `((b,q), vs)` is the outcome, probability, and result of all calls to `random`, respectively (recalling that the list of doubles may be shorter than the number of calls to `random`). The value of these is probabilistic in the underlying monad `m`. We then use the MH criterion to decide whether to accept the new list of doubles as our trace. MH is then easily defined as taking steps with this kernel, in the usual fashion. Note that it works for any probabilistic program whatsoever. @@ -736,7 +733,7 @@ MH is then easily defined as taking steps with this kernel, in the usual fashion This is provided by ```haskell -sis :: +sequentially :: Monad m => -- | transformation (forall x. m x -> m x) -> @@ -744,10 +741,10 @@ sis :: Int -> Sequential m a -> m a -sis f k = finish . composeCopies k (advance . hoistFirst f) +sequentially f k = finish . composeCopies k (advance . hoistFirst f) ``` -in Control.Monad.Bayes.Sequential. You provide a natural transformation in the underlying monad `m`, and `sis` applies that natural transformation at each point of conditioning in your program. The main use case is in defining `smc`, below, but here is a nice alternative use case: +in `Control.Monad.Bayes.Sequential.Coroutine`. You provide a natural transformation in the underlying monad `m`, and `sequentially` applies that natural transformation at each point of conditioning in your program. 
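To make this concrete, here is a rough sketch of the idea behind `smc` (not the library's exact definition): spawn a population of particles up front, then resample at every point of conditioning:

```haskell
-- Hypothetical helper: run a model with n particles, resampling
-- systematically at each of the k conditioning points.
smcSketch ::
  MonadSample m =>
  -- | number of conditioning points
  Int ->
  -- | number of particles
  Int ->
  Sequential (Population m) a ->
  Population m a
smcSketch k n =
  sequentially resampleSystematic k . hoistFirst (spawn n >>)
```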
The main use case is in defining `smc`, below, but here is a nice didactic use case: Consider the program: diff --git a/flake.lock b/flake.lock index 0d7b434a..0967fb94 100644 --- a/flake.lock +++ b/flake.lock @@ -697,4 +697,4 @@ }, "root": "root", "version": 7 -} +} \ No newline at end of file diff --git a/models/HMM.hs b/models/HMM.hs index 2cd44c63..e6b1ee4d 100644 --- a/models/HMM.hs +++ b/models/HMM.hs @@ -17,88 +17,88 @@ import Pipes (MFunctor (hoist), MonadTrans (lift), each, yield, (>->)) import Pipes.Core (Producer) import qualified Pipes.Prelude as Pipes --- | Observed values -values :: [Double] -values = - [ 0.9, - 0.8, - 0.7, - 0, - -0.025, - -5, - -2, - -0.1, - 0, - 0.13, - 0.45, - 6, - 0.2, - 0.3, - -1, - -1 - ] +-- -- | Observed values +-- values :: [Double] +-- values = +-- [ 0.9, +-- 0.8, +-- 0.7, +-- 0, +-- -0.025, +-- -5, +-- -2, +-- -0.1, +-- 0, +-- 0.13, +-- 0.45, +-- 6, +-- 0.2, +-- 0.3, +-- -1, +-- -1 +-- ] --- | The transition model. -trans :: MonadSample m => Int -> m Int -trans 0 = categorical $ fromList [0.1, 0.4, 0.5] -trans 1 = categorical $ fromList [0.2, 0.6, 0.2] -trans 2 = categorical $ fromList [0.15, 0.7, 0.15] -trans _ = error "unreachable" +-- -- | The transition model. +-- trans :: MonadSample m => Int -> m Int +-- trans 0 = categorical $ fromList [0.1, 0.4, 0.5] +-- trans 1 = categorical $ fromList [0.2, 0.6, 0.2] +-- trans 2 = categorical $ fromList [0.15, 0.7, 0.15] +-- trans _ = error "unreachable" --- | The emission model. -emissionMean :: Int -> Double -emissionMean 0 = -1 -emissionMean 1 = 1 -emissionMean 2 = 0 -emissionMean _ = error "unreachable" +-- -- | The emission model. +-- emissionMean :: Int -> Double +-- emissionMean 0 = -1 +-- emissionMean 1 = 1 +-- emissionMean 2 = 0 +-- emissionMean _ = error "unreachable" --- | Initial state distribution -start :: MonadSample m => m Int -start = uniformD [0, 1, 2] +-- -- | Initial state distribution +-- start :: MonadSample m => m Int +-- start = uniformD [0, 1, 2] --- | Example HMM from http://dl.acm.org/citation.cfm?id=2804317 -hmm :: (MonadInfer m) => [Double] -> m [Int] -hmm dataset = f dataset (const . return) - where - expand x y = do - x' <- trans x - factor $ normalPdf (emissionMean x') 1 y - return x' - f [] k = start >>= k [] - f (y : ys) k = f ys (\xs x -> expand x y >>= k (x : xs)) +-- -- | Example HMM from http://dl.acm.org/citation.cfm?id=2804317 +-- hmm :: (MonadInfer m) => [Double] -> m [Int] +-- hmm dataset = f dataset (const . return) +-- where +-- expand x y = do +-- x' <- trans x +-- factor $ normalPdf (emissionMean x') 1 y +-- return x' +-- f [] k = start >>= k [] +-- f (y : ys) k = f ys (\xs x -> expand x y >>= k (x : xs)) -syntheticData :: MonadSample m => Int -> m [Double] -syntheticData n = replicateM n syntheticPoint - where - syntheticPoint = uniformD [0, 1, 2] +-- syntheticData :: MonadSample m => Int -> m [Double] +-- syntheticData n = replicateM n syntheticPoint +-- where +-- syntheticPoint = uniformD [0, 1, 2] --- | Equivalent model, but using pipes for simplicity +-- -- | Equivalent model, but using pipes for simplicity --- | Prior expressed as a stream -hmmPrior :: MonadSample m => Producer Int m b -hmmPrior = do - x <- lift start - yield x - Pipes.unfoldr (fmap (Right . (\k -> (k, k))) . trans) x +-- -- | Prior expressed as a stream +-- hmmPrior :: MonadSample m => Producer Int m b +-- hmmPrior = do +-- x <- lift start +-- yield x +-- Pipes.unfoldr (fmap (Right . (\k -> (k, k))) . 
trans) x --- | Observations expressed as a stream -hmmObservations :: Functor m => [a] -> Producer (Maybe a) m () -hmmObservations dataset = each (Nothing : (Just <$> reverse dataset)) +-- -- | Observations expressed as a stream +-- hmmObservations :: Functor m => [a] -> Producer (Maybe a) m () +-- hmmObservations dataset = each (Nothing : (Just <$> reverse dataset)) --- | Posterior expressed as a stream -hmmPosterior :: (MonadInfer m) => [Double] -> Producer Int m () -hmmPosterior dataset = - zipWithM - hmmLikelihood - hmmPrior - (hmmObservations dataset) - where - hmmLikelihood :: MonadCond f => (Int, Maybe Double) -> f () - hmmLikelihood (l, o) = when (isJust o) (factor $ normalPdf (emissionMean l) 1 (fromJust o)) +-- -- | Posterior expressed as a stream +-- hmmPosterior :: (MonadInfer m) => [Double] -> Producer Int m () +-- hmmPosterior dataset = +-- zipWithM +-- hmmLikelihood +-- hmmPrior +-- (hmmObservations dataset) +-- where +-- hmmLikelihood :: MonadCond f => (Int, Maybe Double) -> f () +-- hmmLikelihood (l, o) = when (isJust o) (factor $ normalPdf (emissionMean l) 1 (fromJust o)) - zipWithM f p1 p2 = Pipes.zip p1 p2 >-> Pipes.chain f >-> Pipes.map fst +-- zipWithM f p1 p2 = Pipes.zip p1 p2 >-> Pipes.chain f >-> Pipes.map fst -hmmPosteriorPredictive :: MonadSample m => [Double] -> Producer Double m () -hmmPosteriorPredictive dataset = - Pipes.hoist enumerateToDistribution (hmmPosterior dataset) - >-> Pipes.mapM (\x -> normal (emissionMean x) 1) +-- hmmPosteriorPredictive :: MonadSample m => [(Real m)] -> Producer (Real m) m () +-- hmmPosteriorPredictive dataset = +-- Pipes.hoist enumerateToDistribution (hmmPosterior dataset) +-- >-> Pipes.mapM (\x -> normal (emissionMean x) 1) diff --git a/models/LDA.hs b/models/LDA.hs index 97aa69b1..86830054 100644 --- a/models/LDA.hs +++ b/models/LDA.hs @@ -26,59 +26,59 @@ import Numeric.Log (Log (Exp)) import Text.Pretty.Simple (pPrint) import Prelude hiding (words) -vocabulary :: [Text] -vocabulary = ["bear", "wolf", "python", "prolog"] +-- vocabulary :: [Text] +-- vocabulary = ["bear", "wolf", "python", "prolog"] -topics :: [Text] -topics = ["topic1", "topic2"] +-- topics :: [Text] +-- topics = ["topic1", "topic2"] -type Documents = [[Text]] +-- type Documents = [[Text]] -documents :: Documents -documents = - [ words "bear wolf bear wolf bear wolf python wolf bear wolf", - words "python prolog python prolog python prolog python prolog python prolog", - words "bear wolf bear wolf bear wolf bear wolf bear wolf", - words "python prolog python prolog python prolog python prolog python prolog", - words "bear wolf bear python bear wolf bear wolf bear wolf" - ] +-- documents :: Documents +-- documents = +-- [ words "bear wolf bear wolf bear wolf python wolf bear wolf", +-- words "python prolog python prolog python prolog python prolog python prolog", +-- words "bear wolf bear wolf bear wolf bear wolf bear wolf", +-- words "python prolog python prolog python prolog python prolog python prolog", +-- words "bear wolf bear python bear wolf bear wolf bear wolf" +-- ] -wordDistPrior :: MonadSample m => m (V.Vector Double) -wordDistPrior = dirichlet $ V.replicate (length vocabulary) 1 +-- wordDistPrior :: MonadSample m => m (V.Vector Double) +-- wordDistPrior = dirichlet $ V.replicate (length vocabulary) 1 -topicDistPrior :: MonadSample m => m (V.Vector Double) -topicDistPrior = dirichlet $ V.replicate (length topics) 1 +-- topicDistPrior :: MonadSample m => m (V.Vector Double) +-- topicDistPrior = dirichlet $ V.replicate (length topics) 1 -wordIndex :: 
Map.Map Text Int -wordIndex = Map.fromList $ zip vocabulary [0 ..] +-- wordIndex :: Map.Map Text Int +-- wordIndex = Map.fromList $ zip vocabulary [0 ..] -lda :: - MonadInfer m => - Documents -> - m (Map.Map Text (V.Vector (Text, Double)), [(Text, V.Vector (Text, Double))]) -lda docs = do - word_dist_for_topic <- do - ts <- List.replicateM (length topics) wordDistPrior - return $ Map.fromList $ zip topics ts - let obs doc = do - topic_dist <- topicDistPrior - let f word = do - topic <- (fmap (topics !!) . categorical) topic_dist - factor $ (Exp . log) $ (word_dist_for_topic Map.! topic) V.! (wordIndex Map.! word) - mapM_ f doc - return topic_dist - td <- mapM obs docs - return - ( fmap (V.zip (V.fromList vocabulary)) word_dist_for_topic, - zip (fmap (foldr1 (\x y -> x <> " " <> y)) docs) (fmap (V.zip $ V.fromList ["topic1", "topic2"]) td) - ) +-- lda :: +-- MonadInfer m => +-- Documents -> +-- m (Map.Map Text (V.Vector (Text, Double)), [(Text, V.Vector (Text, Double))]) +-- lda docs = do +-- word_dist_for_topic <- do +-- ts <- List.replicateM (length topics) wordDistPrior +-- return $ Map.fromList $ zip topics ts +-- let obs doc = do +-- topic_dist <- topicDistPrior +-- let f word = do +-- topic <- (fmap (topics !!) . categorical) topic_dist +-- factor $ (Exp . log) $ (word_dist_for_topic Map.! topic) V.! (wordIndex Map.! word) +-- mapM_ f doc +-- return topic_dist +-- td <- mapM obs docs +-- return +-- ( fmap (V.zip (V.fromList vocabulary)) word_dist_for_topic, +-- zip (fmap (foldr1 (\x y -> x <> " " <> y)) docs) (fmap (V.zip $ V.fromList ["topic1", "topic2"]) td) +-- ) -syntheticData :: MonadSample m => Int -> Int -> m [[Text]] -syntheticData d w = List.replicateM d (List.replicateM w syntheticWord) - where - syntheticWord = uniformD vocabulary +-- syntheticData :: MonadSample m => Int -> Int -> m [[Text]] +-- syntheticData d w = List.replicateM d (List.replicateM w syntheticWord) +-- where +-- syntheticWord = uniformD vocabulary -runLDA :: IO () -runLDA = do - s <- sampleIOfixed $ unweighted $ mh 1000 $ lda documents - pPrint (head s) +-- runLDA :: IO () +-- runLDA = do +-- s <- sampleIOfixed $ unweighted $ mh 1000 $ lda documents +-- pPrint (head s) diff --git a/models/LogReg.hs b/models/LogReg.hs index 65bdfda9..0cf335da 100644 --- a/models/LogReg.hs +++ b/models/LogReg.hs @@ -3,7 +3,7 @@ -- Logistic regression model from Anglican -- (https://bitbucket.org/probprog/anglican-white-paper) -module LogReg (logisticRegression, syntheticData, xs, labels) where +module LogReg () where import Control.Monad (replicateM) import Control.Monad.Bayes.Class @@ -13,29 +13,29 @@ import Control.Monad.Bayes.Class ) import Numeric.Log (Log (Exp)) -logisticRegression :: MonadInfer m => [(Double, Bool)] -> m Double -logisticRegression dat = do - m <- normal 0 1 - b <- normal 0 1 - sigma <- gamma 1 1 - let y x = normal (m * x + b) sigma - sigmoid x = y x >>= \t -> return $ 1 / (1 + exp (-t)) - obs x label = do - p <- sigmoid x - factor $ (Exp . log) $ if label then p else 1 - p - mapM_ (uncurry obs) dat - sigmoid 8 +-- logisticRegression :: MonadInfer m => [(Double, Bool)] -> m Double +-- logisticRegression dat = do +-- m <- normal 0 1 +-- b <- normal 0 1 +-- sigma <- gamma 1 1 +-- let y x = normal (m * x + b) sigma +-- sigmoid x = y x >>= \t -> return $ 1 / (1 + exp (-t)) +-- obs x label = do +-- p <- sigmoid x +-- factor $ (Exp . 
log) $ if label then p else 1 - p +-- mapM_ (uncurry obs) dat +-- sigmoid 8 --- make a synthetic dataset by randomly choosing input-label pairs -syntheticData :: MonadSample m => Int -> m [(Double, Bool)] -syntheticData n = replicateM n do - x <- uniform (-1) 1 - label <- bernoulli 0.5 - return (x, label) +-- -- make a synthetic dataset by randomly choosing input-label pairs +-- syntheticData :: MonadSample m => Int -> m [(Double, Bool)] +-- syntheticData n = replicateM n do +-- x <- uniform (-1) 1 +-- label <- bernoulli 0.5 +-- return (x, label) --- a tiny test dataset, for sanity-checking -xs :: [Double] -xs = [-10, -5, 2, 6, 10] +-- -- a tiny test dataset, for sanity-checking +-- xs :: [Double] +-- xs = [-10, -5, 2, 6, 10] -labels :: [Bool] -labels = [False, False, True, True, True] +-- labels :: [Bool] +-- labels = [False, False, True, True, True] diff --git a/models/NestedInference.hs b/models/NestedInference.hs index 79aaa5c7..fa8abc88 100644 --- a/models/NestedInference.hs +++ b/models/NestedInference.hs @@ -6,30 +6,30 @@ import Control.Monad.Bayes.Class (MonadInfer, MonadSample (uniformD), factor) import Control.Monad.Bayes.Enumerator (mass) import Numeric.Log (Log (Exp)) -data Utterance = ASquare | AShape deriving (Eq, Show, Ord) - -data State = Square | Circle deriving (Eq, Show, Ord) - -data Action = Speak Utterance | DoNothing deriving (Eq, Show, Ord) - --- | uniformly likely to say any true utterance to convey the given state -truthfulAgent :: MonadSample m => State -> m Action -truthfulAgent state = uniformD case state of - Square -> [Speak ASquare, Speak AShape, DoNothing] - Circle -> [Speak AShape, DoNothing] - --- | a listener which applies Bayes rule to infer the state --- given an observed action of the other agent -listener :: MonadInfer m => Action -> m State -listener observedAction = do - state <- uniformD [Square, Circle] - factor $ log $ Exp $ mass (truthfulAgent state) observedAction - return state - --- | an agent which produces an action by reasoning about --- how the listener would interpret it -informativeAgent :: MonadInfer m => State -> m Action -informativeAgent state = do - utterance <- uniformD [Speak ASquare, Speak AShape, DoNothing] - factor $ log $ Exp $ mass (listener utterance) state - return utterance +-- data Utterance = ASquare | AShape deriving (Eq, Show, Ord) + +-- data State = Square | Circle deriving (Eq, Show, Ord) + +-- data Action = Speak Utterance | DoNothing deriving (Eq, Show, Ord) + +-- -- | uniformly likely to say any true utterance to convey the given state +-- truthfulAgent :: MonadSample m => State -> m Action +-- truthfulAgent state = uniformD case state of +-- Square -> [Speak ASquare, Speak AShape, DoNothing] +-- Circle -> [Speak AShape, DoNothing] + +-- -- | a listener which applies Bayes rule to infer the state +-- -- given an observed action of the other agent +-- listener :: MonadInfer m => Action -> m State +-- listener observedAction = do +-- state <- uniformD [Square, Circle] +-- factor $ log $ Exp $ mass (truthfulAgent state) observedAction +-- return state + +-- -- | an agent which produces an action by reasoning about +-- -- how the listener would interpret it +-- informativeAgent :: MonadInfer m => State -> m Action +-- informativeAgent state = do +-- utterance <- uniformD [Speak ASquare, Speak AShape, DoNothing] +-- factor $ log $ Exp $ mass (listener utterance) state +-- return utterance diff --git a/monad-bayes-site/AdvancedSampling.html b/monad-bayes-site/AdvancedSampling.html index 61baff06..b9cdb665 100644 --- 
a/monad-bayes-site/AdvancedSampling.html +++ b/monad-bayes-site/AdvancedSampling.html @@ -14592,7 +14592,7 @@ import Control.Monad.Bayes.Enumerator import Control.Monad.Bayes.Weighted import Control.Monad.Bayes.Sampler -import Control.Monad.Bayes.Free +import Control.Monad.Bayes.Density.Free import Control.Monad.Bayes.Population import Control.Monad.Bayes.Sequential import Control.Monad.Bayes.Inference.SMC diff --git a/monad-bayes-site/Functional_PPLs.html b/monad-bayes-site/Functional_PPLs.html index 4b3fff56..a444e2a0 100644 --- a/monad-bayes-site/Functional_PPLs.html +++ b/monad-bayes-site/Functional_PPLs.html @@ -16843,7 +16843,7 @@

Probability in a functional langua import Control.Monad.Bayes.Sampler import Control.Monad.Bayes.Integrator import Control.Monad.Bayes.Population -import Control.Monad.Bayes.Free +import Control.Monad.Bayes.Density.Free import Control.Monad.Bayes.Traced.Static import Control.Monad.Bayes.Inference.SMC diff --git a/monad-bayes-site/SMC.html b/monad-bayes-site/SMC.html index bcdbeeb9..43555376 100644 --- a/monad-bayes-site/SMC.html +++ b/monad-bayes-site/SMC.html @@ -14607,7 +14607,7 @@

Sequential Inferenceimport Control.Monad.Bayes.Enumerator import Control.Monad.Bayes.Weighted import Control.Monad.Bayes.Sampler -import Control.Monad.Bayes.Free +import Control.Monad.Bayes.Density.Free import Control.Monad.Bayes.Population import Control.Monad.Bayes.Sequential import Control.Monad.Bayes.Inference.SMC diff --git a/monad-bayes.cabal b/monad-bayes.cabal index b4e99401..723c5cfb 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -32,8 +32,9 @@ flag dev library exposed-modules: Control.Monad.Bayes.Class + Control.Monad.Bayes.Density.Free + Control.Monad.Bayes.Density.State Control.Monad.Bayes.Enumerator - Control.Monad.Bayes.Free Control.Monad.Bayes.Inference.MCMC Control.Monad.Bayes.Inference.PMMH Control.Monad.Bayes.Inference.RMSMC @@ -47,6 +48,7 @@ library Control.Monad.Bayes.Traced.Basic Control.Monad.Bayes.Traced.Dynamic Control.Monad.Bayes.Traced.Static + Control.Monad.Bayes.Traced.Grad Control.Monad.Bayes.Weighted Math.Integrators.StormerVerlet @@ -54,29 +56,33 @@ library other-modules: Control.Monad.Bayes.Traced.Common default-language: Haskell2010 build-depends: - base >=4.11 && <4.17 + ad + , base >=4.11 && <4.17 , containers >=0.5.10 && <0.7 , foldl - , free >=5.0.2 && <5.2 - , ieee754 ^>=0.8.0 + , free >=5.0.2 && <5.2 + , ieee754 ^>=0.8.0 , integration , lens , linear - , log-domain >=0.12 && <0.14 - , math-functions >=0.2.1 && <0.4 + , log-domain >=0.12 && <0.14 + , math-functions >=0.2.1 && <0.4 , matrix - , monad-coroutine ^>=0.9.0 - , mtl ^>=2.2.2 - , mwc-random >=0.13.6 && <0.16 + , monad-coroutine ^>=0.9.0 + , mtl ^>=2.2.2 + , mwc-random >=0.13.6 && <0.16 , pipes , primitive , random - , safe ^>=0.3.17 + , recursion-schemes + , safe ^>=0.3.17 , scientific - , statistics >=0.14.0 && <0.17 + , statistics >=0.14.0 && <0.17 , text - , transformers ^>=0.5.2 - , vector ^>=0.12.0 + , transformers ^>=0.5.2 + , vector ^>=0.12.0 + , reflection + , erf default-extensions: BlockArguments @@ -122,7 +128,7 @@ executable example if flag(dev) ghc-options: - -Wall -Wcompat -Wincomplete-record-updates + -threaded -O2 -Wall -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wnoncanonical-monad-instances else @@ -174,7 +180,6 @@ test-suite monad-bayes-test , mwc-random , pipes , pretty-simple - , profunctors , QuickCheck , random , statistics diff --git a/notebooks/.ipynb_checkpoints/AdvancedSampling-checkpoint.ipynb b/notebooks/.ipynb_checkpoints/AdvancedSampling-checkpoint.ipynb index df135650..9abc6dac 100644 --- a/notebooks/.ipynb_checkpoints/AdvancedSampling-checkpoint.ipynb +++ b/notebooks/.ipynb_checkpoints/AdvancedSampling-checkpoint.ipynb @@ -18,9 +18,9 @@ "import Control.Monad.Bayes.Enumerator\n", "import Control.Monad.Bayes.Weighted\n", "import Control.Monad.Bayes.Sampler\n", - "import Control.Monad.Bayes.Free\n", + "import Control.Monad.Bayes.Density.Free\n", "import Control.Monad.Bayes.Population\n", - "import Control.Monad.Bayes.Sequential\n", + "import Control.Monad.Bayes.Sequential.Coroutine\n", "import Control.Monad.Bayes.Inference.SMC\n", "import Control.Monad.Bayes.Inference.RMSMC\n", "import Control.Monad.Bayes.Inference.PMMH\n", diff --git a/notebooks/.ipynb_checkpoints/Functional_PPLs-checkpoint.ipynb b/notebooks/.ipynb_checkpoints/Functional_PPLs-checkpoint.ipynb index 4fa46e51..37bb79bf 100644 --- a/notebooks/.ipynb_checkpoints/Functional_PPLs-checkpoint.ipynb +++ b/notebooks/.ipynb_checkpoints/Functional_PPLs-checkpoint.ipynb @@ -321,7 +321,7 @@ "import Control.Monad.Bayes.Sampler\n", "import Control.Monad.Bayes.Integrator\n", "import 
Control.Monad.Bayes.Population\n", - "import Control.Monad.Bayes.Free\n", + "import Control.Monad.Bayes.Density.Free\n", "\n", ":l Plotting.hs\n", "\n", diff --git a/notebooks/.ipynb_checkpoints/SMC-checkpoint.ipynb b/notebooks/.ipynb_checkpoints/SMC-checkpoint.ipynb index fb5691ef..1c15dcbd 100644 --- a/notebooks/.ipynb_checkpoints/SMC-checkpoint.ipynb +++ b/notebooks/.ipynb_checkpoints/SMC-checkpoint.ipynb @@ -32,9 +32,9 @@ "import Control.Monad.Bayes.Enumerator\n", "import Control.Monad.Bayes.Weighted\n", "import Control.Monad.Bayes.Sampler\n", - "import Control.Monad.Bayes.Free\n", + "import Control.Monad.Bayes.Density.Free\n", "import Control.Monad.Bayes.Population\n", - "import Control.Monad.Bayes.Sequential\n", + "import Control.Monad.Bayes.Sequential.Coroutine\n", "import Control.Monad.Bayes.Inference.SMC\n", "\n", "import qualified Graphics.Vega.VegaLite as VL\n", @@ -2338,7 +2338,7 @@ "outputs": [], "source": [ "-- import Control.Monad.IO.Class\n", - "-- import Control.Monad.Bayes.Sequential\n", + "-- import Control.Monad.Bayes.Sequential.Coroutine\n", "-- import Numeric.Log\n", "-- import Control.Monad.Identity\n", "\n", diff --git a/notebooks/AdvancedSampling.ipynb b/notebooks/AdvancedSampling.ipynb index 34807cbd..12e2595a 100644 --- a/notebooks/AdvancedSampling.ipynb +++ b/notebooks/AdvancedSampling.ipynb @@ -18,9 +18,9 @@ "import Control.Monad.Bayes.Enumerator\n", "import Control.Monad.Bayes.Weighted\n", "import Control.Monad.Bayes.Sampler\n", - "import Control.Monad.Bayes.Free\n", + "import Control.Monad.Bayes.Density.Free\n", "import Control.Monad.Bayes.Population\n", - "import Control.Monad.Bayes.Sequential\n", + "import Control.Monad.Bayes.Sequential.Coroutine\n", "import Control.Monad.Bayes.Inference.SMC\n", "import Control.Monad.Bayes.Inference.RMSMC\n", "import Control.Monad.Bayes.Inference.PMMH\n", diff --git a/notebooks/Functional_PPLs.ipynb b/notebooks/Functional_PPLs.ipynb index e88b65a2..8d4569df 100644 --- a/notebooks/Functional_PPLs.ipynb +++ b/notebooks/Functional_PPLs.ipynb @@ -320,7 +320,7 @@ "import Control.Monad.Bayes.Sampler\n", "import Control.Monad.Bayes.Integrator\n", "import Control.Monad.Bayes.Population\n", - "import Control.Monad.Bayes.Free\n", + "import Control.Monad.Bayes.Density.Free\n", "\n", ":l Plotting.hs\n", "\n", diff --git a/notebooks/SMC.ipynb b/notebooks/SMC.ipynb index fb5691ef..1c15dcbd 100644 --- a/notebooks/SMC.ipynb +++ b/notebooks/SMC.ipynb @@ -32,9 +32,9 @@ "import Control.Monad.Bayes.Enumerator\n", "import Control.Monad.Bayes.Weighted\n", "import Control.Monad.Bayes.Sampler\n", - "import Control.Monad.Bayes.Free\n", + "import Control.Monad.Bayes.Density.Free\n", "import Control.Monad.Bayes.Population\n", - "import Control.Monad.Bayes.Sequential\n", + "import Control.Monad.Bayes.Sequential.Coroutine\n", "import Control.Monad.Bayes.Inference.SMC\n", "\n", "import qualified Graphics.Vega.VegaLite as VL\n", @@ -2338,7 +2338,7 @@ "outputs": [], "source": [ "-- import Control.Monad.IO.Class\n", - "-- import Control.Monad.Bayes.Sequential\n", + "-- import Control.Monad.Bayes.Sequential.Coroutine\n", "-- import Numeric.Log\n", "-- import Control.Monad.Identity\n", "\n", diff --git a/src/Control/Monad/Bayes/Class.hs b/src/Control/Monad/Bayes/Class.hs index a3c8453b..96bce91d 100644 --- a/src/Control/Monad/Bayes/Class.hs +++ b/src/Control/Monad/Bayes/Class.hs @@ -1,5 +1,6 @@ {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-deprecations #-} -- | @@ -36,7 +37,7 @@ -- 
return rain -- @ module Control.Monad.Bayes.Class - ( MonadSample, + ( MonadSample (..), random, uniform, normal, @@ -50,6 +51,7 @@ module Control.Monad.Bayes.Class poisson, dirichlet, MonadCond, + CustomReal, score, factor, condition, @@ -74,6 +76,7 @@ import Control.Monad.Trans.List (ListT) import Control.Monad.Trans.Reader (ReaderT) import Control.Monad.Trans.State (StateT) import Control.Monad.Trans.Writer (WriterT) +import Data.Kind (Type) import Data.Matrix hiding ((!)) import Data.Vector qualified as V import Data.Vector.Generic as VG (Vector, map, mapM, null, sum, (!)) @@ -88,76 +91,89 @@ import Statistics.Distribution.Geometric (geometric0) import Statistics.Distribution.Normal (normalDistr) import Statistics.Distribution.Poisson qualified as Poisson import Statistics.Distribution.Uniform (uniformDistr) +import Prelude hiding (Real) +import Data.Number.Erf (InvErf (inverf)) --- | Monads that can draw random variables. -class Monad m => MonadSample m where - -- | Draw from a uniform distribution. +type CustomReal a = (RealFloat a, InvErf a) + +class (CustomReal (Real m), Monad m) => MonadSample m where + type Real m :: Type + + -- random :: m (Real m) + + -- Monads that can draw random variables. + -- class Monad m => MonadSample m where + -- | Draw from a uniform distribution. random :: -- | \(\sim \mathcal{U}(0, 1)\) - m Double + m (Real m) -- | Draw from a uniform distribution. uniform :: -- | lower bound a - Double -> + (Real m) -> -- | upper bound b - Double -> + (Real m) -> -- | \(\sim \mathcal{U}(a, b)\). - m Double - uniform a b = draw (uniformDistr a b) + m (Real m) + + -- uniform a b = draw (uniformDistr a b) -- | Draw from a normal distribution. normal :: -- | mean μ - Double -> + (Real m) -> -- | standard deviation σ - Double -> + (Real m) -> -- | \(\sim \mathcal{N}(\mu, \sigma^2)\) - m Double - normal m s = draw (normalDistr m s) + m (Real m) + normal m s = inverf <$> random -- draw (normalDistr m s) -- | Draw from a gamma distribution. gamma :: -- | shape k - Double -> + (Real m) -> -- | scale θ - Double -> + (Real m) -> -- | \(\sim \Gamma(k, \theta)\) - m Double - gamma shape scale = draw (gammaDistr shape scale) + m (Real m) + + -- gamma shape scale = draw (gammaDistr shape scale) -- | Draw from a beta distribution. beta :: -- | shape α - Double -> + (Real m) -> -- | shape β - Double -> + (Real m) -> -- | \(\sim \mathrm{Beta}(\alpha, \beta)\) - m Double - beta a b = draw (betaDistr a b) + m (Real m) + + -- beta a b = draw (betaDistr a b) -- | Draw from a Bernoulli distribution. bernoulli :: -- | probability p - Double -> + Real m -> -- | \(\sim \mathrm{B}(1, p)\) m Bool bernoulli p = fmap (< p) random -- | Draw from a categorical distribution. categorical :: - Vector v Double => + Vector v (Real m) => -- | event probabilities - v Double -> + v (Real m) -> -- | outcome category m Int - categorical ps = if VG.null ps then error "empty input list" else fromPMF (ps !) + + -- categorical ps = if VG.null ps then error "empty input list" else fromPMF (ps !) -- | Draw from a categorical distribution in the log domain. logCategorical :: - (Vector v (Log Double), Vector v Double) => + (Vector v (Log (Real m)), Vector v (Real m)) => -- | event probabilities - v (Log Double) -> + v (Log (Real m)) -> -- | outcome category m Int logCategorical = categorical . VG.map (exp . ln) @@ -179,7 +195,8 @@ class Monad m => MonadSample m where Double -> -- | \(\sim\) number of failed Bernoulli trials with success probability p before first success m Int - geometric = discrete . 
geometric0 + + -- geometric = discrete . geometric0 -- | Draw from a Poisson distribution. poisson :: @@ -187,15 +204,16 @@ class Monad m => MonadSample m where Double -> -- | \(\sim \mathrm{Pois}(\lambda)\) m Int - poisson = discrete . Poisson.poisson + + -- poisson = discrete . Poisson.poisson -- | Draw from a Dirichlet distribution. dirichlet :: - Vector v Double => + Vector v (Real m) => -- | concentration parameters @as@ - v Double -> + v (Real m) -> -- | \(\sim \mathrm{Dir}(\mathrm{as})\) - m (v Double) + m (v (Real m)) dirichlet as = do xs <- VG.mapM (`gamma` 1) as let s = VG.sum xs @@ -204,12 +222,12 @@ class Monad m => MonadSample m where -- | Draw from a continuous distribution using the inverse cumulative density -- function. -draw :: (ContDistr d, MonadSample m) => d -> m Double +draw :: (ContDistr d, MonadSample m, Real m ~ Double) => d -> m Double draw d = fmap (quantile d) random -- | Draw from a discrete distribution using a sequence of draws from -- Bernoulli. -fromPMF :: MonadSample m => (Int -> Double) -> m Int +fromPMF :: (MonadSample m, Real m ~ Double) => (Int -> Double) -> m Int fromPMF p = f 0 1 where f i r = do @@ -220,7 +238,7 @@ fromPMF p = f 0 1 if b then pure i else f (i + 1) (r - q) -- | Draw from a discrete distributions using the probability mass function. -discrete :: (DiscreteDistr d, MonadSample m) => d -> m Int +discrete :: (DiscreteDistr d, MonadSample m, Real m ~ Double) => d -> m Int discrete = fromPMF . probability -- | Monads that can score different execution paths. @@ -228,19 +246,19 @@ class Monad m => MonadCond m where -- | Record a likelihood. score :: -- | likelihood of the execution path - Log Double -> + Log (Real m) -> m () -- | Synonym for 'score'. factor :: MonadCond m => -- | likelihood of the execution path - Log Double -> + Log (Real m) -> m () factor = score -- | Hard conditioning. -condition :: MonadCond m => Bool -> m () +condition :: (MonadCond m, RealFloat (Real m)) => Bool -> m () condition b = score $ if b then 1 else 0 independent :: Applicative m => Int -> m a -> m [a] @@ -264,7 +282,7 @@ normalPdf mu sigma x = Exp $ logDensity (normalDistr mu sigma) x -------------------- -- | multivariate normal -mvNormal :: MonadSample m => V.Vector Double -> Matrix Double -> m (V.Vector Double) +mvNormal :: MonadSample m => V.Vector (Real m) -> Matrix (Real m) -> m (V.Vector (Real m)) mvNormal mu bigSigma = do let n = length mu ss <- replicateM n (normal 0 1) @@ -276,7 +294,7 @@ mvNormal mu bigSigma = do data Bayesian m z o = Bayesian { latent :: m z, -- prior over latent variable Z generative :: z -> m o, -- distribution over observations given Z=z - likelihood :: z -> o -> Log Double -- p(o|z) + likelihood :: z -> o -> Log (Real m) -- p(o|z) } posterior :: (MonadInfer m, Foldable f, Functor f) => Bayesian m z o -> f o -> m z @@ -299,6 +317,7 @@ posteriorPredictive bm os = posterior bm os >>= generative bm -- Instances that lift probabilistic effects to standard tranformers. instance MonadSample m => MonadSample (IdentityT m) where + type Real (IdentityT m) = Real m random = lift random bernoulli = lift . bernoulli @@ -308,6 +327,7 @@ instance MonadCond m => MonadCond (IdentityT m) where instance MonadInfer m => MonadInfer (IdentityT m) instance MonadSample m => MonadSample (ExceptT e m) where + type Real (ExceptT e m) = Real m random = lift random uniformD = lift . 
uniformD @@ -317,6 +337,7 @@ instance MonadCond m => MonadCond (ExceptT e m) where instance MonadInfer m => MonadInfer (ExceptT e m) instance MonadSample m => MonadSample (ReaderT r m) where + type Real (ReaderT r m) = Real m random = lift random bernoulli = lift . bernoulli @@ -326,6 +347,7 @@ instance MonadCond m => MonadCond (ReaderT r m) where instance MonadInfer m => MonadInfer (ReaderT r m) instance (Monoid w, MonadSample m) => MonadSample (WriterT w m) where + type Real (WriterT w m) = Real m random = lift random bernoulli = lift . bernoulli categorical = lift . categorical @@ -336,6 +358,7 @@ instance (Monoid w, MonadCond m) => MonadCond (WriterT w m) where instance (Monoid w, MonadInfer m) => MonadInfer (WriterT w m) instance MonadSample m => MonadSample (StateT s m) where + type Real (StateT s m) = Real m random = lift random bernoulli = lift . bernoulli categorical = lift . categorical @@ -347,6 +370,7 @@ instance MonadCond m => MonadCond (StateT s m) where instance MonadInfer m => MonadInfer (StateT s m) instance MonadSample m => MonadSample (ListT m) where + type Real (ListT m) = Real m random = lift random bernoulli = lift . bernoulli categorical = lift . categorical @@ -357,6 +381,7 @@ instance MonadCond m => MonadCond (ListT m) where instance MonadInfer m => MonadInfer (ListT m) instance MonadSample m => MonadSample (ContT r m) where + type Real (ContT r m) = Real m random = lift random instance MonadCond m => MonadCond (ContT r m) where diff --git a/src/Control/Monad/Bayes/Free.hs b/src/Control/Monad/Bayes/Density/Free.hs similarity index 54% rename from src/Control/Monad/Bayes/Free.hs rename to src/Control/Monad/Bayes/Density/Free.hs index 7c59010c..22617bda 100644 --- a/src/Control/Monad/Bayes/Free.hs +++ b/src/Control/Monad/Bayes/Density/Free.hs @@ -2,9 +2,11 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} -- | --- Module : Control.Monad.Bayes.Free +-- Module : Control.Monad.Bayes.Density.Free -- Description : Free monad transformer over random sampling -- Copyright : (c) Adam Scibior, 2015-2020 -- License : MIT @@ -12,67 +14,71 @@ -- Stability : experimental -- Portability : GHC -- --- 'FreeSampler' is a free monad transformer over random sampling. -module Control.Monad.Bayes.Free - ( FreeSampler, +-- 'Density' is a free monad transformer over random sampling. +module Control.Monad.Bayes.Density.Free + ( Density, hoist, interpret, withRandomness, - withPartialRandomness, + density, traced, ) where -import Control.Monad.Bayes.Class (MonadSample (random)) +import Control.Monad.Bayes.Class import Control.Monad.State (evalStateT, get, put) import Control.Monad.Trans (MonadTrans (..)) import Control.Monad.Trans.Free.Church (FT, MonadFree (..), hoistFT, iterT, iterTM, liftF) import Control.Monad.Writer (WriterT (..), tell) import Data.Functor.Identity (Identity, runIdentity) +import Prelude hiding (Real) -- | Random sampling functor. -newtype SamF a = Random (Double -> a) +newtype SamF n a = Random (n -> a) -instance Functor SamF where +instance Functor (SamF n) where fmap f (Random k) = Random (f . k) -- | Free monad transformer over random sampling. -- -- Uses the Church-encoded version of the free monad for efficiency. 
-newtype FreeSampler m a = FreeSampler {runFreeSampler :: FT SamF m a} - deriving newtype (Functor, Applicative, Monad, MonadTrans) +newtype Density m a = Density {runDensity :: FT (SamF (Real m)) m a} + deriving newtype (Functor, Applicative, Monad) -instance MonadFree SamF (FreeSampler m) where - wrap = FreeSampler . wrap . fmap runFreeSampler +instance MonadTrans (Density) -instance Monad m => MonadSample (FreeSampler m) where - random = FreeSampler $ liftF (Random id) +instance (Real m ~ n, CustomReal (n)) => MonadFree (SamF n) (Density m) where + wrap = Density . wrap . fmap runDensity --- | Hoist 'FreeSampler' through a monad transform. -hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> FreeSampler m a -> FreeSampler n a -hoist f (FreeSampler m) = FreeSampler (hoistFT f m) +instance (Monad m, CustomReal (Real m)) => MonadSample (Density m) where + type Real (Density m) = Real m + random = Density $ liftF (Random id) + +-- | Hoist 'Density' through a monad transform. +hoist :: (Real m ~ Real n, Monad m, Monad n) => (forall x. m x -> n x) -> Density m a -> Density n a +hoist f (Density m) = Density (hoistFT f m) -- | Execute random sampling in the transformed monad. -interpret :: MonadSample m => FreeSampler m a -> m a -interpret (FreeSampler m) = iterT f m +interpret :: MonadSample m => Density m a -> m a +interpret (Density m) = iterT f m where f (Random k) = random >>= k -- | Execute computation with supplied values for random choices. -withRandomness :: Monad m => [Double] -> FreeSampler m a -> m a -withRandomness randomness (FreeSampler m) = evalStateT (iterTM f m) randomness +withRandomness :: Monad m => [Real m] -> Density m a -> m a +withRandomness randomness (Density m) = evalStateT (iterTM f m) randomness where f (Random k) = do xs <- get case xs of - [] -> error "FreeSampler: the list of randomness was too short" + [] -> error "Density: the list of randomness was too short" y : ys -> put ys >> k y -- | Execute computation with supplied values for a subset of random choices. -- Return the output value and a record of all random choices used, whether -- taken as input or drawn using the transformed monad. -withPartialRandomness :: MonadSample m => [Double] -> FreeSampler m a -> m (a, [Double]) -withPartialRandomness randomness (FreeSampler m) = +density :: MonadSample m => [Real m] -> Density m a -> m (a, [Real m]) +density randomness (Density m) = runWriterT $ evalStateT (iterTM f $ hoistFT lift m) randomness where f (Random k) = do @@ -87,5 +93,5 @@ withPartialRandomness randomness (FreeSampler m) = k x -- | Like 'withPartialRandomness', but use an arbitrary sampling monad. -traced :: MonadSample m => [Double] -> FreeSampler Identity a -> m (a, [Double]) -traced randomness m = withPartialRandomness randomness $ hoist (return . runIdentity) m +traced :: MonadSample m => [Real m] -> Density Identity a -> m (a, [Real m]) +traced randomness m = undefined -- withPartialRandomness randomness $ hoist (return . runIdentity) m diff --git a/src/Control/Monad/Bayes/Density/State.hs b/src/Control/Monad/Bayes/Density/State.hs new file mode 100644 index 00000000..e66d91b9 --- /dev/null +++ b/src/Control/Monad/Bayes/Density/State.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | +-- slower than Control.Monad.Bayes.Density.Free, +-- but much more elementary to understand. 
Just uses standard +-- monad transformer techniques. +-- @ +module Control.Monad.Bayes.Density.State where + +import Control.Monad.Bayes.Class +import Control.Monad.State (MonadState (get, put), MonadTrans, StateT, evalStateT) +import Control.Monad.Trans.Class (MonadTrans (lift)) +import Control.Monad.Trans.Writer.Strict (WriterT, runWriterT) +import Control.Monad.Writer (MonadWriter, tell) +import Prelude hiding (Real) +import Numeric.AD + +newtype Density m a = Density (WriterT [Real m] (StateT [Real m] m) a) deriving newtype (Functor, Applicative, Monad) + +instance MonadTrans Density where + lift = Density . lift . lift + +instance (Real m ~ n, Monad m) => MonadState [n] (Density m) where + get = Density $ lift $ get + put = Density . lift . put + +instance (Real m ~ n, Monad m) => MonadWriter [n] (Density m) where + tell = Density . tell + +instance (MonadSample m) => MonadSample (Density m) where + type (Real (Density m)) = Real m + random = do + trace <- get + x <- case trace of + [] -> random + r : xs -> put xs >> pure r + tell [x] + pure x + +density :: Monad m => Density m b -> [Real m] -> m (b, [Real m]) +density (Density m) = evalStateT (runWriterT m) + diff --git a/src/Control/Monad/Bayes/Enumerator.hs b/src/Control/Monad/Bayes/Enumerator.hs index 6c7b3c94..08b1b13c 100644 --- a/src/Control/Monad/Bayes/Enumerator.hs +++ b/src/Control/Monad/Bayes/Enumerator.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE TypeFamilies #-} -- | -- Module : Control.Monad.Bayes.Enumerator @@ -34,10 +35,6 @@ import Control.Applicative (Alternative) import Control.Arrow (second) import Control.Monad (MonadPlus) import Control.Monad.Bayes.Class - ( MonadCond (..), - MonadInfer, - MonadSample (bernoulli, categorical, logCategorical, random), - ) import Control.Monad.Trans.Writer (WriterT (..)) import Data.AEq (AEq, (===), (~==)) import Data.List (sortOn) @@ -48,6 +45,7 @@ import Data.Ord (Down (Down)) import Data.Vector qualified as VV import Data.Vector.Generic qualified as V import Numeric.Log as Log (Log (..), sum) +import Prelude hiding (Real) -- | An exact inference transformer that integrates -- discrete random variables by enumerating all execution paths. @@ -55,6 +53,7 @@ newtype Enumerator a = Enumerator (WriterT (Product (Log Double)) [] a) deriving newtype (Functor, Applicative, Monad, Alternative, MonadPlus) instance MonadSample Enumerator where + type Real Enumerator = Double random = error "Infinitely supported random variables not supported in Enumerator" bernoulli p = fromList [(True, (Exp . log) p), (False, (Exp . log) (1 - p))] categorical v = fromList $ zip [0 ..] $ map (Exp . log) (V.toList v) @@ -131,7 +130,7 @@ toEmpirical ls = normalizeWeights $ compact (zip ls (repeat 1)) toEmpiricalWeighted :: (Fractional b, Ord a, Ord b) => [(a, b)] -> [(a, b)] toEmpiricalWeighted = normalizeWeights . 
compact -enumerateToDistribution :: (MonadSample n) => Enumerator a -> n a +enumerateToDistribution :: (MonadSample n, Real n ~ Double) => Enumerator a -> n a enumerateToDistribution model = do let samples = logExplicit model let (support, logprobs) = unzip samples diff --git a/src/Control/Monad/Bayes/Inference/PMMH.hs b/src/Control/Monad/Bayes/Inference/PMMH.hs index 5f745c76..b3694090 100644 --- a/src/Control/Monad/Bayes/Inference/PMMH.hs +++ b/src/Control/Monad/Bayes/Inference/PMMH.hs @@ -18,7 +18,7 @@ module Control.Monad.Bayes.Inference.PMMH ) where -import Control.Monad.Bayes.Class (Bayesian (generative), MonadInfer, MonadSample, latent) +import Control.Monad.Bayes.Class import Control.Monad.Bayes.Inference.MCMC (MCMCConfig, mcmc) import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smc) import Control.Monad.Bayes.Population as Pop @@ -33,6 +33,7 @@ import Control.Monad.Bayes.Traced.Static (Traced) import Control.Monad.Bayes.Weighted import Control.Monad.Trans (lift) import Numeric.Log (Log) +import Prelude hiding (Real) -- | Particle Marginal Metropolis-Hastings sampling. pmmh :: @@ -41,7 +42,7 @@ pmmh :: SMCConfig (Weighted m) -> Traced (Weighted m) a1 -> (a1 -> Sequential (Population (Weighted m)) a2) -> - m [[(a2, Log Double)]] + m [[(a2, Log (Real m))]] pmmh mcmcConf smcConf param model = mcmc mcmcConf @@ -59,5 +60,5 @@ pmmhBayesianModel :: MCMCConfig -> SMCConfig (Weighted m) -> (forall m. MonadInfer m => Bayesian m a1 a2) -> - m [[(a2, Log Double)]] + m [[(a2, Log (Real m))]] pmmhBayesianModel mcmcConf smcConf bm = pmmh mcmcConf smcConf (latent bm) (generative bm) diff --git a/src/Control/Monad/Bayes/Inference/RMSMC.hs b/src/Control/Monad/Bayes/Inference/RMSMC.hs index b0105551..9cb6de0b 100644 --- a/src/Control/Monad/Bayes/Inference/RMSMC.hs +++ b/src/Control/Monad/Bayes/Inference/RMSMC.hs @@ -22,7 +22,7 @@ where import Control.Monad.Bayes.Class (MonadSample) import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..)) -import Control.Monad.Bayes.Inference.SMC (SMCConfig (..)) +import Control.Monad.Bayes.Inference.SMC import Control.Monad.Bayes.Population ( Population, spawn, diff --git a/src/Control/Monad/Bayes/Inference/SMC2.hs b/src/Control/Monad/Bayes/Inference/SMC2.hs index 5dd80297..a3cf0c24 100644 --- a/src/Control/Monad/Bayes/Inference/SMC2.hs +++ b/src/Control/Monad/Bayes/Inference/SMC2.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} -- | -- Module : Control.Monad.Bayes.Inference.SMC2 @@ -19,10 +21,6 @@ module Control.Monad.Bayes.Inference.SMC2 where import Control.Monad.Bayes.Class - ( MonadCond (..), - MonadInfer, - MonadSample (random), - ) import Control.Monad.Bayes.Inference.MCMC import Control.Monad.Bayes.Inference.RMSMC (rmsmc) import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush) @@ -31,28 +29,34 @@ import Control.Monad.Bayes.Sequential (Sequential) import Control.Monad.Bayes.Traced import Control.Monad.Trans (MonadTrans (..)) import Numeric.Log (Log) +import Prelude hiding (Real) -- | Helper monad transformer for preprocessing the model for 'smc2'. newtype SMC2 m a = SMC2 (Sequential (Traced (Population m)) a) - deriving newtype (Functor, Applicative, Monad) + deriving newtype (Functor) setup :: SMC2 m a -> Sequential (Traced (Population m)) a setup (SMC2 m) = m instance MonadTrans SMC2 where - lift = SMC2 . lift . lift . 
lift + lift = undefined -- SMC2 . lift . lift . lift + +instance Monad m => (Applicative (SMC2 m)) + +instance Monad m => (Monad (SMC2 m)) instance MonadSample m => MonadSample (SMC2 m) where + type Real (SMC2 m) = Real m random = lift random -instance Monad m => MonadCond (SMC2 m) where +instance (Monad m, RealFloat (Real m), MonadCond m) => MonadCond (SMC2 m) where score = SMC2 . score -instance MonadSample m => MonadInfer (SMC2 m) +instance MonadInfer m => MonadInfer (SMC2 m) -- | Sequential Monte Carlo squared. smc2 :: - MonadSample m => + (MonadSample m, MonadInfer m) => -- | number of time steps Int -> -- | number of inner particles @@ -65,7 +69,7 @@ smc2 :: Sequential (Traced (Population m)) b -> -- | model (b -> Sequential (Population (SMC2 m)) a) -> - Population m [(a, Log Double)] + Population m [(a, Log (Real m))] smc2 k n p t param model = rmsmc MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0} diff --git a/src/Control/Monad/Bayes/Integrator.hs b/src/Control/Monad/Bayes/Integrator.hs index 8b8bb830..5cf32485 100644 --- a/src/Control/Monad/Bayes/Integrator.hs +++ b/src/Control/Monad/Bayes/Integrator.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-type-defaults #-} {-# OPTIONS_GHC -Wno-unused-top-binds #-} @@ -34,7 +35,7 @@ where import Control.Applicative (Applicative (..)) import Control.Foldl (Fold) import Control.Foldl qualified as Foldl -import Control.Monad.Bayes.Class (MonadSample (bernoulli, random, uniformD)) +import Control.Monad.Bayes.Class import Control.Monad.Bayes.Weighted (Weighted, weighted) import Control.Monad.Trans.Cont ( Cont, @@ -47,7 +48,7 @@ import Data.Set (Set, elems) import Data.Text qualified as T import Numeric.Integration.TanhSinh (Result (result), trap) import Numeric.Log (Log (ln)) -import Statistics.Distribution (density) +import Statistics.Distribution qualified as Statistics import Statistics.Distribution.Uniform qualified as Statistics newtype Integrator a = Integrator {getCont :: Cont Double a} @@ -58,7 +59,8 @@ integrator f (Integrator a) = runCont a f runIntegrator = integrator instance MonadSample Integrator where - random = fromDensityFunction $ density $ Statistics.uniformDistr 0 1 + type Real Integrator = Double + random = fromDensityFunction $ Statistics.density $ Statistics.uniformDistr 0 1 bernoulli p = Integrator $ cont (\f -> p * f True + (1 - p) * f False) uniformD ls = fromMassFunction (const (1 / fromIntegral (length ls))) ls diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index 756d7b48..b825debb 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -2,6 +2,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-deprecations #-} -- | @@ -41,11 +42,6 @@ where import Control.Arrow (second) import Control.Monad (replicateM) import Control.Monad.Bayes.Class - ( MonadCond, - MonadInfer, - MonadSample (categorical, logCategorical, random, uniform), - factor, - ) import Control.Monad.Bayes.Weighted ( Weighted, applyWeight, @@ -62,45 +58,53 @@ import Data.Vector ((!)) import Data.Vector qualified as V import Numeric.Log (Log, ln, sum) import Numeric.Log qualified as Log -import Prelude hiding (all, sum) +import Prelude hiding (Real, all, sum) -- | A collection of weighted samples, or 
particles. newtype Population m a = Population (Weighted (ListT m) a) - deriving newtype (Functor, Applicative, Monad, MonadIO, MonadSample, MonadCond, MonadInfer) + deriving newtype (Functor, Applicative, Monad, MonadIO) + +instance MonadSample m => MonadSample (Population m) where + type Real (Population m) = Real m + +instance MonadCond m => MonadCond (Population m) where + score = lift . score + +instance MonadInfer m => MonadInfer (Population m) instance MonadTrans Population where lift = Population . lift . lift -- | Explicit representation of the weighted sample with weights in the log -- domain. -population, runPopulation :: Population m a -> m [(a, Log Double)] +population, runPopulation :: RealFloat (Real m) => Population m a -> m [(a, Log (Real m))] population (Population m) = runListT $ weighted m -- | deprecated synonym runPopulation = population -- | Explicit representation of the weighted sample. -explicitPopulation :: Functor m => Population m a -> m [(a, Double)] +explicitPopulation :: (Floating (Real m), Functor m, RealFloat (Real m)) => Population m a -> m [(a, Real m)] explicitPopulation = fmap (map (second (exp . ln))) . population -- | Initialize 'Population' with a concrete weighted sample. -fromWeightedList :: Monad m => m [(a, Log Double)] -> Population m a +fromWeightedList :: (Monad m, RealFloat (Real m)) => m [(a, Log (Real m))] -> Population m a fromWeightedList = Population . withWeight . ListT -- | Increase the sample size by a given factor. -- The weights are adjusted such that their sum is preserved. -- It is therefore safe to use 'spawn' in arbitrary places in the program -- without introducing bias. -spawn :: Monad m => Int -> Population m () +spawn :: (Monad m, RealFloat (Real m)) => Int -> Population m () spawn n = fromWeightedList $ pure $ replicate n ((), 1 / fromIntegral n) -withParticles :: Monad m => Int -> Population m a -> Population m a +withParticles :: (Monad m, RealFloat (Real m)) => Int -> Population m a -> Population m a withParticles n = (spawn n >>) resampleGeneric :: MonadSample m => -- | resampler - (V.Vector Double -> m [Int]) -> + (V.Vector (Real m) -> m [Int]) -> Population m a -> Population m a resampleGeneric resampler m = fromWeightedList $ do @@ -135,7 +139,7 @@ resampleGeneric resampler m = fromWeightedList $ do -- Q^{(m)}=\sum_{k=1}^{m} w^{(k)} -- \] -- and \(w^{(k)}\) are the weights. See also [Comparison of Resampling Schemes for Particle Filtering](https://arxiv.org/abs/cs/0507025). -systematic :: Double -> V.Vector Double -> [Int] +systematic :: RealFloat n => n -> V.Vector n -> [Int] systematic u ps = f 0 (u / fromIntegral n) 0 0 [] where prob i = ps V.! i @@ -172,7 +176,7 @@ resampleSystematic = resampleGeneric (\ps -> (`systematic` ps) <$> random) -- and \(w^{(k)}\) are the weights. -- -- The conditional variance of stratified sampling is always smaller than that of multinomial sampling and it is also unbiased - see [Comparison of Resampling Schemes for Particle Filtering](https://arxiv.org/abs/cs/0507025). -stratified :: MonadSample m => V.Vector Double -> m [Int] +stratified :: MonadSample m => V.Vector (Real m) -> m [Int] stratified weights = do let bigN = V.length weights dithers <- V.replicateM bigN (uniform 0.0 1.0) @@ -200,7 +204,7 @@ resampleStratified = resampleGeneric stratified -- | Multinomial sampler. Sample from \(0, \ldots, n - 1\) \(n\) -- times drawn at random according to the weights where \(n\) is the -- length of vector of weights. 
-multinomial :: MonadSample m => V.Vector Double -> m [Int] +multinomial :: MonadSample m => V.Vector (Real m) -> m [Int] multinomial ps = replicateM (V.length ps) (categorical ps) -- | Resample the population using the underlying monad and a multinomial resampling scheme. @@ -214,7 +218,7 @@ resampleMultinomial = resampleGeneric multinomial -- | Separate the sum of weights into the 'Weighted' transformer. -- Weights are normalized after this operation. extractEvidence :: - Monad m => + (Monad m, RealFloat (Real m)) => Population m a -> Population (Weighted m) a extractEvidence m = fromWeightedList $ do @@ -228,7 +232,7 @@ extractEvidence m = fromWeightedList $ do -- | Push the evidence estimator as a score to the transformed monad. -- Weights are normalized after this operation. pushEvidence :: - MonadCond m => + (MonadCond m, RealFloat (Real m)) => Population m a -> Population m a pushEvidence = hoist applyWeight . extractEvidence @@ -247,7 +251,7 @@ proper m = do return x -- | Model evidence estimator, also known as pseudo-marginal likelihood. -evidence :: (Monad m) => Population m a -> m (Log Double) +evidence :: (Monad m, RealFloat (Real m)) => Population m a -> m (Log (Real m)) evidence = extractWeight . population . extractEvidence -- | Picks one point from the population and uses model evidence as a 'score' @@ -261,7 +265,7 @@ collapse :: collapse = applyWeight . proper -- | Population average of a function, computed using unnormalized weights. -popAvg :: (Monad m) => (a -> Double) -> Population m a -> m Double +popAvg :: (Monad m, Floating (Real m), RealFloat (Real m)) => (a -> (Real m)) -> Population m a -> m (Real m) popAvg f p = do xs <- explicitPopulation p let ys = map (\(x, w) -> f x * w) xs @@ -270,7 +274,7 @@ popAvg f p = do -- | Applies a transformation to the inner monad. hoist :: - Monad n => + (Real m ~ Real n, Monad n, RealFloat (Real n)) => (forall x. 
m x -> n x) -> Population m a -> Population n a diff --git a/src/Control/Monad/Bayes/Sampler.hs b/src/Control/Monad/Bayes/Sampler.hs index 7542e789..a8e4b400 100644 --- a/src/Control/Monad/Bayes/Sampler.hs +++ b/src/Control/Monad/Bayes/Sampler.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE TypeFamilies #-} -- | -- Module : Control.Monad.Bayes.Sampler @@ -30,17 +31,6 @@ where import Control.Foldl qualified as F hiding (random) import Control.Monad.Bayes.Class - ( MonadSample - ( bernoulli, - beta, - categorical, - gamma, - geometric, - normal, - random, - uniform - ), - ) import Control.Monad.IO.Class import Control.Monad.ST (ST) import Control.Monad.Trans.Reader (ReaderT (..), runReaderT) @@ -63,6 +53,7 @@ type SamplerIO = Sampler (IOGenM StdGen) IO type SamplerST s = Sampler (STGenM StdGen s) (ST s) instance StatefulGen g m => MonadSample (Sampler g m) where + type Real (Sampler g m) = Double random = Sampler (ReaderT uniformDouble01M) uniform a b = Sampler (ReaderT $ uniformRM (a, b)) diff --git a/src/Control/Monad/Bayes/Sequential.hs b/src/Control/Monad/Bayes/Sequential.hs index 0ae92ad9..36d3eea3 100644 --- a/src/Control/Monad/Bayes/Sequential.hs +++ b/src/Control/Monad/Bayes/Sequential.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} -- | -- Module : Control.Monad.Bayes.Sequential @@ -26,10 +27,6 @@ module Control.Monad.Bayes.Sequential where import Control.Monad.Bayes.Class - ( MonadCond (..), - MonadInfer, - MonadSample (bernoulli, categorical, random), - ) import Control.Monad.Coroutine ( Coroutine (..), bounce, @@ -42,6 +39,7 @@ import Control.Monad.Coroutine.SuspensionFunctors ) import Control.Monad.Trans (MonadIO, MonadTrans (..)) import Data.Either (isRight) +import Prelude hiding (Real) -- | Represents a computation that can be suspended at certain points. -- The intermediate monadic effects can be extracted, which is particularly @@ -55,6 +53,7 @@ extract :: Await () a -> a extract (Await f) = f () instance MonadSample m => MonadSample (Sequential m) where + type Real (Sequential m) = Real m random = lift random bernoulli = lift . bernoulli categorical = lift . categorical @@ -115,5 +114,5 @@ sequentially, m a sequentially f k = finish . composeCopies k (advance . 
hoistFirst f) --- | deprecated synonym +-- | synonym sis = sequentially diff --git a/src/Control/Monad/Bayes/Traced/Basic.hs b/src/Control/Monad/Bayes/Traced/Basic.hs index e8c6a22f..3188136e 100644 --- a/src/Control/Monad/Bayes/Traced/Basic.hs +++ b/src/Control/Monad/Bayes/Traced/Basic.hs @@ -1,4 +1,6 @@ {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} -- | -- Module : Control.Monad.Bayes.Traced.Basic @@ -9,21 +11,15 @@ -- Stability : experimental -- Portability : GHC module Control.Monad.Bayes.Traced.Basic - ( Traced, - hoist, - marginal, - mhStep, - mh, - ) where import Control.Applicative (liftA2) import Control.Monad.Bayes.Class ( MonadCond (..), MonadInfer, - MonadSample (random), + MonadSample (random, Real), ) -import Control.Monad.Bayes.Free (FreeSampler) +import Control.Monad.Bayes.Density.Free (Density) import Control.Monad.Bayes.Traced.Common ( Trace (..), bind, @@ -34,33 +30,36 @@ import Control.Monad.Bayes.Traced.Common import Control.Monad.Bayes.Weighted (Weighted) import Data.Functor.Identity (Identity) import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) +import Prelude hiding (Real) -- | Tracing monad that records random choices made in the program. data Traced m a = Traced { -- | Run the program with a modified trace. - model :: Weighted (FreeSampler Identity) a, + model :: Weighted (Density Identity) a, -- | Record trace and output. - traceDist :: m (Trace a) + traceDist :: m (Trace (Real m) a) } instance Monad m => Functor (Traced m) where fmap f (Traced m d) = Traced (fmap f m) (fmap (fmap f) d) -instance Monad m => Applicative (Traced m) where +instance (Monad m, RealFloat (Real m)) => Applicative (Traced m) where pure x = Traced (pure x) (pure (pure x)) (Traced mf df) <*> (Traced mx dx) = Traced (mf <*> mx) (liftA2 (<*>) df dx) -instance Monad m => Monad (Traced m) where +instance (Monad m, RealFloat (Real m)) => Monad (Traced m) where (Traced mx dx) >>= f = Traced my dy where my = mx >>= model . f dy = dx `bind` (traceDist . 
f)
 
 instance MonadSample m => MonadSample (Traced m) where
-  random = Traced random (fmap singleton random)
+  type Real (Traced m) = Real m
+
+-- random = Traced random (fmap singleton random)
 
-instance MonadCond m => MonadCond (Traced m) where
-  score w = Traced (score w) (score w >> pure (scored w))
+instance (MonadCond m, RealFloat (Real m)) => MonadCond (Traced m) where
+  score w = undefined -- Traced (score w) (score w >> pure (scored w))
 
 instance MonadInfer m => MonadInfer (Traced m)
diff --git a/src/Control/Monad/Bayes/Traced/Common.hs b/src/Control/Monad/Bayes/Traced/Common.hs
index c424361f..d1e9e5f9 100644
--- a/src/Control/Monad/Bayes/Traced/Common.hs
+++ b/src/Control/Monad/Bayes/Traced/Common.hs
@@ -7,7 +7,7 @@
 -- Stability : experimental
 -- Portability : GHC
 module Control.Monad.Bayes.Traced.Common
-  ( Trace,
+  ( Trace(..),
     singleton,
     output,
     scored,
@@ -15,18 +15,12 @@ module Control.Monad.Bayes.Traced.Common
     mhTrans,
     mhTrans',
     burnIn,
+    _variables
   )
 where
 
 import Control.Monad.Bayes.Class
-  ( MonadSample (bernoulli, random),
-    discrete,
-  )
-import Control.Monad.Bayes.Free as FreeSampler
-  ( FreeSampler,
-    hoist,
-    withPartialRandomness,
-  )
+import Control.Monad.Bayes.Density.Free
 import Control.Monad.Bayes.Weighted as Weighted
   ( Weighted,
     hoist,
@@ -36,64 +30,72 @@ import Control.Monad.Trans.Writer (WriterT (WriterT, runWriterT))
 import Data.Functor.Identity (Identity (runIdentity))
 import Numeric.Log (Log, ln)
 import Statistics.Distribution.DiscreteUniform (discreteUniformAB)
+import Prelude hiding (Real)
 
 -- | Collection of random variables sampled during the program's execution.
-data Trace a = Trace
+data Trace n a = Trace
   { -- | Sequence of random variables sampled during the program's execution.
-    variables :: [Double],
+    variables :: [n],
     --
     output :: a,
     -- | The probability of observing this particular sequence.
- density :: Log Double + probDensity :: Log n } -instance Functor Trace where +_variables :: Functor m => ([Real m] -> m [Real m]) + -> Trace (Real m) a -> m (Trace (Real m) a) +_variables f tr@Trace {variables = v, output = o, probDensity = p} = fmap (\v -> tr {variables = v}) (f v) + +instance Functor (Trace n) where fmap f t = t {output = f (output t)} -instance Applicative Trace where - pure x = Trace {variables = [], output = x, density = 1} +instance RealFloat n => Applicative (Trace n) where + pure x = Trace {variables = [], output = x, probDensity = 1} tf <*> tx = Trace { variables = variables tf ++ variables tx, output = output tf (output tx), - density = density tf * density tx + probDensity = probDensity tf * probDensity tx } -instance Monad Trace where +instance RealFloat n => Monad (Trace n) where t >>= f = let t' = f (output t) - in t' {variables = variables t ++ variables t', density = density t * density t'} + in t' {variables = variables t ++ variables t', probDensity = probDensity t * probDensity t'} -singleton :: Double -> Trace Double -singleton u = Trace {variables = [u], output = u, density = 1} +singleton :: RealFloat n => n -> Trace n n +singleton u = Trace {variables = [u], output = u, probDensity = 1} -scored :: Log Double -> Trace () -scored w = Trace {variables = [], output = (), density = w} +scored :: Log n -> Trace n () +scored w = Trace {variables = [], output = (), probDensity = w} -bind :: Monad m => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b) +bind :: (Monad m, RealFloat n) => m (Trace n a) -> (a -> m (Trace n b)) -> m (Trace n b) bind dx f = do t1 <- dx t2 <- f (output t1) - return $ t2 {variables = variables t1 ++ variables t2, density = density t1 * density t2} + return $ t2 {variables = variables t1 ++ variables t2, probDensity = probDensity t1 * probDensity t2} + + + -- | A single Metropolis-corrected transition of single-site Trace MCMC. -mhTrans :: MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a) -mhTrans m t@Trace {variables = us, density = p} = do +mhTrans :: MonadSample m => Weighted (Density m) a -> Trace (Real m) a -> m (Trace (Real m) a) +mhTrans m t@Trace {variables = us, probDensity = p} = do let n = length us us' <- do - i <- discrete $ discreteUniformAB 0 (n - 1) + i <- undefined -- discrete $ discreteUniformAB 0 (n - 1) u' <- random case splitAt i us of (xs, _ : ys) -> return $ xs ++ (u' : ys) _ -> error "impossible" - ((b, q), vs) <- runWriterT $ weighted $ Weighted.hoist (WriterT . withPartialRandomness us') m + ((b, q), vs) <- runWriterT $ weighted $ Weighted.hoist (WriterT . density us') m let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs))) accept <- bernoulli ratio return $ if accept then Trace vs b q else t -- | A variant of 'mhTrans' with an external sampling monad. -mhTrans' :: MonadSample m => Weighted (FreeSampler Identity) a -> Trace a -> m (Trace a) -mhTrans' m = mhTrans (Weighted.hoist (FreeSampler.hoist (return . runIdentity)) m) +mhTrans' :: MonadSample m => Weighted (Density Identity) a -> Trace (Real m) a -> m (Trace (Real m) a) +mhTrans' m = undefined -- mhTrans (Weighted.hoist (FreeSampler.hoist (return . 
runIdentity)) m) -- | burn in an MCMC chain for n steps (which amounts to dropping samples of the end of the list) burnIn :: Functor m => Int -> m [a] -> m [a] diff --git a/src/Control/Monad/Bayes/Traced/Dynamic.hs b/src/Control/Monad/Bayes/Traced/Dynamic.hs index 6e23bf33..6d934234 100644 --- a/src/Control/Monad/Bayes/Traced/Dynamic.hs +++ b/src/Control/Monad/Bayes/Traced/Dynamic.hs @@ -1,4 +1,6 @@ {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} -- | -- Module : Control.Monad.Bayes.Traced.Dynamic @@ -20,11 +22,7 @@ where import Control.Monad (join) import Control.Monad.Bayes.Class - ( MonadCond (..), - MonadInfer, - MonadSample (random), - ) -import Control.Monad.Bayes.Free (FreeSampler) +import Control.Monad.Bayes.Density.Free (Density) import Control.Monad.Bayes.Traced.Common ( Trace (..), bind, @@ -35,12 +33,13 @@ import Control.Monad.Bayes.Traced.Common import Control.Monad.Bayes.Weighted (Weighted) import Control.Monad.Trans (MonadTrans (..)) import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) +import Prelude hiding (Real) -- | A tracing monad where only a subset of random choices are traced and this -- subset can be adjusted dynamically. -newtype Traced m a = Traced {runTraced :: m (Weighted (FreeSampler m) a, Trace a)} +newtype Traced m a = Traced {runTraced :: m (Weighted (Density m) a, Trace (Real m) a)} -pushM :: Monad m => m (Weighted (FreeSampler m) a) -> Weighted (FreeSampler m) a +pushM :: Monad m => m (Weighted (Density m) a) -> Weighted (Density m) a pushM = join . lift . lift instance Monad m => Functor (Traced m) where @@ -50,14 +49,14 @@ instance Monad m => Functor (Traced m) where let t' = fmap f t return (m', t') -instance Monad m => Applicative (Traced m) where +instance (Monad m, RealFloat (Real m)) => Applicative (Traced m) where pure x = Traced $ pure (pure x, pure x) (Traced cf) <*> (Traced cx) = Traced $ do (mf, tf) <- cf (mx, tx) <- cx return (mf <*> mx, tf <*> tx) -instance Monad m => Monad (Traced m) where +instance (Monad m, RealFloat (Real m)) => Monad (Traced m) where (Traced cx) >>= f = Traced $ do (mx, tx) <- cx let m = mx >>= pushM . fmap fst . runTraced . f @@ -65,13 +64,14 @@ instance Monad m => Monad (Traced m) where return (m, t) instance MonadTrans Traced where - lift m = Traced $ fmap ((,) (lift $ lift m) . pure) m + lift m = undefined -- Traced $ fmap ((,) (lift $ lift m) . pure) m instance MonadSample m => MonadSample (Traced m) where + type Real (Traced m) = Real m random = Traced $ fmap ((,) random . singleton) random -instance MonadCond m => MonadCond (Traced m) where - score w = Traced $ fmap (score w,) (score w >> pure (scored w)) +instance (MonadCond m, RealFloat (Real m)) => MonadCond (Traced m) where + score w = undefined -- Traced $ fmap (score w,) (score w >> pure (scored w)) instance MonadInfer m => MonadInfer (Traced m) @@ -84,7 +84,7 @@ marginal (Traced c) = fmap (output . snd) c -- | Freeze all traced random choices to their current values and stop tracing -- them. 
-freeze :: Monad m => Traced m a -> Traced m a
+freeze :: (Monad m, RealFloat (Real m)) => Traced m a -> Traced m a
 freeze (Traced c) = Traced $ do
   (_, t) <- c
   let x = output t
diff --git a/src/Control/Monad/Bayes/Traced/Grad.hs b/src/Control/Monad/Bayes/Traced/Grad.hs
new file mode 100644
index 00000000..820d9566
--- /dev/null
+++ b/src/Control/Monad/Bayes/Traced/Grad.hs
@@ -0,0 +1,229 @@
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+
+-- |
+-- Module : Control.Monad.Bayes.Traced.Grad
+-- Description : Distributions on execution traces of full programs
+-- Copyright : (c) Adam Scibior, 2015-2020
+-- License : MIT
+-- Maintainer : leonhard.markert@tweag.io
+-- Stability : experimental
+-- Portability : GHC
+module Control.Monad.Bayes.Traced.Grad
+-- ( Traced,
+-- hoist,
+-- marginal,
+-- mhStep,
+-- mh,
+-- )
+where
+
+import Control.Applicative (liftA2)
+import Control.Monad.Bayes.Class
+import Control.Monad.Bayes.Traced.Common
+
+import Control.Monad.Bayes.Weighted (Weighted, weighted, unweighted)
+import Control.Monad.Trans (MonadTrans (..))
+import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList)
+import Prelude hiding (Real)
+import Control.Monad.State
+import Control.Monad.Writer
+import Numeric.Log (Log (ln, Exp))
+import Numeric.AD
+import Numeric.AD.Internal.Reverse (Reverse (Lift), Tape)
+import Data.Reflection (Reifies)
+import Data.Number.Erf
+import Linear (V2, _y)
+import Linear.V2 (V2(..))
+import Control.Lens (view)
+import Math.Integrators.StormerVerlet
+import Unsafe.Coerce (unsafeCoerce)
+import Control.Monad.Bayes.Sampler (sampler)
+import qualified Control.Monad.Bayes.Weighted as Weighted
+
+newtype Density n a = Density (State [n] a) deriving newtype (Functor, Applicative, Monad)
+
+instance MonadState [n] (Density n) where
+  get = Density get
+  put = Density . put
+
+
+instance CustomReal n => MonadSample (Density n) where
+  type (Real (Density n)) = n
+  random = do
+    trace <- get
+    x <- case trace of
+      [] -> error "ran out of randomness in the trace: this suggests that you are running HMC on a probabilistic program with stochastic control flow. Don't do that!"
+      r : xs -> put xs >> pure r
+    pure x
+
+
+data Traced n m a = Traced
+  { model :: Weighted (Density n) a,
+    traceDist :: m (Trace (Real m) a)
+  }
+
+marginal :: Monad m => Traced n m a -> m a
+marginal (Traced _ d) = fmap output d
+
+
+instance Monad m => Functor (Traced n m) where
+  fmap f (Traced m d) = Traced (fmap f m) (fmap (fmap f) d)
+
+instance (Monad m, CustomReal (Real m)) => Applicative (Traced n m) where
+  pure x = Traced (pure x) (pure (pure x))
+  (Traced mf df) <*> (Traced mx dx) = Traced (mf <*> mx) (liftA2 (<*>) df dx)
+
+instance (Monad m, CustomReal (Real m)) => Monad (Traced n m) where
+  (Traced mx dx) >>= f = Traced my dy
+    where
+      my = mx >>= model . f
+      dy = dx `bind` (traceDist .
f) + +-- instance MonadTrans (Traced n) where +-- lift m = undefined -- Traced (lift $ lift m) (fmap pure m) + +instance (MonadSample m, CustomReal n, n ~ Real m) => MonadSample (Traced n m) where + type Real (Traced n m) = Real m + random = Traced random (fmap singleton random) + -- normal a b = lift $ normal a b + normal a b = inverf <$> random + +instance (MonadCond m, CustomReal (Real m), n ~ Real m) => MonadCond (Traced n m) where + score w = Traced (score w) (score w >> pure (scored w)) + +instance (CustomReal n, n ~ Real m, MonadInfer m) => MonadInfer (Traced n m) + +density :: Density n b -> [n] -> b +density (Density m) = evalState ( m) + +tangent :: CustomReal n => (forall m . MonadInfer m => m a) -> [n] -> [n] +tangent m = grad $ pdf m + +pdf :: CustomReal c => Weighted (Density c) a -> [c] -> c +pdf m = ln . exp . snd . density (weighted m) + +ex :: CustomReal n => [n] -> [n] +ex = tangent example + +example :: MonadInfer m => m Bool +example = do + x <- random + -- condition (x > 0.5) + -- factor (Exp $ log ((1/x) ** 3)) + return (x > 0.5) + + +nrml :: Floating a => a -> a -> a -> a +nrml mu sigma x = 1 / (sigma * sqrt (2 * pi)) * exp ((-0.5) * (((x - mu) / sigma) ^^ 2)) + +getPhasePoint :: MonadSample m => Real m -> m (Real m, Real m) +getPhasePoint q = (,q) <$> normal 0 1 -- TODO: fix to normal q 1 + +getPhasePoints :: MonadSample m => [Real m] -> m (V2 [Real m]) +getPhasePoints x = uncurry V2 . unzip <$> Prelude.mapM getPhasePoint x + +hamiltonian :: CustomReal n => ([n] -> n) -> [n] -> [n] -> n +hamiltonian potential p q = negate (log $ potential q) - Prelude.sum (log (nrml q 1 p)) + + + + +hmcKernel :: (MonadSample m) => +-- (Weighted (FreeSampler IdentityN) (Reverse s a) a) -> + (forall n. (CustomReal n) => [n] -> n) -> [Real m] -> m [Real m] +hmcKernel potential = + fmap + ( view _y . + Prelude.foldr (.) id (Prelude.replicate 100 stepForward) + ) + . getPhasePoints + where + h :: (CustomReal n) => [n] -> [n] -> n + h = hamiltonian potential + stepForward :: (CustomReal a, Num a) => V2 [a] -> V2 [a] + stepForward x@(V2 p q) = stormerVerlet2H 0.1 ((grad $ h (Lift <$> p))) (grad (flip (h) (Lift <$> q))) x + + + + +-- example :: (Show n, CustomReal n, MonadSample n m) => m n [n] +-- example = (\x -> ((hmcKernel . pdf) program) x) [invSigmoid 0.5] + +ex2 :: IO [Bool] +ex2 = sampler $ unweighted $ marginal $ mh 10 example + +-- | A single step of the Trace Metropolis-Hastings algorithm. +mhStep :: forall n m a . MonadSample m => Traced n m a -> Traced n m a +mhStep (Traced m d) = Traced m d' + where + d' = do + tr <- d + let vars = variables tr + newVars <- hmcKernel (pdf (unsafeCoerce m)) vars + return (tr {variables = newVars} ) + +-- | A single Metropolis-corrected transition of single-site Trace MCMC. +-- mhTrans :: MonadSample m => Weighted (Density n) a -> Trace (Real m) a -> m (Trace (Real m) a) +-- mhTrans m t@Trace {variables = us, probDensity = p} = do +-- let n = length us +-- us' <- do +-- i <- undefined -- discrete $ discreteUniformAB 0 (n - 1) +-- u' <- random +-- case splitAt i us of +-- (xs, _ : ys) -> return $ xs ++ (u' : ys) +-- _ -> error "impossible" +-- ((b, q), vs) <- runWriterT $ weighted $ Weighted.hoist (WriterT . density us') m +-- let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs))) +-- accept <- bernoulli ratio +-- return $ if accept then Trace vs b q else t + +mh :: MonadSample m => Int -> Traced n m a -> m [a] +mh n (Traced m d) = fmap (map output . 
NE.toList) (f n) + where + f k + | k <= 0 = fmap (:| []) d + | otherwise = do + (x :| xs) <- f (k - 1) + y <- (_variables (hmcKernel (pdf (unsafeCoerce m)))) x + return (y :| x : xs) + + +instance (Num a) => Num ([a]) where + ([a]) + (b) = fmap ((+) a) b + (a) + ([b]) = fmap ((+) b) a + (a) + (b) = zipWith (+) a b + + -- (a) * (b) = zipWith (*) (traceIt "left" a) (traceIt "right" b) + ([a]) * (b) = fmap ((*) a) b + (a) * ([b]) = fmap ((*) b) a + (a) * (b) = zipWith (*) a b + abs (a) = abs <$> a + negate (a) = negate <$> a + signum (a) = signum <$> a + fromInteger a = [fromInteger a] + +instance (Fractional a) => Fractional ([a]) where + fromRational a = [fromRational a] + recip = fmap recip + +instance Floating a => Floating ([a]) where + pi = [pi] + log = fmap log + sin = fmap sin + cos = fmap cos + cosh = fmap cosh + sinh = fmap sinh + exp = fmap exp + atan = fmap atan + asin = fmap asin + acosh = fmap acosh + atanh = fmap atanh + asinh = fmap asinh + acos = fmap acos + tan = fmap tan \ No newline at end of file diff --git a/src/Control/Monad/Bayes/Traced/Static.hs b/src/Control/Monad/Bayes/Traced/Static.hs index 4be243e8..9c1fbb84 100644 --- a/src/Control/Monad/Bayes/Traced/Static.hs +++ b/src/Control/Monad/Bayes/Traced/Static.hs @@ -1,5 +1,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} -- | -- Module : Control.Monad.Bayes.Traced.Static @@ -10,7 +12,7 @@ -- Stability : experimental -- Portability : GHC module Control.Monad.Bayes.Traced.Static - ( Traced, + ( Traced(..), hoist, marginal, mhStep, @@ -20,11 +22,7 @@ where import Control.Applicative (liftA2) import Control.Monad.Bayes.Class - ( MonadCond (..), - MonadInfer, - MonadSample (random), - ) -import Control.Monad.Bayes.Free (FreeSampler) +import Control.Monad.Bayes.Density.Free (Density) import Control.Monad.Bayes.Traced.Common ( Trace (..), bind, @@ -35,37 +33,39 @@ import Control.Monad.Bayes.Traced.Common import Control.Monad.Bayes.Weighted (Weighted) import Control.Monad.Trans (MonadTrans (..)) import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) +import Prelude hiding (Real) -- | A tracing monad where only a subset of random choices are traced. -- -- The random choices that are not to be traced should be lifted from the -- transformed monad. data Traced m a = Traced - { model :: Weighted (FreeSampler m) a, - traceDist :: m (Trace a) + { model :: Weighted (Density m) a, + traceDist :: m (Trace (Real m) a) } instance Monad m => Functor (Traced m) where fmap f (Traced m d) = Traced (fmap f m) (fmap (fmap f) d) -instance Monad m => Applicative (Traced m) where +instance (Monad m, RealFloat (Real m)) => Applicative (Traced m) where pure x = Traced (pure x) (pure (pure x)) (Traced mf df) <*> (Traced mx dx) = Traced (mf <*> mx) (liftA2 (<*>) df dx) -instance Monad m => Monad (Traced m) where +instance (Monad m, RealFloat (Real m)) => Monad (Traced m) where (Traced mx dx) >>= f = Traced my dy where my = mx >>= model . f dy = dx `bind` (traceDist . 
f) instance MonadTrans Traced where - lift m = Traced (lift $ lift m) (fmap pure m) + lift m = undefined -- Traced (lift $ lift m) (fmap pure m) instance MonadSample m => MonadSample (Traced m) where + type Real (Traced m) = Real m random = Traced random (fmap singleton random) -instance MonadCond m => MonadCond (Traced m) where - score w = Traced (score w) (score w >> pure (scored w)) +instance (MonadCond m, RealFloat (Real m)) => MonadCond (Traced m) where + score w = undefined -- Traced (score w) (score w >> pure (scored w)) instance MonadInfer m => MonadInfer (Traced m) diff --git a/src/Control/Monad/Bayes/Weighted.hs b/src/Control/Monad/Bayes/Weighted.hs index 2dd29a03..a3e70836 100644 --- a/src/Control/Monad/Bayes/Weighted.hs +++ b/src/Control/Monad/Bayes/Weighted.hs @@ -1,6 +1,8 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} -- | -- Module : Control.Monad.Bayes.Weighted @@ -28,59 +30,63 @@ where import Control.Arrow (Arrow (first)) import Control.Monad.Bayes.Class - ( MonadCond (..), - MonadInfer, - MonadSample, - factor, - ) import Control.Monad.Trans (MonadIO, MonadTrans (..)) import Control.Monad.Trans.State (StateT (..), mapStateT, modify) import Data.Fixed (mod') import Numeric.Log (Log) +import Prelude hiding (Real) -- | Execute the program using the prior distribution, while accumulating likelihood. -newtype Weighted m a = Weighted (StateT (Log Double) m a) +newtype Weighted m a = Weighted (StateT (Log (Real m)) m a) -- StateT is more efficient than WriterT - deriving newtype (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadSample) + deriving newtype (Functor, Applicative, Monad, MonadIO) + +instance (MonadTrans Weighted) where + lift = Weighted . lift -instance Monad m => MonadCond (Weighted m) where +instance (Monad m, RealFloat (Real m)) => MonadCond (Weighted m) where score w = Weighted (modify (* w)) +instance MonadSample m => MonadSample (Weighted m) where + type Real (Weighted m) = Real m + random = lift random + normal m v = lift $ normal m v + instance MonadSample m => MonadInfer (Weighted m) -- | Obtain an explicit value of the likelihood for a given value. -weighted, runWeighted :: Weighted m a -> m (a, Log Double) +weighted, runWeighted :: RealFloat (Real m) => Weighted m a -> m (a, Log (Real m)) weighted (Weighted m) = runStateT m 1 runWeighted = weighted -- | Compute the sample and discard the weight. -- -- This operation introduces bias. -unweighted :: Functor m => Weighted m a -> m a +unweighted :: (Functor m, RealFloat (Real m)) => Weighted m a -> m a unweighted = fmap fst . weighted -- | Compute the weight and discard the sample. -extractWeight :: Functor m => Weighted m a -> m (Log Double) +extractWeight :: (Functor m, RealFloat (Real m)) => Weighted m a -> m (Log (Real m)) extractWeight = fmap snd . weighted -- | Embed a random variable with explicitly given likelihood. -- -- > weighted . withWeight = id -withWeight :: (Monad m) => m (a, Log Double) -> Weighted m a +withWeight :: (Monad m, RealFloat (Real m)) => m (a, Log (Real m)) -> Weighted m a withWeight m = Weighted $ do (x, w) <- lift m modify (* w) return x -- | Use the weight as a factor in the transformed monad. -applyWeight :: MonadCond m => Weighted m a -> m a +applyWeight :: (MonadCond m, RealFloat (Real m)) => Weighted m a -> m a applyWeight m = do (x, w) <- weighted m factor w return x -- | Apply a transformation to the transformed monad. 
-hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
+hoist :: Real n ~ Real m => (forall x. m x -> n x) -> Weighted m a -> Weighted n a
 hoist t (Weighted m) = Weighted $ mapStateT t m
 
 toBinsWeighted :: Double -> [(Double, Log Double)] -> [(Double, Log Double)]
diff --git a/src/Math/Integrators/StormerVerlet.hs b/src/Math/Integrators/StormerVerlet.hs
index 0bcf8a8f..6b48229e 100644
--- a/src/Math/Integrators/StormerVerlet.hs
+++ b/src/Math/Integrators/StormerVerlet.hs
@@ -25,7 +25,7 @@ type Integrator a =
 -- | Störmer-Verlet integration scheme for systems of the form
 -- \(\mathbb{H}(p,q) = T(p) + V(q)\)
 stormerVerlet2H ::
-  (Applicative f, Num (f a), Show (f a), Fractional a) =>
+  (Applicative f, Num (f a), Fractional a) =>
   -- | Step size
   a ->
   -- | \(\frac{\partial H}{\partial q}\)
diff --git a/stack.yaml b/stack.yaml
index a5c40966..0bbdb223 100644
--- a/stack.yaml
+++ b/stack.yaml
@@ -1,7 +1,10 @@
-resolver: nightly-2022-07-26
+resolver: nightly-2022-03-19
 packages:
 - "."
 flags:
   monad-bayes:
     dev: True
+extra-deps:
+- monad-coroutine-0.9.2@sha256:33b0851419996ddacf665101369bc9e5e80263758564841eba9c9f65a8a75553,1302
+- monad-parallel-0.8@sha256:44f64399061036580acfa627c1f07f5481553f058de7d54ab9953801f75d5e86,1155
\ No newline at end of file
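
The thread running through this patch is that MonadSample no longer hard-codes Double: each instance declares an associated type family Real m, and downstream signatures trade Log Double for Log (Real m), so the same probabilistic program can run with Double for ordinary sampling or with an automatic-differentiation carrier when gradients of the trace density are needed (as in Traced.Grad above). Below is a minimal, self-contained sketch of that pattern; MonadSampleSketch, Replay, model and runSketch are illustrative stand-ins invented for this example, not part of the monad-bayes API.

{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- A simplified stand-in for MonadSample: the class only fixes how to draw a
-- uniform, and each instance chooses its own numeric carrier via 'Real'.
module RealFamilySketch where

import Control.Monad.State (State, evalState, get, put)
import Prelude hiding (Real)

class Monad m => MonadSampleSketch m where
  -- | Numeric carrier of random draws: Double for plain sampling, an AD type
  -- when the trace density has to be differentiated.
  type Real m
  random :: m (Real m)

-- | A deterministic sampler that replays draws from a supplied trace,
-- mirroring the Density monads defined in the diff above.
newtype Replay n a = Replay (State [n] a)
  deriving newtype (Functor, Applicative, Monad)

instance MonadSampleSketch (Replay n) where
  type Real (Replay n) = n
  random = Replay $ do
    draws <- get
    case draws of
      [] -> error "ran out of randomness in the replayed trace"
      r : rs -> put rs >> pure r

-- | A model written once against the class; it typechecks for any carrier
-- that supports comparison and fractional literals.
model :: (MonadSampleSketch m, Ord (Real m), Fractional (Real m)) => m Bool
model = do
  x <- random
  pure (x > 0.5)

-- | Replaying a fixed trace gives a deterministic answer (here True).
runSketch :: Bool
runSketch = case model :: Replay Double Bool of
  Replay s -> evalState s [0.7]

Replaying a stored list of uniforms is essentially what the Density monads above do during an MH or HMC transition; keeping the carrier abstract is what allows that list to hold reverse-mode AD values instead of Doubles when hmcKernel asks for gradients.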