From 272fd01be1cfe920a6749a0257b16ad4f8cf098e Mon Sep 17 00:00:00 2001
From: Laines Schmalwasser
Date: Fri, 27 Jun 2025 14:37:51 +0200
Subject: [PATCH] Initial FastCAV

---
 captum/concept/__init__.py              |   7 +-
 captum/concept/_utils/classifier.py     | 337 +++++++++++++++++++++++-
 sphinx/source/concept.rst               |  12 +
 tests/concept/_utils/__init__.py        |   0
 tests/concept/_utils/test_classifier.py | 194 ++++++++++++++
 5 files changed, 546 insertions(+), 4 deletions(-)
 create mode 100644 tests/concept/_utils/__init__.py
 create mode 100644 tests/concept/_utils/test_classifier.py

diff --git a/captum/concept/__init__.py b/captum/concept/__init__.py
index d821a664da..eb76856be5 100644
--- a/captum/concept/__init__.py
+++ b/captum/concept/__init__.py
@@ -4,7 +4,11 @@
 from captum.concept._core.cav import CAV
 from captum.concept._core.concept import Concept, ConceptInterpreter
 from captum.concept._core.tcav import TCAV
-from captum.concept._utils.classifier import Classifier, DefaultClassifier
+from captum.concept._utils.classifier import (
+    Classifier,
+    DefaultClassifier,
+    FastCAVClassifier,
+)
 
 __all__ = [
     "CAV",
@@ -13,4 +17,5 @@
     "TCAV",
     "Classifier",
     "DefaultClassifier",
+    "FastCAVClassifier",
 ]
diff --git a/captum/concept/_utils/classifier.py b/captum/concept/_utils/classifier.py
index 477fa0c255..7a3ce5662f 100644
--- a/captum/concept/_utils/classifier.py
+++ b/captum/concept/_utils/classifier.py
@@ -3,12 +3,14 @@
 # pyre-strict
 
 import random
+import time
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-from captum._utils.models.linear_model import model
+from captum._utils.models.linear_model.model import LinearModel, SkLearnSGDClassifier
+from captum._utils.models.linear_model.train import NormLayer
 from torch import Tensor
 from torch.utils.data import DataLoader, TensorDataset
@@ -141,7 +143,7 @@ def __init__(self) -> None:
             " `Classifer` abstract class",
             stacklevel=2,
         )
-        self.lm = model.SkLearnSGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3)
+        self.lm: LinearModel = SkLearnSGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3)
 
     def train_and_eval(
         self,
@@ -252,3 +254,332 @@ def _train_test_split(
         torch.stack(y_train),
         torch.stack(y_test),
     )
+
+
+class FastCAVClassifier(DefaultClassifier):
+    r"""Fast implementation of concept activation vector (CAV) computation
+    using mean differences. This implementation requires balanced classes.
+    It implements the classifier proposed in the paper `FastCAV: Efficient
+    Computation of Concept Activation Vectors for Explaining Deep Neural Networks
+    `_.
+
+    This classifier provides an efficient alternative to other CAV classifiers.
+
+    It is equivalent to an SVM when the following assumptions hold:
+
+    - **Gaussian Distribution**: The activation vectors for both the random samples
+      and the concept samples are assumed to follow independent multivariate
+      Gaussian distributions.
+    - **Equal Mixture**: The set of concept examples and the set of random examples
+      are of equal size (:math:`|D_c| = |D_r|`), resulting in a uniform mixture of
+      the two Gaussian distributions.
+    - **Isotropic Covariance**: The within-class covariance matrices are assumed to
+      be isotropic, meaning they are proportional to the unit matrix. This is a
+      critical assumption that makes the FastCAV solution equivalent to the solution
+      of a Fisher discriminant analysis.
+    - **High-Dimensionality**: The method is analyzed in the context of
+      high-dimensional activation spaces, where the number of dimensions :math:`d`
+      is significantly larger than the number of samples :math:`n`
+      (:math:`d \gg n`). In such spaces, the set of support vectors used by an SVM
+      is likely to contain most of the training samples, making `the SVM solution
+      approximate the Fisher discriminant solution `_,
+      and by extension, the FastCAV solution.
+
+    Note that the default implementation slices the input dataset into train and
+    test splits and keeps them in memory.
+    If the concept datasets are large, this can lead to out-of-memory errors, and
+    we recommend providing a custom classifier that extends the `Classifier`
+    abstract class and handles large concept datasets accordingly.
+
+    Example::
+
+        >>> import torchvision
+        >>> from captum.concept import TCAV
+        >>> from captum.concept._utils.classifier import FastCAVClassifier
+        >>> from captum.attr import LayerIntegratedGradients
+        >>>
+        >>> model = torchvision.models.googlenet(pretrained=True)
+        >>> model = model.eval()
+        >>> clf = FastCAVClassifier()
+        >>> layers = ['inception4c', 'inception4d', 'inception4e']
+        >>> mytcav = TCAV(model=model,
+        >>>               layers=layers,
+        >>>               classifier=clf,
+        >>>               layer_attr_method=LayerIntegratedGradients(
+        >>>                   model, None, multiply_by_inputs=False))
+        >>> # ...
+        >>> # For a full workflow, follow `tutorials/TCAV_Image.ipynb` and
+        >>> # replace the classifier.
+    """
+
+    def __init__(self) -> None:
+        self.lm: LinearModel = FastCAVLinearModel()
+
+
+class FastCAVLinearModel(LinearModel):
+    """
+    FastCAVLinearModel is a wrapper that exposes `FastCAV` as a `LinearModel`.
+
+    Args:
+        **kwargs (Any): Additional keyword arguments passed to the
+                `LinearModel` base class.
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(train_fn=fastcav_train_linear_model, **kwargs)
+
+
+def fastcav_train_linear_model(
+    model: LinearModel,
+    dataloader: DataLoader,
+    construct_kwargs: Dict[str, Any],
+    norm_input: bool = False,
+    **fit_kwargs: Any,
+) -> Dict[str, float]:
+    r"""
+    Trains a `captum._utils.models.linear_model.model.LinearModel` using the
+    `FastCAV` classifier. It closely follows the implementation of the
+    `linear_model.train.sklearn_train_linear_model` function.
+
+    This method consumes the entire dataloader to construct the dataset in
+    memory for training.
+
+    Please note that this assumes:
+
+    1. The dataset can fit into memory.
+
+    Args:
+        model (LinearModel): The model to train.
+        dataloader (DataLoader): The data to use. This will be exhausted and
+                converted to single tensors. Do not use an infinite dataloader.
+        construct_kwargs (dict): Arguments to pass to the FastCAV constructor.
+                FastCAV currently does not support any additional parameters.
+        norm_input (bool, optional): Whether or not to normalize the input.
+                Default: False
+        fit_kwargs (dict, optional): Other arguments to send to FastCAV's fit
+                method. FastCAV does not support sample weights or other fit
+                arguments.
+                Default: None
+
+    Returns:
+        dict: A dictionary containing the train_time.
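+
+    Example (a minimal sketch; the two-class toy tensors and batch size are
+    illustrative assumptions, not part of the API)::
+
+        >>> import torch
+        >>> from torch.utils.data import DataLoader, TensorDataset
+        >>> # 20 activation vectors per class, 16 features each
+        >>> x = torch.cat([torch.randn(20, 16), torch.randn(20, 16) + 1.0])
+        >>> y = torch.cat([torch.zeros(20), torch.ones(20)])
+        >>> loader = DataLoader(TensorDataset(x, y), batch_size=8)
+        >>> lm = FastCAVLinearModel()
+        >>> stats = fastcav_train_linear_model(
+        ...     lm, loader, construct_kwargs={}, norm_input=False
+        ... )
+        >>> sorted(stats.keys())
+        ['train_time']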
+ """ + # Extract data from dataloader + fast_classifier = FastCAV(**construct_kwargs) + num_batches = 0 + xs: List[Tensor] = [] + ys: List[Tensor] = [] + ws: List[Tensor] = [] + for data in dataloader: + if len(data) == 3: + x, y, w = data + else: + assert len(data) == 2 + x, y = data + w = None + + xs.append(x) + ys.append(y) + if w is not None: + ws.append(w) + num_batches += 1 + + x = torch.cat(xs, dim=0) + y = torch.cat(ys, dim=0) + if len(ws) > 0: + w = torch.cat(ws, dim=0) + else: + w = None + + if norm_input: + mean, std = x.mean(0), x.std(0) + x -= mean + x /= std + + t1 = time.time() + + if len(w) > 0: + warnings.warn( + "Sample weight is not supported for FastCAV!" + " Trained model without weighting inputs", + stacklevel=1, + ) + + fast_classifier.fit(x, y, **fit_kwargs) + + t2 = time.time() + + # Convert weights to pytorch + classes = torch.IntTensor(fast_classifier.classes_) + + # extract model device + device = getattr(model, "device", "cpu") + + num_outputs = ( + fast_classifier.coef_.shape[0] # type: ignore + if fast_classifier.coef_.ndim > 1 # type: ignore + else 1 + ) # type: ignore + weight_values = torch.FloatTensor(fast_classifier.coef_).to(device) # type: ignore + bias_values = torch.FloatTensor([fast_classifier.intercept_]).to( # type: ignore + device # type: ignore + ) # type: ignore + model._construct_model_params( + norm_type=None, + weight_values=weight_values.view(num_outputs, -1), + bias_value=bias_values.squeeze().unsqueeze(0), + classes=classes, + ) + + if norm_input: + # pyre-fixme[61]: `mean` is undefined, or not always defined. + # pyre-fixme[61]: `std` is undefined, or not always defined. + model.norm = NormLayer(mean, std) + + return {"train_time": t2 - t1} + + +class FastCAV: + r"""Fast implementation of concept activation vectors calculation + using mean differences. This implementation requires balanced classes. + + This classifier provides an efficient alternative to other CAV classifiers. + + It is equal to an SVM when the following assumptions hold: + + - **Gaussian Distribution**: The activation vectors for both the random samples + and the concept samples are assumed to follow independent multivariate + Gaussian distributions. + - **Equal Mixture**: The set of concept examples and the set of random examples + are of equal size :math:`(∣D_c​∣=∣D_r​∣)`, resulting in a uniform mixture of + the two Gaussian distributions. + - **Isotropic Covariance**: The within-class covariance matrices are assumed to + be isotropic, meaning they are proportional to the unit matrix. This is a + critical assumption that makes the FastCAV solution equivalent to the solution + of a Fisher discriminant analysis. + - **High-Dimensionality**: The method is analyzed in the context of + high-dimensional activation spaces, where the number of dimensions :math:`d` is + significantly larger than the number of samples :math:`n` (:math:`d >> n`). + In such spaces, the set of support vectors used by an SVM is likely to contain + most of the training samples, making `the SVM solution approximate the Fisher + discriminant solution + `_, + and by extension, the FastCAV solution. + + For more details, see the paper: + `FastCAV: Efficient Computation of Concept Activation Vectors for Explaining + Deep Neural Networks `_. 
+
+    Example::
+
+        >>> import torch
+        >>> from captum.concept._utils import classifier
+        >>> fast_cav = classifier.FastCAV()
+        >>> x = torch.randn(100, 20)  # 100 samples, 20 features
+        >>> y = torch.randint(0, 2, (100,))  # Binary labels
+        >>> fast_cav.fit(x, y)
+        >>> predictions = fast_cav.predict(x)
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        self.intercept_: Optional[torch.Tensor] = None
+        self.coef_: Optional[torch.Tensor] = None
+        self.mean: Optional[torch.Tensor] = None
+        self.classes_: Optional[torch.Tensor] = None
+        if kwargs:
+            warnings.warn(
+                "FastCAV does not support any additional parameters. "
+                f"Ignoring provided parameters: {kwargs.keys()}.",
+                stacklevel=2,
+            )
+
+    def fit(self, x: Tensor, y: Tensor) -> None:
+        """
+        Fits a binary linear classifier to obtain a Concept Activation Vector
+        (CAV) using the mean difference between two classes.
+
+        Args:
+            x (Tensor): Input data of shape (n_samples, n_features).
+                    Training data for binary classification.
+            y (Tensor): Binary target labels of shape (n_samples,).
+                    Exactly two distinct label values are expected, and the
+                    classes should be balanced.
+
+        Returns:
+            None
+
+        Note:
+            Computes the linear concept boundary from the mean difference
+            vector between the two classes. Converts inputs to PyTorch tensors
+            if needed.
+
+            Balanced classes matter because imbalanced classes skew the
+            computed CAV toward the majority class, leading to inaccurate
+            results; FastCAV works best with balanced classes.
+        """
+        x = torch.as_tensor(x)
+        y = torch.as_tensor(y)
+
+        assert x.ndim == 2, "Input tensor must be 2D (batch_size, num_features)"
+        assert y.ndim == 1, "Labels tensor must be 1D (batch_size,)"
+        assert x.shape[0] == y.shape[0], "Input and labels must have same batch size"
+
+        classes, class_counts = torch.unique(y, return_counts=True)
+        self.classes_ = classes.int()
+        assert len(self.classes_) == 2, "Only binary classification is supported"
+
+        if torch.abs(class_counts[0] - class_counts[1]).float() / len(y) > 0.2:
+            warnings.warn(
+                "Classes are imbalanced (>20% difference). "
+                "FastCAV works best with balanced classes.",
+                stacklevel=2,
+            )
+
+        with torch.no_grad():
+            self.mean = x.mean(dim=0)
+            self.coef_ = (
+                (x[y == self.classes_[-1]] - self.mean).mean(dim=0).unsqueeze(0)
+            )
+            self.intercept_ = (-self.coef_ @ self.mean).unsqueeze(1)
+
+    def predict(self, x: Tensor) -> Tensor:
+        """
+        Predicts class labels for the given input tensor using the fitted model.
+
+        Args:
+            x (Tensor): Input tensor of shape (n_samples, n_features) or
+                    (n_features,). A 1D tensor is treated as a single sample.
+
+        Returns:
+            Tensor: Predicted class labels as a tensor of shape (n_samples, 1).
+
+        Raises:
+            ValueError: If the model has not been trained (i.e., `coef_`,
+                    `intercept_`, or `classes_` is None).
+        """
+        if self.coef_ is None or self.intercept_ is None or self.classes_ is None:
+            raise ValueError("Model not trained. Call fit() first.")
+
+        x = torch.as_tensor(x)
+        if x.ndim == 1:
+            x = x.unsqueeze(0)
+        with torch.no_grad():
+            return torch.take(
+                self.classes_,
+                ((self.coef_ @ x.T + self.intercept_) > 0).long(),
+            ).T
+
+    def classes(self) -> Tensor:
+        """
+        Returns the classes learned by the classifier.
+
+        Returns:
+            Tensor: A tensor containing the unique class labels identified
+                    during model training.
+
+        Raises:
+            ValueError: If the model has not been trained and `fit` has not
+                    been called.
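+
+        Example (an illustrative sketch; the random inputs are arbitrary)::
+
+            >>> import torch
+            >>> fast_cav = FastCAV()
+            >>> fast_cav.fit(torch.randn(10, 4), torch.tensor([0, 1] * 5))
+            >>> fast_cav.classes()
+            tensor([0, 1], dtype=torch.int32)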
+ """ + if self.classes_ is None: + raise ValueError("Please call `fit` to train the model first.") + return self.classes_ diff --git a/sphinx/source/concept.rst b/sphinx/source/concept.rst index 19157398b7..fcd7995709 100644 --- a/sphinx/source/concept.rst +++ b/sphinx/source/concept.rst @@ -27,3 +27,15 @@ Classifier .. autoclass:: captum.concept.Classifier :members: + +DefaultClassifier +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: captum.concept.DefaultClassifier + :members: + +FastCAVClassifier +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: captum.concept.FastCAVClassifier + :members: diff --git a/tests/concept/_utils/__init__.py b/tests/concept/_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/concept/_utils/test_classifier.py b/tests/concept/_utils/test_classifier.py new file mode 100644 index 0000000000..7509b4441b --- /dev/null +++ b/tests/concept/_utils/test_classifier.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +# pyre-unsafe + +import unittest +import warnings + +import torch +from captum.concept._utils.classifier import FastCAV + + +class TestFastCAV(unittest.TestCase): + def setUp(self): + """Set up simple, deterministic data for tests.""" + # Balanced, well-separated, non-deterministic data + self.x_train_balanced_randn = torch.cat( + [ + torch.randn(10, 5), # Class 0 + torch.randn(10, 5), # Class 1 + ] + ) + self.y_train_balanced_randn = torch.cat( + [ + torch.zeros(10), + torch.ones(10), + ] + ).int() + + # Imbalanced data (triggers warning) + self.x_train_imbalanced = torch.cat( + [ + torch.randn(15, 5), # Class 0 + torch.randn(5, 5), # Class 1 + ] + ) + self.y_train_imbalanced = torch.cat( + [ + torch.zeros(15), + torch.ones(5), + ] + ).int() + + # Simple, deterministic data for predictable results + self.x_train_simple = torch.tensor( + [[-1.0, -1.0], [-2.0, -2.0], [1.0, 1.0], [2.0, 2.0]] + ) + self.y_train_simple = torch.tensor([0, 0, 1, 1]).int() + + # Test data for simple model + self.x_test_simple = torch.tensor( + [ + [-10.0, -10.0], # Should be class 0 + [10.0, 10.0], # Should be class 1 + [0.0, 0.0], # Should be class 0 (on boundary) + ] + ) + self.expected_pred_simple = torch.tensor([[0], [1], [0]]) + + def test_init(self): + """Test FastCAV initialization.""" + cav = FastCAV() + self.assertIsNone(cav.coef_) + self.assertIsNone(cav.intercept_) + self.assertIsNone(cav.mean) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + FastCAV(foo="bar", baz=123) + self.assertTrue( + any( + "FastCAV does not support any additional parameters" + in str(warn.message) + for warn in w + ) + ) + + def test_fit_balanced(self): + """Test fitting with balanced, deterministic data.""" + cav = FastCAV() + cav.fit(self.x_train_simple, self.y_train_simple) + + self.assertIsNotNone(cav.coef_) + self.assertIsNotNone(cav.intercept_) + self.assertIsNotNone(cav.mean) + self.assertIsNotNone(cav.classes_) + + self.assertEqual(cav.coef_.shape, (1, 2)) + self.assertEqual(cav.intercept_.shape, (1, 1)) + self.assertEqual(cav.mean.shape, (2,)) + self.assertEqual(cav.classes_.shape, (2,)) + + expected_mean = torch.tensor([0.0, 0.0]) + torch.testing.assert_close(cav.mean, expected_mean) + + expected_coef = torch.tensor([[1.5, 1.5]]) + torch.testing.assert_close(cav.coef_, expected_coef) + + expected_intercept = torch.tensor([[0.0]]) + torch.testing.assert_close(cav.intercept_, expected_intercept) + + self.assertTrue(torch.equal(cav.classes_, torch.tensor([0, 1]))) + + def test_fit_imbalanced_warns(self): + """Test 
+        cav = FastCAV()
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            cav.fit(self.x_train_imbalanced, self.y_train_imbalanced)
+        self.assertTrue(
+            any("Classes are imbalanced" in str(warn.message) for warn in w)
+        )
+
+        self.assertIsNotNone(cav.coef_)
+        self.assertIsNotNone(cav.intercept_)
+
+    def test_fit_assertions(self):
+        """Test assertions for invalid input shapes and labels."""
+        cav = FastCAV()
+        with self.assertRaises(AssertionError) as cm:
+            cav.fit(torch.randn(20), self.y_train_balanced_randn)
+        self.assertIn("Input tensor must be 2D", str(cm.exception))
+
+        with self.assertRaises(AssertionError) as cm:
+            cav.fit(self.x_train_balanced_randn, torch.randn(20, 1))
+        self.assertIn("Labels tensor must be 1D", str(cm.exception))
+
+        with self.assertRaises(AssertionError) as cm:
+            cav.fit(self.x_train_balanced_randn, torch.zeros(5))
+        self.assertIn("Input and labels must have same batch size", str(cm.exception))
+
+        y_multi = self.y_train_balanced_randn.clone()
+        y_multi[0] = 2
+        with self.assertRaises(AssertionError) as cm:
+            cav.fit(self.x_train_balanced_randn, y_multi)
+        self.assertIn("Only binary classification is supported", str(cm.exception))
+
+    def test_predict_before_fit(self):
+        """Test that predict raises an error if called before fit."""
+        cav = FastCAV()
+        with self.assertRaises(ValueError) as cm:
+            cav.predict(self.x_test_simple)
+        self.assertIn("Model not trained. Call fit() first.", str(cm.exception))
+
+    def test_predict_after_fit(self):
+        """Test prediction on single and batch inputs after fitting."""
+        cav = FastCAV()
+        cav.fit(self.x_train_simple, self.y_train_simple)
+
+        predictions = cav.predict(self.x_test_simple)
+        self.assertTrue(torch.equal(predictions, self.expected_pred_simple))
+
+        prediction_single_0 = cav.predict(self.x_test_simple[0])
+        self.assertEqual(
+            prediction_single_0.item(), self.expected_pred_simple[0].item()
+        )
+
+        prediction_single_1 = cav.predict(self.x_test_simple[1])
+        self.assertEqual(
+            prediction_single_1.item(), self.expected_pred_simple[1].item()
+        )
+
+    def test_fit_zero_mean_difference(self):
+        """Test fitting when class means are identical."""
+        cav = FastCAV()
+        x_train = torch.tensor([[-1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [1.0, -1.0]])
+        y_train = torch.tensor([0, 0, 1, 1]).int()
+
+        cav.fit(x_train, y_train)
+
+        torch.testing.assert_close(cav.coef_, torch.zeros_like(cav.coef_))
+        torch.testing.assert_close(cav.intercept_, torch.zeros_like(cav.intercept_))
+
+        predictions = cav.predict(torch.randn(5, 2))
+        self.assertTrue(torch.all(predictions == cav.classes_[0]))
+
+    def test_classes_before_fit(self):
+        """Test that classes raises an error if called before fit."""
+        cav = FastCAV()
+        with self.assertRaises(ValueError) as cm:
+            cav.classes()
+        self.assertIn("Please call `fit` to train the model first.", str(cm.exception))
+
+    def test_classes_after_fit(self):
+        """Test the classes method after fitting."""
+        cav = FastCAV()
+        cav.fit(self.x_train_simple, self.y_train_simple)
+        classes = cav.classes()
+        self.assertTrue(torch.equal(classes, torch.tensor([0, 1]).int()))
+        self.assertTrue(torch.equal(classes, cav.classes_))
+
+        y_train_custom_labels = torch.tensor([2, 2, 5, 5]).int()
+        cav.fit(self.x_train_simple, y_train_custom_labels)
+        classes = cav.classes()
+        self.assertTrue(torch.equal(classes, torch.tensor([2, 5]).int()))