diff --git a/src/Bootstrap.jl b/src/Bootstrap.jl
new file mode 100644
index 0000000..26a58d2
--- /dev/null
+++ b/src/Bootstrap.jl
@@ -0,0 +1,149 @@
+using AxisKeys
+using Distributions  # provides `Dirichlet` for the bootstrap weights
+using LoopVectorization
+using Random
+using Statistics
+using Tullio
+
+export bayes_cv
+
+"""
+    bayes_cv(
+        log_likelihood::Array{<:AbstractFloat} [, args...];
+        resamples::Integer=2^10, rng=MersenneTwister(1865)
+        [, chain_index::Vector{Int}, kwargs...]
+    ) -> BayesCV
+
+Use the Bayesian bootstrap (Bayes cross-validation) and PSIS to calculate an approximate
+posterior distribution for the out-of-sample score.
+
+
+# Arguments
+
+  - `log_likelihood::Array`: An array or matrix of log-likelihood values indexed as
+    `[data, step, chain]`. The chain argument can be left off if `chain_index` is
+    provided or if all posterior samples were drawn from a single chain.
+  - `args...`: Positional arguments to be passed to [`psis`](@ref).
+  - `resamples::Integer`: The number of Bayesian bootstrap resamples to draw.
+  - `rng`: The random number generator used to draw the Dirichlet bootstrap weights.
+  - `chain_index::Vector`: An (optional) vector of integers specifying which chain each
+    step belongs to. For instance, `chain_index[3]` should return `2` if
+    `log_likelihood[:, 3]` belongs to the second chain.
+  - `kwargs...`: Keyword arguments to be passed to [`psis`](@ref).
+
+
+# Extended help
+
+The Bayesian bootstrap works similarly to other cross-validation methods: First, we
+remove some piece of information from the model. Then, we test how well the model can
+reproduce that information. With leave-k-out cross-validation, the information we leave
+out is the value of one or more data points. With the Bayesian bootstrap, the
+information left out is the true probability of each observation.
+
+
+See also: [`BayesCV`](@ref), [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
+"""
+function bayes_cv(
+    log_likelihood::T,
+    args...;
+    resamples::Integer=2^10,
+    rng=MersenneTwister(1865),
+    kwargs...
+) where {F<:AbstractFloat, T<:AbstractArray{F, 3}}
+
+    dims = size(log_likelihood)
+    data_size = dims[1]
+    mcmc_count = dims[2] * dims[3]  # total number of samples from posterior
+    log_count = log(mcmc_count)
+
+
+    # TODO: Add a way of using score functions other than ELPD
+    bb_weights = data_size * rand(rng, Dirichlet(ones(data_size)), resamples)
+    @tullio bb_samples[re, step, chain] :=
+        bb_weights[datum, re] * log_likelihood[datum, step, chain]
+    @tullio log_is_ratios[re, step, chain] :=
+        (bb_weights[datum, re] - 1) * log_likelihood[datum, step, chain]
+    psis_object = psis(log_is_ratios, args...; kwargs...)
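+    # Why `(bb_weights - 1)`: under Dirichlet weights `w`, the bootstrap posterior is
+    # proportional to `prod_i p(y_i | θ)^(w_i)`, so its log importance ratio relative to
+    # the original posterior is `sum_i (w_i - 1) * log p(y_i | θ)` -- exactly the
+    # reduction Tullio performs above. `psis` then smooths these ratios into weights.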
+    psis_weights = psis_object.weights
+
+    # calculate the naive estimate within each resample
+    @tullio re_naive[re] := log <|
+        psis_weights[re, step, chain] * exp(bb_samples[re, step, chain])
+    @tullio sample_est[i] := exp(log_likelihood[i, j, k] - log_count) |> log
+    @tullio naive_est := sample_est[i]
+
+    bb_ests = (2 * naive_est) .- re_naive
+    @tullio mcse[re] := sqrt <|
+        (psis_weights[re, step, chain] * (bb_samples[re, step, chain] - re_naive[re]))^2
+    bootstrap_se = std(re_naive) / sqrt(resamples)  # Monte Carlo SE of the bootstrap mean
+
+    # Posterior for the *average score*, not the mean of the posterior distribution:
+    resample_calcs = KeyedArray(
+        hcat(
+            bb_ests,
+            re_naive,
+            re_naive - bb_ests,
+            mcse,
+            psis_object.pareto_k
+        );
+        data=Base.OneTo(resamples),
+        statistic=[
+            :cv_est,
+            :naive_est,
+            :overfit,
+            :mcse,
+            :pareto_k
+        ],
+    )
+
+    estimates = _generate_bayes_table(log_likelihood, resample_calcs, resamples, data_size)
+
+    return BayesCV(
+        estimates,
+        resample_calcs,
+        psis_object,
+        data_size
+    )
+
+end
+
+
+function bayes_cv(
+    log_likelihood::T,
+    args...;
+    chain_index::AbstractVector=ones(size(log_likelihood, 2)),
+    kwargs...,
+) where {F<:AbstractFloat, T<:AbstractMatrix{F}}
+    new_log_ratios = _convert_to_array(log_likelihood, chain_index)
+    return bayes_cv(new_log_ratios, args...; kwargs...)
+end
+
+
+function _generate_bayes_table(
+    log_likelihood::AbstractArray,
+    pointwise::AbstractArray,
+    resamples::Integer,
+    data_size::Integer
+)
+
+    # create table with the right labels
+    table = KeyedArray(
+        similar(log_likelihood, 3, 4);
+        criterion=[:cv_est, :naive_est, :overfit],
+        statistic=[:total, :se_total, :mean, :se_mean],
+    )
+
+    # calculate the sample expectation for the total score
+    to_sum = pointwise([:cv_est, :naive_est, :overfit])
+    @tullio totals[crit] := to_sum[re, crit] / resamples
+    totals = reshape(totals, 3)
+    table(:, :total) .= totals
+
+    # calculate the sample expectation for the average score
+    table(:, :mean) .= totals ./ data_size
+
+    # each row of `to_sum` is already a whole-dataset total, so its spread across
+    # resamples directly estimates the standard error of the total
+    se_total = std(to_sum; dims=1)
+    se_total = reshape(se_total, 3)
+    table(:, :se_total) .= se_total
+
+    # calculate the sample expectation for the standard error in averages
+    table(:, :se_mean) .= se_total / data_size
+
+    return table
+end
diff --git a/src/InternalHelpers.jl b/src/InternalHelpers.jl
index fe2549d..cb3808d 100644
--- a/src/InternalHelpers.jl
+++ b/src/InternalHelpers.jl
@@ -115,4 +115,41 @@ function _convert_to_array(matrix::AbstractMatrix, chain_index::AbstractVector)
         new_ratios[:, :, i] .= matrix[:, chain_index .== i]
     end
     return new_ratios
+end
+
+"""
+    _generate_cv_table(log_likelihood, pointwise, data_size) -> KeyedArray
+
+Generate a table containing the results of cross-validation.
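+
+# Arguments
+
+  - `log_likelihood`: An array of log-likelihood values, used only to determine the
+    element type and array type of the returned table.
+  - `pointwise`: A `KeyedArray` of pointwise estimates with statistics `:cv_est`,
+    `:naive_est`, and `:overfit`.
+  - `data_size`: The number of data points.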
+""" +function _generate_cv_table( + log_likelihood::AbstractArray, + pointwise::AbstractArray, + data_size::Integer +) + + # create table with the right labels + table = KeyedArray( + similar(log_likelihood, 3, 4); + criterion=[:loo_est, :naive_est, :overfit], + statistic=[:total, :se_total, :mean, :se_mean], + ) + + # calculate the sample expectation for the total score + to_sum = pointwise([:loo_est, :naive_est, :overfit]) + @tullio averages[crit] := to_sum[data, crit] / data_size + averages = reshape(averages, 3) + table(:, :mean) .= averages + + # calculate the sample expectation for the average score + table(:, :total) .= table(:, :mean) .* data_size + + # calculate the sample expectation for the standard error in the totals + se_mean = std(to_sum; mean=averages', dims=1) / sqrt(data_size) + se_mean = reshape(se_mean, 3) + table(:, :se_mean) .= se_mean + + # calculate the sample expectation for the standard error in averages + table(:, :se_total) .= se_mean * data_size + + return table end \ No newline at end of file diff --git a/src/LeaveOneOut.jl b/src/LeaveOneOut.jl index 513fe18..822adee 100644 --- a/src/LeaveOneOut.jl +++ b/src/LeaveOneOut.jl @@ -147,7 +147,7 @@ function psis_loo( statistic=[:cv_est, :naive_est, :overfit, :mcse, :pareto_k], ) - table = _generate_loo_table(pointwise) + table = _generate_cv_table(log_likelihood, pointwise, data_size) return PsisLoo(table, pointwise, psis_object) @@ -165,7 +165,7 @@ function psis_loo( end -function _generate_loo_table(pointwise::AbstractArray) +function _generate_cv_table(pointwise::AbstractArray) data_size = size(pointwise, :data) # create table with the right labels @@ -212,4 +212,4 @@ end function _mom_var_log_n(mean, variance) return sqrt(log1p(variance / mean^2)) # MOM estimate for σ -end \ No newline at end of file +end diff --git a/src/LooStructs.jl b/src/LooStructs.jl new file mode 100644 index 0000000..fbb828b --- /dev/null +++ b/src/LooStructs.jl @@ -0,0 +1,254 @@ + +using AxisKeys +using PrettyTables +export PsisLoo, PsisLooMethod, Psis, BayesCV + +const POINTWISE_LABELS = (:cv_est, :naive_est, :overfit, :ess, :pareto_k) +const CV_DESC = """ +# Fields + + - `estimates::KeyedArray`: A KeyedArray with columns `:total, :se_total, :mean, :se_mean`, + and rows `:cv_est, :naive_est, :overfit`. See `# Extended help` for more. + - `:cv_est` contains estimates for the out-of-sample prediction error, as + predicted using the jackknife (LOO-CV). + - `:naive_est` contains estimates of the in-sample prediction error. + - `:overfit` is the difference between the previous two estimators, and estimates + the amount of overfitting. When using the log probability score, it is equal to + the effective number of parameters -- a model with an overfit of 2 is "about as + overfit" as a model with 2 independent parameters that have a flat prior. + - `pointwise::KeyedArray`: A `KeyedArray` of pointwise estimates with 5 columns -- + - `:cv_est` contains the estimated out-of-sample error for this point, as measured + using leave-one-out cross validation. + - `:naive_est` contains the in-sample estimate of error for this point. + - `:overfit` is the difference in the two previous estimates. + - `:ess` is the effective sample size, which measures the simulation error introduced + by the computer program. It is *not* related to the number of data points collected, + and it does *not* measure how accurate your predictions are. + - `:pareto_k` is the estimated value for the parameter `ξ` of the generalized Pareto + distribution. 
+        Pareto distribution. Values above 0.7 indicate that PSIS has failed to
+        approximate the true distribution.
+  - `psis_object::Psis`: A `Psis` object containing the results of Pareto-smoothed
+    importance sampling.
+
+
+# Extended help
+
+The total score depends on the sample size, and summarizes the weight of evidence for or
+against a model. Total scores are on an interval scale, meaning that only differences of
+scores are meaningful. *It is not possible to interpret a total score by looking at it.*
+The total score is not a relative goodness-of-fit statistic (for that, see the average
+score).
+
+
+The overfit is equal to the difference between the in-sample and out-of-sample
+predictive accuracy. When using the log probability score, it is equal to the "effective
+number of parameters" -- a model with an overfit of 2 is "about as overfit" as a model
+with 2 free parameters and flat priors.
+
+
+The average score is the total score divided by the sample size. It estimates the
+expected log score, i.e. the expectation of the log probability density of observing the
+next point. The average score is a relative goodness-of-fit statistic which does not
+depend on sample size.
+
+
+Unlike in chi-square goodness-of-fit tests, models do not have to be nested for model
+comparison using cross-validation methods.
+"""
+
+
+###########################
+### IMPORTANCE SAMPLING ###
+###########################
+
+"""
+    Psis{
+        F<:AbstractFloat,
+        AF<:AbstractArray{F,3},
+        VF<:AbstractVector{F},
+        I<:Integer,
+        VI<:AbstractVector{I},
+    }
+
+A struct containing the results of Pareto-smoothed importance sampling.
+
+# Fields
+
+  - `weights`: An array of smoothed, truncated, and normalized importance sampling
+    weights.
+  - `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto
+    distribution.
+  - `ess`: Estimated effective sample size for each LOO evaluation.
+  - `r_eff`: The relative efficiency of the MCMC chains for each observation.
+  - `tail_len`: Vector indicating how large the "tail" is for each observation.
+  - `posterior_sample_size`: The total number of draws from the posterior.
+  - `data_size`: The number of data points.
+"""
+struct Psis{
+    F<:AbstractFloat,
+    AF<:AbstractArray{F,3},
+    VF<:AbstractVector{F},
+    I<:Integer,
+    VI<:AbstractVector{I},
+}
+    weights::AF
+    pareto_k::VF
+    ess::VF
+    r_eff::VF
+    tail_len::VI
+    posterior_sample_size::I
+    data_size::I
+end
+
+
+function _throw_pareto_k_warning(ξ)
+    if any(ξ .≥ 0.7)
+        @warn "Some Pareto k values are very high (>0.7), indicating that PSIS has " *
+              "failed to approximate the true distribution."
+    elseif any(ξ .≥ 0.5)
+        @info "Some Pareto k values are slightly high (>0.5); some pointwise estimates " *
+              "may be slow to converge or have high variance."
+    end
+end
+
+
+function Base.show(io::IO, ::MIME"text/plain", psis_object::Psis)
+    table = hcat(psis_object.pareto_k, psis_object.ess)
+    post_samples = psis_object.posterior_sample_size
+    data_size = psis_object.data_size
+    println(io, "Results of PSIS with $post_samples Monte Carlo samples and " *
+                "$data_size data points.")
+    _throw_pareto_k_warning(psis_object.pareto_k)
+    return pretty_table(
+        io,
+        table;
+        compact_printing=false,
+        header=[:pareto_k, :ess],
+        formatters=ft_printf("%5.2f"),
+        alignment=:r,
+    )
+end
+
+
+
+##########################
+#### CROSS VALIDATION ####
+##########################
+
+"""
+    AbstractCV
+
+An abstract type used in cross-validation.
+"""
+abstract type AbstractCV end
+
+
+"""
+    AbstractCVMethod
+
+An abstract type used to dispatch the correct method for cross-validation.
+"""
+abstract type AbstractCVMethod end
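+
+# Each cross-validation scheme pairs a zero-field subtype of `AbstractCVMethod` (used
+# purely for dispatch, e.g. `PsisLooMethod`) with an `AbstractCV` subtype holding its
+# results (e.g. `PsisLoo`); new schemes are expected to follow the same pattern.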
+""" +abstract type AbstractCVMethod end + + + +########################## +######## PSIS-LOO ######## +########################## + +""" + PsisLooMethod + +Use Pareto-smoothed importance sampling together with leave-one-out cross validation to +estimate the out-of-sample predictive accuracy. +""" +struct PsisLooMethod <: AbstractCVMethod end + + +""" + PsisLoo{ + F <: AbstractFloat, + AF <: AbstractArray{F}, + VF <: AbstractVector{F}, + I <: Integer, + VI <: AbstractVector{I}, + } <: AbstractCV + +A struct containing the results of jackknife (leave-one-out) cross validation using Pareto +smoothed importance sampling. + +$CV_DESC + +See also: [`loo`]@ref, [`bayes_cv`]@ref, [`psis_loo`]@ref, [`Psis`]@ref +""" +struct PsisLoo{ + F <: AbstractFloat, + AF <: AbstractArray{F}, + VF <: AbstractVector{F}, + I <: Integer, + VI <: AbstractVector{I}, +} <: AbstractCV + estimates::KeyedArray + pointwise::KeyedArray + psis_object::Psis{F, AF, VF, I, VI} +end + + + + +function Base.show(io::IO, ::MIME"text/plain", loo_object::PsisLoo) + table = loo_object.estimates + _throw_pareto_k_warning(loo_object.pointwise(:pareto_k)) + post_samples = loo_object.psis_object.posterior_sample_size + data_size = loo_object.psis_object.data_size + println("Results of PSIS-LOO-CV with $post_samples Monte Carlo samples and " * + "$data_size data points.") + return pretty_table( + table; + compact_printing=false, + header=table.statistic, + row_names=table.criterion, + formatters=ft_printf("%5.2f"), + alignment=:r, + ) +end + + + +########################## +### BAYESIAN BOOTSTRAP ### +########################## + +""" + BayesCV{ + F <: AbstractFloat, + AF <: AbstractArray{F}, + VF <: AbstractVector{F}, + I <: Integer, + VI <: AbstractVector{I}, + } <: AbstractCV + +A struct containing the results of cross-validation using the Bayesian bootstrap. + +$CV_DESC + +See also: [`bayes_cv`]@ref, [`psis_loo`]@ref, [`psis`]@ref, [`Psis`]@ref +""" +struct BayesCV{ + F <: AbstractFloat, + AF <: AbstractArray{F}, + VF <: AbstractVector{F}, + I <: Integer, + VI <: AbstractVector{I}, +} <: AbstractCV + estimates::KeyedArray + posteriors::KeyedArray + psis_object::Psis{F, AF, VF, I, VI} + data_size::I +end + + +function Base.show(io::IO, ::MIME"text/plain", cv_object::BayesCV) + table = cv_object.estimates + post_samples = cv_object.psis_object.posterior_sample_size + resamples = cv_object.psis_object.data_size + data_size = cv_object.data_size + + _throw_pareto_k_warning(cv_object.psis_object.pareto_k) + println("Results of Bayesian bootstrap CV with $post_samples Monte Carlo samples, " * + "$resamples bootstrap samples, and $data_size data points.") + return pretty_table( + table; + compact_printing=false, + header=table.statistic, + row_names=table.criterion, + formatters=ft_printf("%5.2f"), + alignment=:r, + ) +end