-
Notifications
You must be signed in to change notification settings - Fork 30
New Plot: Violin Plot Implementation #316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
I think you probably ended up making a new branch from your other PR's branch in #307, rather than from the
|
d1cadce
to
4674ea5
Compare
0bdcdfd
to
e4c9765
Compare
FYI @PaulinaMartin96 the tests are failing on this one. You can see why the test failures are happening by clicking "Details" in the testing box towards the bottom of the thread on GH: |
3d14727
to
1bb0f93
Compare
I'm adding Paulina's ViolinPlot implementation here as a backup, since it will be lost entirely after resolving the merge conflicts with my fresh reimplementation—it won't be retained in the commit history. Paulina's Violin Plot Implementation (src/plot.jl)
@shorthands meanplot
@shorthands autocorplot
@shorthands mixeddensity
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
@shorthands violinplot
struct _TracePlot; c; val; end
struct _MeanPlot; c; val; end
struct _DensityPlot; c; val; end
struct _HistogramPlot; c; val; end
struct _AutocorPlot; lags; val; end
struct _ViolinPlot; par; val; end
# define alias functions for old syntax
const translationdict = Dict(
:traceplot => _TracePlot,
:meanplot => _MeanPlot,
:density => _DensityPlot,
:histogram => _HistogramPlot,
:autocorplot => _AutocorPlot,
:pooleddensity => _DensityPlot,
:violinplot => _ViolinPlot
)
const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner)
@recipe f(c::Chains, s::Symbol) = c, [s]
@recipe function f(
chains::Chains, i::Int;
colordim = :chain,
barbounds = (-Inf, Inf),
maxlag = nothing,
append_chains = false,
par_sections = chains.name_map[:parameters],
combined = true
)
st = get(plotattributes, :seriestype, :traceplot)
c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains
if colordim == :parameter
title --> "Chain $(MCMCChains.chains(c)[i])"
label --> string.(names(c))
val = c.value[:, :, i]
elseif colordim == :chain
title --> string(names(c)[i])
label --> map(x -> "Chain $x", MCMCChains.chains(c))
val = c.value[:, i, :]
else
throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
end
if st == :mixeddensity || st == :pooleddensity
discrete = indiscretesupport(c, barbounds)
st = if colordim == :chain
discrete[i] ? :histogram : :density
else
# NOTE: It might make sense to overlay histograms and density plots here.
:density
end
seriestype := st
end
if st == :autocorplot
lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
ac = autocor(c; sections = nothing, lags = lags)
ac_mat = convert(Array, ac)
val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
_AutocorPlot(lags, val)
elseif st == :violinplot
n_iter, n_par, n_chains = size(chains)
if combined
colordim := :chain
par = string.(reshape(repeat(par_sections, inner = n_iter), n_iter, n_par))[:,i]
val = Array(chains)[:,i]
_ViolinPlot(par, val)
elseif combined == false
if colordim == :chain
par_names = ["$(par_sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains]
pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains)))
val = chains.value[:,i,:]
par = pars[:,i,:]
elseif colordim == :parameter
par_vec = repeat(par_sections, inner = n_iter)
pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains)))
val = chains.value[:,:,i]
par = pars[:,:,i]
label --> string.(names(c))
else
throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
end
_ViolinPlot(par, val)
else
throw(ArgumentError("In `ViolinPlots` `Chains` can be combined or separated "))
end
elseif st ∈ supportedplots
translationdict[st](c, val)
else
range(c), val
end
end
@recipe function f(p::_DensityPlot)
xaxis --> "Sample value"
yaxis --> "Density"
trim --> true
[collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end
@recipe function f(p::_HistogramPlot)
xaxis --> "Sample value"
yaxis --> "Frequency"
fillalpha --> 0.7
bins --> 25
trim --> true
[collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end
@recipe function f(p::_MeanPlot)
seriestype := :path
xaxis --> "Iteration"
yaxis --> "Mean"
range(p.c), cummean(p.val)
end
@recipe function f(p::_AutocorPlot)
seriestype := :path
xaxis --> "Lag"
yaxis --> "Autocorrelation"
p.lags, p.val
end
@recipe function f(p::_TracePlot)
seriestype := :path
xaxis --> "Iteration"
yaxis --> "Sample value"
range(p.c), p.val
end
@recipe function f(
chains::Chains,
parameters::AbstractVector{Symbol};
colordim = :chain
)
colordim != :chain &&
error("Symbol names are interpreted as parameter names, only compatible with ",
"`colordim = :chain`")
ret = indexin(parameters, names(chains))
any(y === nothing for y in ret) && error("Parameter not found")
return chains, Int.(ret)
end
@recipe function f(
chains::Chains,
parameters::AbstractVector{<:Integer} = Int[];
sections = _default_sections(chains),
width = 500,
height = 250,
colordim = :chain,
append_chains = false
)
_chains = isempty(parameters) ? Chains(chains, _clean_sections(chains, sections)) : chains
c = append_chains ? pool_chain(_chains) : _chains
ptypes = get(plotattributes, :seriestype, (:traceplot, :mixeddensity))
ptypes = ptypes isa Symbol ? (ptypes,) : ptypes
@assert all(ptype -> ptype ∈ supportedplots, ptypes)
ntypes = length(ptypes)
nrows, nvars, nchains = size(c)
isempty(parameters) && (parameters = colordim == :chain ? (1:nvars) : (1:nchains))
N = length(parameters)
if :corner ∉ ptypes
size --> (ntypes*width, N*height)
legend --> false
multiple_plots = N * ntypes > 1
if multiple_plots
layout := (N, ntypes)
end
i = 0
for par in parameters
for ptype in ptypes
i += 1
@series begin
if multiple_plots
subplot := i
end
colordim := colordim
seriestype := ptype
c, par
end
end
end
else
ntypes > 1 && error(":corner is not compatible with multiple seriestypes")
Corner(c, names(c)[parameters])
end
end
struct Corner
c
parameters
end
@recipe function f(corner::Corner)
label --> permutedims(corner.parameters)
compact --> true
size --> (600, 600)
ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c))
RecipesBase.recipetype(:cornerplot, vcat(ar...))
end
@recipe function f(p::_ViolinPlot)
@series begin
seriestype := :violin
p.par, p.val
end
@series begin
seriestype := :boxplot
bar_width --> 0.1
linewidth --> 2
fillalpha --> 0.8
p.par, p.val
end
end |
1bb0f93
to
3a2005f
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #316 +/- ##
==========================================
+ Coverage 85.55% 85.81% +0.26%
==========================================
Files 20 20
Lines 1073 1107 +34
==========================================
+ Hits 918 950 +32
- Misses 155 157 +2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
I have implemented Violin plot and here are its results: Chain object construction
using MCMCChains
using StatsPlots
# Define the experiment
n_iter = 100
n_name = 3
n_chain = 2
# experiment results
val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]'
val = hcat(val, rand(1:2, n_iter, 1, n_chain))
# construct a Chains object
chn = Chains(val, [:A, :B, :C, :D]) violinplot(chn) # Plotting parameter 1 across all chains
|
Can someone please review this PR, it's ready! cc @yebai |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't run the code but based on the plots in the last comment, the legends on the violin plots aren't correct.
It seems to me that legends aren't really useful for these plots -- maybe they should just be removed?
I have removed legends from this! Thanks for the suggestion! |
A recipe for violin plots was added. Following ArviZ's convetion for violin plots,
kwarg
combined
was added. Iftrue
, chains are appended and only one plot per parameter is returned. In this case,colordim := :chain
. Otherwise, a violin plot is returned for every chain and every parameter.colordim
can be set to:chain
or:parameter
as shown below. Discrete parameters are plotted as defined inStatsPlots.jl
.