
Commit
Merge pull request #511 from JuliaTrustworthyAI/add-divergence-to-benchmark

Add divergence to benchmark
pat-alt authored Jan 13, 2025
2 parents a137d70 + e60ce2b commit b4b0527
Showing 10 changed files with 127 additions and 29 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
### Added

- Added preliminary support for divergence metrics that can be used to evaluate counterfactuals with respect to target distributions.
-

## Version [1.4.3] - 2025-01-02

### Changed
1 change: 1 addition & 0 deletions src/CounterfactualExplanations.jl
@@ -22,6 +22,7 @@ export AbstractCounterfactualExplanation
export AbstractModel
export AbstractGenerator
export AbstractConvergence
export AbstractMeasure
export AbstractPenalty, PenaltyOrFun

# Global constants:
7 changes: 6 additions & 1 deletion src/base_types.jl
@@ -16,8 +16,13 @@ Base.broadcastable(gen::AbstractGenerator) = Ref(gen)
"An abstract type that serves as the base type for convergence objects."
abstract type AbstractConvergence end

"An abstract type that serves as the base type for measures. Objects of type `AbstractMeasure` need to be callable."
abstract type AbstractMeasure <: Function end

measure_name(m::Function) = Symbol(m)

"An abstract type for penalty functions."
abstract type AbstractPenalty end
abstract type AbstractPenalty <: AbstractMeasure end

"Treat `AbstractPenalty` as scalar when broadcasting."
Base.broadcastable(pen::AbstractPenalty) = Ref(pen)
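For context, a minimal sketch (not part of this commit) of how the new `AbstractMeasure` base type is meant to be used: subtypes are callable objects, and `measure_name` controls how their results are labelled. The type `MyMeasure` and its return value are purely hypothetical.

```julia
using CounterfactualExplanations

# Hypothetical custom measure: subtypes `AbstractMeasure` and is callable,
# as the docstring above requires.
struct MyMeasure <: AbstractMeasure end

# Make the struct callable; it simply returns zero for illustration.
(m::MyMeasure)(ce) = 0.0

# Plain functions fall back to `measure_name(m::Function) = Symbol(m)`;
# custom measures can overload `measure_name` to pick their own label.
CounterfactualExplanations.measure_name(::MyMeasure) = :my_measure
```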
10 changes: 8 additions & 2 deletions src/evaluation/Evaluation.jl
@@ -7,11 +7,16 @@ using ..Models
using LinearAlgebra: LinearAlgebra
using Statistics

abstract type AbstractDivergenceMetric <: AbstractMeasure end

include("serialization.jl")
include("divergence/divergence.jl")

export MMD

include("measures.jl")
include("benchmark.jl")
include("evaluate.jl")
include("measures.jl")
include("divergence/divergence.jl")

export global_serializer, Serializer, NullSerializer, _serialization_state
export global_output_identifier, DefaultOutputIdentifier, _output_id, get_global_output_id
@@ -26,6 +31,7 @@ export plausibility_energy_differential,
export faithfulness
export plausibility_measures, default_measures, distance_measures, all_measures
export concatenate_benchmarks
export compute_divergence

"Available plausibility measures."
const plausibility_measures = [
74 changes: 64 additions & 10 deletions src/evaluation/benchmark.jl
@@ -84,7 +84,7 @@ end
benchmark(
counterfactual_explanations::Vector{CounterfactualExplanation};
meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
store_ce::Bool=false,
)
@@ -93,7 +93,7 @@ Generates a `Benchmark` for a vector of counterfactual explanations. Optionally
function benchmark(
counterfactual_explanations::Vector{CounterfactualExplanation};
meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
store_ce::Bool=false,
parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
)
@@ -105,7 +105,7 @@ function benchmark(
measure=measure,
report_each=true,
report_meta=true,
store_ce=store_ce,
store_ce=needs_ce(store_ce, measure),
output_format=:DataFrame,
)
bmk = Benchmark(reduce(vcat, evaluations))
@@ -119,7 +119,7 @@ end
data::CounterfactualData;
models::Dict{<:Any,<:AbstractModel},
generators::Dict{<:Any,<:AbstractGenerator},
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
xids::Union{Nothing,AbstractArray}=nothing,
dataname::Union{Nothing,Symbol,String}=nothing,
verbose::Bool=true,
@@ -136,7 +136,7 @@ function benchmark(
data::CounterfactualData;
models::Dict{<:Any,<:AbstractModel},
generators::Dict{<:Any,<:AbstractGenerator},
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
dataname::Union{Nothing,Symbol,String}=nothing,
verbose::Bool=true,
store_ce::Bool=false,
@@ -201,7 +201,7 @@ function benchmark(
measure=measure,
report_each=true,
report_meta=true,
store_ce=store_ce,
store_ce=needs_ce(store_ce, measure),
output_format=:DataFrame,
verbose=verbose,
)
@@ -233,7 +233,7 @@ end
test_data::Union{Nothing,CounterfactualData}=nothing,
models::Dict{<:Any,<:Any}=standard_models_catalogue,
generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
n_individuals::Int=5,
n_runs::Int=1,
suppress_training::Bool=false,
@@ -256,7 +256,7 @@ Benchmark a set of counterfactuals for a given data set and additional inputs.
- `test_data::Union{Nothing,CounterfactualData}`: Optional test data for evaluation. Defaults to `nothing`, in which case `data` is used for evaluation.
- `models::Dict{<:Any,<:Any}`: A dictionary of model objects keyed by their names. Defaults to `standard_models_catalogue`.
- `generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}`: Optional dictionary of generator functions keyed by their names. Defaults to `nothing`, in which case the whole [`generator_catalogue`](@ref) is used.
- `measure::Union{Function,Vector{Function}}`: The measure(s) to evaluate the counterfactuals against. Defaults to `default_measures`.
- `measure::Union{Function,Vector{<:Function}}`: The measure(s) to evaluate the counterfactuals against. Defaults to `default_measures`.
- `n_individuals::Int=5`: Number of individuals to generate for each model and generator.
- `n_runs::Int=1`: Number of runs for each model and generator.
- `suppress_training::Bool=false`: Whether to suppress training of models during benchmarking. This is useful if models have already been trained.
@@ -291,7 +291,7 @@ function benchmark(
test_data::Union{Nothing,CounterfactualData}=nothing,
models::Dict{<:Any,<:Any}=standard_models_catalogue,
generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
n_individuals::Int=5,
n_runs::Int=1,
suppress_training::Bool=false,
@@ -471,7 +471,7 @@ function benchmark(
measure=measure,
report_each=true,
report_meta=true,
store_ce=store_ce,
store_ce=needs_ce(store_ce, measure),
output_format=:DataFrame,
)

@@ -544,3 +544,57 @@ function get_benchmark_files(storage_path::String)

return bmk_files
end

"""
includes_divergence_metric(measure::Union{Function,Vector{<:Function}})
Checks if the provided `measure` includes a divergence metric.
"""
function includes_divergence_metric(measure::Union{Function,Vector{<:Function}})
if isa(measure, Function)
return isa(measure, AbstractDivergenceMetric)
else
return any(isa(m, AbstractDivergenceMetric) for m in measure)
end
end

"""
needs_ce(store_ce::Bool,measure::Union{Function,Vector{<:Function}})
A helper function to determine if counterfactual explanations should be stored based on the given `store_ce` flag and the presence of a divergence metric in the `measure`.
"""
function needs_ce(store_ce::Bool, measure::Union{Function,Vector{<:Function}})
if !store_ce && includes_divergence_metric(measure)
@warn "Divergence metric detected. Will temporarily store counterfactual explanations, which can lead to increased memory usage."
end
return store_ce || includes_divergence_metric(measure)
end

function compute_divergence(
bmk::Benchmark, measure::Union{Function,Vector{<:Function}}, data::CounterfactualData
)
@assert !isnothing(bmk.counterfactuals) "Cannot compute divergence without counterfactuals. Set `store_ce=true` when running the benchmark."
if !includes_divergence_metric(measure)
@info "No divergence metric detected. Skipping computation."
return bmk
end
df = innerjoin(bmk.evaluation, bmk.counterfactuals; on=:sample)
div_metrics = String.(measure_name.(measure)[isa.(measure, AbstractDivergenceMetric)])
gdf = groupby(df, [:variable, :generator, :model, :target, :factual])
final_df = DataFrame()
for _df in gdf
if !(unique(_df.variable)[1] in div_metrics)
_df.pval .= NaN
else
metric = measure[String.(measure_name.(measure)) .== unique(_df.variable)[1]][1]
val, pval = metric(collect(_df.ce), data)
_df.value .= val
_df.pval .= pval
end
first_cols = [:sample, :num_counterfactual, :variable, :value, :pval]
select!(df, first_cols, Not(first_cols))
final_df = vcat(final_df, _df)
end

return Benchmark(final_df, bmk.counterfactuals)
end
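For orientation, a hedged usage sketch of the two-step flow introduced here, mirroring the test added further down. The variables `ces` and `counterfactual_data` are placeholders for an existing vector of counterfactual explanations and the corresponding `CounterfactualData`.

```julia
using CounterfactualExplanations.Evaluation: benchmark, compute_divergence, validity, MMD

# Step 1: `needs_ce` detects the divergence metric and temporarily stores the
# counterfactuals (with a warning), even though `store_ce=false` by default.
bmk = benchmark(ces; measure=[validity, MMD()])

# Step 2: divergence values are filled in per (variable, generator, model,
# target, factual) group; `compute_p=nothing` skips the permutation test.
bmk = compute_divergence(bmk, [validity, MMD(; compute_p=nothing)], counterfactual_data)
```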
2 changes: 0 additions & 2 deletions src/evaluation/divergence/divergence.jl
@@ -1,8 +1,6 @@
using CounterfactualExplanations
using CounterfactualExplanations: counterfactual

abstract type AbstractDivergenceMetric end

include("mmd.jl")

function get_samples_for_metric(
19 changes: 12 additions & 7 deletions src/evaluation/divergence/mmd.jl
@@ -13,25 +13,30 @@ Concrete type for the Maximum Mean Discrepancy (MMD) metric.
"""
struct MMD{K<:KernelFunctions.Kernel} <: AbstractDivergenceMetric
kernel::K
compute_p::Union{Nothing,Int}
end

MMD() = MMD(default_kernel)
function MMD(; kernel=default_kernel, compute_p=1000)
return MMD(kernel, compute_p)
end

CounterfactualExplanations.measure_name(m::MMD) = :mmd

"""
(m::MMD)(x::AbstractArray, y::AbstractArray)
Computes the maximum mean discrepancy (MMD) between two datasets `x` and `y`. The MMD is a measure of the difference between two probability distributions. It is defined as the maximum value of the kernelized dot product between the two datasets. It is computed as the sum of average kernel values between columns (samples) of `x` and `y`, minus twice the average kernel value between columns (samples) of `x` and `y`. A larger MMD value indicates that the distributions are more different, while a value closer to zero suggests they are more similar. See also [`kernelsum`](@ref).
"""
function (m::MMD)(x::AbstractArray, y::AbstractArray; compute_p::Union{Nothing,Int}=1000)
function (m::MMD)(x::AbstractArray, y::AbstractArray)
xx = kernelsum(m.kernel, x)
yy = kernelsum(m.kernel, y)
xy = kernelsum(m.kernel, x, y)
mmd = xx + yy - 2xy
if !isnothing(compute_p)
mmd_null = mmd_null_dist(x, y, m.kernel; l=compute_p)
if !isnothing(m.compute_p)
mmd_null = mmd_null_dist(x, y, m.kernel; l=m.compute_p)
p_val = mmd_significance(mmd, mmd_null)
else
p_val = nothing
p_val = NaN
end
return mmd, p_val
end
@@ -44,7 +49,7 @@ end
kwrgs...
)
Computes the MMD between two datasets `x` and `y`, along with a p-value based on a null distribution of MMD values (unless `compute_p=nothing`) for a random subset of the data (of sample size `n`). The p-value is computed using a permutation test.
Computes the MMD between two datasets `x` and `y`, along with a p-value based on a null distribution of MMD values (unless `m.compute_p=nothing`) for a random subset of the data (of sample size `n`). The p-value is computed using a permutation test.
"""
function (m::MMD)(x::AbstractArray, y::AbstractArray, n::Int; kwrgs...)
n = minimum([size(x, 2), n])
@@ -74,7 +79,7 @@ function mmd_null_dist(
Zs = [Z[:, shuffle(1:end)] for i in 1:l]

bootstrap = function (z)
return MMD(k)(z[:, 1:n], z[:, (n + 1):end]; compute_p=nothing)[1]
return MMD(k, nothing)(z[:, 1:n], z[:, (n + 1):end])[1]
end

mmd_null = map(Zs) do z
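As a rough sketch of the revised constructor and functor (toy random data; columns are treated as samples, consistent with the docstring above):

```julia
using CounterfactualExplanations.Evaluation: MMD

x = randn(5, 100)           # 100 samples with 5 features each
y = randn(5, 100) .+ 0.5    # samples from a shifted distribution

m = MMD()                   # default kernel, compute_p=1000
val, pval = m(x, y)         # MMD estimate plus permutation-test p-value

m_fast = MMD(; compute_p=nothing)
val2, p2 = m_fast(x, y)     # p2 is NaN, as per the change above
```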
16 changes: 13 additions & 3 deletions src/evaluation/evaluate.jl
@@ -1,3 +1,4 @@
using CounterfactualExplanations: measure_name
using DataFrames: nrow
using UUIDs: uuid1

@@ -48,6 +49,15 @@ function compute_measure(ce::CounterfactualExplanation, measure::Function, agg::
return ndims(val) > 1 ? vec(val) : [val]
end

"""
compute_measure(ce::CounterfactualExplanation, measure::AbstractDivergenceMetric, agg::Function)
For abstract divergence metrics, returns a vector of NaN values.
"""
compute_measure(
ce::CounterfactualExplanation, measure::AbstractDivergenceMetric, agg::Function
) = [NaN]

"""
evaluate_dict(ce::CounterfactualExplanation, measure::Vector{Function}, agg::Function)
Evaluates a counterfactual explanation and returns a dictionary of evaluation measures.
@@ -83,7 +93,7 @@ function to_dataframe(
evaluation = DataFrames.DataFrame(
Dict(
m => report_each ? val[1] : val for
(m, val) in zip(Symbol.(measure), computed_measures)
(m, val) in zip(measure_name.(measure), computed_measures)
),
)
evaluation.num_counterfactual = 1:nrow(evaluation)
@@ -128,7 +138,7 @@ end
evaluate(
ce::CounterfactualExplanation,
meta_data::Union{Nothing,Dict}=nothing;
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
agg::Function=mean,
report_each::Bool=false,
output_format::Symbol=:Vector,
@@ -155,7 +165,7 @@ Just computes evaluation `measures` for the counterfactual explanation. By default
function evaluate(
ce::CounterfactualExplanation,
meta_data::Union{Nothing,Dict}=nothing;
measure::Union{Function,Vector{Function}}=default_measures,
measure::Union{Function,Vector{<:Function}}=default_measures,
agg::Function=mean,
report_each::Bool=false,
output_format::Symbol=:Vector,
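A minimal sketch of what the new `compute_measure` method implies for single-explanation evaluation; this is exactly what the test below asserts. Here `ce` is assumed to be an existing `CounterfactualExplanation`.

```julia
using CounterfactualExplanations.Evaluation: evaluate, MMD

# Divergence metrics need a sample of counterfactuals, so evaluating a single
# explanation returns a NaN placeholder that `compute_divergence` fills later.
res = evaluate(ce; measure=MMD())
isnan(res[1][1])  # true
```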
23 changes: 21 additions & 2 deletions test/other/evaluation.jl
@@ -1,6 +1,11 @@
using CounterfactualExplanations.Convergence
using CounterfactualExplanations.Evaluation:
Benchmark, evaluate, validity, distance_measures, concatenate_benchmarks
Benchmark,
evaluate,
validity,
distance_measures,
concatenate_benchmarks,
compute_divergence
using CounterfactualExplanations.Objectives: distance
using Serialization: serialize
using TaijaData: load_moons, load_circles
@@ -67,6 +72,10 @@ generators = Dict(
plaus = Evaluation.plausibility_cosine(ce)
plaus = Evaluation.plausibility_energy_differential(ce)
@test true

@testset "Divergence Metrics" begin
@test isnan(evaluate(ce; measure=MMD())[1][1])
end
end

@testset "Benchmarking" begin
@@ -126,6 +135,10 @@ end
bmk = vcat(bmks[1], bmks[2]; ids=collect(keys(datasets)))
@test typeof(bmk) <: Benchmark
end

@testset "Divergence" begin
@test all(isnan.(benchmark(ces; measure=MMD())().value))
end
end

@testset "Serialization" begin
@@ -175,6 +188,12 @@ end

mmd_generic = mmd(ces, counterfactual_data, n_individuals)

@test true
bmk =
benchmark(ces; measure=[validity, MMD()]) |>
bmk -> compute_divergence(
bmk, [validity, MMD(; compute_p=nothing)], counterfactual_data
)

@test all(.!isnan.(bmk.evaluation.value))
end
end
2 changes: 1 addition & 1 deletion test/other/performance.jl
@@ -13,6 +13,6 @@ generator = GenericGenerator()

if VERSION >= v"1" && !Sys.iswindows()
t = @benchmark generate_counterfactual(x, target, data, M, generator) samples = 1000
expected_allocs = 7000
expected_allocs = 10000
@test t.allocs <= expected_allocs
end

2 comments on commit b4b0527

@pat-alt
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/122885

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text "Release notes:" and it will be added to the registry PR. If TagBot is installed, it will also be added to the release that TagBot creates, e.g.:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed; alternatively, it can be done manually through the GitHub interface or via:

git tag -a v1.4.4 -m "<description of version>" b4b05279f59bec33934299cbb3d1df68991bcf22
git push origin v1.4.4
