Skip to content

Commit 3d14727

Browse files
Merge branch 'pm/violin_plot' of https://github.com/PaulinaMartin96/MCMCChains.jl into pm/violin_plot
2 parents 6669d0a + cf89757 commit 3d14727

12 files changed

+102
-48
lines changed

docs/Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
33
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
44
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
55
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
6-
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
6+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
77
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
88
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
99

1010
[compat]
1111
CategoricalArrays = "0.8, 0.9, 0.10"
1212
DataFrames = "0.22, 1"
13-
Documenter = "0.26"
13+
Documenter = "0.26, 0.27"
1414
Gadfly = "1.3"
15-
MLJModels = "0.14"
15+
MLJBase = "0.18"
1616
MLJXGBoostInterface = "0.1"
1717
StatsPlots = "0.14"
18+
julia = "1.3"

src/chains.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,22 @@ function Chains(
3131
name_map = (parameters = parameter_names,);
3232
start::Int = 1,
3333
thin::Int = 1,
34+
iterations::AbstractVector{Int} = range(start; step=thin, length=size(val, 1)),
3435
evidence = missing,
3536
info::NamedTuple = NamedTuple()
3637
)
38+
# Check that iteration numbers are reasonable
39+
if length(iterations) != size(val, 1)
40+
error("length of `iterations` (", length(iterations),
41+
") is not equal to the number of iterations (", size(val, 1), ")")
42+
end
43+
if !isempty(iterations) && first(iterations) < 1
44+
error("iteration numbers must be positive integers")
45+
end
46+
if !isstrictlyincreasing(iterations)
47+
error("iteration numbers must be strictly increasing")
48+
end
49+
3750
# Make sure that we have a `:parameters` index. Copying the name map avoids mutating the caller's state.
3851
_name_map = initnamemap(name_map)
3952

@@ -58,7 +71,7 @@ function Chains(
5871

5972
# Construct the AxisArray.
6073
arr = AxisArray(val;
61-
iter = range(start, step=thin, length=size(val, 1)),
74+
iter = iterations,
6275
var = parameter_names,
6376
chain = 1:size(val, 3))
6477

@@ -444,17 +457,21 @@ Return the range of iteration indices of the `chains`.
444457
Base.range(chains::Chains) = chains.value[Axis{:iter}].val
445458

446459
"""
447-
setrange(chains::Chains, range)
460+
setrange(chains::Chains, range::AbstractVector{Int})
448461
449462
Generate a new chain from `chains` with iterations indexed by `range`.
450463
451464
The new chain and `chains` share the same data in memory.
452465
"""
453-
function setrange(chains::Chains, range::AbstractRange{<:Integer})
466+
function setrange(chains::Chains, range::AbstractVector{Int})
454467
if length(chains) != length(range)
455468
error("length of `range` (", length(range),
456469
") is not equal to the number of iterations (", length(chains), ")")
457470
end
471+
if !isempty(range) && first(range) < 1
472+
error("iteration numbers must be positive integers")
473+
end
474+
isstrictlyincreasing(range) || error("iteration numbers must be strictly increasing")
458475

459476
value = AxisArray(chains.value.data;
460477
iter = range, var = names(chains), chain = MCMCChains.chains(chains))
@@ -574,8 +591,7 @@ function header(c::Chains; section=missing)
574591
# Return header.
575592
return string(
576593
ismissing(c.logevidence) ? "" : "Log evidence = $(c.logevidence)\n",
577-
"Iterations = $(first(c)):$(last(c))\n",
578-
"Thinning interval = $(step(c))\n",
594+
"Iterations = $(range(c))\n",
579595
"Number of chains = $(size(c, 3))\n",
580596
"Samples per chain = $(length(range(c)))\n",
581597
ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n",
@@ -725,8 +741,11 @@ _cat(dim::Int, cs::Chains...) = _cat(Val(dim), cs...)
725741

726742
function _cat(::Val{1}, c1::Chains, args::Chains...)
727743
# check inputs
728-
thin = step(c1)
729-
all(c -> step(c) == thin, args) || throw(ArgumentError("chain thinning differs"))
744+
lastiter = last(c1)
745+
for c in args
746+
first(c) > lastiter || throw(ArgumentError("iterations have to be sorted"))
747+
lastiter = last(c)
748+
end
730749
nms = names(c1)
731750
all(c -> names(c) == nms, args) || throw(ArgumentError("chain names differ"))
732751
chns = chains(c1)
@@ -735,7 +754,7 @@ function _cat(::Val{1}, c1::Chains, args::Chains...)
735754
# concatenate all chains
736755
data = mapreduce(c -> c.value.data, vcat, args; init = c1.value.data)
737756
value = AxisArray(data;
738-
iter = range(first(c1); length = size(data, 1), step = thin),
757+
iter = mapreduce(range, vcat, args; init=range(c1)),
739758
var = nms,
740759
chain = chns)
741760

src/constructors.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ function Base.Array(
6363
end
6464

6565
function to_matrix(chain::Chains)
66-
return Matrix(reshape(permutedims(chain.value.data, (1, 3, 2)), :, size(chain, 2)))
66+
x = permutedims(chain.value.data, (1, 3, 2))
67+
return Matrix(reshape(x, size(x, 1) * size(x, 2), size(x, 3)))
6768
end
6869

6970
function to_vector(chain::Chains)
@@ -79,4 +80,3 @@ function to_vector_of_matrices(chain::Chains)
7980
data = chain.value.data
8081
return [Matrix(data[:, :, i]) for i in axes(data, 3)]
8182
end
82-

src/fileio.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ function readcoda(output::AbstractString, index::AbstractString)
2222
value[:, i] = out[inds, 2]
2323
end
2424

25-
Chains(value, start=first(window), thin=step(window), names=names)
25+
Chains(value; iterations=window, names=names)
2626
end

src/rstar.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@ verbosity level.
1818
1919
# Example
2020
```jldoctest rstar; output = false, filter = r".*"s
21-
using MLJModels
21+
using MLJBase, MLJXGBoostInterface
2222
23-
XGBoost = @load XGBoostClassifier verbosity=0
2423
chn = Chains(fill(4, 100, 2, 3))
2524
26-
Rs = rstar(XGBoost(), chn; iterations=20)
25+
Rs = rstar(XGBoostClassifier(), chn; iterations=20)
2726
R = round(mean(Rs); digits=0)
2827
2928
# output

src/utils.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ function merge_union(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
9696
:(getfield(b, $(QuoteNode(n))))
9797
end
9898
end
99-
99+
100100
return :(NamedTuple{$names,$types}(($(values...),)))
101101
else
102102
names = Base.merge_names(an, bn)
@@ -113,7 +113,7 @@ function merge_union(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
113113
getfield(b, n)
114114
end
115115
end
116-
116+
117117
return NamedTuple{names,types}(values)
118118
end
119119
end
@@ -179,8 +179,8 @@ function concretize(x::AbstractArray)
179179
return x
180180
else
181181
xnew = map(concretize, x)
182-
T = mapreduce(typeof, promote_type, xnew)
183-
if T <: eltype(xnew)
182+
T = mapreduce(typeof, promote_type, xnew; init=Union{})
183+
if T <: eltype(xnew) && T !== Union{}
184184
return convert(AbstractArray{T}, xnew)
185185
else
186186
return xnew
@@ -196,3 +196,17 @@ function concretize(x::Chains)
196196
return Chains(concretize(value), x.logevidence, x.name_map, x.info)
197197
end
198198
end
199+
200+
function isstrictlyincreasing(x::AbstractVector{Int})
201+
return isempty(x) || _isstrictlyincreasing_nonempty(x)
202+
end
203+
204+
_isstrictlyincreasing_nonempty(x::AbstractRange{Int}) = step(x) > 0
205+
function _isstrictlyincreasing_nonempty(x::AbstractVector{Int})
206+
i = first(x)
207+
for j in Iterators.drop(x, 1)
208+
j > i || return false
209+
i = j
210+
end
211+
return true
212+
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2424
AbstractMCMC = "2.2.1, 3.0"
2525
DataFrames = "0.22.4, 1.0"
2626
Distributions = "0.24.12, 0.25"
27-
Documenter = "0.26"
27+
Documenter = "0.26, 0.27"
2828
FFTW = "1.1"
2929
IteratorInterfaceExtensions = "1"
3030
KernelDensity = "0.6.2"

test/arrayconstructor_tests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ using MCMCChains, Test
6767
Array(chns[:a])
6868
Array(chns, [:parameters])
6969
Array(chns, [:parameters, :internals])
70+
71+
# empty chain: #317
72+
empty_chain = chns[Symbol[]]
73+
@test isempty(MCMCChains.to_matrix(empty_chain))
74+
@test isempty(Array(empty_chain))
7075
end
7176
@testset "Accuracy" begin
7277
nchains = 5

test/concatenation_tests.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,58 +46,60 @@ end
4646
chn = Chains(rand(10, 5, 2), ["a", "b", "c", "d", "e"], Dict(:internal => ["d", "e"]))
4747
chn1 = Chains(rand(5, 5, 2), ["a", "b", "c", "d", "e"], Dict(:internal => ["a", "b"]))
4848

49-
# incorrect thinning
50-
@test_throws ArgumentError vcat(chn, Chains(rand(2, 5, 2); thin = 2))
49+
# incorrect iterations
50+
@test_throws ArgumentError vcat(chn, Chains(rand(2, 5, 2)))
5151

5252
# incorrect names
53-
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]))
53+
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]; start=11))
5454

5555
# incorrect number of chains
56-
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]))
56+
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]; start=11))
5757

5858
# concatenate the same chain
59-
chn2 = vcat(chn, chn)
59+
chn_shifted = setrange(chn, 11:20)
60+
chn2 = vcat(chn, chn_shifted)
6061
@test chn2.value.data == vcat(chn.value.data, chn.value.data)
6162
@test size(chn2) == (20, 5, 2)
6263
@test names(chn2) == names(chn)
6364
@test range(chn2) == 1:20
6465
@test chn2.name_map == (parameters = [:a, :b, :c], internal = [:d, :e])
65-
66-
chn2a = cat(chn, chn)
66+
67+
chn2a = cat(chn, chn_shifted)
6768
@test chn2a.value == chn2.value
6869
@test chn2a.name_map == chn2.name_map
6970
@test chn2a.info == chn2.info
7071

71-
chn2b = cat(chn, chn; dims = Val(1))
72+
chn2b = cat(chn, chn_shifted; dims = Val(1))
7273
@test chn2b.value == chn2.value
7374
@test chn2b.name_map == chn2.name_map
7475
@test chn2b.info == chn2.info
7576

76-
chn2c = cat(chn, chn; dims = 1)
77+
chn2c = cat(chn, chn_shifted; dims = 1)
7778
@test chn2c.value == chn2.value
7879
@test chn2c.name_map == chn2.name_map
7980
@test chn2c.info == chn2.info
8081

8182
# concatenate a different chain
82-
chn3 = vcat(chn, chn1)
83+
chn1_shifted = setrange(chn1, 11:15)
84+
chn3 = vcat(chn, chn1_shifted)
8385
@test chn3.value.data == vcat(chn.value.data, chn1.value.data)
8486
@test size(chn3) == (15, 5, 2)
8587
@test names(chn3) == names(chn)
8688
@test range(chn3) == 1:15
8789
# just take the name map of first argument
8890
@test chn3.name_map == (parameters = [:a, :b, :c], internal = [:d, :e])
89-
90-
chn3a = cat(chn, chn1)
91+
92+
chn3a = cat(chn, chn1_shifted)
9193
@test chn3a.value == chn3.value
9294
@test chn3a.name_map == chn3.name_map
9395
@test chn3a.info == chn3.info
9496

95-
chn3b = cat(chn, chn1; dims = Val(1))
97+
chn3b = cat(chn, chn1_shifted; dims = Val(1))
9698
@test chn3b.value == chn3.value
9799
@test chn3b.name_map == chn3.name_map
98100
@test chn3b.info == chn3.info
99101

100-
chn3c = cat(chn, chn1; dims = 1)
102+
chn3c = cat(chn, chn1_shifted; dims = 1)
101103
@test chn3c.value == chn3.value
102104
@test chn3c.name_map == chn3.name_map
103105
@test chn3c.info == chn3.info

test/diagnostic_tests.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ val = hcat(val, rand(1:2, niter, 1, nchains))
1515

1616
# construct a Chains object
1717
chn = Chains(val, start = 1, thin = 2)
18+
@test_throws ErrorException Chains(val; start=0, thin=2)
19+
@test_throws ErrorException Chains(val; start=niter, thin=-1)
20+
@test_throws ErrorException Chains(val; iterations=1:(niter - 1))
21+
@test_throws ErrorException Chains(val; iterations=range(0; step=2, length=niter))
22+
@test_throws ErrorException Chains(val; iterations=niter:-1:1)
23+
@test_throws ErrorException Chains(val; iterations=ones(Int, niter))
1824

1925
# Chains object for discretediag
2026
val_disc = rand(Int16, 200, nparams, nchains)
@@ -29,18 +35,26 @@ chn_disc = Chains(val_disc, start = 1, thin = 2)
2935
@test keys(chn) == names(chn) == [:param_1, :param_2, :param_3, :param_4]
3036

3137
@test range(chn) == range(1; step = 2, length = niter)
38+
@test range(chn) == range(Chains(val; iterations=range(chn)))
39+
@test range(chn) == range(Chains(val; iterations=collect(range(chn))))
3240

3341
@test_throws ErrorException setrange(chn, 1:10)
42+
@test_throws ErrorException setrange(chn, 0:(niter - 1))
43+
@test_throws ErrorException setrange(chn, niter:-1:1)
44+
@test_throws ErrorException setrange(chn, ones(Int, niter))
3445
@test_throws MethodError setrange(chn, float.(range(chn)))
3546

36-
chn2 = setrange(chn, range(1; step = 10, length = niter))
37-
@test range(chn2) == range(1; step = 10, length = niter)
38-
@test names(chn2) === names(chn)
39-
@test chains(chn2) === chains(chn)
40-
@test chn2.value.data === chn.value.data
41-
@test chn2.logevidence === chn.logevidence
42-
@test chn2.name_map === chn.name_map
43-
@test chn2.info == chn.info
47+
chn2a = setrange(chn, range(1; step = 10, length = niter))
48+
chn2b = setrange(chn, collect(range(1; step = 10, length = niter)))
49+
for chn2 in (chn2a, chn2b)
50+
@test range(chn2) == range(1; step = 10, length = niter)
51+
@test names(chn2) === names(chn)
52+
@test chains(chn2) === chains(chn)
53+
@test chn2.value.data === chn.value.data
54+
@test chn2.logevidence === chn.logevidence
55+
@test chn2.name_map === chn.name_map
56+
@test chn2.info == chn.info
57+
end
4458

4559
chn3 = resetrange(chn)
4660
@test range(chn3) == 1:niter

test/rstar_tests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using MCMCChains
2-
using MLJModels
2+
using MLJBase
3+
using MLJXGBoostInterface
34
using Test
45

56
N = 1000
@@ -8,8 +9,7 @@ colnames = ["a", "b", "c", "d", "e", "f", "g", "h"]
89
internal_colnames = ["c", "d", "e", "f", "g", "h"]
910
chn = Chains(val, colnames, Dict(:internals => internal_colnames))
1011

11-
XGBoost = @load XGBoostClassifier
12-
classif = XGBoost()
12+
classif = XGBoostClassifier()
1313

1414
@testset "R star test" begin
1515
# Compute R* statistic for a mixed chain.

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Random.seed!(0)
2222
if VERSION >= v"1.3" && Sys.WORD_SIZE == 64
2323
# run tests related to rstar statistic
2424
println("Rstar")
25-
Pkg.add("MLJModels")
25+
Pkg.add("MLJBase")
2626
Pkg.add("MLJXGBoostInterface")
2727
@time include("rstar_tests.jl")
2828

0 commit comments

Comments
 (0)