Skip to content

Commit 9f8c6fa

Browse files
authored
Merge pull request #13 from TuringLang/csp/iterator
Add iterator interface
2 parents 3bb7767 + 4ad538d commit 9f8c6fa

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "0.3.0"
6+
version = "0.4.0"
77

88
[deps]
99
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"

src/AbstractMCMC.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,4 +406,56 @@ function psample(
406406
return reduce(chainscat, chains)
407407
end
408408

409+
410+
##################
411+
# Iterator tools #
412+
##################
413+
struct Stepper{A<:AbstractRNG, ModelType<:AbstractModel, SamplerType<:AbstractSampler, K}
414+
rng::A
415+
model::ModelType
416+
s::SamplerType
417+
kwargs::K
418+
end
419+
420+
function Base.iterate(stp::Stepper, state=nothing)
421+
t = step!(stp.rng, stp.model, stp.s, 1, state; stp.kwargs...)
422+
return t, t
423+
end
424+
425+
Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite()
426+
Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown()
427+
428+
"""
429+
steps!([rng::AbstractRNG, ]model::AbstractModel, s::AbstractSampler, kwargs...)
430+
431+
`steps!` returns an iterator that returns samples continuously, after calling `sample_init!`.
432+
433+
Usage:
434+
435+
```julia
436+
for transition in steps!(MyModel(), MySampler())
437+
println(transition)
438+
439+
# Do other stuff with transition below.
440+
end
441+
```
442+
"""
443+
function steps!(
444+
model::AbstractModel,
445+
s::AbstractSampler,
446+
kwargs...
447+
)
448+
return steps!(GLOBAL_RNG, model, s; kwargs...)
449+
end
450+
451+
function steps!(
452+
rng::AbstractRNG,
453+
model::AbstractModel,
454+
s::AbstractSampler,
455+
kwargs...
456+
)
457+
sample_init!(rng, model, s, 0)
458+
return Stepper(rng, model, s, kwargs)
459+
end
460+
409461
end # module AbstractMCMC

test/runtests.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using AbstractMCMC
2-
using AbstractMCMC: sample, psample
2+
using AbstractMCMC: sample, psample, steps!
33

44
using Random
55
using Statistics
@@ -57,4 +57,30 @@ include("interface.jl")
5757
@test chain1 isa Vector{MyTransition}
5858
@test chain2 isa MyChain
5959
end
60+
61+
@testset "Iterator sampling" begin
62+
Random.seed!(1234)
63+
as = []
64+
bs = []
65+
66+
iter = steps!(MyModel(), MySampler())
67+
68+
for (count, t) in enumerate(iter)
69+
if count >= 1000
70+
break
71+
end
72+
73+
push!(as, t.a)
74+
push!(bs, t.b)
75+
end
76+
77+
@test mean(as) 0.5 atol=1e-2
78+
@test var(as) 1 / 12 atol=5e-3
79+
@test mean(bs) 0.0 atol=5e-2
80+
@test var(bs) 1 atol=5e-2
81+
82+
println(eltype(iter))
83+
@test Base.IteratorSize(iter) == Base.IsInfinite()
84+
@test Base.IteratorEltype(iter) == Base.EltypeUnknown()
85+
end
6086
end

0 commit comments

Comments
 (0)