Skip to content

Latest commit

 

History

History
117 lines (76 loc) · 4.84 KB

README.md

File metadata and controls

117 lines (76 loc) · 4.84 KB

PerAda: Parameter-Efficient Federated Learning Personalization with Generalization Guarantees

Chulin Xie, De-An Huang, Wenda Chu, Daguang Xu, Chaowei Xiao, Bo Li, Anima Anandkumar

[CVPR 2024] [BibTeX]

Contents

Introduction

This repository contains the code for our CVPR'24 paper PerAda: Parameter-Efficient Federated Learning Personalization with Generalization Guarantees. Personalized Federated Learning (pFL) has emerged as a promising solution to tackle data heterogeneity across clients in FL. However, existing pFL methods either (1) introduce high computation and communication costs or (2) overfit to local data, which can be limited in scope and vulnerable to evolved test samples with natural distribution shifts.

In this paper, we propose PerAda, a parameter-efficient pFL framework that reduces communication and computational costs and exhibits superior generalization performance, especially under test-time distribution shifts. PerAda reduces the costs by leveraging the power of pretrained models and only updates and communicates a small number of additional parameters from adapters. PerAda achieves high generalization by regularizing each client's personalized adapter with a global adapter, while the global adapter uses knowledge distillation to aggregate generalized information from all clients.

Theoretically, we provide generalization bounds of PerAda, and we prove its convergence to stationary points under non-convex settings. Empirically, PerAda demonstrates higher personalized performance (+4.85% on CheXpert) and enables better out-of-distribution generalization (+5.23% on CIFAR-10-C) on different datasets across natural and medical domains compared with baselines, while only updating 12.6% of parameters per model on ResNet-18.

Install

  1. Create a conda environment:

    conda create -n perada python=3.8
    conda activate perada
    conda install pytorch==1.12.0 torchvision==0.13.0 cudatoolkit=11.3 -c pytorch

    Alternatively, use the docker image nvidia/pytorch:22.05-py3.

  2. Install additional packages:

    pip install -r requirements.txt

Dataset

See Datasets Preparation.

Important note: Each user is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use and applicable links before the script runs and the data is placed in the user machine.

Train

Take CIFAR-10 dataset as an example. Run PerAda without knowledge distillation:

export CUDA_VISIBLE_DEVICES=0
sh scripts/cifar10/perada-nokd.sh

Run PerAda with knowledge distillation:

sh scripts/cifar10/perada-kd.sh

Run FedAvg baseline:

sh scripts/cifar10/fedavg.sh

Run Standalone baseline:

sh scripts/cifar10/standalone.sh

For other experiments, replace cifar10 with oh or chexpert in the above commands.

Evaluation

We provide the evaluation pipeline by testing the personalized models on local test sets, global test sets, or out-of-distribution test sets.

Take CIFAR-10 dataset as example. Run:

sh scripts/cifar10/inf.sh

Update the model_paths in the above script for your trained models.

For other experiments, replace cifar10 with oh or chexpert in the above commands.

License

Copyright © 2024, NVIDIA Corporation. All rights reserved.

This work is made available under the Nvidia Source Code License-NC. Click here to view a copy of this license.

For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing.

Citation

If you find this work useful for your research and applications, please cite using this BibTeX:

@inproceedings{xie2024perada,
  title={PerAda: Parameter-Efficient Federated Learning Personalization with Generalization Guarantees},
  author={Xie, Chulin and Huang, De-An and Chu, Wenda and Xu, Daguang and Xiao, Chaowei and Li, Bo and Anandkumar, Anima},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={23838--23848},
  year={2024}
}

Acknowledgement