Merge pull request #492 from JuliaTrustworthyAI/490-use-gradient-magnitude-matching-or-root-scaling

Balancing gradients
pat-alt authored Nov 7, 2024
2 parents 9013ce1 + 719ca82 commit 3a8cdc3
Showing 14 changed files with 142 additions and 90 deletions.
13 changes: 12 additions & 1 deletion CHANGELOG.md

@@ -6,7 +6,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 
 *Note*: We try to adhere to these practices as of version [v1.1.1].
 
-## Version [1.3.5]
+## Version [1.3.6]
+
+### Changed
+
+- Slight changes to the implementation of `ProbeGenerator` (no longer calling a redundant `hinge_loss` function for all other generators).
+
+### Added
+
+- Added a warning message to the `ProbeGenerator` pointing to the issues with the current implementation.
+- Added links to papers to all docstrings for generators.
+
+## Version [1.3.5] - 2024-10-28
 
 ### Changed
 
2 changes: 1 addition & 1 deletion Project.toml

@@ -1,7 +1,7 @@
 name = "CounterfactualExplanations"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 authors = ["Patrick Altmeyer <p.altmeyer@tudelft.nl> and contributors"]
-version = "1.3.5"
+version = "1.3.6"
 
 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
10 changes: 6 additions & 4 deletions src/CounterfactualExplanations.jl

@@ -55,6 +55,11 @@ export flux_training_params
 export probs, logits
 export standard_models_catalogue, all_models_catalogue, model_evaluation, predict_label
 
+# Convergence
+include("convergence/Convergence.jl")
+using .Convergence
+export conditions_satisfied
+
 ### Objectives
 # ℓ( ℳ[𝒟](xᵢ) , target ) + λ cost(xᵢ)
 ###
@@ -79,12 +84,9 @@ export REVISEGenerator
 export DiCEGenerator
 export WachterGenerator
 export generator_catalogue
-export generate_perturbations, conditions_satisfied
+export generate_perturbations
 export @objective
 
-include("convergence/Convergence.jl")
-using .Convergence
-
 ### CounterfactualExplanation
 # argmin
 ###
4 changes: 2 additions & 2 deletions src/convergence/Convergence.jl

@@ -4,7 +4,6 @@ using Distributions
 using Flux
 using LinearAlgebra
 using ..CounterfactualExplanations
-using ..Generators
 using ..Models
 
 include("decision_threshold.jl")
@@ -69,11 +68,12 @@ end
 export convergence_catalogue
 export converged
 export get_convergence_type
-export hinge_loss, invalidation_rate
+export invalidation_rate
 export threshold_reached
 export DecisionThresholdConvergence
 export GeneratorConditionsConvergence
 export InvalidationRateConvergence
 export MaxIterConvergence
+export conditions_satisfied
 
 end
11 changes: 10 additions & 1 deletion src/convergence/generator_conditions.jl

@@ -52,5 +52,14 @@ function converged(
     ce::AbstractCounterfactualExplanation,
     x::Union{AbstractArray,Nothing}=nothing,
 )
-    return threshold_reached(ce, x) && Generators.conditions_satisfied(ce.generator, ce)
+    return threshold_reached(ce, x) && conditions_satisfied(ce.generator, ce)
 end
+
+"""
+    conditions_satisfied(gen::AbstractGenerator, ce::AbstractCounterfactualExplanation)
+
+This function is overloaded in the `Generators` module to check whether the counterfactual search has converged with respect to generator conditions.
+"""
+function conditions_satisfied(gen::AbstractGenerator, ce::AbstractCounterfactualExplanation)
+    return true
+end
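The diff above is one half of a standard Julia pattern for breaking a circular module dependency: `Convergence` now owns the generic `conditions_satisfied` and ships a permissive fallback, while `Generators` extends it for its own types (see the `utils.jl` diff below). A minimal, self-contained sketch of the pattern — module and type names here are illustrative, not the package's actual layout:

```julia
module Sketch

module Convergence
abstract type AbstractGenerator end
# Permissive fallback: a generator without a specialized method is
# always considered to satisfy its convergence conditions.
conditions_satisfied(gen::AbstractGenerator, state) = true
end

module Generators
using ..Convergence  # Generators depends on Convergence, not vice versa
struct GradientGenerator <: Convergence.AbstractGenerator end
# Extend the function owned by Convergence for this generator type:
function Convergence.conditions_satisfied(gen::GradientGenerator, state)
    return all(abs.(state) .< 1e-2)  # converged once all proposed changes are tiny
end
end

end # module Sketch

Sketch.Convergence.conditions_satisfied(Sketch.Generators.GradientGenerator(), [1e-3])  # true
```

This is also why the `src/CounterfactualExplanations.jl` diff above moves `include("convergence/Convergence.jl")` ahead of the generator code: the generic function must exist before `Generators` can add methods to it.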
51 changes: 13 additions & 38 deletions src/convergence/invalidation_rate.jl

@@ -1,3 +1,4 @@
+using ChainRulesCore: ignore_derivatives
 using Distributions: Distributions
 using Flux: Flux
 using LinearAlgebra: LinearAlgebra
@@ -40,44 +41,18 @@ Calculates the invalidation rate of a counterfactual explanation.
 The invalidation rate of the counterfactual explanation.
 """
 function invalidation_rate(ce::AbstractCounterfactualExplanation)
-    index_target = findfirst(map(x -> x == ce.target, ce.data.y_levels))
-    f_loss = logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target]
-    grad = []
-    for i in 1:length(ce.counterfactual_state)
-        push!(
-            grad,
-            Flux.gradient(
-                () -> logits(ce.M, CounterfactualExplanations.decode_state(ce))[i],
-                Flux.params(ce.counterfactual_state),
-            )[ce.counterfactual_state],
-        )
+    z = []
+    ignore_derivatives() do
+        index_target = get_target_index(ce.data.y_levels, ce.target)
+        f_loss = logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target]
+        grad = Flux.gradient(
+            () -> logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target],
+            Flux.params(ce.counterfactual_state),
+        )[ce.counterfactual_state]
+        denominator = sqrt(ce.convergence.variance) * norm(grad)
+        normalized_gradient = f_loss / denominator
+        push!(z, normalized_gradient)
     end
-    gradᵀ = LinearAlgebra.transpose(grad)
-
-    identity_matrix = LinearAlgebra.Matrix{Float32}(
-        LinearAlgebra.I, length(grad), length(grad)
-    )
-    denominator = sqrt(gradᵀ * ce.convergence.variance * identity_matrix * grad)[1]
-
-    normalized_gradient = f_loss / denominator
-    ϕ = Distributions.cdf(Distributions.Normal(0, 1), normalized_gradient)
+    ϕ = Distributions.cdf(Distributions.Normal(0, 1), z[1])
     return 1 - ϕ
 end
-
-"""
-    hinge_loss(convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation)
-
-Calculates the hinge loss of a counterfactual explanation.
-
-# Arguments
-- `convergence::InvalidationRateConvergence`: The convergence criterion to use.
-- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation to calculate the hinge loss for.
-
-# Returns
-The hinge loss of the counterfactual explanation.
-"""
-function hinge_loss(
-    convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation
-)
-    return max(0, invalidation_rate(ce) - convergence.invalidation_rate)
-end
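The rewrite above rests on a linear-algebra identity. For input noise ε ~ N(0, Σ), the invalidation rate from Pawelczyk et al. is 1 − Φ(f(x′) / √(∇fᵀ Σ ∇f)); with the isotropic Σ = σ²·I used here, the denominator collapses to √σ² · ‖∇f‖, so the per-logit gradient loop, explicit transpose, and identity matrix are unnecessary. A quick standalone numerical check of the identity (illustrative values):

```julia
using LinearAlgebra

grad = randn(10)  # stand-in for the gradient of the target logit w.r.t. the state
σ² = 0.01         # stand-in for ce.convergence.variance

old = sqrt(grad' * (σ² * I(10)) * grad)  # matrix form used before this commit
new = sqrt(σ²) * norm(grad)              # closed form used after

@assert old ≈ new
```

Note that the whole computation now sits inside `ignore_derivatives`, so no gradients flow through this term — which is precisely the `ProbeGenerator` limitation flagged in the new warning and in issue #376.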
4 changes: 2 additions & 2 deletions src/generators/Generators.jl

@@ -8,6 +8,7 @@ using ..GenerativeModels
 using Flux
 using LinearAlgebra
 using ..Models
+using ..Convergence
 using ..Objectives
 using Statistics: Statistics
 using DataFrames: DataFrames
@@ -30,11 +31,10 @@ export REVISEGenerator
 export WachterGenerator
 export FeatureTweakGenerator
 export generator_catalogue
-export generate_perturbations, conditions_satisfied
+export generate_perturbations
 export GradientBasedGenerator
 export @objective, @with_optimiser, @search_feature_space, @search_latent_space
 export JSMADescent
-export hinge_loss
 export predictive_entropy
 export ProbeGenerator
 
48 changes: 36 additions & 12 deletions src/generators/gradient_based/generators.jl

@@ -5,49 +5,63 @@ function GenericGenerator(; λ::AbstractFloat=0.1, kwargs...)
     return GradientBasedGenerator(; penalty=default_distance, λ=λ, kwargs...)
 end
 
-"Constructor for `ECCoGenerator`. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty."
+const DOC_ECCCo = "For details, see Altmeyer et al. ([2024](https://ojs.aaai.org/index.php/AAAI/article/view/28956))."
+
+"Constructor for `ECCoGenerator`. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty. $DOC_ECCCo"
 function ECCoGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 1.0], kwargs...)
     _penalties = [default_distance, Objectives.energy_constraint]
     return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
 end
 
-"Constructor for `WachterGenerator`."
+const DOC_Wachter = "For details, see Wachter et al. ([2018](https://arxiv.org/abs/1711.00399))."
+
+"Constructor for `WachterGenerator`. $DOC_Wachter"
 function WachterGenerator(; λ::AbstractFloat=0.1, kwargs...)
     return GradientBasedGenerator(; penalty=Objectives.distance_mad, λ=λ, kwargs...)
 end
 
-"Constructor for `DiCEGenerator`."
+const DOC_DiCE = "For details, see Mothilal et al. ([2020](https://arxiv.org/abs/1905.07697))."
+
+"Constructor for `DiCEGenerator`. $DOC_DiCE"
 function DiCEGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.1], kwargs...)
     _penalties = [default_distance, Objectives.ddp_diversity]
     return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
 end
 
-"Constructor for `ClaPGenerator`."
+const DOC_SaTML = "For details, see Altmeyer et al. ([2023](https://ieeexplore.ieee.org/abstract/document/10136130))."
+
+"Constructor for `ClaPGenerator`. $DOC_SaTML"
 function ClaPROARGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.5], kwargs...)
     _penalties = [default_distance, Objectives.model_loss_penalty]
     return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
 end
 
-"Constructor for `GravitationalGenerator`."
+"Constructor for `GravitationalGenerator`. $DOC_SaTML"
 function GravitationalGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.5], kwargs...)
     _penalties = [default_distance, Objectives.distance_from_target]
     return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
 end
 
-"Constructor for `REVISEGenerator`."
+const DOC_REVISE = "For details, see Joshi et al. ([2019](https://arxiv.org/abs/1907.09615))."
+
+"Constructor for `REVISEGenerator`. $DOC_REVISE"
 function REVISEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...)
     return GradientBasedGenerator(;
         penalty=default_distance, λ=λ, latent_space=latent_space, kwargs...
     )
 end
 
-"Constructor for `GreedyGenerator`."
+const DOC_Greedy = "For details, see Schut et al. ([2021](https://proceedings.mlr.press/v130/schut21a/schut21a.pdf))."
+
+"Constructor for `GreedyGenerator`. $DOC_Greedy"
 function GreedyGenerator(; η=0.1, n=nothing, kwargs...)
     opt = CounterfactualExplanations.Generators.JSMADescent(; η=η, n=n)
     return GradientBasedGenerator(; penalty=default_distance, λ=0.0, opt=opt, kwargs...)
 end
 
-"Constructor for `CLUEGenerator`."
+const DOC_CLUE = "For details, see Antoran et al. ([2021](https://arxiv.org/abs/2006.06848))."
+
+"Constructor for `CLUEGenerator`. $DOC_CLUE"
 function CLUEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...)
     return GradientBasedGenerator(;
         loss=predictive_entropy,
@@ -58,14 +72,24 @@ function CLUEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...)
     )
 end
 
-"Constructor for `ProbeGenerator`."
+const DOC_Probe = "For details, see Pawelczyk et al. ([2022](https://proceedings.mlr.press/v151/pawelczyk22a/pawelczyk22a.pdf))."
+
+const DOC_Probe_warn = "The `ProbeGenerator` is currently not working adequately. In particular, gradients are not computed with respect to the Hinge loss term proposed in the paper. It is still possible, however, to use this generator to achieve a desired invalidation rate. See issue [#376](https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl/issues/376) for details."
+
+"""
+Constructor for `ProbeGenerator`. $DOC_Probe
+
+## Warning
+
+$DOC_Probe_warn
+"""
 function ProbeGenerator(;
-    λ::AbstractFloat=0.1,
+    λ::Vector{<:AbstractFloat}=[0.1, 1.0],
     loss::Symbol=:logitbinarycrossentropy,
-    penalty=Objectives.distance_l1,
+    penalty=[Objectives.distance_l1, Objectives.hinge_loss],
     kwargs...,
 )
     @assert haskey(losses_catalogue, loss) "Loss function not found in catalogue."
+    @warn DOC_Probe_warn
     user_loss = Objectives.losses_catalogue[loss]
    return GradientBasedGenerator(; loss=user_loss, penalty=penalty, λ=λ, kwargs...)
 end
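A brief usage sketch of the updated constructor, using only the keyword arguments visible in the diff (everything else about the counterfactual search — data, model, target — is assumed to be set up elsewhere). The penalty strength `λ[1]` weights `distance_l1` and `λ[2]` weights the new `hinge_loss` penalty:

```julia
using CounterfactualExplanations

# New two-element penalty weighting; construction emits the DOC_Probe_warn warning.
generator = ProbeGenerator(; λ=[0.1, 1.0], loss=:logitbinarycrossentropy)

# The generator can then be passed to the package's search routine together with
# an `InvalidationRateConvergence` criterion (data/model setup omitted here).
```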
18 changes: 7 additions & 11 deletions src/generators/gradient_based/loss.jl

@@ -45,15 +45,11 @@ It simply computes the weighted sum over partial derivatives. It assumes that `Zyg
 If the counterfactual is being generated using Probe, the hinge loss is added to the gradient.
 """
 function ∇(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
-    return ∂ℓ(generator, ce) .+ ∂h(generator, ce) .+ hinge_loss(ce.convergence, ce)
-end
-
-"""
-    hinge_loss(convergence::AbstractConvergence, ce::AbstractCounterfactualExplanation)
-
-The default hinge loss for any convergence criterion.
-
-Can be overridden inside the `Convergence` module as part of the definition of specific convergence criteria.
-"""
-function hinge_loss(convergence::AbstractConvergence, ce::AbstractCounterfactualExplanation)
-    return 0
+    grad_loss = ∂ℓ(generator, ce)
+    # println("Loss:")
+    # display(grad_loss)
+    grad_pen = ∂h(generator, ce)
+    # println("Penalty:")
+    # display(grad_pen)
+    return grad_loss .+ grad_pen
 end
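With the special-cased hinge term gone, the update direction is simply the sum of the two gradient components. A schematic illustration with made-up numbers (not package source):

```julia
∂ℓ = [0.20, -0.10]  # illustrative gradient of the prediction loss ℓ(M(x′), target)
∂h = [0.05, 0.05]   # illustrative gradient of the weighted penalties Σᵢ λᵢ costᵢ(x′)

Δx = ∂ℓ .+ ∂h       # direction consumed by the optimiser step: [0.25, -0.05]
```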
10 changes: 8 additions & 2 deletions src/generators/gradient_based/utils.jl

@@ -8,18 +8,24 @@ function _replace_nans(Δcounterfactual_state::AbstractArray, old_new::Pair=(NaN
 end
 
 """
-    conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
+    Convergence.conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
 
 The default method to check if all conditions for convergence of the counterfactual search have been satisfied for gradient-based generators.
 
 By default, gradient-based search is considered to have converged as soon as the proposed changes for all features are smaller than one percent of their standard deviation.
 """
-function conditions_satisfied(
+function Convergence.conditions_satisfied(
     generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation
 )
     Δcounterfactual_state = ∇(generator, ce)
     Δcounterfactual_state = CounterfactualExplanations.apply_mutability(
         ce, Δcounterfactual_state
     )
+    if !hasfield(typeof(ce.convergence), :gradient_tol)
+        # Temporary fix due to the fact that `ProbeGenerator` relies on `InvalidationRateConvergence`.
+        @warn "Checking for generator conditions convergence is not implemented for this generator type. Return `false`." maxlog =
+            1
+        return false
+    end
     τ = ce.convergence.gradient_tol
     satisfied = map(
         x -> all(abs.(x) .< τ),
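Past the new guard, the convergence test itself is a per-feature tolerance check, as the docstring describes. A minimal standalone sketch with assumed values:

```julia
Δx = [0.0003, -0.0001, 0.002]  # illustrative proposed feature changes
τ = 0.001                      # illustrative ce.convergence.gradient_tol

converged = all(abs.(Δx) .< τ)  # false: the third feature still moves too much
```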
2 changes: 2 additions & 0 deletions src/objectives/Objectives.jl

@@ -18,6 +18,7 @@ export losses_catalogue
 export distance, distance_mad, distance_l0, distance_l1, distance_l2, distance_linf
 export ddp_diversity
 export EnergyDifferential
+export hinge_loss
 export penalties_catalogue
 
 const losses_catalogue = Dict(
@@ -35,6 +36,7 @@ const penalties_catalogue = Dict(
     :ddp_diversity => ddp_diversity,
     :energy_constraint => energy_constraint,
     :energy_differential => EnergyDifferential(),
+    :hinge_loss => hinge_loss,
 )
 
 end
22 changes: 22 additions & 0 deletions src/objectives/penalties.jl

@@ -1,5 +1,6 @@
 using CounterfactualExplanations: polynomial_decay
 using CounterfactualExplanations.Models
+using CounterfactualExplanations.Convergence
 using EnergySamplers: EnergySamplers
 using LinearAlgebra: LinearAlgebra, det, norm
 using Random: Random
@@ -274,3 +275,24 @@ function EnergySamplers.energy_differential(M::AbstractModel, xgen, xsampled, y:
     f = M.fitresult.fitresult
     return EnergySamplers.energy_differential(f, xgen, xsampled, y)
 end
+
+"""
+    hinge_loss(ce::AbstractCounterfactualExplanation)
+
+Calculates the hinge loss of a counterfactual explanation with `InvalidationRateConvergence`.
+
+# Arguments
+- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation to calculate the hinge loss for.
+
+# Returns
+The hinge loss of the counterfactual explanation.
+"""
+function hinge_loss(ce::AbstractCounterfactualExplanation)
+    typeof(ce.M.type) <: Models.AbstractFluxNN || throw(NotImplementedModel(ce.M))
+    if !(ce.convergence isa InvalidationRateConvergence)
+        @warn "The hinge loss is only defined for `InvalidationRateConvergence`s. Setting convergence to default `InvalidationRateConvergence`." maxlog =
+            1
+        ce.convergence = InvalidationRateConvergence()
+    end
+    return max(0, invalidation_rate(ce) - ce.convergence.invalidation_rate)
+end
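Functionally, the relocated penalty is a one-sided deviation from the target invalidation rate. A standalone sketch of the arithmetic (illustrative numbers; the real function derives the estimate from `invalidation_rate(ce)`):

```julia
# Positive only when the counterfactual is more likely to be invalidated
# by input noise than the target rate tolerates:
hinge(ir_est, ir_target) = max(0, ir_est - ir_target)

hinge(0.35, 0.1)  # 0.25 -> penalised
hinge(0.05, 0.1)  # 0.0  -> target invalidation rate met
```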