Skip to content

Adding FastCAV #1622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion captum/concept/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from captum.concept._core.cav import CAV
from captum.concept._core.concept import Concept, ConceptInterpreter
from captum.concept._core.tcav import TCAV
from captum.concept._utils.classifier import Classifier, DefaultClassifier
from captum.concept._utils.classifier import (
Classifier,
DefaultClassifier,
FastCAVClassifier,
)

__all__ = [
"CAV",
Expand All @@ -13,4 +17,5 @@
"TCAV",
"Classifier",
"DefaultClassifier",
"FastCAVClassifier",
]
337 changes: 334 additions & 3 deletions captum/concept/_utils/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
# pyre-strict

import random
import time
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.models.linear_model import model
from captum._utils.models.linear_model.model import LinearModel, SkLearnSGDClassifier
from captum._utils.models.linear_model.train import NormLayer
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

Expand Down Expand Up @@ -141,7 +143,7 @@ def __init__(self) -> None:
" `Classifer` abstract class",
stacklevel=2,
)
self.lm = model.SkLearnSGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3)
self.lm: LinearModel = SkLearnSGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3)

def train_and_eval(
self,
Expand Down Expand Up @@ -252,3 +254,332 @@ def _train_test_split(
torch.stack(y_train),
torch.stack(y_test),
)


class FastCAVClassifier(DefaultClassifier):
r"""Fast implementation of concept activation vectors calculation
using mean differences. This implementation requires balanced classes.
This implements the classifier proposed in the paper `FastCAV: Efficient
Computation of Concept Activation Vectors for Explaining Deep Neural Networks
<https://arxiv.org/abs/2505.17883>`_.

This classifier provides an efficient alternative to other CAV classifiers.

It is equivalent to an SVM when the following assumptions hold:

- **Gaussian Distribution**: The activation vectors for both the random samples
and the concept samples are assumed to follow independent multivariate
Gaussian distributions.
- **Equal Mixture**: The set of concept examples and the set of random examples
are of equal size :math:`(∣D_c​∣=∣D_r​∣)`, resulting in a uniform mixture of
the two Gaussian distributions.
- **Isotropic Covariance**: The within-class covariance matrices are assumed to
be isotropic, meaning they are proportional to the unit matrix. This is a
critical assumption that makes the FastCAV solution equivalent to the solution
of a Fisher discriminant analysis.
- **High-Dimensionality**: The method is analyzed in the context of
high-dimensional activation spaces, where the number of dimensions :math:`d` is
significantly larger than the number of samples :math:`n` (:math:`d >> n`).
In such spaces, the set of support vectors used by an SVM is likely to contain
most of the training samples, making `the SVM solution approximate the Fisher
discriminant solution
<https://link.springer.com/article/10.1023/A:1018677409366>`_,
and by extension, the FastCAV solution.

Note that default implementation slices input dataset into train and test
splits and keeps them in memory.
In case concept datasets are large, this can lead to out of memory and we
recommend to provide a custom Classier that extends `Classifier` abstract
class and handles large concept datasets accordingly.

Example:

>>> import torchvision
>>> from captum.concept import TCAV
>>> from captum.concept._utils.classifier import FastCAVClassifier
>>> from captum.attr import LayerIntegratedGradients
>>>
>>> model = torchvision.models.googlenet(pretrained=True)
>>> model = model.eval()
>>> clf = FastCAVClassifier()
>>> layers=['inception4c', 'inception4d', 'inception4e']
>>> mytcav = TCAV(model=model,
>>> layers=layers,
>>> classifier=clf,
>>> layer_attr_method = LayerIntegratedGradients(
>>> model, None, multiply_by_inputs=False))
>>> # ...
>>> # For a full workflow, follow `tutorials/TCAV_Image.ipynb` and
>>> # replace the classifier.
"""

def __init__(self) -> None:
self.lm = FastCAVLinearModel()


class FastCAVLinearModel(LinearModel):
"""
FastCAVLinearModel is a wrapper to convert `FastCAV` into a `LinearModel`.

Args:
**kwargs (Any): Additional keyword arguments passed to the
`LinearModel` base class.

Returns:
None
"""

def __init__(self, **kwargs) -> None:
super().__init__(train_fn=fastcav_train_linear_model, **kwargs)


def fastcav_train_linear_model(
model: LinearModel,
dataloader: DataLoader,
construct_kwargs: Dict[str, Any],
norm_input: bool = False,
**fit_kwargs: Any,
) -> Dict[str, float]:
r"""
Trains a `captum.concept._utils.models.linear_model.LinearModel` using
the `FastCAV` classifier. It follows closely the implementation of the
`linea_model.train.sklearn_train_linear_model` function.

This method consumes the entire dataloader to construct the dataset in
memory for training.

Please note that this assumes:

1. The dataset can fit into memory.

Args:
model (LinearModel): The model to train.
dataloader (DataLoader): The data to use. This will be exhausted and converted
to single tensors. Do not use an infinite dataloader.
construct_kwargs (dict): Arguments to pass to the FastCAV constructor.
FastCAV currently does not support any additional parameters.
norm_input (bool, optional): Whether or not to normalize the input.
Default: False
fit_kwargs (dict, optional): Other arguments to send to FastCAV's fit method.
FastCAV does not support sample weights or other fit arguments.
Default: None

Returns:
dict: A dictionary containing the train_time.
"""
# Extract data from dataloader
fast_classifier = FastCAV(**construct_kwargs)
num_batches = 0
xs: List[Tensor] = []
ys: List[Tensor] = []
ws: List[Tensor] = []
for data in dataloader:
if len(data) == 3:
x, y, w = data
else:
assert len(data) == 2
x, y = data
w = None

xs.append(x)
ys.append(y)
if w is not None:
ws.append(w)
num_batches += 1

x = torch.cat(xs, dim=0)
y = torch.cat(ys, dim=0)
if len(ws) > 0:
w = torch.cat(ws, dim=0)
else:
w = None

if norm_input:
mean, std = x.mean(0), x.std(0)
x -= mean
x /= std

t1 = time.time()

if len(w) > 0:
warnings.warn(
"Sample weight is not supported for FastCAV!"
" Trained model without weighting inputs",
stacklevel=1,
)

fast_classifier.fit(x, y, **fit_kwargs)

t2 = time.time()

# Convert weights to pytorch
classes = torch.IntTensor(fast_classifier.classes_)

# extract model device
device = getattr(model, "device", "cpu")

num_outputs = (
fast_classifier.coef_.shape[0] # type: ignore
if fast_classifier.coef_.ndim > 1 # type: ignore
else 1
) # type: ignore
weight_values = torch.FloatTensor(fast_classifier.coef_).to(device) # type: ignore
bias_values = torch.FloatTensor([fast_classifier.intercept_]).to( # type: ignore
device # type: ignore
) # type: ignore
model._construct_model_params(
norm_type=None,
weight_values=weight_values.view(num_outputs, -1),
bias_value=bias_values.squeeze().unsqueeze(0),
classes=classes,
)

if norm_input:
# pyre-fixme[61]: `mean` is undefined, or not always defined.
# pyre-fixme[61]: `std` is undefined, or not always defined.
model.norm = NormLayer(mean, std)

return {"train_time": t2 - t1}


class FastCAV:
r"""Fast implementation of concept activation vectors calculation
using mean differences. This implementation requires balanced classes.

This classifier provides an efficient alternative to other CAV classifiers.

It is equal to an SVM when the following assumptions hold:

- **Gaussian Distribution**: The activation vectors for both the random samples
and the concept samples are assumed to follow independent multivariate
Gaussian distributions.
- **Equal Mixture**: The set of concept examples and the set of random examples
are of equal size :math:`(∣D_c​∣=∣D_r​∣)`, resulting in a uniform mixture of
the two Gaussian distributions.
- **Isotropic Covariance**: The within-class covariance matrices are assumed to
be isotropic, meaning they are proportional to the unit matrix. This is a
critical assumption that makes the FastCAV solution equivalent to the solution
of a Fisher discriminant analysis.
- **High-Dimensionality**: The method is analyzed in the context of
high-dimensional activation spaces, where the number of dimensions :math:`d` is
significantly larger than the number of samples :math:`n` (:math:`d >> n`).
In such spaces, the set of support vectors used by an SVM is likely to contain
most of the training samples, making `the SVM solution approximate the Fisher
discriminant solution
<https://link.springer.com/article/10.1023/A:1018677409366>`_,
and by extension, the FastCAV solution.

For more details, see the paper:
`FastCAV: Efficient Computation of Concept Activation Vectors for Explaining
Deep Neural Networks <https://arxiv.org/abs/2505.17883>`_.

Example::

>>> from captum.concept._utils import classifier
>>> fast_cav = classifier.FastCAV()
>>> x = torch.randn(100, 20) # 100 samples, 20 features
>>> y = torch.randint(0, 2, (100,)) # Binary
>>> fast_cav.fit(x, y)
>>> predictions = fast_cav.predict(x)

"""

def __init__(self, **kwargs) -> None:
self.intercept_: Optional[torch.Tensor] = None
self.coef_: Optional[torch.Tensor] = None
self.mean: Optional[torch.Tensor] = None
self.classes_: Optional[torch.Tensor] = None
if kwargs:
warnings.warn(
"FastCAV does not support any additional parameters. "
f"Ignoring provided parameters: {kwargs.keys()}.",
stacklevel=2,
)

def fit(self, x: Tensor, y: Tensor) -> None:
"""
Fits a binary linear classifier to obtain a Concept Activation Vector (CAV)
using the mean difference between two classes.

Args:
x (Tensor): Input data of shape (n_samples, n_features).
Training data for binary classification.
y (Tensor): Binary target labels of shape (n_samples,).
Labels should be 0 or 1. Classes should be balanced.

Returns:
None

Note:
Computes the linear concept boundary using the mean difference vector
between the two classes. Converts inputs to PyTorch tensors if needed.

Why balanced classes:
Imbalanced classes will skew the computed CAV toward the majority class,
leading to inaccurate results. FastCAV works best with balanced classes.
"""
x = torch.as_tensor(x)
y = torch.as_tensor(y)

assert x.ndim == 2, "Input tensor must be 2D (batch_size, num_features)"
assert y.ndim == 1, "Labels tensor must be 1D (batch_size,)"
assert x.shape[0] == y.shape[0], "Input and labels must have same batch size"

self.classes_ = torch.unique(y).int()
assert len(self.classes_) == 2, "Only binary classification is supported"

class_counts = torch.bincount(y)
if torch.abs(class_counts[0] - class_counts[1]).float() / len(y) > 0.2:
warnings.warn(
"Classes are imbalanced (>20% difference). "
"FastCAV works best with balanced classes."
)

with torch.no_grad():
self.mean = x.mean(dim=0)
self.coef_ = (
(x[y == self.classes_[-1]] - self.mean).mean(dim=0).unsqueeze(0)
)
self.intercept_ = (-self.coef_ @ self.mean).unsqueeze(1)

def predict(self, x: Tensor) -> Tensor:
"""
Predicts the class labels for the given input tensor using the trained model.

Args:
x (Tensor): Input tensor of shape (n_samples, n_features) or (n_features,).
If a 1D tensor is provided, it is treated as a single sample.

Returns:
Tensor: Predicted class labels as a tensor of shape (n_samples,).

Raises:
ValueError: If the model has not been trained (i.e., `coef_`, `intercept_`,
or `classes_` is None).
"""
if self.coef_ is None or self.intercept_ is None or self.classes_ is None:
raise ValueError("Model not trained. Call fit() first.")

x = torch.as_tensor(x)
if x.ndim == 1:
x = x.unsqueeze(0)
with torch.no_grad():
return torch.take(
self.classes_,
((self.coef_ @ torch.as_tensor(x).T + self.intercept_) > 0).long(),
).T

def classes(self) -> Tensor:
"""
Returns the classes learned by the classifier.

Returns:
Tensor: A tensor containing the unique class labels identified during
model training.

Raises:
ValueError: If the model has not been trained and `fit` has not been called.
"""
if self.classes_ is None:
raise ValueError("Please call `fit` to train the model first.")
return self.classes_
12 changes: 12 additions & 0 deletions sphinx/source/concept.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,15 @@ Classifier

.. autoclass:: captum.concept.Classifier
:members:

DefaultClassifier
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: captum.concept.DefaultClassifier
:members:

FastCAVClassifier
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: captum.concept.FastCAVClassifier
:members:
Empty file.
Loading