Skip to content

Lazy-compute-ESS #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ jobs:
version:
- '1' # latest stable 1.x release of Julia
- '1.6' # oldest supported version
- 'nightly'
os:
- ubuntu-latest
arch:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ version = "0.6.6"
[deps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down
13 changes: 6 additions & 7 deletions src/AbstractCV.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
using AxisKeys
using PrettyTables

export AbstractCVMethod, AbstractCV
# export AbstractCVMethod, AbstractCV

const POINTWISE_LABELS = (:cv_elpd, :naive_lpd, :p_eff, :ess, :pareto_k)
const CV_DESC = """
# Fields

Expand Down Expand Up @@ -73,12 +72,12 @@ An abstract type used in cross-validation.
abstract type AbstractCV end


"""
AbstractCVMethod
# """
# AbstractCVMethod

An abstract type used to dispatch the correct method for cross validation.
"""
abstract type AbstractCVMethod end
# An abstract type used to dispatch the correct method for cross validation.
# """
# abstract type AbstractCVMethod end


##########################
Expand Down
10 changes: 5 additions & 5 deletions src/ESS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export relative_eff, psis_ess, sup_ess

"""
relative_eff(
sample::AbstractArray{Real, 3};
sample::AbstractArray{<:Real, 3};
method=MCMCDiagnosticTools.FFTESSMethod()
)

Expand All @@ -16,7 +16,7 @@ by the nominal sample size.

- `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values.
"""
function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...)
function relative_eff(sample::AbstractArray{<:Real,3}; maxlag=size(sample, 2), kwargs...)
dims = size(sample)
post_sample_size = dims[2] * dims[3]
ess_sample = permutedims(sample, [2, 1, 3])
Expand Down Expand Up @@ -60,13 +60,13 @@ end

"""
function sup_ess(
weights::AbstractVector{T},
weights::AbstractMatrix{T},
r_eff::AbstractVector{T}
) -> AbstractVector

Calculate the supremum-based effective sample size of a PSIS sample, i.e. the inverse of the
maximum weight. This measure is more trustworthy than the `ess` from `psis_ess`. It uses the
L-∞ norm.
maximum weight. This measure is more sensitive than the `ess` from `psis_ess`, but also
much more variable. It uses the L-∞ norm.

# Arguments
- `weights`: A set of importance sampling weights derived from PSIS.
Expand Down
14 changes: 8 additions & 6 deletions src/GPD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using Tullio


"""
gpdfit(
sample::AbstractVector{T<:Real};
gpd_fit(
sample::AbstractVector{T<:Real},
r_eff::T = 1;
wip::Bool=true,
min_grid_pts::Integer=30,
sort_sample::Bool=false
Expand All @@ -29,12 +30,13 @@ generalized Pareto distribution (GPD), assuming the location parameter is 0.
Estimation method taken from Zhang, J. and Stephens, M.A. (2009). The parameter ξ is the
negative of k.
"""
function gpdfit(
sample::AbstractVector{T};
function gpd_fit(
sample::AbstractVector{T},
r_eff::T=1;
wip::Bool=true,
min_grid_pts::Integer=30,
sort_sample::Bool=false,
) where {T <: Real}
) where T<:Real

len = length(sample)
# sample must be sorted, but we can skip if sample is already sorted
Expand Down Expand Up @@ -70,7 +72,7 @@ function gpdfit(

# Drag towards .5 to reduce variance for small len
if wip
@fastmath ξ = (ξ * len + 0.5 * n_0) / (len + n_0)
@fastmath ξ = (r_eff * ξ * len + 0.5 * n_0) / (r_eff * len + n_0)
end

return ξ, σ
Expand Down
99 changes: 61 additions & 38 deletions src/ImportanceSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ double check it is correct.
const MIN_TAIL_LEN = 5 # Minimum size of a tail for PSIS to give sensible answers
const SAMPLE_SOURCES = ["mcmc", "vi", "other"]

export psis, psis!, PsisLoo, PsisLooMethod, Psis
export psis, psis!, Psis


###########################
Expand All @@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling.

# Fields

- `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
weights.
- `weights`: A lazy
- `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
- `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
- `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
the weights.
Expand All @@ -39,21 +37,38 @@ A struct containing the results of Pareto-smoothed importance sampling.
- `data_size`: How many data points were used for PSIS.
"""
struct Psis{
RealType <: Real,
AT <: AbstractArray{RealType, 3},
VT <: AbstractVector{RealType},
R <: Real,
AT <: AbstractArray{R, 3},
VT <: AbstractVector{R}
}
weights::AT
pareto_k::VT
ess::VT
sup_ess::VT
r_eff::VT
tail_len::Vector{Int}
tail_len::AbstractVector{Int}
posterior_sample_size::Int
data_size::Int
end


function Base.getproperty(psis_obj::Psis, k::Symbol)
if k === :log_weights
return log.(getfield(psis_obj, :weights))
else
return getfield(psis_obj, k)
end
end


function Base.propertynames(psis_object::Psis)
return (
fieldnames(typeof(psis_object))...,
:log_weights,
)
end


function Base.show(io::IO, ::MIME"text/plain", psis_object::Psis)
table = hcat(psis_object.pareto_k, psis_object.ess, psis_object.sup_ess)
post_samples = psis_object.posterior_sample_size
Expand All @@ -79,14 +94,16 @@ end
"""
psis(
log_ratios::AbstractArray{T<:Real},
r_eff::AbstractVector;
r_eff::AbstractVector{T};
source::String="mcmc"
) -> Psis

Implements Pareto-smoothed importance sampling (PSIS).

# Arguments

## Positional Arguments

- `log_ratios::AbstractArray`: A 2d or 3d array of (unnormalized) importance ratios on the
log scale. Indices must be ordered as `[data, step, chain]`. The chain index can be left
off if there is only one chain, or if keyword argument `chain_index` is provided.
Expand All @@ -98,15 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS).
- `source::String="mcmc"`: A string or symbol describing the source of the sample being
used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be
independent. Currently permitted values are $SAMPLE_SOURCES.
- `calc_ess::Bool=true`: If `false`, do not calculate ESS diagnostics. Attempting to
access ESS diagnostics will return an empty list.

See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref.
"""
function psis(
log_ratios::AbstractArray{<:Real, 3};
r_eff::AbstractVector{<:Real}=similar(log_ratios, 0),
log_ratios::AbstractArray{T, 3};
r_eff::AbstractVector{T}=similar(log_ratios, 0),
source::Union{AbstractString, Symbol}="mcmc",
log_weights::Bool=true
)
calc_ess::Bool = true
) where T <: Real

source = lowercase(String(source))
dims = size(log_ratios)
Expand All @@ -115,27 +134,35 @@ function psis(
post_sample_size = dims[2] * dims[3]

# Reshape to matrix (easier to deal with)
log_ratios = reshape(log_ratios, data_size, post_sample_size)
r_eff = _generate_r_eff(log_ratios, dims, r_eff, source)
_check_input_validity_psis(reshape(log_ratios, dims), r_eff)
weights = @. exp(log_ratios - $maximum(log_ratios; dims=2))
log_ratios_mat = reshape(log_ratios, data_size, post_sample_size)
r_eff = _generate_r_eff(log_ratios_mat, dims, r_eff, source)
_check_input_validity_psis(log_ratios, r_eff)
weights = similar(log_ratios)
weights_mat = reshape(weights, data_size, post_sample_size)
@. weights = exp(log_ratios - $maximum(log_ratios; dims=(2,3)))


tail_length = Vector{Int}(undef, data_size)
tail_length = similar(r_eff, Int)
ξ = similar(r_eff)
@inbounds Threads.@threads for i in eachindex(tail_length)
tail_length[i] = _def_tail_length(post_sample_size, r_eff[i])
ξ[i] = @views psis!(weights[i, :], tail_length[i])
ξ[i] = @views psis!(weights_mat[i, :], r_eff[i]; tail_length=tail_length[i])
end

@tullio norm_const[i] := weights[i, j]
@tullio norm_const[i] := weights[i, j, k]
@. weights = weights / norm_const
ess = psis_ess(weights, r_eff)
inf_ess = sup_ess(weights, r_eff)

weights = reshape(weights, dims)

if calc_ess
ess = psis_ess(weights_mat, r_eff)
inf_ess = sup_ess(weights_mat, r_eff)
else
ess = similar(weights_mat, 0)
inf_ess = similar(weights_mat, 0)
end

return Psis(
weights,
weights,
ξ,
ess,
inf_ess,
Expand Down Expand Up @@ -193,10 +220,11 @@ log-weights.
Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
valid.
"""
function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T);
tail_length::Integer = _def_tail_length(length(is_ratios), r_eff),
log_weights::Bool=false
)

) where T<:Real
len = length(is_ratios)
tail_start = len - tail_length + 1 # index of smallest tail value

Expand All @@ -213,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;

# Get value just before the tail starts:
cutoff = is_ratios[tail_start - 1]
ξ = _psis_smooth_tail!(tail, cutoff)
ξ = _psis_smooth_tail!(tail, cutoff, r_eff)

# truncate at max of raw weights (1 after scaling)
clamp!(is_ratios, 0, 1)
Expand All @@ -228,38 +256,33 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
end


function psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real=1)
tail_length = _def_tail_length(length(is_ratios), r_eff)
return psis!(is_ratios, tail_length)
end


"""
_def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer

Define the tail length as in Vehtari et al. (2019), with the small addition that the tail
must a multiple of `32*bit_length` (which improves performance).
"""
function _def_tail_length(length::Integer, r_eff::Real=1)
function _def_tail_length(length::Integer, r_eff::Real=one(T))
return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int
end


"""
_psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} -> ξ::T
_psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=1) where {T<:Real}
-> ξ::T

Takes an *already sorted* vector of observations from the tail and smooths it *in place*
with PSIS before returning shape parameter `ξ`.
"""
function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real}
function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=one(T)) where {T <: Real}
len = length(tail)
if any(isinf.(tail))
return ξ = Inf
else
@. tail = tail - cutoff

# save time not sorting since tail is already sorted
ξ, σ = gpdfit(tail)
ξ, σ = gpd_fit(tail, r_eff)
@. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff
end
return ξ
Expand Down
24 changes: 12 additions & 12 deletions src/LeaveOneOut.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,27 @@ using Statistics
using Printf
using Tullio

export loo, psis_loo, loo_from_psis
export loo, psis_loo, loo_from_psis, PsisLoo


#####################
###### STRUCTS ######
#####################


"""
PsisLooMethod
# """
# PsisLooMethod

Use Pareto-smoothed importance sampling together with leave-one-out cross validation to
estimate the out-of-sample predictive accuracy.
"""
struct PsisLooMethod <: AbstractCVMethod end
# Use Pareto-smoothed importance sampling together with leave-one-out cross validation to
# estimate the out-of-sample predictive accuracy.
# """
# struct PsisLooMethod <: AbstractCVMethod end


"""
PsisLoo <: AbstractCV

A struct containing the results of leave-one-out cross validation using Pareto
A struct containing the results of leave-one-out cross validation computed with Pareto
smoothed importance sampling.

$CV_DESC
Expand Down Expand Up @@ -71,17 +71,17 @@ end


"""
function loo(args...; method=PsisLooMethod(), kwargs...) -> PsisLoo
function loo(args...; kwargs...) -> PsisLoo

Compute the approximate leave-one-out cross-validation score using the specified method.
Compute an approximate leave-one-out cross-validation score.

Currently, this function only serves to call `psis_loo`, but this could change in the
future. The default methods or return type may change without warning; thus, we recommend
future. The default methods or return type may change without warning, so we recommend
using `psis_loo` instead if reproducibility is required.

See also: [`psis_loo`](@ref), [`PsisLoo`](@ref).
"""
function loo(args...; method=PsisLooMethod(), kwargs...)
function loo(args...; kwargs...)
return psis_loo(args...; kwargs...)
end

Expand Down
Loading