Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Changelog

## development

### Introducing CTEImputer [#225](https://github.com/mmschlk/shapiq/pull/225)
Adds the [`CTEImputer`](https://github.com/mmschlk/shapiq/blob/main/src/shapiq/imputer/cte_imputer.py) following the [*compress then explain* (CTE)](https://openreview.net/forum?id=LiUfN9h0Lx) methodology.
It replaces missing features of the explanation point by values sampled from the background data, which is first subsampled using a distribution compression algorithm, specifically [Compress++](https://openreview.net/forum?id=lzupY5zjaU9) with [Kernel Thinning](https://www.jmlr.org/papers/v25/21-1334.html).
CTE has been shown to provide accurate and stable estimates of explanations while being computationally efficient.
It is the new default imputer in `TabularExplainer`, removing the need to set `sample_size`.


## v1.4.1 (2025-11-10)

### Bugfix
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"requests",
"sparse-transform", # for sparse approximators (spex, proxyspex)
"galois", # for sparse approximators (spex, proxyspex)
"goodpoints", # for the compress then explain (cte) imputer
# plotting
"matplotlib",
"networkx",
Expand Down Expand Up @@ -55,6 +56,7 @@ keywords = [
"machine learning",
"interpretable machine learning",
"shap",
"shapiq",
"xai",
"explainable ai",
"interaction",
Expand Down
18 changes: 13 additions & 5 deletions src/shapiq/explainer/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


TabularExplainerApproximators = Literal["spex", "montecarlo", "svarm", "permutation", "regression"]
TabularExplainerImputers = Literal["marginal", "baseline", "conditional"]
TabularExplainerImputers = Literal["cte", "marginal", "baseline", "conditional"]
TabularExplainerIndices = ExplainerIndices


Expand All @@ -47,7 +47,7 @@ def __init__(
data: np.ndarray,
*,
class_index: int | None = None,
imputer: Imputer | TabularExplainerImputers = "marginal",
imputer: Imputer | TabularExplainerImputers = "cte",
approximator: (
Literal["auto"] | TabularExplainerApproximators | Approximator[TabularExplainerIndices]
) = "auto",
Expand All @@ -71,9 +71,9 @@ def __init__(

imputer: Either an :class:`~shapiq.games.imputer.Imputer` as implemented in the
:mod:`~shapiq.games.imputer` module, or a literal string from
``["marginal", "baseline", "conditional"]``. Defaults to ``"marginal"``, which
``["cte", "marginal", "baseline", "conditional"]``. Defaults to ``"cte"``, which
initializes the default
:class:`~shapiq.games.imputer.marginal_imputer.MarginalImputer` with its default
:class:`~shapiq.games.imputer.cte_imputer.CTEImputer` with its default
parameters or as provided in ``kwargs``.

approximator: An :class:`~shapiq.approximator.Approximator` object to use for the
Expand Down Expand Up @@ -110,6 +110,7 @@ def __init__(
"""
from shapiq.imputer import (
BaselineImputer,
CTEImputer,
GenerativeConditionalImputer,
MarginalImputer,
TabPFNImputer,
Expand All @@ -128,7 +129,14 @@ def __init__(
stacklevel=2,
)

if imputer == "marginal":
if imputer == "cte":
self._imputer = CTEImputer(
self.predict,
self._data,
random_state=random_state,
**kwargs,
)
elif imputer == "marginal":
self._imputer = MarginalImputer(
self.predict,
self._data,
Expand Down
2 changes: 2 additions & 0 deletions src/shapiq/imputer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Imputer objects for the shapiq package."""

from .baseline_imputer import BaselineImputer
from .cte_imputer import CTEImputer
from .gaussian_copula_imputer import GaussianCopulaImputer
from .gaussian_imputer import GaussianImputer
from .generative_conditional_imputer import GenerativeConditionalImputer
Expand All @@ -11,6 +12,7 @@
"MarginalImputer",
"GenerativeConditionalImputer",
"BaselineImputer",
"CTEImputer",
"TabPFNImputer",
"GaussianImputer",
"GaussianCopulaImputer",
Expand Down
164 changes: 164 additions & 0 deletions src/shapiq/imputer/cte_imputer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Implementation of the compress then explain (CTE) imputer."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from goodpoints import compress

from .base import Imputer

if TYPE_CHECKING:
from shapiq.typing import CoalitionMatrix, GameValues, Model


class CTEImputer(Imputer):
    """The compress then explain (CTE) imputer for the shapiq package.

    The CTE imputer replaces missing features of the explanation point ``x`` by values
    taken from the background data. The background data is first subsampled using a
    distribution compression algorithm; every row of the compressed background data is then
    used as one joint replacement sample and the model predictions are averaged over these
    rows. This has been shown to provide accurate and stable estimates of explanations while
    being computationally efficient. For details, see the paper introducing CTE by Baniecki
    et al. (2025) [Ban25]_.

    This corresponds to *interventional* imputation (often called *marginal fANOVA* in the
    literature), as opposed to *observational* imputers that condition on observed features.

    Examples:
        >>> model = lambda x: np.sum(x, axis=1)  # some dummy model
        >>> data = np.random.rand(1000, 4)  # some background data
        >>> x_to_impute = np.array([[1, 1, 1, 1]])  # some data point to impute
        >>> imputer = CTEImputer(model=model, data=data, x=x_to_impute, random_state=42)
        >>> # get the model prediction with missing values
        >>> imputer(np.array([[True, False, True, False]]))
        np.array([2.01])  # some model prediction (might be different)
        >>> # exchange the background data
        >>> new_data = np.random.rand(1000, 4)
        >>> imputer.init_background(data=new_data)

    See Also:
        - :class:`shapiq.imputer.MarginalImputer` for the marginal imputer.
        - :class:`shapiq.imputer.BaselineImputer` for the baseline imputer.
        - :class:`shapiq.imputer.base.Imputer` for the base imputer class.

    References:
        .. [Ban25] Baniecki, H., Casalicchio, G., Bischl, B., Biecek, P., (2025). Efficient and Accurate Explanation Estimation with Distribution Compression. In International Conference on Learning Representations. url: https://openreview.net/forum?id=LiUfN9h0Lx

    """

    def __init__(
        self,
        model: Model,
        data: np.ndarray,
        *,
        x: np.ndarray | None = None,
        normalize: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Initializes the CTE imputer.

        Args:
            model: The model to explain as a callable function expecting data points as input
                and returning the model's predictions.

            data: The background data to use for the explainer as a two-dimensional array
                with shape ``(n_samples, n_features)``.

            x: The explanation point to use the imputer on either as a 2-dimensional array with
                shape ``(1, n_features)`` or as a vector with shape ``(n_features,)``. If
                ``None``, the imputer must be fitted before it can be used.

            normalize: A flag to normalize the game values. If ``True``, then the game values
                are normalized and centered to be zero for the empty set of features.

            random_state: The random state to use for the compression. If ``None``, the random
                state is not fixed.
        """
        super().__init__(
            model=model,
            data=data,
            x=x,
            random_state=random_state,
        )

        # placeholder until ``init_background`` stores the compressed background data
        self._replacement_data: np.ndarray = np.zeros((1, self.n_features))
        self.init_background(self.data)

        if normalize:  # update normalization value
            self.normalization_value = self.empty_prediction

    def value_function(self, coalitions: CoalitionMatrix) -> GameValues:
        """Imputes the missing values of a data point and calls the model.

        For every row of the compressed background data, the features that are absent from a
        coalition are replaced by that row's values, the model is evaluated, and the
        predictions are averaged over all rows.

        Args:
            coalitions: A boolean array indicating which features are present (``True``) and
                which are missing (``False``). The shape of the array must be
                ``(n_subsets, n_features)``.

        Returns:
            The model's predictions on the imputed data points, averaged over the compressed
            background samples (one value per coalition).

        """
        n_coalitions = coalitions.shape[0]
        sample_size = self._replacement_data.shape[0]
        outputs = np.zeros((sample_size, n_coalitions))
        # one copy of the explanation point per coalition
        x_tiled = np.tile(self.x, (n_coalitions, 1))
        for i, replacement_row in enumerate(self._replacement_data):
            # ``np.where`` promotes to a common dtype. Writing the replacements in place into
            # a tile of ``x`` would silently cast them to the dtype of ``x`` and, e.g.,
            # truncate float background values if ``x`` is an integer array.
            imputed_data = np.where(coalitions, x_tiled, replacement_row)
            outputs[i] = self.predict(imputed_data)
        outputs = np.mean(outputs, axis=0)  # average over the samples
        # insert the better approximate empty prediction for the empty coalitions
        outputs[~np.any(coalitions, axis=1)] = self.empty_prediction
        return outputs

    def init_background(self, data: np.ndarray) -> CTEImputer:
        """Initializes the imputer to a background data set.

        The background data is compressed with ``goodpoints.compress.compresspp_kt`` using a
        Gaussian kernel and the rows it selects are stored as replacement data for the missing
        features. Afterwards the empty prediction is recomputed on the compressed data. To
        change the background data, use this method.

        Args:
            data: The background data to use for the imputer. The shape of the array must
                be ``(n_samples, n_features)``.

        Returns:
            The initialized imputer.

        Examples:
            >>> model = lambda x: np.sum(x, axis=1)
            >>> data = np.random.rand(1000, 3)
            >>> imputer = CTEImputer(model=model, data=data, x=data[0])
            >>> new_data = np.random.rand(1000, 3)
            >>> imputer.init_background(data=new_data)

        """
        n_features = data.shape[1]
        # kernel bandwidth grows with the dimensionality; the squared value is handed over
        sigma = np.sqrt(2 * n_features)
        id_compressed = compress.compresspp_kt(
            data,
            kernel_type=b"gaussian",
            k_params=np.array([sigma**2]),
            g=4,  # NOTE(review): presumably the Compress++ oversampling parameter - confirm
            seed=self.random_state,
        )
        # the compression returns row indices into ``data``
        self._replacement_data = data[id_compressed]
        self.calc_empty_prediction()  # reset the empty prediction to the new background data
        return self

    def calc_empty_prediction(self) -> float:
        """Runs the model on empty data points (all features missing) to get the empty prediction.

        The empty prediction is the mean model output over the compressed background data. It
        is stored in ``self.empty_prediction`` and, if the game is normalized, also used as the
        new normalization value.

        Returns:
            The empty prediction of the model provided only missing features.

        """
        empty_predictions = self.predict(self._replacement_data)
        empty_prediction = float(np.mean(empty_predictions))
        self.empty_prediction = empty_prediction
        if self.normalize:  # reset the normalization value
            self.normalization_value = empty_prediction
        return empty_prediction
Loading