Diffusion Driven Balancing (DDB)

This repository contains the official implementation of our ICCV 2025 main conference paper.
Deep neural networks trained with Empirical Risk Minimization (ERM) perform well when both training and test data come from the same domain, but they often fail to generalize to out-of-distribution samples. In image classification, these models may rely on spurious correlations that often exist between labels and irrelevant features of images, making predictions unreliable when those features are absent. We propose a Diffusion Driven Balancing (DDB) technique that generates training samples with text-to-image diffusion models to address the spurious correlation problem. First, we use a textual inversion mechanism to compute the token that best describes the visual features of the causal components of the samples. Then, leveraging a language segmentation method and a diffusion model, we generate new samples by combining the causal component with elements from other classes. We also prune the generated samples based on the prediction probabilities and attribution scores of the ERM model to ensure they are composed correctly for our objective. Finally, we retrain the ERM model on our augmented dataset. This process reduces the model’s reliance on spurious correlations by learning from carefully crafted samples in which this correlation does not exist. Our experiments show that, across different benchmarks, our technique achieves better worst-group accuracy than existing state-of-the-art methods.
We follow the dataset implementation approach from the DaC repository. Our code supports three datasets: MetaShift, Waterbirds, and CelebA.
- Directory: `data/metashift`
- Please download the MetaShift dataset into the `data/metashift` directory. The dataset can be downloaded from here.

Our code expects the following structure:

- `data/metashift/train`
- `data/metashift/test`
- Directory: `data/waterbirds`
- Place the Waterbirds dataset in the `data/waterbirds` directory. The dataset can be downloaded from here.

Our code expects the following structure:

- `data/waterbirds`
- `data/waterbirds/metadata.csv`
- Directory: `data/celeba`
- Place the CelebA dataset in the `data/celeba` directory. The dataset can be downloaded from here.

Our code expects the following structure:

- `data/celeba/img_align_celeba`
- `data/celeba/celeba_split.csv`
- `data/celeba/list_attr_celeba.csv`
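A quick way to confirm the layout before running anything is a small path check. The sketch below is not part of the repository: the `EXPECTED` table and the `missing_paths` helper are hypothetical names, and the paths are simply copied from the structure listed above.

```python
from pathlib import Path

# Expected entries per dataset, copied from the layout described above.
EXPECTED = {
    "metashift": ["data/metashift/train", "data/metashift/test"],
    "waterbirds": ["data/waterbirds", "data/waterbirds/metadata.csv"],
    "celeba": [
        "data/celeba/img_align_celeba",
        "data/celeba/celeba_split.csv",
        "data/celeba/list_attr_celeba.csv",
    ],
}


def missing_paths(dataset: str, root: str = ".") -> list[str]:
    """Return the expected paths for `dataset` that do not exist under `root`."""
    return [p for p in EXPECTED[dataset] if not (Path(root) / p).exists()]
```

Calling `missing_paths("celeba")` from the project root returns an empty list when the CelebA files are in place, and otherwise lists the entries that still need to be added.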
The textual inversion component is adapted from Hugging Face Diffusers. This script allows you to learn a new token representing a specific visual concept (e.g., a dog) using your dataset.
To train a textual inversion embedding for a class (e.g., "bird" in the Waterbirds dataset), run:
accelerate launch textual_inversion.py \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-2-base" \
--train_data_dir="Textual_inversion_meta/wb" \
--learnable_property="object" \
--placeholder_token="<bird>" \
--initializer_token="bird" \
--class_name="bird" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=3000 \
--learning_rate=5.0e-04 \
--scale_lr \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--output_dir="textual_inversion_bird"

This will produce a learned embedding for the placeholder token (e.g., `<bird>`), which can be used in subsequent diffusion-based sample generation steps.
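To sanity-check a learned embedding outside the DDB pipeline, it can be loaded into a standard Diffusers pipeline. This is a minimal sketch, assuming the adapted script writes its embedding file to the output directory the same way the upstream Hugging Face script does; `build_prompt` and `load_pipeline_with_token` are hypothetical helper names, not functions from this repository.

```python
def build_prompt(token: str, class_name: str) -> str:
    """Compose a prompt in the same style as the generate.py example."""
    return f"a photo of a {token} {class_name}"


def load_pipeline_with_token(embedding_dir: str, token: str):
    """Load Stable Diffusion and register the learned embedding under `token`."""
    # Imported lazily so build_prompt works without diffusers installed.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-base"
    )
    # Assumption: embedding_dir holds the embedding file saved by training.
    pipe.load_textual_inversion(embedding_dir, token=token)
    return pipe
```

With the command above, `load_pipeline_with_token("textual_inversion_bird", "<bird>")` followed by `pipe(build_prompt("<bird>", "bird")).images[0]` should yield one test image. Note that `generate.py` inpaints rather than generating from scratch, so this only checks that the token loads and produces the intended concept.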
Our project expects a `languagesegmentanything` folder containing the language segmentation model. You can obtain it by cloning this GitHub repository.
The ERM models used in this project are available at the following link: ERM Models - Google Drive
- Download the ERM models from the link above.
- Place the downloaded files in the `models` folder of the project.
This script (`generate.py`) runs the DDB pipeline to generate counterfactual training samples using Stable Diffusion and LangSAM. It performs the following steps:
- Select low-loss samples of a target class from a trained ERM model.
- Segment the causal region (e.g., hair, bird) using LangSAM.
- Inpaint new images using a textual inversion token with Stable Diffusion.
- Prune generated samples using attribution scores and classifier confidence.
- Save filtered samples and metadata for retraining.
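The first step in the list above, selecting low-loss samples of a target class, could be sketched as follows. This is an illustration under assumptions, not the code in `generate.py`: it assumes the loss is the cross-entropy of the ERM model's softmax output, and the function names and the `(sample_id, label, probs)` tuple format are hypothetical.

```python
import heapq
import math


def cross_entropy(probs, label):
    """Negative log-probability the model assigns to the true label."""
    return -math.log(max(probs[label], 1e-12))


def select_low_loss(samples, target_label, k):
    """Return ids of the k lowest-loss samples whose label is target_label.

    `samples` is an iterable of (sample_id, label, probs) tuples, where
    `probs` is the classifier's softmax output for that sample.
    """
    scored = [
        (cross_entropy(probs, label), sample_id)
        for sample_id, label, probs in samples
        if label == target_label
    ]
    # nsmallest returns the k entries with the smallest loss, in ascending order.
    return [sample_id for _, sample_id in heapq.nsmallest(k, scored)]
```

Here `target_label` and `k` play the roles of the `--label` and `--k` arguments in the command that follows.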
python generate.py \
--dataset waterbirds \
--model_path models/ERM-wb.model \
--textual_inversion Textual_inversions/textual_inversion_wb \
--prompt "a photo of a <waterbird> bird" \
--mask_prompt "bird" \
--token "<waterbird>" \
--label 0 \
--k 1000 \
--save_dir generated_data/waterbirds-wb/

| Argument | Type | Description |
|---|---|---|
| `--dataset` | str | Dataset name: celeba, waterbirds, or metashift |
| `--model_path` | str | Path to the pretrained ERM model |
| `--textual_inversion` | str | Path to learned textual inversion embedding (e.g., `<waterbird>`) |
| `--prompt` | str | Prompt to guide the diffusion model (must include the learned token) |
| `--mask_prompt` | str | Language prompt used by LangSAM to extract causal mask (e.g., "bird") |
| `--token` | str | The learned placeholder token (e.g., `<waterbird>`) |
| `--label` | int | Class label of initial samples (typically 0 or 1) |
| `--k` | int | Number of low-loss samples to generate new images from |
| `--threshold` | float | Attribution score threshold for pruning generated images |
| `--save_dir` | str | Directory to save generated images and their metadata |
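The role of `--threshold` in pruning can be illustrated with a small sketch. How `generate.py` actually computes the attribution score is not stated here, so the rule below is only one plausible reading: it assumes the score is the share of absolute attribution that falls inside the LangSAM mask, and that a sample must also reach a minimum classifier confidence. The names `attribution_in_mask`, `keep_sample`, and `min_confidence` are hypothetical.

```python
def attribution_in_mask(attribution, mask):
    """Fraction of total absolute attribution that lies inside the mask.

    Both arguments are equally shaped 2-D lists; mask entries are 0 or 1.
    """
    inside = 0.0
    total = 0.0
    for attr_row, mask_row in zip(attribution, mask):
        for a, m in zip(attr_row, mask_row):
            total += abs(a)
            inside += abs(a) * m
    return inside / total if total > 0 else 0.0


def keep_sample(attribution, mask, prob_target, threshold, min_confidence=0.5):
    """Keep a generated image only if both checks pass (assumed rule)."""
    score = attribution_in_mask(attribution, mask)
    return score >= threshold and prob_target >= min_confidence
```

Under this reading, a higher `threshold` keeps only images where the classifier attends mostly to the causal region, at the cost of discarding more generated samples.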
This script (`train.py`) retrains the ERM model on the combined dataset of:
- Original training samples
- Generated counterfactuals (from `generate.py`) for both classes

- Loads a pretrained ResNet50 ERM model
- Merges synthetic and real training samples
- Applies per-environment loss weighting (`gamma1`, `gamma2`)
- Evaluates performance per environment during training and testing
python train.py \
--new_data new_data \
--dataset celeba \
--model_path models/ERM-wb.model \
--gamma1 6 \
--gamma2 10 \
--lr 5e-6 \
--epochs 60 \
--batch_size 32 \
--output_dir checkpoints/

| Argument | Type | Description |
|---|---|---|
| `--new_data` | str | Folder path to the generated data (with metadata CSVs) |
| `--dataset` | str | Dataset name: celeba, waterbirds, or metashift |
| `--model_path` | str | Path to the pretrained ERM model |
| `--gamma1` | float | Loss multiplier for class 0 counterfactuals |
| `--gamma2` | float | Loss multiplier for class 1 counterfactuals |
| `--lr` | float | Learning rate for optimizer |
| `--epochs` | int | Number of epochs to retrain |
| `--batch_size` | int | Batch size for retraining |
| `--output_dir` | str | Directory to save best checkpoint (best_model.pth) |
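The effect of `--gamma1` and `--gamma2` can be sketched in plain Python. This is an illustration under assumptions rather than the loss used in `train.py`: it assumes real samples keep a weight of 1.0, generated samples are scaled by the multiplier for their class, and the batch loss is the weighted sum divided by the batch size. The helper names are hypothetical.

```python
def sample_weight(is_generated, label, gamma1, gamma2):
    """Weight 1.0 for real samples; gamma1 or gamma2 for generated class 0 or 1."""
    if not is_generated:
        return 1.0
    return gamma1 if label == 0 else gamma2


def weighted_mean_loss(losses, is_generated, labels, gamma1, gamma2):
    """Average per-sample losses after applying the weights above."""
    weights = [
        sample_weight(g, y, gamma1, gamma2)
        for g, y in zip(is_generated, labels)
    ]
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)
```

With the values from the example command (`gamma1=6`, `gamma2=10`), a generated class 1 sample contributes ten times as much to the batch loss as a real sample with the same per-sample loss.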

