Skip to content

Commit bb7896d

Browse files
authored
Separate cache from model (#13)
1 parent cc3a792 commit bb7896d

11 files changed

+198
-180
lines changed

.github/workflows/CI.yml

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ jobs:
5050
${{ runner.os }}-
5151
- uses: julia-actions/julia-buildpkg@v1
5252
- uses: julia-actions/julia-runtest@v1
53+
env:
54+
JULIA_NUM_THREADS: 2
5355
- uses: julia-actions/julia-processcoverage@v1
5456
if: matrix.coverage
5557
- uses: codecov/codecov-action@v1

Project.toml

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EllipticalSliceSampling"
22
uuid = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
3-
authors = ["David Widmann <[email protected]>"]
4-
version = "0.3.1"
3+
authors = ["David Widmann <[email protected]>"]
4+
version = "0.4.0"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -18,11 +18,11 @@ julia = "1"
1818

1919
[extras]
2020
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
21+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2122
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2223
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
23-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2424
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2525
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2626

2727
[targets]
28-
test = ["Distances", "FillArrays", "LinearAlgebra", "SafeTestsets", "Statistics", "Test"]
28+
test = ["Distances", "Distributed", "FillArrays", "LinearAlgebra", "Statistics", "Test"]

README.md

+13-6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ which returns a vector of `N` samples for approximating the posterior of
3030
a model with a Gaussian prior that allows sampling from the `prior` and
3131
evaluation of the log likelihood `loglikelihood`.
3232

33+
You can sample multiple chains in parallel with multiple threads or processes
34+
by running
35+
```julia
36+
sample([rng, ]ESSModel(prior, loglikelihood), ESS(), MCMCThreads(), N, nchains[; kwargs...])
37+
```
38+
or
39+
```julia
40+
sample([rng, ]ESSModel(prior, loglikelihood), ESS(), MCMCDistributed(), N, nchains[; kwargs...])
41+
```
42+
3343
If you want to have more control over the sampling procedure (e.g., if you
3444
only want to save a subset of samples or want to use another stopping
3545
criterion), the function
@@ -44,6 +54,9 @@ AbstractMCMC.steps(
4454
gives you access to an iterator from which you can generate an unlimited
4555
number of samples.
4656

57+
For more details regarding `sample` and `steps`, please check the documentation of
58+
[AbstractMCMC.jl](https://github.com/TuringLang/AbstractMCMC.jl).
59+
4760
### Prior
4861

4962
You may specify Gaussian priors with arbitrary means. EllipticalSliceSampling.jl
@@ -74,12 +87,6 @@ Statistics.mean(dist::GaussianPrior) = ...
7487
# - otherwise only `rand!(rng, dist, sample)` is required
7588
Base.rand(rng::AbstractRNG, dist::GaussianPrior) = ...
7689
Random.rand!(rng::AbstractRNG, dist::GaussianPrior, sample) = ...
77-
78-
# specify the type of a sample from the distribution
79-
Base.eltype(::Type{<:GaussianPrior}) = ...
80-
81-
# in the case of mutable samples, specify the array size of the samples
82-
Base.size(dist::GaussianPrior) = ...
8390
```
8491

8592
### Log likelihood

src/EllipticalSliceSampling.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@ import Distributions
77
import Random
88
import Statistics
99

10-
export sample, ESSModel, ESS
10+
export ESSModel, ESS
11+
12+
# reexports
13+
using AbstractMCMC: sample, MCMCThreads, MCMCDistributed
14+
export sample, MCMCThreads, MCMCDistributed
1115

1216
include("abstractmcmc.jl")
1317
include("model.jl")
14-
include("distributions.jl")
1518
include("interface.jl")
1619

1720
end # module

src/abstractmcmc.jl

+23-6
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@
22
struct ESS <: AbstractMCMC.AbstractSampler end
33

44
# state of the elliptical slice sampler
5-
struct ESSState{S,L}
5+
struct ESSState{S,L,C}
66
"Sample of the elliptical slice sampler."
77
sample::S
88
"Log-likelihood of the sample."
99
loglikelihood::L
10+
"Cache used for in-place sampling."
11+
cache::C
12+
end
13+
14+
function ESSState(sample, loglikelihood)
15+
# create cache since it was not provided (initial sampling step)
16+
cache = ArrayInterface.ismutable(sample) ? similar(sample) : nothing
17+
return ESSState(sample, loglikelihood, cache)
1018
end
1119

1220
# first step of the elliptical slice sampler
@@ -33,8 +41,17 @@ function AbstractMCMC.step(
3341
state::ESSState;
3442
kwargs...
3543
)
44+
# obtain the prior
45+
prior = EllipticalSliceSampling.prior(model)
46+
3647
# sample from Gaussian prior
37-
ν = sample_prior(rng, model)
48+
cache = state.cache
49+
if cache === nothing
50+
ν = Random.rand(rng, prior)
51+
else
52+
Random.rand!(rng, prior, cache)
53+
ν = cache
54+
end
3855

3956
# sample log-likelihood threshold
4057
loglikelihood = state.loglikelihood
@@ -47,7 +64,7 @@ function AbstractMCMC.step(
4764

4865
# compute the proposal
4966
f = state.sample
50-
fnext = proposal(model, f, ν, θ)
67+
fnext = proposal(prior, f, ν, θ)
5168

5269
# compute the log-likelihood of the proposal
5370
loglikelihood = Distributions.loglikelihood(model, fnext)
@@ -66,14 +83,14 @@ function AbstractMCMC.step(
6683

6784
# recompute the proposal
6885
if ArrayInterface.ismutable(fnext)
69-
proposal!(fnext, model, f, ν, θ)
86+
proposal!(fnext, prior, f, ν, θ)
7087
else
71-
fnext = proposal(model, f, ν, θ)
88+
fnext = proposal(prior, f, ν, θ)
7289
end
7390

7491
# compute the log-likelihood of the proposal
7592
loglikelihood = Distributions.loglikelihood(model, fnext)
7693
end
7794

78-
return fnext, ESSState(fnext, loglikelihood)
95+
return fnext, ESSState(fnext, loglikelihood, cache)
7996
end

src/distributions.jl

-54
This file was deleted.

src/interface.jl

+46-22
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,73 @@
11
# private interface
22

33
"""
4-
initial_sample(rng, model)
4+
isgaussian(dist)
55
6-
Return the initial sample for the `model` using the random number generator `rng`.
6+
Check if distribution `dist` is a Gaussian distribution.
7+
"""
8+
isgaussian(dist) = false
9+
isgaussian(::Type{<:Distributions.Normal}) = true
10+
isgaussian(::Type{<:Distributions.NormalCanon}) = true
11+
isgaussian(::Type{<:Distributions.AbstractMvNormal}) = true
712

8-
By default, sample from the prior by calling [`sample_prior(rng, model)`](@ref).
913
"""
10-
function initial_sample(rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel)
11-
return sample_prior(rng, model)
12-
end
14+
prior(model)
15+
16+
Return the prior distribution of the `model`.
17+
"""
18+
function prior(::AbstractMCMC.AbstractModel) end
1319

1420
"""
15-
sample_prior(rng, model)
21+
initial_sample(rng, model)
1622
17-
Sample from the prior of the `model` using the random number generator `rng`.
23+
Return the initial sample for the `model` using the random number generator `rng`.
24+
25+
By default, sample from [`prior(model)`](@ref).
1826
"""
19-
function sample_prior(::Random.AbstractRNG, ::AbstractMCMC.AbstractModel) end
27+
function initial_sample(rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel)
28+
return Random.rand(rng, prior(model))
29+
end
2030

2131
"""
22-
proposal(model, f, ν, θ)
32+
proposal(prior, f, ν, θ)
2333
24-
Compute the proposal for the next sample in the elliptical slice sampling algorithm for the
25-
`model` from the previous sample `f`, the sample `ν` from the Gaussian prior, and the angle
26-
`θ`.
34+
Compute the proposal for the next sample in the elliptical slice sampling algorithm.
2735
2836
Mathematically, the proposal can be computed as
2937
```math
30-
\\cos θ f + ν \\sin θ ν + μ (1 - \\sin θ + \\cos θ),
38+
f \\cos θ + ν \\sin θ + μ (1 - (\\sin θ + \\cos θ)),
3139
```
32-
where ``μ`` is the mean of the Gaussian prior.
40+
where ``μ`` is the mean of the Gaussian `prior`, `f` is the previous sample, and `ν` is a
41+
sample from the Gaussian `prior`.
42+
43+
See also: [`proposal!`](@ref)
3344
"""
34-
function proposal(model::AbstractMCMC.AbstractModel, f, ν, θ) end
45+
function proposal(prior, f, ν, θ)
46+
sinθ, cosθ = sincos(θ)
47+
a = 1 - (sinθ + cosθ)
48+
μ = Statistics.mean(prior)
49+
return @. cosθ * f + sinθ * ν + a * μ
50+
end
3551

3652
"""
3753
proposal!(out, model, f, ν, θ)
3854
39-
Compute the proposal for the next sample in the elliptical slice sampling algorithm for the
40-
`model` from the previous sample `f`, the sample `ν` from the Gaussian prior, and the angle
41-
`θ`, and save it to `out`.
55+
Compute the proposal for the next sample in the elliptical slice sampling algorithm, and
56+
save it to `out`.
4257
4358
Mathematically, the proposal can be computed as
4459
```math
45-
\\cos θ f + ν \\sin θ ν + μ (1 - \\sin θ + \\cos θ),
60+
f \\cos θ + ν \\sin θ + μ (1 - (\\sin θ + \\cos θ)),
4661
```
47-
where ``μ`` is the mean of the Gaussian prior.
62+
where ``μ`` is the mean of the Gaussian `prior`, `f` is the previous sample, and `ν` is a
63+
sample from the Gaussian `prior`.
64+
65+
See also: [`proposal`](@ref)
4866
"""
49-
function proposal!(out, model::AbstractMCMC.AbstractModel, f, ν, θ) end
67+
function proposal!(out, prior, f, ν, θ)
68+
sinθ, cosθ = sincos(θ)
69+
a = 1 - (sinθ + cosθ)
70+
μ = Statistics.mean(prior)
71+
@. out = cosθ * f + sinθ * ν + a * μ
72+
return out
73+
end

src/model.jl

+5-52
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,23 @@
1-
# internal model structure consisting of prior, log-likelihood function, and a cache
1+
# internal model structure consisting of prior and log-likelihood function
22

3-
struct ESSModel{P,L,C} <: AbstractMCMC.AbstractModel
3+
struct ESSModel{P,L} <: AbstractMCMC.AbstractModel
44
"Gaussian prior."
55
prior::P
66
"Log likelihood function."
77
loglikelihood::L
8-
"Cache."
9-
cache::C
108

119
function ESSModel{P,L}(prior::P, loglikelihood::L) where {P,L}
1210
isgaussian(P) ||
1311
error("prior distribution has to be a Gaussian distribution")
14-
15-
# create cache
16-
c = cache(prior)
17-
18-
new{P,L,typeof(c)}(prior, loglikelihood, c)
12+
new{P,L}(prior, loglikelihood)
1913
end
2014
end
2115

2216
ESSModel(prior, loglikelihood) =
2317
ESSModel{typeof(prior),typeof(loglikelihood)}(prior, loglikelihood)
2418

25-
# cache for high-dimensional samplers
26-
function cache(dist)
27-
T = randtype(typeof(dist))
28-
29-
# only create a cache if the distribution produces mutable samples
30-
ArrayInterface.ismutable(T) || return nothing
31-
32-
similar(T, size(dist))
33-
end
34-
35-
# test if a distribution is Gaussian
36-
isgaussian(dist) = false
37-
38-
# unify element type of samplers
39-
randtype(dist) = eltype(dist)
19+
# obtain prior
20+
prior(model::ESSModel) = model.prior
4021

4122
# evaluate the loglikelihood of a sample
4223
Distributions.loglikelihood(model::ESSModel, f) = model.loglikelihood(f)
43-
44-
# sample from the prior
45-
initial_sample(rng::Random.AbstractRNG, model::ESSModel) = rand(rng, model.prior)
46-
function sample_prior(rng::Random.AbstractRNG, model::ESSModel)
47-
cache = model.cache
48-
49-
if cache === nothing
50-
return rand(rng, model.prior)
51-
else
52-
Random.rand!(rng, model.prior, model.cache)
53-
return model.cache
54-
end
55-
end
56-
57-
# compute the proposal
58-
proposal(model::ESSModel, f, ν, θ) = proposal(model.prior, f, ν, θ)
59-
proposal!(out, model::ESSModel, f, ν, θ) = proposal!(out, model.prior, f, ν, θ)
60-
61-
# default out-of-place implementation
62-
function proposal(prior, f, ν, θ)
63-
sinθ, cosθ = sincos(θ)
64-
a = 1 - (sinθ + cosθ)
65-
μ = Statistics.mean(prior)
66-
return @. cosθ * f + sinθ * ν + a * μ
67-
end
68-
69-
# default in-place implementation
70-
proposal!(out, prior, f, ν, θ) = copyto!(out, proposal(prior, f, ν, θ))

0 commit comments

Comments
 (0)