Bootstrap #19

Draft · wants to merge 9 commits into base: main
149 changes: 149 additions & 0 deletions src/Bootstrap.jl
@@ -0,0 +1,149 @@
using AxisKeys
using Distributions
using InteractiveUtils
using LoopVectorization
using Random
using Statistics
using Tullio

export bayes_cv

"""
bayes_cv(
    log_likelihood::Array{<:AbstractFloat} [, args...];
    resamples::Integer=2^10, source::String="mcmc" [, chain_index::Vector{Int}, kwargs...]
) -> BayesCV

Use the Bayesian bootstrap (Bayes cross-validation) and PSIS to calculate an approximate
posterior distribution for the out-of-sample score.


# Arguments

- `log_likelihood::Array`: An array or matrix of log-likelihood values indexed as
`[data, step, chain]`. The chain dimension can be left off if `chain_index` is provided
or if all posterior samples were drawn from a single chain.
- `args...`: Positional arguments to be passed to [`psis`](@ref).
- `resamples::Integer=2^10`: The number of Bayesian-bootstrap resamples to draw.
- `rng`: The random-number generator used to draw the Dirichlet weights; defaults to
`MersenneTwister(1865)`.
- `chain_index::Vector`: An (optional) vector of integers specifying which chain each
step belongs to. For instance, `chain_index[3]` should return `2` if
`log_likelihood[:, 3]` belongs to the second chain.
- `kwargs...`: Keyword arguments to be passed to [`psis`](@ref).


# Extended help
The Bayesian bootstrap works similarly to other cross-validation methods: First, we remove
some piece of information from the model. Then, we test how well the model can reproduce
that information. With leave-k-out cross-validation, the information we leave out is the
value for one or more data points. With the Bayesian bootstrap, the information being left
out is the true probability of each observation.
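
For intuition, here is a standalone sketch of the resampling idea (toy numbers, plain
`Distributions`/`Statistics` code rather than this package's API):

```julia
using Distributions, Random, Statistics

scores = [-1.2, -0.8, -1.5, -0.9, -1.1]  # e.g. pointwise log-likelihoods
rng = MersenneTwister(1)

# Each resample draws observation probabilities from a flat Dirichlet prior and
# recomputes the average score under those probabilities.
draws = [sum(rand(rng, Dirichlet(ones(5))) .* scores) for _ in 1:1000]
mean(draws), std(draws)  # posterior mean and spread of the average score
```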


See also: [`BayesCV`](@ref), [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
"""
function bayes_cv(
log_likelihood::T,
args...;
resamples::Integer=2^10,
rng=MersenneTwister(1865),
kwargs...
) where {F<:AbstractFloat, T<:AbstractArray{F, 3}}

dims = size(log_likelihood)
data_size = dims[1]
mcmc_count = dims[2] * dims[3] # total number of samples from posterior
log_count = log(mcmc_count)


# TODO: Add a way of using score functions other than ELPD
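# Draw flat-Dirichlet weights for each resample; scaling by `data_size` makes each column
# of `bb_weights` sum to the number of observations.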
bb_weights = data_size * rand(rng, Dirichlet(ones(data_size)), resamples)
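# Reweighted log-score for each resample and posterior draw (summed over data points).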
@tullio bb_samples[re, step, chain] :=
bb_weights[datum, re] * log_likelihood[datum, step, chain]
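# Log importance ratios between the bootstrap-reweighted target and the original
# posterior: each observation's log-likelihood enters with weight (w - 1).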
@tullio log_is_ratios[re, step, chain] :=
(bb_weights[datum, re] - 1) * log_likelihood[datum, step, chain]
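# Pareto-smooth the importance ratios; the smoothed weights are reused below.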
psis_object = psis(log_is_ratios, args...; kwargs...)
psis_weights = psis_object.weights

@tullio re_naive[re] := log <| # calculate the naive estimate in many resamples
psis_weights[re, step, chain] * exp(bb_samples[re, step, chain])
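# Full-data naive estimate: log of the posterior-mean likelihood of each observation,
# summed over observations.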
@tullio sample_est[i] := exp(log_likelihood[i, j, k] - log_count) |> log
@tullio naive_est := sample_est[i]

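# Reflect each resampled naive estimate around the full-data naive estimate (the usual
# bootstrap bias correction).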
bb_ests = (2 * naive_est) .- re_naive
@tullio mcse[re] := sqrt <|
(psis_weights[re, step, chain] * (bb_samples[re, step, chain] - re_naive[re]))^2
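# Monte Carlo error from using a finite number of bootstrap resamples.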
bootstrap_se = std(re_naive) / sqrt(resamples)

# Posterior for the *average score*, not the mean of the posterior distribution:
resample_calcs = KeyedArray(
hcat(
bb_ests,
re_naive,
re_naive - bb_ests,
mcse,
psis_object.pareto_k
);
data=Base.OneTo(resamples),
statistic=[
:loo_est,
:naive_est,
:overfit,
:mcse,
:pareto_k
],
)

estimates = _generate_bayes_table(log_likelihood, resample_calcs, resamples, data_size)

return BayesCV(
estimates,
resample_calcs,
psis_object,
data_size
)

end


function bayes_cv(
log_likelihood::T,
args...;
chain_index::AbstractVector=ones(size(log_likelihood, 2)),
kwargs...,
) where {F<:AbstractFloat, T<:AbstractMatrix{F}}
new_log_likelihood = _convert_to_array(log_likelihood, chain_index)
return bayes_cv(new_log_likelihood, args...; kwargs...)
end


function _generate_bayes_table(
log_likelihood::AbstractArray,
pointwise::AbstractArray,
resamples::Integer,
data_size::Integer
)

# create table with the right labels
table = KeyedArray(
similar(log_likelihood, 3, 4);
criterion=[:loo_est, :naive_est, :overfit],
statistic=[:total, :se_total, :mean, :se_mean],
)
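# Rows are keyed by criterion and columns by statistic, so entries can be read with
# AxisKeys lookups such as `table(:loo_est, :total)`.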

# posterior mean of the total score (averaged over the bootstrap resamples)
to_sum = pointwise([:loo_est, :naive_est, :overfit])
@tullio totals[crit] := to_sum[re, crit] / resamples
totals = reshape(totals, 3)
table(:, :total) .= totals

# average score per observation
table(:, :mean) .= table(:, :total) ./ data_size

# standard error of the total score
se_total = std(to_sum; dims=1) * sqrt(data_size)
se_total = reshape(se_total, 3)
table(:, :se_total) .= se_total

# standard error of the average score
table(:, :se_mean) .= se_total / data_size

return table
end
37 changes: 37 additions & 0 deletions src/InternalHelpers.jl
@@ -115,4 +115,41 @@ function _convert_to_array(matrix::AbstractMatrix, chain_index::AbstractVector)
new_ratios[:, :, i] .= matrix[:, chain_index .== i]
end
return new_ratios
end

"""
    _generate_cv_table(log_likelihood, pointwise, data_size)

Generate a table containing the results of cross-validation.
"""
function _generate_cv_table(
log_likelihood::AbstractArray,
pointwise::AbstractArray,
data_size::Integer
)

# create table with the right labels
table = KeyedArray(
similar(log_likelihood, 3, 4);
criterion=[:loo_est, :naive_est, :overfit],
statistic=[:total, :se_total, :mean, :se_mean],
)

# average score per observation
to_sum = pointwise([:loo_est, :naive_est, :overfit])
@tullio averages[crit] := to_sum[data, crit] / data_size
averages = reshape(averages, 3)
table(:, :mean) .= averages

# total score (average score times the number of observations)
table(:, :total) .= table(:, :mean) .* data_size

# standard error of the average score
se_mean = std(to_sum; mean=averages', dims=1) / sqrt(data_size)
se_mean = reshape(se_mean, 3)
table(:, :se_mean) .= se_mean

# standard error of the total score
table(:, :se_total) .= se_mean * data_size

return table
end
6 changes: 3 additions & 3 deletions src/LeaveOneOut.jl
@@ -147,7 +147,7 @@ function psis_loo(
statistic=[:cv_est, :naive_est, :overfit, :mcse, :pareto_k],
)

table = _generate_loo_table(pointwise)
table = _generate_cv_table(log_likelihood, pointwise, data_size)

return PsisLoo(table, pointwise, psis_object)

@@ -165,7 +165,7 @@
end


function _generate_loo_table(pointwise::AbstractArray)
function _generate_cv_table(pointwise::AbstractArray)

data_size = size(pointwise, :data)
# create table with the right labels
@@ -212,4 +212,4 @@ end

function _mom_var_log_n(mean, variance)
return sqrt(log1p(variance / mean^2)) # MOM estimate for σ
end
end