@@ -11,7 +11,7 @@ struct _MeanPlot; c; val; end
1111struct _DensityPlot; c; val; end
1212struct _HistogramPlot; c; val; end
1313struct _AutocorPlot; lags; val; end
14- struct _ViolinPlot; parameters ; val; total_chains ; end
14+ struct _ViolinPlot; par ; val; end
1515
1616# define alias functions for old syntax
1717const translationdict = Dict (
@@ -33,7 +33,9 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
3333 colordim = :chain ,
3434 barbounds = (- Inf , Inf ),
3535 maxlag = nothing ,
36- append_chains = false
36+ append_chains = false ,
37+ sections = chains. name_map[:parameters ],
38+ combined = true
3739)
3840 st = get (plotattributes, :seriestype , :traceplot )
3941 c = append_chains || st == :pooleddensity ? pool_chain (chains) : chains
@@ -72,6 +74,39 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
7274 else
7375 range (c), val
7476 end
77+
78+ total_chains = i
79+ if st == :violinplot
80+ n_iter, n_par, n_chains = size (chains)
81+ if combined
82+ colordim := :chain
83+ par = string .(reshape (repeat (sections, inner = n_iter), n_iter, n_par))[:,i]
84+ val = Array (chains)[:,i]
85+ _ViolinPlot (par, val)
86+ elseif combined == false
87+ if colordim == :chain
88+ par_names = [" $(sections[i]) .Chain $j " for i in 1 : n_par, j in 1 : n_chains]
89+ pars = string .(reshape (repeat (vec (par_names), inner = n_iter), (n_iter, n_par, n_chains)))
90+ val = chains. value[:,i,:]
91+ par = pars[:,i,:]
92+ elseif colordim == :parameter
93+ par_vec = repeat (sections, inner = n_iter)
94+ pars = string .(reshape (repeat (par_vec, n_chains, 1 ), (n_iter, n_par, n_chains)))
95+ val = chains. value[:,:,i]
96+ par = pars[:,:,i]
97+ label --> string .(names (c))
98+ else
99+ throw (ArgumentError (" `colordim` must be one of `:chain` or `:parameter`" ))
100+ end
101+ _ViolinPlot (par, val)
102+ else
103+ throw (ArgumentError (" In `ViolinPlots` `Chains` can be combined or separated " ))
104+ end
105+ elseif st ∈ supportedplots
106+ translationdict[st](c, val)
107+ else
108+ range (c), val
109+ end
75110end
76111
77112@recipe function f (p:: _DensityPlot )
@@ -188,59 +223,17 @@ end
188223 RecipesBase. recipetype (:cornerplot , vcat (ar... ))
189224end
190225
191- @recipe function f (
192- chains:: Chains ;
193- sections:: Vector{Symbol} = chains. name_map[:parameters ],
194- combined = true
195- )
196-
197- st = get (plotattributes, :seriestype , :traceplot )
198- total_chains = 0
199- if st == :violinplot
200- if combined
201- n_iter, n_parameters = size (Array (chains))
202- parameters = string .(repeat (sections, inner = n_iter))
203- val = vec (Array (chains))
204- total_chains = Integer (size (chains. value. data)[3 ])
205- _ViolinPlot (parameters, val, total_chains)
206- elseif combined == false
207- n_parameters = length (sections)
208- chain_arr = Array (chains, append_chains = false )
209- val_vec = [chain_arr[j][:,i]
210- for i in 1 : n_parameters
211- for j in 1 : length (chain_arr)]
212- n_iter = length (val_vec[1 ])
213- total_chains = length (val_vec)
214- val = zeros (Float64, n_iter, total_chains)
215- for i in 1 : total_chains
216- val[:,i] = val_vec[:][i]
217- end
218- val = vec (val)
219- parameters_names = [" param $(sections[i]) .Chain $j "
220- for i in 1 : n_parameters
221- for j in 1 : length (chain_arr)]
222- parameters = string .(repeat (parameters_names, inner = n_iter))
223- _ViolinPlot (parameters, val, total_chains)
224- else
225- error (" Symbol names are interpreted as parameter names, only compatible with " ,
226- " `colordim = :chain`" )
227- end
228- end
229- end
230-
231226@recipe function f (p:: _ViolinPlot )
232227 @series begin
233228 seriestype := :violin
234- xaxis --> " Parameter"
235- size --> (200 * p. total_chains, 500 )
236- p. parameters, p. val
229+ p. par, p. val
237230 end
238231
239232 @series begin
240233 seriestype := :boxplot
241234 bar_width --> 0.1
242235 linewidth --> 2
243236 fillalpha --> 0.8
244- p. parameters , p. val
237+ p. par , p. val
245238 end
246239end
0 commit comments