SWiFT: Soft-Mask Weight Fine-tuning for Bias Mitigation


Description

Recent studies have shown that machine learning (ML) models can exhibit bias in real-world scenarios, posing significant challenges in ethically sensitive domains such as healthcare. Such bias can harm model fairness and generalization, and risks amplifying social discrimination, so there is a need to remove biases from trained models. Existing debiasing approaches often require access to the original training data and extensive model retraining; they also typically trade off model fairness against discriminative performance. To address these challenges, we propose Soft-Mask Weight Fine-Tuning (SWiFT), a debiasing framework that efficiently improves fairness while preserving discriminative performance at a much lower debiasing cost. Notably, SWiFT requires only a small external dataset and a few epochs of model fine-tuning. The idea behind SWiFT is to first find the relative, yet distinct, contributions of model parameters to both bias and predictive performance. Then, a two-step fine-tuning process updates each parameter with a different gradient flow defined by its contribution. Extensive experiments with three bias-sensitive attributes (gender, skin tone, and age) across four dermatological and two chest X-ray datasets demonstrate that SWiFT can consistently reduce model bias while achieving competitive or even superior diagnostic accuracy under common fairness and accuracy metrics, compared to the state of the art. In particular, we demonstrate improved model generalization, as evidenced by superior performance on several out-of-distribution (OOD) datasets.
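The two-step idea described above can be illustrated with a toy sketch in plain Python. This is not the repository's implementation. It assumes, purely for illustration, that a parameter's contribution is scored by its squared gradient and that the soft mask is the bias share of the total score. The names `soft_mask`, `masked_step`, and `two_step_update` are hypothetical; only the two learning rates echo the `--lr-forget` and `--lr-base` flags used later in this README.

```python
from typing import List, Tuple


def soft_mask(bias_scores: List[float], task_scores: List[float]) -> List[float]:
    """Per-parameter mask in [0, 1]; values near 1 mean 'mostly contributes to bias'.

    Assumption: the mask is the bias share of the total contribution score.
    """
    eps = 1e-12  # avoids division by zero when both scores vanish
    return [b / (b + t + eps) for b, t in zip(bias_scores, task_scores)]


def masked_step(weights: List[float], grads: List[float],
                mask: List[float], lr: float) -> List[float]:
    """Gradient-descent step where each parameter's gradient is scaled by its mask."""
    return [w - lr * m * g for w, g, m in zip(weights, grads, mask)]


def two_step_update(weights: List[float], bias_grads: List[float],
                    task_grads: List[float], lr_forget: float,
                    lr_base: float) -> Tuple[List[float], List[float]]:
    # Assumption: squared gradients stand in for parameter contributions.
    bias_scores = [g * g for g in bias_grads]
    task_scores = [g * g for g in task_grads]
    mask = soft_mask(bias_scores, task_scores)

    # Step 1: reduce the bias objective, mainly through bias-related parameters.
    weights = masked_step(weights, bias_grads, mask, lr_forget)

    # Step 2: restore predictive performance, mainly through the remaining
    # parameters. A real training loop would recompute gradients here.
    keep = [1.0 - m for m in mask]
    weights = masked_step(weights, task_grads, keep, lr_base)
    return weights, mask


if __name__ == "__main__":
    new_weights, mask = two_step_update(
        weights=[1.0, 1.0], bias_grads=[1.0, 0.0], task_grads=[0.0, 1.0],
        lr_forget=0.1, lr_base=0.1,
    )
    print(new_weights, mask)
```

In this toy case the first parameter only affects the bias objective and the second only affects the task objective, so each step moves exactly one of them.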

Installation

Pip

# clone project
git clone https://github.com/vios-s/SWiFT.git
cd SWiFT

# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt

Conda

# clone project
git clone https://github.com/vios-s/SWiFT.git
cd SWiFT

# create conda environment and install dependencies
conda env create -f environment.yaml -n myenv

# activate conda environment
conda activate myenv

Dataset

The datasets used in this project are available from the following sources (registration may be required):

| Dataset | Access |
| --- | --- |
| ISIC 2020 | Link |
| ISIC 2019/2018/2017 | Link |
| Interactive Atlas of Dermoscopy | Link |
| PAD | Link |
| Fitzpatrick17k | Link |
| MIMIC-CXR | Link |
| CheXpert | Link |
| NIH ChestX-ray14 | Link |

How to run

Run the debiasing algorithm

python algorithm/TwoStageFinetune.py

You can override any parameter defined in argument.py from the command line, like this:

python algorithm/TwoStageFinetune.py --arch resnet50 --task 'xray' --attr 'age_attribute' --lr-base 0.000001 --lr-forget 0.000001 --beta 0.01 --model-dir './logs/model/resnet50_mimic_val0_gender.ckpt' --csv-dir './data/chestXray/csv/mimic_val_gender_0.csv' --batch-size 128 --num-attr 'binary'
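The sketch below shows how command-line flags like these typically override defaults in a Python script. It assumes that `argument.py` is built on the standard-library `argparse` module, which this README does not confirm. The flag names come from the command above; the defaults simply reuse the values from that command as placeholders and are not the repository's actual defaults. `build_parser` is a hypothetical helper name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags shown in the README command."""
    parser = argparse.ArgumentParser(description="Two-stage fine-tuning (sketch)")
    parser.add_argument("--arch", default="resnet50")
    parser.add_argument("--task", default="xray")
    parser.add_argument("--attr", default="age_attribute")
    # argparse turns "--lr-base" into the attribute name "lr_base".
    parser.add_argument("--lr-base", type=float, default=1e-6)
    parser.add_argument("--lr-forget", type=float, default=1e-6)
    parser.add_argument("--beta", type=float, default=0.01)
    parser.add_argument("--model-dir", default=None)
    parser.add_argument("--csv-dir", default=None)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-attr", default="binary")
    return parser


if __name__ == "__main__":
    # Flags given on the command line replace the defaults; the rest are kept.
    args = build_parser().parse_args(["--lr-base", "0.00001", "--batch-size", "64"])
    print(args.lr_base, args.batch_size, args.arch)
```

Any flag that is omitted keeps its default, so a command only needs to list the parameters that differ from the defaults.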

Train the baseline model

Pre-train the baseline model with the default configuration

# train on CPU
python src/train.py trainer=cpu

# train on GPU
python src/train.py trainer=gpu

Train the model with a chosen experiment configuration from configs/experiment/

python src/train.py experiment=experiment_name.yaml

You can override any parameter from the command line, like this:

python src/train.py trainer.max_epochs=20 data.batch_size=64
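To make the dotted override syntax concrete, here is a minimal stand-in written in plain Python. It is not Hydra's implementation and does not use Hydra's API. It only mimics how a `key.subkey=value` argument replaces one entry in a nested configuration. The function names `parse_value` and `apply_overrides` and the starting configuration values are hypothetical.

```python
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    """Interpret a raw override value as int, then float, else keep the string."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            continue
    return text


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'a.b=value' style overrides to a nested dict (modified in place)."""
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Walk down the nesting, creating missing groups on the way.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw)
    return config


if __name__ == "__main__":
    # Hypothetical defaults, standing in for values loaded from config files.
    config = {
        "trainer": {"max_epochs": 10, "accelerator": "cpu"},
        "data": {"batch_size": 128},
    }
    apply_overrides(config, ["trainer.max_epochs=20", "data.batch_size=64"])
    print(config)
```

Entries that are not named in an override, such as `trainer.accelerator` here, keep their configured values.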
