Skip to content

New Plot: Violin Plot Implementation #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

PaulinaMartin96
Copy link
Contributor

@PaulinaMartin96 PaulinaMartin96 commented Jul 6, 2021

A recipe for violin plots was added. Following ArviZ's convetion for violin plots, kwarg combined was added. If true, chains are appended and only one plot per parameter is returned. In this case, colordim := :chain. Otherwise, a violin plot is returned for every chain and every parameter. colordim can be set to :chain or :parameter as shown below. Discrete parameters are plotted as defined in StatsPlots.jl.

using MCMCChains
using StatsPlots

n_iter = 100
n_name = 3
n_chain = 2

val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]'
val = hcat(val, rand(1:2, n_iter, 1, n_chain))

chn = Chains(val, [:A, :B, :C, :D])
violinplot(chn, size = (200, 1000)) #violinplot(chn, combined = true, colordim = :chain)

image

violinplot(chn, combined =  false, colordim = :chain)

image

violinplot(chn, combined = false, colordim = :parameter, size=(800, 800))

image

@cpfiffer
Copy link
Member

cpfiffer commented Jul 6, 2021

I think you probably ended up making a new branch from your other PR's branch in #307, rather than from the master branch. Each PR should be independent from one another unless they are explicitly planned to be merged in sequence or something, which we very rarely do. It might be worth investigating some time on figuring out how to make it so your two commits 4fc9d62 and d1cadce are added to master rather than on top of the commits in #307.

  • Make sure to increment the version number for this one -- I think if we end up merging Make Chains objects display only information and not statistical eval #307 first, this PR should be associated with 4.15.0.
  • Please add a demo plot and a minimum working example so the reviewer knows what to expect. It's also generally considered a good idea to do a brief writeup in general, rather than just use the commit title you currently have.

@PaulinaMartin96 PaulinaMartin96 changed the title Pm/violin plot Violin plot Jul 13, 2021
@PaulinaMartin96 PaulinaMartin96 marked this pull request as ready for review July 13, 2021 03:21
@cpfiffer
Copy link
Member

FYI @PaulinaMartin96 the tests are failing on this one. You can see why the test failures are happening by clicking "Details" in the testing box towards the bottom of the thread on GH:
image

@shravanngoswamii
Copy link
Member

I'm adding Paulina's ViolinPlot implementation here as a backup, since it will be lost entirely after resolving the merge conflicts with my fresh reimplementation—it won't be retained in the commit history.

Paulina's Violin Plot Implementation (src/plot.jl)

@shorthands meanplot
@shorthands autocorplot
@shorthands mixeddensity
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
@shorthands violinplot

struct _TracePlot; c; val; end
struct _MeanPlot; c; val;  end
struct _DensityPlot; c; val;  end
struct _HistogramPlot; c; val;  end
struct _AutocorPlot; lags; val;  end
struct _ViolinPlot; par; val; end

# define alias functions for old syntax
const translationdict = Dict(
                        :traceplot => _TracePlot,
                        :meanplot => _MeanPlot,
                        :density => _DensityPlot,
                        :histogram => _HistogramPlot,
                        :autocorplot => _AutocorPlot,
                        :pooleddensity => _DensityPlot,
                        :violinplot => _ViolinPlot
                      )

const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner)

@recipe f(c::Chains, s::Symbol) = c, [s]

@recipe function f(
    chains::Chains, i::Int;
    colordim = :chain,
    barbounds = (-Inf, Inf),
    maxlag = nothing,
    append_chains = false,
    par_sections = chains.name_map[:parameters],
    combined = true
)
    st = get(plotattributes, :seriestype, :traceplot)
    c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains

    if colordim == :parameter
        title --> "Chain $(MCMCChains.chains(c)[i])"
        label --> string.(names(c))
        val = c.value[:, :, i]
    elseif colordim == :chain
        title --> string(names(c)[i])
        label --> map(x -> "Chain $x", MCMCChains.chains(c))
        val = c.value[:, i, :]
    else
        throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
    end

    if st == :mixeddensity || st == :pooleddensity
        discrete = indiscretesupport(c, barbounds)
        st = if colordim == :chain
            discrete[i] ? :histogram : :density
        else
            # NOTE: It might make sense to overlay histograms and density plots here.
            :density
        end
        seriestype := st
    end

    if st == :autocorplot
        lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
        ac = autocor(c; sections = nothing, lags = lags)
        ac_mat = convert(Array, ac)
        val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
        _AutocorPlot(lags, val)

    elseif st == :violinplot
        n_iter, n_par, n_chains = size(chains)
        if combined
            colordim := :chain
            par = string.(reshape(repeat(par_sections, inner = n_iter), n_iter, n_par))[:,i]
            val = Array(chains)[:,i]
            _ViolinPlot(par, val)
        elseif combined == false
            if colordim == :chain
                par_names = ["$(par_sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains]
                pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains)))
                val = chains.value[:,i,:]
                par = pars[:,i,:]
            elseif colordim == :parameter
                par_vec = repeat(par_sections, inner = n_iter)
                pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains)))
                val = chains.value[:,:,i]
                par = pars[:,:,i]
                label --> string.(names(c))
            else
                throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
            end
            _ViolinPlot(par, val)
        else
            throw(ArgumentError("In `ViolinPlots` `Chains` can be combined or separated "))
        end
    elseif st  supportedplots
        translationdict[st](c, val)
    else
        range(c), val
    end
end

@recipe function f(p::_DensityPlot)
    xaxis --> "Sample value"
    yaxis --> "Density"
    trim --> true
    [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end

@recipe function f(p::_HistogramPlot)
    xaxis --> "Sample value"
    yaxis --> "Frequency"
    fillalpha --> 0.7
    bins --> 25
    trim --> true
    [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)]
end

@recipe function f(p::_MeanPlot)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Mean"
    range(p.c), cummean(p.val)
end

@recipe function f(p::_AutocorPlot)
    seriestype := :path
    xaxis --> "Lag"
    yaxis --> "Autocorrelation"
    p.lags, p.val
end

@recipe function f(p::_TracePlot)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Sample value"
    range(p.c), p.val
end

@recipe function f(
    chains::Chains,
    parameters::AbstractVector{Symbol};
    colordim = :chain
)
    colordim != :chain &&
        error("Symbol names are interpreted as parameter names, only compatible with ",
              "`colordim = :chain`")

    ret = indexin(parameters, names(chains))
    any(y === nothing for y in ret) && error("Parameter not found")

    return chains, Int.(ret)
end

@recipe function f(
    chains::Chains,
    parameters::AbstractVector{<:Integer} = Int[];
    sections = _default_sections(chains),
    width = 500,
    height = 250,
    colordim = :chain,
    append_chains = false
)
    _chains = isempty(parameters) ? Chains(chains, _clean_sections(chains, sections)) : chains
    c = append_chains ? pool_chain(_chains) : _chains
    ptypes = get(plotattributes, :seriestype, (:traceplot, :mixeddensity))
    ptypes = ptypes isa Symbol ? (ptypes,) : ptypes
    @assert all(ptype -> ptype  supportedplots, ptypes)
    ntypes = length(ptypes)
    nrows, nvars, nchains = size(c)
    isempty(parameters) && (parameters = colordim == :chain ? (1:nvars) : (1:nchains))
    N = length(parameters)

    if :corner  ptypes
        size --> (ntypes*width, N*height)
        legend --> false

        multiple_plots = N * ntypes > 1
        if multiple_plots
            layout := (N, ntypes)
        end

        i = 0
        for par in parameters
            for ptype in ptypes
                i += 1

                @series begin
                    if multiple_plots
                        subplot := i
                    end
                    colordim := colordim
                    seriestype := ptype
                    c, par
                end
            end
        end
    else
        ntypes > 1 && error(":corner is not compatible with multiple seriestypes")
        Corner(c, names(c)[parameters])
    end
end

struct Corner
    c
    parameters
end

@recipe function f(corner::Corner)
    label --> permutedims(corner.parameters)
    compact --> true
    size --> (600, 600)
    ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c))
    RecipesBase.recipetype(:cornerplot, vcat(ar...))
end

@recipe function f(p::_ViolinPlot)
    @series begin
        seriestype := :violin
        p.par, p.val
    end

    @series begin
        seriestype := :boxplot
        bar_width --> 0.1
        linewidth --> 2
        fillalpha --> 0.8
        p.par, p.val
    end
end

Copy link

codecov bot commented May 24, 2025

Codecov Report

Attention: Patch coverage is 94.02985% with 4 lines in your changes missing coverage. Please review.

Project coverage is 85.81%. Comparing base (268886e) to head (8613e98).

Files with missing lines Patch % Lines
src/plot.jl 94.02% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #316      +/-   ##
==========================================
+ Coverage   85.55%   85.81%   +0.26%     
==========================================
  Files          20       20              
  Lines        1073     1107      +34     
==========================================
+ Hits          918      950      +32     
- Misses        155      157       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@shravanngoswamii
Copy link
Member

shravanngoswamii commented May 24, 2025

I have implemented Violin plot and here are its results:

Chain object construction

using MCMCChains
using StatsPlots

# Define the experiment
n_iter = 100
n_name = 3
n_chain = 2

# experiment results
val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]'
val = hcat(val, rand(1:2, n_iter, 1, n_chain))

# construct a Chains object
chn = Chains(val, [:A, :B, :C, :D])

violinplot(chn) # Plotting parameter 1 across all chains
Plot

plot_1

violinplot(chn, 1) # Plotting parameter 1 across all chains
Plot

plot_2

violinplot(chn, :A) # Plotting a specific parameter across all chains
Plot

plot_3

violinplot(chn, [:C, :B, :A]) # Plotting multiple specific parameters across all chains
Plot

plot_4

violinplot(chn, 1, colordim = :parameter) # Plotting chain 1 across all parameters
Plot

plot_5

violinplot(chn, show_boxplot = false) # Plotting all parameters without the inner boxplot
Plot

plot_6

violinplot(chn, :A, append_chains = true) # Single parameter, all chains appended
Plot

plot_7

violinplot(chn, append_chains = true) # All parameters, all chains appended
Plot

plot_8

plot(chn, seriestype = :violin) # Using plot function with violin series
Plot

plot_9

@shravanngoswamii
Copy link
Member

Can someone please review this PR, it's ready!

cc @yebai

@shravanngoswamii shravanngoswamii changed the title Violin plot New Plot: Violin Plot Implementation May 24, 2025
Copy link
Member

@penelopeysm penelopeysm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't run the code but based on the plots in the last comment, the legends on the violin plots aren't correct.

It seems to me that legends aren't really useful for these plots -- maybe they should just be removed?

@shravanngoswamii
Copy link
Member

It seems to me that legends aren't really useful for these plots -- maybe they should just be removed?

I have removed legends from this! Thanks for the suggestion!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants