diff --git a/CHANGELOG.md b/CHANGELOG.md
index 99faf4a9f..e748524d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,7 +6,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
*Note*: We try to adhere to these practices as of version [v1.1.1].
-## Version [1.4.0]
+## Version [1.4.1] - 2024-12-19
+
+### Changed
+
+- Updated dependencies. [#504]
+
+### Removed
+
+- Removed everything related to GrowingSpheres. [#504]
+
+## Version [1.4.0] - 2024-12-19
### Added
diff --git a/Project.toml b/Project.toml
index 59826e80f..7545116aa 100755
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "CounterfactualExplanations"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
authors = ["Patrick Altmeyer
and contributors"]
-version = "1.4.0"
+version = "1.4.1"
[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -44,13 +44,13 @@ NeuroTreeExt = "NeuroTreeModels"
[compat]
Aqua = "0.8"
CategoricalArrays = "0.10"
-CausalInference = "0.17.0"
+CausalInference = "0.17, 0.18"
ChainRulesCore = "1.15"
DataFrames = "1"
DecisionTree = "0.12.3, 0.12.4"
Distributions = "0.25.97"
EnergySamplers = "1.0"
-Flux = "0.12, 0.13, 0.14"
+Flux = "0.12, 0.13, 0.14, 0.15, 0.16"
Graphs = "1.11.1"
JointEnergyModels = "0.1.7"
LaplaceRedux = "0.1.4, 0.2, 1.0"
diff --git a/README.md b/README.md
index 40cdad007..999541aac 100644
--- a/README.md
+++ b/README.md
@@ -97,7 +97,9 @@ ce = generate_counterfactual(
plot(ce)
```
-![](README_files/figure-commonmark/cell-3-output-1.svg)
+ [ Info: No target label supplied, using first.
+
+![](README_files/figure-commonmark/cell-3-output-2.svg)
### Example: Give Me Some Credit
@@ -148,7 +150,7 @@ To this end, we specify a counterfactual generator of our choice:
generator = DiCEGenerator(λ=[0.1,0.3])
```
-Here, we have chosen to use the `GradientBasedGenerator` to move the individual from its factual label 1 to the target label 2.
+Here, we have chosen to use the `CounterfactualExplanations.Generators.GradientBasedGenerator` to move the individual from its factual label 1 to the target label 2.
With all of our ingredients specified, we finally generate counterfactuals using a simple API call:
@@ -162,7 +164,9 @@ ce = generate_counterfactual(
The plot below shows the resulting counterfactual path:
-![](README_files/figure-commonmark/cell-16-output-1.svg)
+ [ Info: No target label supplied, using first.
+
+![](README_files/figure-commonmark/cell-16-output-2.svg)
## ☑️ Implemented Counterfactual Generators
@@ -176,7 +180,6 @@ Currently, the following counterfactual generators are implemented:
- Generic
- GravitationalGenerator (Altmeyer et al. 2023)
- Greedy (Schut et al. 2021)
-- GrowingSpheres (Laugel et al. 2017)
- MINT (Karimi et al. 2020) (**causal CE**)
- PROBE (Pawelczyk et al. 2023)
- REVISE (Joshi et al. 2019)
@@ -251,8 +254,6 @@ Kaggle. 2011. “Give Me Some Credit, Improve on the State of the Art in Credit
Karimi, Amir-Hossein, Julius Von Kügelgen, Bernhard Schölkopf, and Isabel Valera. 2020. “Algorithmic Recourse Under Imperfect Causal Knowledge: A Probabilistic Approach.” .
-Laugel, Thibault, Marie-Jeanne Lesot, Christophe Marsala, Xavier Renard, and Marcin Detyniecki. 2017. “Inverse Classification for Comparison-Based Interpretability in Machine Learning.” .
-
Mothilal, Ramaravind K, Amit Sharma, and Chenhao Tan. 2020. “Explaining Machine Learning Classifiers Through Diverse Counterfactual Explanations.” In *Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency*, 607–17. .
Pawelczyk, Martin, Teresa Datta, Johannes van-den-Heuvel, Gjergji Kasneci, and Himabindu Lakkaraju. 2023. “Probabilistically Robust Recourse: Navigating the Trade-Offs Between Costs and Robustness in Algorithmic Recourse.” .
diff --git a/README.qmd b/README.qmd
index 8f5347b32..e95783002 100644
--- a/README.qmd
+++ b/README.qmd
@@ -9,7 +9,9 @@ crossref:
tbl-prefix: Table
bibliography: https://raw.githubusercontent.com/pat-alt/bib/main/bib.bib
output: asis
-jupyter: julia-1.10
+engine: julia
+julia:
+ exeflags: ["--project=docs/"]
execute:
freeze: auto
eval: true
diff --git a/README_files/figure-commonmark/cell-11-output-1.svg b/README_files/figure-commonmark/cell-11-output-1.svg
index c8775e268..86855e8b1 100644
--- a/README_files/figure-commonmark/cell-11-output-1.svg
+++ b/README_files/figure-commonmark/cell-11-output-1.svg
@@ -1,146 +1,150 @@
diff --git a/README_files/figure-commonmark/cell-16-output-2.svg b/README_files/figure-commonmark/cell-16-output-2.svg
new file mode 100644
index 000000000..8d72b13a6
--- /dev/null
+++ b/README_files/figure-commonmark/cell-16-output-2.svg
@@ -0,0 +1,657 @@
+
+
diff --git a/README_files/figure-commonmark/cell-3-output-2.svg b/README_files/figure-commonmark/cell-3-output-2.svg
new file mode 100644
index 000000000..9b6c8c79b
--- /dev/null
+++ b/README_files/figure-commonmark/cell-3-output-2.svg
@@ -0,0 +1,361 @@
+
+
diff --git a/README_files/figure-commonmark/cell-6-output-1.svg b/README_files/figure-commonmark/cell-6-output-1.svg
index 15b8f698c..593154bf3 100644
--- a/README_files/figure-commonmark/cell-6-output-1.svg
+++ b/README_files/figure-commonmark/cell-6-output-1.svg
@@ -1,98 +1,96 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/setup_docs.jl b/docs/setup_docs.jl
index 1b20ca907..648e5a67d 100644
--- a/docs/setup_docs.jl
+++ b/docs/setup_docs.jl
@@ -34,7 +34,7 @@ setup_docs = quote
# Setup:
theme(:wong)
- Random.seed!(2022)
+ Random.seed!(2025)
synthetic = TaijaData.load_synthetic_data()
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
diff --git a/docs/src/_intro.qmd b/docs/src/_intro.qmd
index 279f01e16..7ee8f985e 100644
--- a/docs/src/_intro.qmd
+++ b/docs/src/_intro.qmd
@@ -115,6 +115,8 @@ plot(ce)
```{julia}
#| echo: false
+Random.seed!(2024)
+
# Data and Model:
data = TaijaData.load_gmsc(10000)
counterfactual_data = CounterfactualExplanations.DataPreprocessing.CounterfactualData(data...)
@@ -377,7 +379,6 @@ Currently, the following counterfactual generators are implemented:
- Generic
- GravitationalGenerator [@altmeyer2023endogenous]
- Greedy [@schut2021generating]
-- GrowingSpheres [@laugel2017inverse]
- MINT [@karimi2020algorithmic] (**causal CE**)
- PROBE [@pawelczyk2022probabilistically]
- REVISE [@joshi2019realistic]
diff --git a/docs/src/index.qmd b/docs/src/index.qmd
index 66772edd0..ef86e75b9 100644
--- a/docs/src/index.qmd
+++ b/docs/src/index.qmd
@@ -1,3 +1,9 @@
+---
+engine: julia
+julia:
+ exeflags: ["--project=docs/"]
+---
+
```@meta
CurrentModule = CounterfactualExplanations
```
diff --git a/docs/src/www/mnist_factual.png b/docs/src/www/mnist_factual.png
index 9709267ce..5e8c6b219 100644
Binary files a/docs/src/www/mnist_factual.png and b/docs/src/www/mnist_factual.png differ
diff --git a/src/counterfactuals/growing_spheres.jl b/src/counterfactuals/growing_spheres.jl
deleted file mode 100644
index 20a76bf28..000000000
--- a/src/counterfactuals/growing_spheres.jl
+++ /dev/null
@@ -1,41 +0,0 @@
-
-"""
- generate_counterfactual(
- x::Matrix,
- target::RawTargetType,
- data::DataPreprocessing.CounterfactualData,
- M::Models.AbstractModel,
- generator::Generators.GrowingSpheresGenerator;
- num_counterfactuals::Int=1,
- convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
- decision_threshold=(1 / length(data.y_levels)), max_iter=1000
- ),
- kwrgs...,
- )
-
-Overloads the `generate_counterfactual` for the `GrowingSpheresGenerator` generator.
-"""
-function generate_counterfactual(
- x::Matrix,
- target::RawTargetType,
- data::DataPreprocessing.CounterfactualData,
- M::Models.AbstractModel,
- generator::Generators.GrowingSpheresGenerator;
- num_counterfactuals::Int=1,
- convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
- decision_threshold=(1 / length(data.y_levels)), max_iter=1000
- ),
- kwrgs...,
-)
- ce = CounterfactualExplanation(
- x, target, data, M, generator; num_counterfactuals, convergence
- )
-
- Generators.growing_spheres_generation!(ce)
- Generators.feature_selection!(ce)
-
- # growing spheres does not support encodings, thus counterfactual is just counterfactual_state
- ce.counterfactual = ce.counterfactual_state
-
- return ce
-end
diff --git a/src/generative_models/encoders.jl b/src/generative_models/encoders.jl
index 474203908..ee83040e9 100644
--- a/src/generative_models/encoders.jl
+++ b/src/generative_models/encoders.jl
@@ -11,7 +11,7 @@ struct Encoder
μ::Any
logσ::Any
end
-Flux.@functor Encoder
+Flux.@layer Encoder
function Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; activation=sigmoid)
return Encoder(
diff --git a/src/generative_models/vae.jl b/src/generative_models/vae.jl
index 8080c2b15..0bae8ab71 100644
--- a/src/generative_models/vae.jl
+++ b/src/generative_models/vae.jl
@@ -58,7 +58,7 @@ function VAE(input_dim; kws...)
return VAE(encoder, decoder, args, false)
end
-Flux.@functor VAE
+Flux.@layer VAE
function Flux.trainable(generative_model::VAE)
return (encoder=generative_model.encoder, decoder=generative_model.decoder)
diff --git a/src/generators/non_gradient_based/growing_spheres/growing_spheres.jl b/src/generators/non_gradient_based/growing_spheres/growing_spheres.jl
deleted file mode 100644
index 83568c248..000000000
--- a/src/generators/non_gradient_based/growing_spheres/growing_spheres.jl
+++ /dev/null
@@ -1,260 +0,0 @@
-using LinearAlgebra: LinearAlgebra
-using Random: Random
-
-"Growing Spheres counterfactual generator class."
-mutable struct GrowingSpheresGenerator <: AbstractNonGradientBasedGenerator
- n::Union{Nothing,Integer}
- η::Union{Nothing,AbstractFloat}
- latent_space::Bool
- dim_reduction::Bool
-end
-
-"""
- GrowingSpheresGenerator(; n::Int=100, η::Float64=0.1, kwargs...)
-
-Constructs a new Growing Spheres Generator object.
-"""
-function GrowingSpheresGenerator(;
- n::Union{Nothing,Integer}=100, η::Union{Nothing,AbstractFloat}=0.1
-)
- return GrowingSpheresGenerator(n, η, false, false)
-end
-
-"""
- growing_spheres_generation(ce::AbstractCounterfactualExplanation)
-
-Generate counterfactual candidates using the growing spheres generation algorithm.
-
-# Arguments
-- `ce::AbstractCounterfactualExplanation`: An instance of the `AbstractCounterfactualExplanation` type representing the counterfactual explanation.
-
-# Returns
-- `nothing`
-
-This function applies the growing spheres generation algorithm to generate counterfactual candidates. It starts by generating random points uniformly on a sphere, gradually reducing the search space until no counterfactuals are found. Then it expands the search space until at least one counterfactual is found or the maximum number of iterations is reached.
-
-The algorithm iteratively generates counterfactual candidates and predicts their labels using the model stored in `ce.M`. It checks if any of the predicted labels are different from the factual class. The process of reducing the search space involves halving the search radius, while the process of expanding the search space involves increasing the search radius.
-"""
-function growing_spheres_generation!(ce::AbstractCounterfactualExplanation)
- generator = ce.generator
- model = ce.M
- factual = ce.factual
- counterfactual_data = ce.data
- target = [ce.target]
- max_iter = 1000
-
- # Copy hyperparameters
- n = generator.n
- η = convert(eltype(factual), generator.η)
-
- # Generate random points uniformly on a sphere
- counterfactual_candidates = hyper_sphere_coordinates(n, factual, 0.0, η)
-
- if (factual == target)
- ce.counterfactual_state = factual
- return nothing
- end
-
- # Predict labels for each candidate counterfactual
- counterfactual = find_counterfactual(
- model, target, counterfactual_data, counterfactual_candidates
- )
-
- # Repeat until there's no counterfactual points (process of removing all counterfactuals by reducing the search space)
- while (!isnothing(counterfactual))
- η /= 2
- a₀ = convert(eltype(factual), 0.0)
-
- counterfactual_candidates = hyper_sphere_coordinates(n, factual, a₀, η)
- counterfactual = find_counterfactual(
- model, target, counterfactual_data, counterfactual_candidates
- )
-
- max_iter -= 1
- if max_iter == 0
- break
- end
- end
-
- # Update path
- ce.search[:iteration_count] += n
- for i in eachindex(counterfactual_candidates[1, :])
- push!(ce.search[:path], reshape(counterfactual_candidates[:, i], :, 1))
- end
-
- # Initialize boundaries of the sphere's radius
- a₀, a₁ = η, 2η
-
- # Repeat until there's at least one counterfactual (process of expanding the search space)
- while (isnothing(counterfactual))
- a₀ = a₁
- a₁ += η
-
- counterfactual_candidates = hyper_sphere_coordinates(n, factual, a₀, a₁)
- counterfactual = find_counterfactual(
- model, target, counterfactual_data, counterfactual_candidates
- )
-
- max_iter -= 1
- if max_iter == 0
- break
- end
- end
-
- # Update path
- ce.search[:iteration_count] += n
- for i in eachindex(counterfactual_candidates[1, :])
- push!(ce.search[:path], reshape(counterfactual_candidates[:, i], :, 1))
- end
-
- ce.counterfactual_state = counterfactual_candidates[:, counterfactual]
- return nothing
-end
-
-"""
- feature_selection!(ce::AbstractCounterfactualExplanation)
-
-Perform feature selection to find the dimension with the closest (but not equal) values between the `ce.factual` (factual) and `ce.counterfactual_state` (counterfactual) arrays.
-
-# Arguments
-- `ce::AbstractCounterfactualExplanation`: An instance of the `AbstractCounterfactualExplanation` type representing the counterfactual explanation.
-
-# Returns
-- `nothing`
-
-The function iteratively modifies the `ce.counterfactual_state` counterfactual array by updating its elements to match the corresponding elements in the `ce.factual` factual array, one dimension at a time, until the predicted label of the modified `ce.counterfactual_state` matches the predicted label of the `ce.factual` array.
-"""
-function feature_selection!(ce::AbstractCounterfactualExplanation)
- model = ce.M
- counterfactual_data = ce.data
- factual = ce.factual
- target = [ce.target]
-
- # Assign the initial counterfactual to both counterfactual′ and counterfactual″
- counterfactual′ = ce.counterfactual_state
- counterfactual″ = ce.counterfactual_state
-
- factual_class = CounterfactualExplanations.Models.predict_label(
- model, counterfactual_data, factual
- )[1]
-
- while (
- factual_class != CounterfactualExplanations.Models.predict_label(
- model, counterfactual_data, counterfactual′
- ) &&
- target == CounterfactualExplanations.Models.predict_label(
- model, counterfactual_data, counterfactual′
- )
- )
- counterfactual″ = counterfactual′
- i = find_closest_dimension(factual, counterfactual′)
- counterfactual′[i] = factual[i]
-
- ce.search[:iteration_count] += 1
- push!(ce.search[:path], reshape(counterfactual″, :, 1))
- end
-
- ce.counterfactual_state = counterfactual″
- return nothing
-end
-
-"""
- hyper_sphere_coordinates(n_search_samples::Int, instance::Vector{Float64}, low::Int, high::Int; p_norm::Int=2)
-
-Generates candidate counterfactuals using the growing spheres method based on hyper-sphere coordinates.
-
-The implementation follows the Random Point Picking over a sphere algorithm described in the paper:
-"Learning Counterfactual Explanations for Tabular Data" by Pawelczyk, Broelemann & Kascneci (2020),
-presented at The Web Conference 2020 (WWW). It ensures that points are sampled uniformly at random
-using insights from: http://mathworld.wolfram.com/HyperspherePointPicking.html
-
-The growing spheres method is originally proposed in the paper:
-"Comparison-based Inverse Classification for Interpretability in Machine Learning" by Thibaut Laugel et al (2018),
-presented at the International Conference on Information Processing and Management of Uncertainty in Knowledge-Based Systems (2018).
-
-# Arguments
-- `n_search_samples::Int`: The number of search samples (int > 0).
-- `instance::AbstractArray`: The input point array.
-- `low::AbstractFloat`: The lower bound (float >= 0, l < h).
-- `high::AbstractFloat`: The upper bound (float >= 0, h > l).
-- `p_norm::Integer`: The norm parameter (int >= 1).
-
-# Returns
-- `candidate_counterfactuals::Array`: An array of candidate counterfactuals.
-"""
-function hyper_sphere_coordinates(
- n_search_samples::Integer,
- instance::AbstractArray,
- low::AbstractFloat,
- high::AbstractFloat;
- p_norm::Integer=2,
-)
- delta_instance = Random.randn(n_search_samples, length(instance))
- delta_instance = convert.(eltype(instance), delta_instance)
-
- # length range [l, h)
- dist = Random.rand(n_search_samples) .* (high - low) .+ low
- norm_p = LinearAlgebra.norm(delta_instance, p_norm)
- # rescale/normalize factor
- d_norm = dist ./ norm_p
- delta_instance .= delta_instance .* d_norm
- instance_matrix = repeat(reshape(instance, 1, length(instance)), n_search_samples)
- candidate_counterfactuals = instance_matrix + delta_instance
-
- return transpose(candidate_counterfactuals)
-end
-
-"""
- find_counterfactual(model, factual_class, counterfactual_data, counterfactual_candidates)
-
-Find the first counterfactual index by predicting labels.
-
-# Arguments
-- `model`: The fitted model used for prediction.
-- `target_class`: Expected target class.
-- `counterfactual_data`: Data required for counterfactual generation.
-- `counterfactual_candidates`: The array of counterfactual candidates.
-
-# Returns
-- `counterfactual`: The index of the first counterfactual found.
-"""
-function find_counterfactual(
- model, target_class, counterfactual_data, counterfactual_candidates
-)
- predicted_labels = map(
- e -> CounterfactualExplanations.Models.predict_label(model, counterfactual_data, e),
- eachcol(counterfactual_candidates),
- )
- counterfactual = findfirst(predicted_labels .== target_class)
-
- return counterfactual
-end
-
-"""
- find_closest_dimension(factual, counterfactual)
-
-Find the dimension with the closest (but not equal) values between the `factual` and `counterfactual` arrays.
-
-# Arguments
-- `factual`: The factual array.
-- `counterfactual`: The counterfactual array.
-
-# Returns
-- `closest_dimension`: The index of the dimension with the closest values.
-
-The function iterates over the indices of the `factual` array and calculates the absolute difference between the corresponding elements in the `factual` and `counterfactual` arrays. It returns the index of the dimension with the smallest difference, excluding dimensions where the values in `factual` and `counterfactual` are equal.
-"""
-function find_closest_dimension(factual, counterfactual)
- min_diff = typemax(eltype(factual))
- closest_dimension = -1
-
- for i in eachindex(factual)
- diff = abs(factual[i] - counterfactual[i])
- if diff < min_diff && factual[i] != counterfactual[i]
- min_diff = diff
- closest_dimension = i
- end
- end
-
- return closest_dimension
-end
diff --git a/test/generators/full_runs.jl b/test/generators/full_runs.jl
index c2da7bf5d..4d6ac2918 100644
--- a/test/generators/full_runs.jl
+++ b/test/generators/full_runs.jl
@@ -8,7 +8,7 @@
for (key, generator_) in generators
name = uppercasefirst(string(key))
- # Feature Tweak and Growing Spheres will be tested separately
+ # Feature Tweak will be tested separately
if generator_() isa Generators.FeatureTweakGenerator
continue
end
diff --git a/test/generators/growing_spheres.jl b/test/generators/growing_spheres.jl
deleted file mode 100644
index df95675e1..000000000
--- a/test/generators/growing_spheres.jl
+++ /dev/null
@@ -1,74 +0,0 @@
-using CounterfactualExplanations
-using CounterfactualExplanations.Evaluation
-using CounterfactualExplanations.Generators
-using CounterfactualExplanations.Models
-using DataFrames
-using Flux
-using LinearAlgebra
-using MLUtils
-using Random
-
-@testset "Growing Spheres" begin
- generator = CounterfactualExplanations.Generators.GrowingSpheresGenerator()
- models = CounterfactualExplanations.Models.standard_models_catalogue
- @testset "Models for synthetic data" begin
- for (key, value) in synthetic
- name = string(key)
- @testset "$name" begin
- counterfactual_data = value[:data]
- X = counterfactual_data.X
- # Loop over values of the dict
-
- for (model_name, model) in models
- name = string(model_name)
- @testset "$name" begin
- M = CounterfactualExplanations.Models.fit_model(
- counterfactual_data, model_name
- )
- # Randomly selected factual:
- Random.seed!(123)
- x = select_factual(counterfactual_data, rand(1:size(X, 2)))
- # Choose target:
- y = predict_label(M, counterfactual_data, x)
- target = get_target(counterfactual_data, y[1])
-
- @testset "Convergence" begin
- @testset "Non-trivial case" begin
- counterfactual_data.input_encoder = nothing
- # Threshold reached if converged:
- counterfactual = generate_counterfactual(
- x, target, counterfactual_data, M, generator;
- )
- @test CounterfactualExplanations.Models.predict_label(
- M,
- counterfactual_data,
- counterfactual.counterfactual_state,
- )[1] == target
-
- @test CounterfactualExplanations.terminated(counterfactual)
- end
-
- @testset "Trivial case (already in target class)" begin
- counterfactual_data.input_encoder = nothing
- # Already in target class:
- y = CounterfactualExplanations.Models.predict_label(
- M, counterfactual_data, x
- )
- target = y[1]
- γ = minimum([1 / length(counterfactual_data.y_levels), 0.5])
- counterfactual = CounterfactualExplanations.generate_counterfactual(
- x, target, counterfactual_data, M, generator;
- )
- cf = counterfactual.counterfactual
- if counterfactual.generator.latent_space == false
- @test isapprox(counterfactual.factual, cf; atol=1e-6)
- end
- @test CounterfactualExplanations.terminated(counterfactual)
- end
- end
- end
- end
- end
- end
- end
-end