Skip to content

Commit 8d7f22f

Browse files
authored
Fix discard_initial, and add support for discard_initial and thinning to iterator and transducer (#102)
* Fix `discard_initial`, and add support for `discard_initial` and `thinning` to iterator and transducer * Fix test errors on Julia < 1.6 * Only enable progress logging on Julia < 1.6 * Use different seed * Update api.md * Update api.md * Update sample.jl * Use `==` instead of `===`
1 parent 650d9e1 commit 8d7f22f

File tree

8 files changed

+243
-42
lines changed

8 files changed

+243
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "4.1.0"
6+
version = "4.1.1"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

docs/src/api.md

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,8 +43,7 @@ AbstractMCMC.MCMCSerial
4343

4444
## Common keyword arguments
4545

46-
Common keyword arguments for regular and parallel sampling (not supported by the iterator and transducer)
47-
are:
46+
Common keyword arguments for regular and parallel sampling are:
4847
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
4948
- `chain_type` (default: `Any`): determines the type of the returned chain
5049
- `callback` (default: `nothing`): if `callback !== nothing`, then
@@ -53,6 +52,9 @@ are:
5352
- `discard_initial` (default: `0`): number of initial samples that are discarded
5453
- `thinning` (default: `1`): factor by which to thin samples.
5554

55+
!!! info
56+
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).
57+
5658
There is no "official" way for providing initial parameter values yet.
5759
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
5860
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):

src/sample.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ function mcmcsample(
120120
sample, state = step(rng, model, sampler; kwargs...)
121121

122122
# Discard initial samples.
123-
for i in 1:(discard_initial - 1)
123+
for i in 1:discard_initial
124124
# Update the progress bar.
125125
if progress && i >= next_update
126126
ProgressLogging.@logprogress i / Ntotal
@@ -218,7 +218,7 @@ function mcmcsample(
218218
sample, state = step(rng, model, sampler; kwargs...)
219219

220220
# Discard initial samples.
221-
for _ in 2:discard_initial
221+
for _ in 1:discard_initial
222222
# Obtain the next sample and state.
223223
sample, state = step(rng, model, sampler, state; kwargs...)
224224
end

src/stepper.jl

Lines changed: 30 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,37 @@ struct Stepper{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K}
55
kwargs::K
66
end
77

8-
Base.iterate(stp::Stepper) = step(stp.rng, stp.model, stp.sampler; stp.kwargs...)
8+
# Initial sample.
9+
function Base.iterate(stp::Stepper)
10+
# Unpack iterator.
11+
rng = stp.rng
12+
model = stp.model
13+
sampler = stp.sampler
14+
kwargs = stp.kwargs
15+
discard_initial = get(kwargs, :discard_initial, 0)::Int
16+
17+
# Start sampling algorithm and discard initial samples if desired.
18+
sample, state = step(rng, model, sampler; kwargs...)
19+
for _ in 1:discard_initial
20+
sample, state = step(rng, model, sampler, state; kwargs...)
21+
end
22+
return sample, state
23+
end
24+
25+
# Subsequent samples.
926
function Base.iterate(stp::Stepper, state)
10-
return step(stp.rng, stp.model, stp.sampler, state; stp.kwargs...)
27+
# Unpack iterator.
28+
rng = stp.rng
29+
model = stp.model
30+
sampler = stp.sampler
31+
kwargs = stp.kwargs
32+
thinning = get(kwargs, :thinning, 1)::Int
33+
34+
# Return next sample, possibly after thinning the chain if desired.
35+
for _ in 1:(thinning - 1)
36+
_, state = step(rng, model, sampler, state; kwargs...)
37+
end
38+
return step(rng, model, sampler, state; kwargs...)
1139
end
1240

1341
Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite()

src/transducer.jl

Lines changed: 43 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -40,24 +40,58 @@ function Sample(
4040
return Sample(rng, model, sampler, kwargs)
4141
end
4242

43+
# Initial sample.
4344
function Transducers.start(rf::Transducers.R_{<:Sample}, result)
44-
sampler = Transducers.xform(rf)
45+
# Unpack transducer.
46+
td = Transducers.xform(rf)
47+
rng = td.rng
48+
model = td.model
49+
sampler = td.sampler
50+
kwargs = td.kwargs
51+
discard_initial = get(kwargs, :discard_initial, 0)::Int
52+
53+
# Start sampling algorithm and discard initial samples if desired.
54+
sample, state = step(rng, model, sampler; kwargs...)
55+
for _ in 1:discard_initial
56+
sample, state = step(rng, model, sampler, state; kwargs...)
57+
end
58+
4559
return Transducers.wrap(
46-
rf,
47-
step(sampler.rng, sampler.model, sampler.sampler; sampler.kwargs...),
48-
Transducers.start(Transducers.inner(rf), result),
60+
rf, (sample, state), Transducers.start(Transducers.inner(rf), result)
4961
)
5062
end
5163

64+
# Subsequent samples.
5265
function Transducers.next(rf::Transducers.R_{<:Sample}, result, input)
53-
t = Transducers.xform(rf)
54-
Transducers.wrapping(rf, result) do (sample, state), iresult
55-
iresult2 = Transducers.next(Transducers.inner(rf), iresult, sample)
56-
return step(t.rng, t.model, t.sampler, state; t.kwargs...), iresult2
66+
# Unpack transducer.
67+
td = Transducers.xform(rf)
68+
rng = td.rng
69+
model = td.model
70+
sampler = td.sampler
71+
kwargs = td.kwargs
72+
thinning = get(kwargs, :thinning, 1)::Int
73+
74+
let rng = rng,
75+
model = model,
76+
sampler = sampler,
77+
kwargs = kwargs,
78+
thinning = thinning,
79+
inner_rf = Transducers.inner(rf)
80+
81+
Transducers.wrapping(rf, result) do (sample, state), iresult
82+
iresult2 = Transducers.next(inner_rf, iresult, sample)
83+
84+
# Perform thinning if desired.
85+
for _ in 1:(thinning - 1)
86+
_, state = step(rng, model, sampler, state; kwargs...)
87+
end
88+
89+
return step(rng, model, sampler, state; kwargs...), iresult2
90+
end
5791
end
5892
end
5993

6094
function Transducers.complete(rf::Transducers.R_{Sample}, result)
61-
_private_state, inner_result = Transducers.unwrap(rf, result)
95+
_, inner_result = Transducers.unwrap(rf, result)
6296
return Transducers.complete(Transducers.inner(rf), inner_result)
6397
end

test/sample.jl

Lines changed: 75 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,7 @@
137137
@test chains isa Vector{<:MyChain}
138138
@test length(chains) == 1000
139139
@test all(x -> length(x.as) == length(x.bs) == N, chains)
140+
@test all(ismissing(x.as[1]) for x in chains)
140141

141142
# test some statistical properties
142143
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
@@ -147,9 +148,9 @@
147148
# test reproducibility
148149
Random.seed!(1234)
149150
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain)
150-
151-
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
152-
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
151+
@test all(ismissing(x.as[1]) for x in chains2)
152+
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
153+
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
153154

154155
# Unexpected order of arguments.
155156
str = "Number of chains (10) is greater than number of samples per chain (5)"
@@ -245,7 +246,7 @@
245246

246247
# Test output type and size.
247248
@test chains isa Vector{<:MyChain}
248-
@test all(c.as[1] === missing for c in chains)
249+
@test all(ismissing(c.as[1]) for c in chains)
249250
@test length(chains) == 1000
250251
@test all(x -> length(x.as) == length(x.bs) == N, chains)
251252

@@ -260,9 +261,9 @@
260261
chains2 = sample(
261262
MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain
262263
)
263-
264-
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
265-
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
264+
@test all(ismissing(c.as[1]) for c in chains2)
265+
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
266+
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
266267

267268
# Unexpected order of arguments.
268269
str = "Number of chains (10) is greater than number of samples per chain (5)"
@@ -330,7 +331,7 @@
330331

331332
# Test output type and size.
332333
@test chains isa Vector{<:MyChain}
333-
@test all(c.as[1] === missing for c in chains)
334+
@test all(ismissing(c.as[1]) for c in chains)
334335
@test length(chains) == 1000
335336
@test all(x -> length(x.as) == length(x.bs) == N, chains)
336337

@@ -343,9 +344,9 @@
343344
# Test reproducibility.
344345
Random.seed!(1234)
345346
chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain)
346-
347-
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
348-
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
347+
@test all(ismissing(c.as[1]) for c in chains2)
348+
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
349+
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
349350

350351
# Unexpected order of arguments.
351352
str = "Number of chains (10) is greater than number of samples per chain (5)"
@@ -415,6 +416,7 @@
415416
progress=false,
416417
chain_type=MyChain,
417418
)
419+
@test all(ismissing(c.as[1]) for c in chains_serial)
418420

419421
# Multi-threaded sampling
420422
Random.seed!(1234)
@@ -427,12 +429,13 @@
427429
progress=false,
428430
chain_type=MyChain,
429431
)
432+
@test all(ismissing(c.as[1]) for c in chains_threads)
430433
@test all(
431-
c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads),
432-
i in 1:N
434+
c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads),
435+
i in 2:N
433436
)
434437
@test all(
435-
c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads),
438+
c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads),
436439
i in 1:N
437440
)
438441

@@ -447,12 +450,13 @@
447450
progress=false,
448451
chain_type=MyChain,
449452
)
453+
@test all(ismissing(c.as[1]) for c in chains_distributed)
450454
@test all(
451-
c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed),
452-
i in 1:N
455+
c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed),
456+
i in 2:N
453457
)
454458
@test all(
455-
c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed),
459+
c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed),
456460
i in 1:N
457461
)
458462
end
@@ -473,24 +477,41 @@
473477
end
474478

475479
@testset "Discard initial samples" begin
476-
chain = sample(MyModel(), MySampler(), 100; sleepy=true, discard_initial=50)
477-
@test length(chain) == 100
480+
# Create a chain and discard initial samples.
481+
Random.seed!(1234)
482+
N = 100
483+
discard_initial = 50
484+
chain = sample(MyModel(), MySampler(), N; discard_initial=discard_initial)
485+
@test length(chain) == N
478486
@test !ismissing(chain[1].a)
487+
488+
# Repeat sampling without discarding initial samples.
489+
# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
490+
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
491+
Random.seed!(1234)
492+
ref_chain = sample(
493+
MyModel(), MySampler(), N + discard_initial; progress=VERSION < v"1.6"
494+
)
495+
@test all(chain[i].a == ref_chain[i + discard_initial].a for i in 1:N)
496+
@test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N)
479497
end
480498

481499
@testset "Thin chain by a factor of `thinning`" begin
482500
# Run a thinned chain with `N` samples thinned by factor of `thinning`.
483-
Random.seed!(1234)
501+
Random.seed!(100)
484502
N = 100
485503
thinning = 3
486-
chain = sample(MyModel(), MySampler(), N; sleepy=true, thinning=thinning)
504+
chain = sample(MyModel(), MySampler(), N; thinning=thinning)
487505
@test length(chain) == N
488506
@test ismissing(chain[1].a)
489507

490508
# Repeat sampling without thinning.
491-
Random.seed!(1234)
492-
ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy=true)
493-
@test all(chain[i].a === ref_chain[(i - 1) * thinning + 1].a for i in 1:N)
509+
# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
510+
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
511+
Random.seed!(100)
512+
ref_chain = sample(MyModel(), MySampler(), N * thinning; progress=VERSION < v"1.6")
513+
@test all(chain[i].a == ref_chain[(i - 1) * thinning + 1].a for i in 2:N)
514+
@test all(chain[i].b == ref_chain[(i - 1) * thinning + 1].b for i in 1:N)
494515
end
495516

496517
@testset "Sample without predetermined N" begin
@@ -501,16 +522,44 @@
501522
@test abs(bmean) <= 0.001 || length(chain) == 10_000
502523

503524
# Discard initial samples.
504-
chain = sample(MyModel(), MySampler(); discard_initial=50)
525+
Random.seed!(1234)
526+
discard_initial = 50
527+
chain = sample(MyModel(), MySampler(); discard_initial=discard_initial)
505528
bmean = mean(x.b for x in chain)
506529
@test !ismissing(chain[1].a)
507530
@test abs(bmean) <= 0.001 || length(chain) == 10_000
508531

532+
# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
533+
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
534+
Random.seed!(1234)
535+
N = length(chain)
536+
ref_chain = sample(
537+
MyModel(),
538+
MySampler(),
539+
N;
540+
discard_initial=discard_initial,
541+
progress=VERSION < v"1.6",
542+
)
543+
@test all(chain[i].a == ref_chain[i].a for i in 1:N)
544+
@test all(chain[i].b == ref_chain[i].b for i in 1:N)
545+
509546
# Thin chain by a factor of `thinning`.
510-
chain = sample(MyModel(), MySampler(); thinning=3)
547+
Random.seed!(1234)
548+
thinning = 3
549+
chain = sample(MyModel(), MySampler(); thinning=thinning)
511550
bmean = mean(x.b for x in chain)
512551
@test ismissing(chain[1].a)
513552
@test abs(bmean) <= 0.001 || length(chain) == 10_000
553+
554+
# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
555+
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
556+
Random.seed!(1234)
557+
N = length(chain)
558+
ref_chain = sample(
559+
MyModel(), MySampler(), N; thinning=thinning, progress=VERSION < v"1.6"
560+
)
561+
@test all(chain[i].a == ref_chain[i].a for i in 2:N)
562+
@test all(chain[i].b == ref_chain[i].b for i in 1:N)
514563
end
515564

516565
@testset "Sample vector of `NamedTuple`s" begin

0 commit comments

Comments (0)