Recent studies have shown that Machine Learning (ML) models can exhibit bias in real-world scenarios, posing significant challenges in ethically sensitive domains such as healthcare. Such bias can harm model fairness and generalization ability, and risks amplifying social discrimination. There is therefore a need to remove biases from trained models. Existing debiasing approaches often require access to the original training data and extensive model retraining; they also typically trade off model fairness against discriminative performance. To address these challenges, we propose Soft-Mask Weight Fine-Tuning (SWiFT), a debiasing framework that efficiently improves fairness while preserving discriminative performance at a much lower debiasing cost. Notably, SWiFT requires only a small external dataset and a few epochs of model fine-tuning. The idea behind SWiFT is to first find the relative, yet distinct, contributions of model parameters to both bias and predictive performance. Then, a two-step fine-tuning process updates each parameter with different gradient flows defined by its contribution. Extensive experiments with three bias-sensitive attributes (gender, skin tone, and age) across four dermatological and two chest X-ray datasets demonstrate that SWiFT consistently reduces model bias while achieving competitive or even superior diagnostic accuracy under common fairness and accuracy metrics, compared to the state of the art. Moreover, we demonstrate improved model generalization ability, as evidenced by superior performance on several out-of-distribution (OOD) datasets.
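The two steps described above (score each parameter's contribution to bias versus performance, then fine-tune with per-parameter gradient flows) can be illustrated with a minimal, dependency-free sketch. It is not the SWiFT implementation: the importance score (squared gradients, a diagonal-Fisher-style proxy), the soft-mask formula, the order of the two steps, and the names `soft_mask`, `two_step_round`, `toy_bias_grad` and `toy_task_grad` are all hypothetical. In a real model the gradients would come from autograd on a bias loss and a task loss computed on the small external dataset.

```python
from typing import Callable, List

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def soft_mask(bias_grad: Vector, task_grad: Vector, eps: float = 1e-12) -> Vector:
    """Per-parameter share of importance attributed to bias, in [0, 1].

    Importance is approximated by the squared gradient of each loss
    (a diagonal-Fisher-style proxy); this scoring rule is an assumption.
    """
    mask = []
    for gb, gt in zip(bias_grad, task_grad):
        imp_bias, imp_task = gb * gb, gt * gt
        mask.append(imp_bias / (imp_bias + imp_task + eps))
    return mask


def two_step_round(params: Vector, bias_grad_fn: GradFn, task_grad_fn: GradFn,
                   mask: Vector, lr_bias: float, lr_task: float) -> Vector:
    """One round of a hypothetical two-step, soft-masked update."""
    # Step 1: reduce the bias loss; parameters with a high mask value move most.
    g_bias = bias_grad_fn(params)
    params = [p - lr_bias * m * g for p, m, g in zip(params, mask, g_bias)]
    # Step 2: preserve task performance; parameters with a low mask value move most.
    g_task = task_grad_fn(params)
    params = [p - lr_task * (1.0 - m) * g for p, m, g in zip(params, mask, g_task)]
    return params


# Hypothetical toy problem: parameter 0 only drives a bias loss 0.5 * w0**2,
# parameter 1 only drives a task loss 0.5 * (w1 - 1)**2.
def toy_bias_grad(w: Vector) -> Vector:
    return [w[0], 0.0]


def toy_task_grad(w: Vector) -> Vector:
    return [0.0, w[1] - 1.0]


if __name__ == "__main__":
    w = [2.0, 0.0]
    # The mask is computed once, before fine-tuning, from the initial gradients.
    m = soft_mask(toy_bias_grad(w), toy_task_grad(w))
    for _ in range(20):
        w = two_step_round(w, toy_bias_grad, toy_task_grad, m, lr_bias=0.5, lr_task=0.5)
    print("mask:", m, "params:", w)
```

Because the mask is fixed after the scoring pass, each parameter keeps its own gradient scaling throughout fine-tuning: on the toy problem the bias-driving parameter is pushed towards zero while the task-driving parameter converges to its optimum of 1.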
```bash
# clone project
git clone https://github.com/vios-s/SWiFT.git
cd SWiFT

# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt
```

Alternatively, create the environment from `environment.yaml` with conda:

```bash
# clone project
git clone https://github.com/vios-s/SWiFT.git
cd SWiFT

# create conda environment and install dependencies
conda env create -f environment.yaml -n myenv

# activate conda environment
conda activate myenv
```

The datasets used in this project are available from the following sources (registration may be required):

| Dataset | Access |
|---|---|
| ISIC 2020 | Link |
| ISIC 2019/2018/2017 | Link |
| Interactive Atlas of Dermoscopy | Link |
| PAD | Link |
| Fitzpatrick17k | Link |
| MIMIC-CXR | Link |
| CheXpert | Link |
| NIH ChestX-ray14 | Link |
Run the two-stage fine-tuning with the default configuration:

```bash
python algorithm/TwoStageFinetune.py
```

You can override any parameter either in `argument.py` or from the command line, like this:

```bash
python algorithm/TwoStageFinetune.py --arch resnet50 --task 'xray' --attr 'age_attribute' \
    --lr-base 0.000001 --lr-forget 0.000001 --beta 0.01 \
    --model-dir './logs/model/resnet50_mimic_val0_gender.ckpt' \
    --csv-dir './data/chestXray/csv/mimic_val_gender_0.csv' \
    --batch-size 128 --num-attr 'binary'
```

Pre-train the baseline model with the default configuration:
```bash
# train on CPU
python src/train.py trainer=cpu

# train on GPU
python src/train.py trainer=gpu
```

Train the model with a chosen experiment configuration from `configs/experiment/`:
```bash
python src/train.py experiment=experiment_name.yaml
```

You can override any parameter from the command line like this:
```bash
python src/train.py trainer.max_epochs=20 data.batch_size=64
```
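The `src/train.py` commands appear to follow the Hydra override syntax (an assumption based on the `trainer=`, `experiment=` and dotted `key=value` arguments). The simplified sketch below shows only how dotted overrides such as `trainer.max_epochs=20` replace leaf values in a nested config. The function name `apply_dotlist` and the default values are hypothetical, and Hydra's real grammar also handles config-group selection (`trainer=gpu`), list values, quoting and `+`/`~` prefixes.

```python
import copy
from typing import Any, Dict, List


def _parse_value(raw: str) -> Any:
    """Convert a command-line string to int, float or bool, else keep it as str."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    return raw


def apply_dotlist(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of `cfg` with each `a.b.c=value` override applied."""
    out = copy.deepcopy(cfg)  # leave the defaults untouched
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value: {item!r}")
        *parents, leaf = key.split(".")
        node = out
        for name in parents:
            # Walk into (or create) the nested group for each dotted component.
            node = node.setdefault(name, {})
        node[leaf] = _parse_value(raw)
    return out


if __name__ == "__main__":
    # Hypothetical defaults standing in for the YAML files under configs/.
    defaults = {"trainer": {"max_epochs": 10, "accelerator": "cpu"},
                "data": {"batch_size": 32}}
    cfg = apply_dotlist(defaults, ["trainer.max_epochs=20", "data.batch_size=64"])
    print(cfg)
```

Outside Hydra, OmegaConf (the configuration library Hydra is built on) provides `OmegaConf.from_dotlist` and `OmegaConf.merge` for the same kind of merge.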