diff --git a/.github/workflows/build_and_test_python.yml b/.github/workflows/build_and_test_python.yml index c56c2807..753a7456 100644 --- a/.github/workflows/build_and_test_python.yml +++ b/.github/workflows/build_and_test_python.yml @@ -24,9 +24,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install "pybind11[global]" + pip3 install "pybind11[global]" if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi - pip install -e . + pip3 install -e . - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/README.md b/README.md index 57b481ea..b9a7a73e 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ AIJack allows you to assess the privacy and security risks of machine learning a - [Federated Learning and Model Inversion Attack](#federated-learning-and-model-inversion-attack) - [Split Learning and Label Leakage Attack](#split-learning-and-label-leakage-attack) - [DPSGD (SGD with Differential Privacy)](#dpsgd-sgd-with-differential-privacy) + - [Federated Learning with Homomorphic Encryption](#federated-learning-with-homomorphic-encryption) - [SecureBoost (XGBoost with Homomorphic Encryption)](#secureboost-xgboost-with-homomorphic-encryption) - [Evasion Attack](#evasion-attack) - [Poisoning Attack](#poisoning-attack) @@ -68,7 +69,7 @@ pip install git+https://github.com/Koukyosyumei/AIJack | FedAVG | [example](docs/aijack_fedavg.ipynb) | [paper](https://arxiv.org/abs/1602.05629) | | FedProx | WIP | [paper](https://arxiv.org/abs/1812.06127) | | FedKD | [example](test/collaborative/fedkd/test_fedkd.py) | [paper](https://arxiv.org/abs/2108.13323) | -| FedMD | WIP | [paper](https://arxiv.org/abs/1910.03581) | +| FedMD | [example](docs/aijack_fedmd.ipynb) | [paper](https://arxiv.org/abs/1910.03581) | | FedGEMS | WIP | [paper](https://arxiv.org/abs/2110.11027) | | DSFL | WIP | [paper](https://arxiv.org/abs/2008.06180) | | SplitNN | [example](docs/aijack_split_learning.ipynb) | [paper](https://arxiv.org/abs/1812.00564) | @@ -87,6 +88,7 @@ pip install git+https://github.com/Koukyosyumei/AIJack | GAN Attack | Model Inversion | [example](example/model_inversion/gan_attack.py) | [paper](https://arxiv.org/abs/1702.07464) | | Shadow Attack | Membership Inference | [example](docs/aijack_membership_inference.ipynb) | [paper](https://arxiv.org/abs/1610.05820) | | Norm attack | Label Leakage | [example](docs/aijack_split_learning.ipynb) | [paper](https://arxiv.org/abs/2102.08504) | +| Delta Weights | Free Rider Attack | WIP | [paper](https://arxiv.org/pdf/1911.12560.pdf) | | Gradient descent attacks | Evasion Attack | [example](docs/aijack_evasion_attack.ipynb) | [paper](https://arxiv.org/abs/1708.06131) | | SVM Poisoning | Poisoning Attack | [example](docs/aijack_poison_attack.ipynb) | [paper](https://arxiv.org/abs/1206.6389) | @@ -219,6 +221,24 @@ for data in lot_loader(trainset): optimizer.step() ``` +## Federated Learning with Homomorphic Encryption + +```Python + from aijack.collaborative import FedAvgClient, FedAvgServer + from aijack.defense import PaillierGradientClientManager, PaillierKeyGenerator + +keygenerator = PaillierKeyGenerator(64) +pk, sk = keygenerator.generate_keypair() + +manager = PaillierGradientClientManager(pk, sk) +PaillierGradFedAvgClient = manager.attach(FedAvgClient) + +clients = [ + PaillierGradFedAvgClient(Net(), user_id=i, lr=lr, server_side_update=False) + for i in range(client_num) +] +``` + ## SecureBoost (XGBoost with Homomorphic Encryption) SecureBoost is a vertically federated version of XGBoost, where each party encrypts sensitive information with Paillier Encryption. You need additional compile to use secureboost, which requires Boost 1.65 or later. diff --git a/docs/_toc.yml b/docs/_toc.yml index 3d8df29c..5a7aab58 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -8,6 +8,7 @@ parts: numbered: true chapters: - file: aijack_fedavg + - file: aijack_fedmd - file: aijack_secureboost - caption: Model Inversion numbered: true diff --git a/docs/aijack_fedavg.ipynb b/docs/aijack_fedavg.ipynb index a620701b..c926174f 100644 --- a/docs/aijack_fedavg.ipynb +++ b/docs/aijack_fedavg.ipynb @@ -1,42 +1,32 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", + "metadata": { + "id": "SROAZ9dO80s8" + }, "source": [ "# FedAVG\n", "\n", "In this tutorial, you will learn how to simulate FedAVG, representative scheme of Federated Learning, with AIJack. You can choose the single process or MPI as the backend." - ], - "metadata": { - "id": "SROAZ9dO80s8" - } + ] }, { "cell_type": "markdown", - "source": [ - "## Single Process" - ], "metadata": { "id": "FNMDQuH49CBO" - } + }, + "source": [ + "## Single Process" + ] }, { "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "J4s0w9rHwOd8" + }, + "outputs": [], "source": [ "import random\n", "\n", @@ -50,15 +40,15 @@ "\n", "from aijack.collaborative import FedAvgClient, FedAvgServer\n", "from aijack.collaborative.fedavg import FedAVGAPI" - ], - "metadata": { - "id": "J4s0w9rHwOd8" - }, - "execution_count": 1, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "gsh_CjamwqfV" + }, + "outputs": [], "source": [ "def fix_seed(seed):\n", " random.seed(seed)\n", @@ -98,15 +88,15 @@ " x = self.ln(x.reshape(-1, 28 * 28))\n", " output = F.log_softmax(x, dim=1)\n", " return output" - ], - "metadata": { - "id": "gsh_CjamwqfV" - }, - "execution_count": 2, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "UX-5FPlbxs4z" + }, + "outputs": [], "source": [ "training_batch_size = 64\n", "test_batch_size = 64\n", @@ -118,27 +108,197 @@ "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "fix_seed(seed)" - ], - "metadata": { - "id": "UX-5FPlbxs4z" - }, - "execution_count": 3, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 431, + "referenced_widgets": [ + "9593a116d90b4813950ab65267268ad2", + "77c48fa6fbbc4ce1b6fe5d2692478dae", + "f88dbb8f70d84434b296e7d828651b14", + "4160c1d4b4ea453aa3568722217c4690", + "550b9614ff5b47a3a09f6d61b390b1d6", + "2d65774840de45428117e57938bae02c", + "7b90829e62d345d8b0dfb04509f79dfd", + "fbf5e10097c347c78c48a742607738ff", + "53ee6fdd59b742f2a0f8fca851d3016e", + "ae648105b31545fea8e15c7683203dd9", + "345e7473970c4eb3a7b2ced1b82b9a57", + "ee62f1db501c4264a76b1a62f48a7be9", + "1091c9f58f8d4c64ae6f90ec42221ced", + "44aa60060ec648debc2334110dddf2ec", + "f39bd6bdb4e14c9789439e906bbf43a6", + "1034bf75a49b45738f37e48ba3d3b450", + "aadbecb8edc8489999632ed9863befdf", + "4d12af4771c444c8a0d26b9823947805", + "31f7c397dc4d4cea84cde7d0b352c7b2", + "39439c0b576e449dbc71aa143a1473b5", + "bc3d0630330841a9ab7186c77bc8b78b", + "3e2f43e944a5412db641fb078990fc3f", + "f9201d21f7064d75b4174d0d309c65bf", + "ae614b87627240fcba7c93a452322adb", + "5e567ac544d243a79db8ee6f8a700b05", + "a7466c8f4b56470dbb2bd30b7501b584", + "fd7acd8304054cbcac4327efb988575b", + "376fc19cdac6469a9c9c6828dd075f01", + "6a4e6266fe6548c6956f983d67386655", + "3e12f79b9c9f4df0b619208ff6b98776", + "c8970960cf904f089bdc2b263dde4385", + "e51ddd7b5ca2417c882ee3824bea929d", + "4096574643fb424fb244a2fbea3ac75b", + "39ab6acaea554a3f80b9ae135bf28e66", + "80599966b0974e538ea4a7ee15b8ef0c", + "4006824773f444cbbe840a8807ac3eab", + "d9230884d5664145aee2a39d6e429b62", + "f23749eb92104627ace585cf6920d6ec", + "c3dec35f56464d948064a5f9f7ca750c", + "a453fe4ebfd34b019bd9e8d74d3ae058", + "06c5e346163144298a141a20c4e778da", + "d40273866c914f578e4e5c1d8f430941", + "b6dbf9b70d534db18aa2853a33626152", + "8f4b1efa44ba415fb9af0f4404a2352c" + ] + }, + "id": "zz_YjoioAcLD", + "outputId": "903d8cbb-dca4-421d-a6af-750d199ffaf4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9593a116d90b4813950ab65267268ad2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/9912422 [00:00= 1.11.0', 'torchvision', 'pybind11', diff --git a/src/aijack/attack/freerider/__init__.py b/src/aijack/attack/freerider/__init__.py new file mode 100644 index 00000000..728c37a5 --- /dev/null +++ b/src/aijack/attack/freerider/__init__.py @@ -0,0 +1 @@ +from .freerider import FreeRiderClientManager # noqa : F401 diff --git a/src/aijack/attack/freerider/freerider.py b/src/aijack/attack/freerider/freerider.py new file mode 100644 index 00000000..bbdb4fb3 --- /dev/null +++ b/src/aijack/attack/freerider/freerider.py @@ -0,0 +1,51 @@ +import copy + +import torch + +from ...manager import BaseManager + + +def attach_freerider_to_client(cls, mu, sigma): + class FreeRiderClientWrapper(cls): + """Implementation of Free Rider Attack (https://arxiv.org/abs/1911.12560)""" + + def __init__(self, *args, **kwargs): + super(FreeRiderClientWrapper, self).__init__(*args, **kwargs) + self.prev_parameters_to_generate_fake_gradients = None + + def upload_gradients(self): + """Upload the local gradients""" + gradients = [] + if self.prev_parameters_to_generate_fake_gradients is not None: + for param, prev_param in zip( + self.model.parameters(), + self.prev_parameters_to_generate_fake_gradients, + ): + gradients.append( + (prev_param - param) / self.lr + + (sigma * torch.randn_like(param) + mu) + ) + else: + for param in self.model.parameters(): + gradients.append(sigma * torch.randn_like(param) + mu) + return gradients + + def download(self, new_global_model): + """Download the new global model""" + self.prev_parameters_to_generate_fake_gradients = copy.deepcopy( + list(self.model.parameters()) + ) + super().download(new_global_model) + + return FreeRiderClientWrapper + + +class FreeRiderClientManager(BaseManager): + """Implementation of Free Rider Attack (https://arxiv.org/abs/1911.12560)""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def attach(self, cls): + return attach_freerider_to_client(cls, *self.args, **self.kwargs) diff --git a/src/aijack/collaborative/core/api.py b/src/aijack/collaborative/core/api.py index e1b6b9b0..c931827a 100644 --- a/src/aijack/collaborative/core/api.py +++ b/src/aijack/collaborative/core/api.py @@ -95,23 +95,29 @@ def train_client(self, public=True): def run(self): pass - def score(self, dataloader): + def score(self, dataloader, only_local=False): """Returns the performance on the given dataset. Args: dataloader (torch.utils.data.DataLoader): a dataloader + only_local (bool): show only the local results Returns: Dict[str, int]: performance of global model and local models """ - server_score = accuracy_torch_dataloader( - self.server, dataloader, device=self.device - ) + clients_score = [ accuracy_torch_dataloader(client, dataloader, device=self.device) for client in self.clients ] - return {"server_score": server_score, "clients_score": clients_score} + + if only_local: + return {"clients_score": clients_score} + else: + server_score = accuracy_torch_dataloader( + self.server, dataloader, device=self.device + ) + return {"server_score": server_score, "clients_score": clients_score} def local_score(self): """Returns the local performance of each clients. diff --git a/src/aijack/collaborative/core/server.py b/src/aijack/collaborative/core/server.py index 7675d913..b51b4eae 100644 --- a/src/aijack/collaborative/core/server.py +++ b/src/aijack/collaborative/core/server.py @@ -29,7 +29,7 @@ def update(self): """Update the global model.""" pass - def distribtue(self): + def distribute(self): """Distribute the global model to each client.""" pass diff --git a/src/aijack/collaborative/core/utils.py b/src/aijack/collaborative/core/utils.py index 3ff3f616..a21a9c66 100644 --- a/src/aijack/collaborative/core/utils.py +++ b/src/aijack/collaborative/core/utils.py @@ -1,5 +1,6 @@ GRADIENTS_TAG = 1 -SPARSEGRADIENTS_TAG = 2 -PARAMETERS_TAG = 3 +LOCAL_LOGIT_TAG = 2 +GLOBAL_LOGIT_TAG = 3 +PARAMETERS_TAG = 4 RECEIVE_NAN_CODE = 11 SEND_NAN_CODE = 12 diff --git a/src/aijack/collaborative/fedavg/__init__.py b/src/aijack/collaborative/fedavg/__init__.py index dac296b2..ff033311 100644 --- a/src/aijack/collaborative/fedavg/__init__.py +++ b/src/aijack/collaborative/fedavg/__init__.py @@ -1,3 +1,3 @@ from .api import FedAVGAPI, MPIFedAVGAPI # noqa: F401 -from .client import FedAvgClient, MPIFedAVGClient # noqa: F401 -from .server import FedAvgServer, MPIFedAVGServer # noqa: F401 +from .client import FedAvgClient, MPIFedAvgClient # noqa: F401 +from .server import FedAvgServer, MPIFedAvgServer # noqa: F401 diff --git a/src/aijack/collaborative/fedavg/api.py b/src/aijack/collaborative/fedavg/api.py index e1418b46..c4e245c4 100644 --- a/src/aijack/collaborative/fedavg/api.py +++ b/src/aijack/collaborative/fedavg/api.py @@ -94,7 +94,7 @@ def run(self): self.server.updata_from_gradients(weight=self.clients_weight) else: self.server.update_from_parameters(weight=self.clients_weight) - self.server.distribtue() + self.server.distribute() self.custom_action(self) @@ -125,28 +125,21 @@ def __init__( self.device = device def run(self): - if self.is_server: - self.party.send_parameters() - else: - self.party.download() + self.party.mpi_initialize() self.comm.Barrier() for _ in range(self.num_communication): - if self.is_server: - self.party.action() - else: + if not self.is_server: self.local_train() - self.party.upload() - self.party.model.zero_grad() - self.party.download() + self.party.action() self.custom_action(self) self.comm.Barrier() def local_train(self): self.party.prev_parameters = [] - for param in self.party.model.parameters(): - self.party.prev_parameters.append(copy.deepcopy(param)) + for param in self.party.client.model.parameters(): + self.party.client.prev_parameters.append(copy.deepcopy(param)) for _ in range(self.local_epoch): running_loss = 0 @@ -154,7 +147,7 @@ def local_train(self): self.local_optimizer.zero_grad() data = data.to(self.device) target = target.to(self.device) - output = self.party.model(data) + output = self.party.client.model(data) loss = self.criterion(output, target) loss.backward() self.local_optimizer.step() diff --git a/src/aijack/collaborative/fedavg/client.py b/src/aijack/collaborative/fedavg/client.py index 7206d9dd..00f573f8 100644 --- a/src/aijack/collaborative/fedavg/client.py +++ b/src/aijack/collaborative/fedavg/client.py @@ -1,7 +1,5 @@ import copy -import torch - from ..core import BaseClient from ..core.utils import GRADIENTS_TAG, PARAMETERS_TAG from ..optimizer import AdamFLOptimizer, SGDFLOptimizer @@ -107,29 +105,29 @@ def download(self, new_global_model): self.prev_parameters.append(copy.deepcopy(param)) -class MPIFedAVGClient(BaseClient): - """Client of FedAVG for mpi-backend simulation""" - - def __init__(self, comm, model, user_id=0, lr=0.1, device="cpu"): - super(MPIFedAVGClient, self).__init__(model, user_id=user_id) +class MPIFedAvgClient: + def __init__(self, comm, client): self.comm = comm - self.lr = lr - self.device = device + self.client = client - self.prev_parameters = [] - for param in self.model.parameters(): - self.prev_parameters.append(copy.deepcopy(param)) + def __call__(self, *args, **kwargs): + return self.client(*args, **kwargs) - def upload(self): - self.upload_gradient() + def action(self): + self.mpi_upload() + self.client.model.zero_grad() + self.mpi_download() - def upload_gradient(self, destination_id=0): - self.gradients = [] - for param, prev_param in zip(self.model.parameters(), self.prev_parameters): - self.gradients.append((prev_param - param) / self.lr) - self.comm.send(self.gradients, dest=destination_id, tag=GRADIENTS_TAG) + def mpi_upload(self): + self.mpi_upload_gradient() + + def mpi_upload_gradient(self, destination_id=0): + self.comm.send( + self.client.upload_gradients(), dest=destination_id, tag=GRADIENTS_TAG + ) + + def mpi_download(self): + self.client.download(self.comm.recv(tag=PARAMETERS_TAG)) - def download(self): - new_parameters = self.comm.recv(tag=PARAMETERS_TAG) - for params, new_params in zip(self.model.parameters(), new_parameters): - params.data = torch.Tensor(new_params).reshape(params.shape).to(self.device) + def mpi_initialize(self): + self.mpi_download() diff --git a/src/aijack/collaborative/fedavg/server.py b/src/aijack/collaborative/fedavg/server.py index 43fded04..5be44915 100644 --- a/src/aijack/collaborative/fedavg/server.py +++ b/src/aijack/collaborative/fedavg/server.py @@ -1,11 +1,8 @@ -import copy - import numpy as np import torch -from mpi4py import MPI from ..core import BaseServer -from ..core.utils import GRADIENTS_TAG, PARAMETERS_TAG, RECEIVE_NAN_CODE +from ..core.utils import GRADIENTS_TAG, PARAMETERS_TAG from ..optimizer import AdamFLOptimizer, SGDFLOptimizer @@ -13,8 +10,8 @@ class FedAvgServer(BaseServer): """Server of FedAVG for single process simulation Args: - clients ([FedAvgClient]): a list of FedAVG clients - global_model (torch.nn.Module): global model + clients ([FedAvgClient] | [int]): a list of FedAVG clients or their ids. + global_model (torch.nn.Module): global model. server_id (int, optional): id of this server. Defaults to 0. lr (float, optional): learning rate. Defaults to 0.1. optimizer_type (str, optional): optimizer for the update of global model . Defaults to "sgd". @@ -31,12 +28,16 @@ def __init__( optimizer_type="sgd", server_side_update=True, optimizer_kwargs={}, + device="cpu", ): super(FedAvgServer, self).__init__(clients, global_model, server_id=server_id) self.lr = lr self._setup_optimizer(optimizer_type, **optimizer_kwargs) self.server_side_update = server_side_update - self.distribtue(force_send_model_state_dict=True) + self.distribute(force_send_model_state_dict=True) + self.device = device + + self.uploaded_gradients = [] def _setup_optimizer(self, optimizer_type, **kwargs): if optimizer_type == "sgd": @@ -57,7 +58,7 @@ def _setup_optimizer(self, optimizer_type, **kwargs): def action(self, use_gradients=True): self.receive(use_gradients) self.update(use_gradients) - self.distribtue() + self.distribute() def receive(self, use_gradients=True): """Receive the local models @@ -77,19 +78,24 @@ def update(self, use_gradients=True): use_gradients (bool, optional): If True, update the global model with aggregated local gradients. Defaults to True. """ if use_gradients: - self.updata_from_gradients() + self.update_from_gradients() else: self.update_from_parameters() + def _preprocess_local_gradients(self, uploaded_grad): + return uploaded_grad + def receive_local_gradients(self): """Receive local gradients""" - self.uploaded_gradients = [c.upload_gradients() for c in self.clients] + self.uploaded_gradients = [ + self._preprocess_local_gradients(c.upload_gradients()) for c in self.clients + ] def receive_local_parameters(self): """Receive local parameters""" self.uploaded_parameters = [c.upload_parameters() for c in self.clients] - def updata_from_gradients(self, weight=None): + def update_from_gradients(self, weight=None): """Update the global model with the local gradients. Args: @@ -136,113 +142,66 @@ def update_from_parameters(self, weight=None): self.server_model.load_state_dict(averaged_params) - def distribtue(self, force_send_model_state_dict=False): + def distribute(self, force_send_model_state_dict=False): """Distribute the current global model to each client. Args: force_send_model_state_dict (bool, optional): If True, send the global model as the dictionary of model state regardless of other parameters. Defaults to False. """ for client in self.clients: - if self.server_side_update or force_send_model_state_dict: - client.download(self.server_model.state_dict()) - else: - client.download(self.aggregated_gradients) + if type(client) != int: + if self.server_side_update or force_send_model_state_dict: + client.download(self.server_model.state_dict()) + else: + client.download(self.aggregated_gradients) -class MPIFedAVGServer(BaseServer): - def __init__( - self, - comm, - global_model, - myid, - client_ids, - server_id=0, - lr=0.1, - optimizer_type="sgd", - optimizer_kwargs={}, - server_side_update=True, - device="cpu", - ): - super(MPIFedAVGServer, self).__init__( - client_ids, global_model, server_id=server_id - ) - self.comm = comm - self.myid = myid - self.client_ids = client_ids +class MPIFedAvgServer: + """MPI Wrapper for FedAvgServer - self.lr = lr + Args: + comm: MPI.COMM_WORLD + server: the instance of FedAvgServer. The `clients` member variable shoud be the list of id. + """ + def __init__(self, comm, server): + self.comm = comm + self.server = server + self.num_clients = len(self.server.clients) self.round = 0 - self.num_clients = len(client_ids) - self.server_side_update = server_side_update - self.device = device - - self._setup_optimizer(optimizer_type, **optimizer_kwargs) - - def _setup_optimizer(self, optimizer_type, **kwargs): - if optimizer_type == "sgd": - self.optimizer = SGDFLOptimizer( - self.server_model.parameters(), lr=self.lr, **kwargs - ) - elif optimizer_type == "adam": - self.optimizer = AdamFLOptimizer( - self.server_model.parameters(), lr=self.lr, **kwargs - ) - elif optimizer_type == "none": - self.optimizer = None - else: - raise NotImplementedError( - f"{optimizer_type} is not supported. You can specify `sgd`, `adam`, or `none`." - ) - - def send_parameters(self): - global_parameters = [] - for params in self.server_model.parameters(): - global_parameters.append(copy.copy(params).reshape(-1).tolist()) - for client_id in self.client_ids: - self.comm.send(global_parameters, dest=client_id, tag=PARAMETERS_TAG) + def __call__(self, *args, **kwargs): + return self.server(*args, **kwargs) def action(self): - self.receive() - self.update() - self.send_parameters() + self.mpi_receive() + self.server.update() + self.mpi_distribute() self.round += 1 - def receive(self): - self.receive_local_gradients() + def mpi_receive(self): + self.mpi_receive_local_gradients() - def receive_local_gradients(self): - self.received_gradients = [] + def mpi_receive_local_gradients(self): + self.server.uploaded_gradients = [] - while len(self.received_gradients) < self.num_clients: + while len(self.server.uploaded_gradients) < self.num_clients: gradients_received = self.comm.recv(tag=GRADIENTS_TAG) - gradients_processed = [] - for grad in gradients_received: - gradients_processed.append(grad.to(self.device)) - if torch.sum(torch.isnan(gradients_processed[-1])): - print("the received gradients contains nan") - MPI.COMM_WORLD.Abort(RECEIVE_NAN_CODE) - - self.received_gradients.append(gradients_processed) - - def update(self): - self.updata_from_gradients() + self.server.uploaded_gradients.append( + self.server._preprocess_local_gradients(gradients_received) + ) - def _aggregate(self): - self.aggregated_gradients = [ - torch.zeros_like(params) for params in self.server_model.parameters() - ] - len_gradients = len(self.aggregated_gradients) + def mpi_distribute(self): + # global_parameters = [] + # for params in self.server.server_model.parameters(): + # global_parameters.append(copy.copy(params).reshape(-1).tolist()) - for gradients in self.received_gradients: - for gradient_id in range(len_gradients): - self.aggregated_gradients[gradient_id] = ( - gradients[gradient_id] * (1 / self.num_clients) - + self.aggregated_gradients[gradient_id] - ) + for client_id in self.server.clients: + self.comm.send( + self.server.server_model.state_dict(), + dest=client_id, + tag=PARAMETERS_TAG, + ) - def updata_from_gradients(self): - self._aggregate() - if self.server_side_update: - self.optimizer.step(self.aggregated_gradients) + def mpi_initialize(self): + self.mpi_distribute() diff --git a/src/aijack/collaborative/fedgems/server.py b/src/aijack/collaborative/fedgems/server.py index e3b18a32..6812543b 100644 --- a/src/aijack/collaborative/fedgems/server.py +++ b/src/aijack/collaborative/fedgems/server.py @@ -37,13 +37,13 @@ def __init__( ) * float("inf") def action(self): - self.distribtue() + self.distribute() def update(self, idxs, x): """Register the predicted logits to self.predicted_values""" self.predicted_values[idxs] = self.server_model(x).detach().to(self.device) - def distribtue(self): + def distribute(self): """Distribute the logits of public dataset to each client.""" for client in self.clients: client.download(self.predicted_values) diff --git a/src/aijack/collaborative/fedmd/__init__.py b/src/aijack/collaborative/fedmd/__init__.py index 8abf14c1..a0d52778 100644 --- a/src/aijack/collaborative/fedmd/__init__.py +++ b/src/aijack/collaborative/fedmd/__init__.py @@ -1,8 +1,8 @@ -from .api import FedMDAPI # noqa : F401 -from .client import FedMDClient # noqa : F401 +from .api import FedMDAPI, MPIFedMDAPI # noqa : F401 +from .client import FedMDClient, MPIFedMDClient # noqa : F401 from .nfdp import ( # noqa : F401 get_delta_of_fedmd_nfdp, get_epsilon_of_fedmd_nfdp, get_k_of_fedmd_nfdp, ) -from .server import FedMDServer # noqa : F401 +from .server import FedMDServer, MPIFedMDServer # noqa : F401 diff --git a/src/aijack/collaborative/fedmd/api.py b/src/aijack/collaborative/fedmd/api.py index ebc4546f..a59af296 100644 --- a/src/aijack/collaborative/fedmd/api.py +++ b/src/aijack/collaborative/fedmd/api.py @@ -2,7 +2,7 @@ import torch -from ..core.api import BaseFLKnowledgeDistillationAPI +from ..core.api import BaseFedAPI, BaseFLKnowledgeDistillationAPI class FedMDAPI(BaseFLKnowledgeDistillationAPI): @@ -49,7 +49,8 @@ def __init__( def train_server(self): if self.server_optimizer is None: - raise ValueError("server_optimzier does not exist") + return 0.0 + running_loss = 0.0 for data in self.public_dataloader: _, x, y = data @@ -86,7 +87,9 @@ def run(self): logging["loss_client_public_dataset_transfer"].append(loss_public) if self.validation_dataloader is not None: - acc_val = self.score(self.validation_dataloader) + acc_val = self.score( + self.validation_dataloader, self.server_optimizer is None + ) print("acc on validation dataset: ", acc_val) logging["acc_val"].append(copy.deepcopy(acc_val)) @@ -96,7 +99,9 @@ def run(self): logging["loss_client_local_dataset_transfer"].append(loss_local) if self.validation_dataloader is not None: - acc_val = self.score(self.validation_dataloader) + acc_val = self.score( + self.validation_dataloader, self.server_optimizer is None + ) print("acc on validation dataset: ", acc_val) logging["acc_val"].append(copy.deepcopy(acc_val)) @@ -104,8 +109,7 @@ def run(self): self.epoch = i - self.server.update() - self.server.distribute() + self.server.action() # Digest temp_consensus_loss = [] @@ -132,15 +136,94 @@ def run(self): acc_on_local_dataset = self.local_score() print(f"epoch={i} acc on local datasets: ", acc_on_local_dataset) logging["acc_local"].append(acc_on_local_dataset) - acc_pub = self.score(self.public_dataloader) + acc_pub = self.score(self.public_dataloader, self.server_optimizer is None) print(f"epoch={i} acc on public dataset: ", acc_pub) logging["acc_pub"].append(copy.deepcopy(acc_pub)) # evaluation if self.validation_dataloader is not None: - acc_val = self.score(self.validation_dataloader) + acc_val = self.score( + self.validation_dataloader, self.server_optimizer is None + ) print(f"epoch={i} acc on validation dataset: ", acc_val) logging["acc_val"].append(copy.deepcopy(acc_val)) self.custom_action(self) return logging + + +class MPIFedMDAPI(BaseFedAPI): + def __init__( + self, + comm, + party, + is_server, + criterion, + local_optimizer=None, + local_dataloader=None, + public_dataloader=None, + num_communication=1, + local_epoch=1, + consensus_epoch=1, + revisit_epoch=1, + transfer_epoch_public=1, + transfer_epoch_private=1, + custom_action=lambda x: x, + device="cpu", + ): + self.comm = comm + self.party = party + self.is_server = is_server + self.criterion = criterion + self.local_optimizer = local_optimizer + self.local_dataloader = local_dataloader + self.public_dataloader = public_dataloader + self.num_communication = num_communication + self.local_epoch = local_epoch + self.consensus_epoch = consensus_epoch + self.revisit_epoch = revisit_epoch + self.transfer_epoch_public = transfer_epoch_public + self.transfer_epoch_private = transfer_epoch_private + self.custom_action = custom_action + self.device = device + + def run(self): + for _ in range(self.num_communication): + + # Transfer phase + if not self.is_server: + for _ in range(1, self.transfer_epoch_public + 1): + self.train_client(public=True) + for _ in range(1, self.transfer_epoch_private + 1): + self.train_client(public=False) + + # Updata global logits + self.party.action() + self.comm.Barrier() + + # Digest phase + if not self.is_server: + for _ in range(1, self.consensus_epoch): + self.party.client.approach_consensus(self.local_optimizer) + + # Revisit phase + if not self.is_server: + for _ in range(self.revisit_epoch): + self.train_client(public=False) + + self.custom_action(self) + self.comm.Barrier() + + def train_client(self, public=True): + for _ in range(self.local_epoch): + running_loss = 0 + trainloader = self.public_dataloader if public else self.local_dataloader + for (_, data, target) in trainloader: + self.local_optimizer.zero_grad() + data = data.to(self.device) + target = target.to(self.device) + output = self.party.client.model(data) + loss = self.criterion(output, target) + loss.backward() + self.local_optimizer.step() + running_loss += loss.item() diff --git a/src/aijack/collaborative/fedmd/client.py b/src/aijack/collaborative/fedmd/client.py index 2f216bf3..8a59c39d 100644 --- a/src/aijack/collaborative/fedmd/client.py +++ b/src/aijack/collaborative/fedmd/client.py @@ -3,6 +3,7 @@ from ...utils.utils import torch_round_x_decimal from ..core import BaseClient +from ..core.utils import GLOBAL_LOGIT_TAG, LOCAL_LOGIT_TAG class FedMDClient(BaseClient): @@ -63,3 +64,29 @@ def approach_consensus(self, consensus_optimizer): running_loss += loss.item() return running_loss + + +class MPIFedMDClient: + def __init__(self, comm, client): + self.comm = comm + self.client = client + + def __call__(self, *args, **kwargs): + return self.client(*args, **kwargs) + + def action(self): + self.mpi_upload() + self.client.model.zero_grad() + self.mpi_download() + + def mpi_upload(self): + self.mpi_upload_logits() + + def mpi_upload_logits(self, destination_id=0): + self.comm.send(self.client.upload(), dest=destination_id, tag=LOCAL_LOGIT_TAG) + + def mpi_download(self): + self.client.download(self.comm.recv(tag=GLOBAL_LOGIT_TAG)) + + def mpi_initialize(self): + self.mpi_download() diff --git a/src/aijack/collaborative/fedmd/server.py b/src/aijack/collaborative/fedmd/server.py index fbf06871..f4d3eeb4 100644 --- a/src/aijack/collaborative/fedmd/server.py +++ b/src/aijack/collaborative/fedmd/server.py @@ -1,4 +1,5 @@ from ..core import BaseServer +from ..core.utils import GLOBAL_LOGIT_TAG, LOCAL_LOGIT_TAG class FedMDServer(BaseServer): @@ -12,6 +13,8 @@ def __init__( super(FedMDServer, self).__init__(clients, server_model, server_id=server_id) self.device = device + self.uploaded_logits = [] + def forward(self, x): if self.server_model is not None: return self.server_model(x) @@ -19,15 +22,58 @@ def forward(self, x): return None def action(self): + self.receive() self.update() - self.distribtue() + self.distribute() + + def receive(self): + self.uploaded_logits = [client.upload() for client in self.clients] def update(self): - self.consensus = self.clients[0].upload() / len(self.clients) - for client in self.clients[1:]: - self.consensus += client.upload() / len(self.clients) + self.consensus = self.uploaded_logits[0] + len_clients = len(self.clients) + for logit in self.uploaded_logits[1:]: + self.consensus += logit / len_clients def distribute(self): """Distribute the logits of public dataset to each client.""" for client in self.clients: client.download(self.consensus) + + +class MPIFedMDServer: + """MPI Wrapper for FedMDServer + + Args: + comm: MPI.COMM_WORLD + server: the instance of FedAvgServer. The `clients` member variable shoud be the list of id. + """ + + def __init__(self, comm, server): + self.comm = comm + self.server = server + self.num_clients = len(self.server.clients) + self.round = 0 + + def __call__(self, *args, **kwargs): + return self.server(*args, **kwargs) + + def action(self): + self.mpi_receive() + self.server.update() + self.mpi_distribute() + self.round += 1 + + def mpi_receive(self): + self.mpi_receive_local_logits() + + def mpi_receive_local_logits(self): + self.server.uploaded_logits = [] + + while len(self.server.uploaded_logits) < self.num_clients: + received_logits = self.comm.recv(tag=LOCAL_LOGIT_TAG) + self.server.uploaded_logits.append(received_logits) + + def mpi_distribute(self): + for client_id in self.server.clients: + self.comm.send(self.server.consensus, dest=client_id, tag=GLOBAL_LOGIT_TAG) diff --git a/src/aijack/defense/mid/loss.py b/src/aijack/defense/mid/loss.py index 92933184..50a73c46 100644 --- a/src/aijack/defense/mid/loss.py +++ b/src/aijack/defense/mid/loss.py @@ -32,7 +32,7 @@ def mib_loss( ) ) - loss = torch.nn.CrossEntropyLoss(reduce=False) + loss = torch.nn.CrossEntropyLoss(reduction="none") cross_entropy_loss = loss( sampled_y_pred, y[:, None].expand(-1, sampled_y_pred.size()[-1]) ) diff --git a/src/aijack/defense/sparse/__init__.py b/src/aijack/defense/sparse/__init__.py new file mode 100644 index 00000000..403d435d --- /dev/null +++ b/src/aijack/defense/sparse/__init__.py @@ -0,0 +1 @@ +from .topk import SparseGradientClientManager, SparseGradientServerManager # noqa: F401 diff --git a/src/aijack/defense/sparse/topk.py b/src/aijack/defense/sparse/topk.py new file mode 100644 index 00000000..5b512385 --- /dev/null +++ b/src/aijack/defense/sparse/topk.py @@ -0,0 +1,84 @@ +import torch + +from ...manager import BaseManager + + +def attach_sparse_gradient_to_client(cls, k): + """Make the client class communicate the sparse gradients. + + Args: + cls: client class + k (int): strength of sparcity + """ + + class SparseGradientClientWrapper(cls): + def __init__(self, *args, **kwargs): + super(SparseGradientClientWrapper, self).__init__(*args, **kwargs) + + def upload_gradients(self): + """Upload sparse gradients""" + vanila_gradients = super().upload_gradients() + sparse_gradients = [] + sparse_indices = [] + for vanila_grad in vanila_gradients: + temp_grad = vanila_grad.reshape(-1) + # only send top-k gradients + topk_indices = torch.topk( + torch.abs(temp_grad), k=int(len(temp_grad) * k) + ).indices + sparse_gradients.append(temp_grad[topk_indices].tolist()) + sparse_indices.append(topk_indices.tolist()) + + return sparse_gradients, sparse_indices + + return SparseGradientClientWrapper + + +def attach_sparse_gradient_to_server(cls): + """Make the server class communicate the sparse gradients. + + Args: + cls: server class + """ + + class SparseGradientServerWrapper(cls): + def __init__(self, *args, **kwargs): + super(SparseGradientServerWrapper, self).__init__(*args, **kwargs) + + def _preprocess_local_gradients(self, uploaded_grad): + sparse_gradients_flattend, sparse_indices = uploaded_grad + gradients_reshaped = [] + for params, grad, idx in zip( + self.server_model.parameters(), + sparse_gradients_flattend, + sparse_indices, + ): + temp_grad = torch.zeros_like(params).reshape(-1) + temp_grad[idx] = torch.Tensor(grad).to(self.device) + gradients_reshaped.append(temp_grad.reshape(params.shape)) + + return gradients_reshaped + + return SparseGradientServerWrapper + + +class SparseGradientClientManager(BaseManager): + """Client-side Manager for sparse gradients.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def attach(self, cls): + return attach_sparse_gradient_to_client(cls, *self.args, **self.kwargs) + + +class SparseGradientServerManager(BaseManager): + """Server-side Manager for sparse gradients.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def attach(self, cls): + return attach_sparse_gradient_to_server(cls, *self.args, **self.kwargs) diff --git a/test/attack/freerider/test_freerider.py b/test/attack/freerider/test_freerider.py new file mode 100644 index 00000000..22214938 --- /dev/null +++ b/test/attack/freerider/test_freerider.py @@ -0,0 +1,62 @@ +def test_fedavg_delta_weight(): + import torch + import torch.nn as nn + import torch.optim as optim + + from aijack.attack.freerider import FreeRiderClientManager + from aijack.collaborative import FedAvgClient, FedAvgServer + + torch.manual_seed(0) + + lr = 0.01 + epochs = 2 + client_num = 2 + + class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.lin = nn.Sequential(nn.Linear(28 * 28, 10)) + + def forward(self, x): + x = x.reshape((-1, 28 * 28)) + x = self.lin(x) + return x + + x = torch.load("test/demodata/demo_mnist_x.pt") + x.requires_grad = True + y = torch.load("test/demodata/demo_mnist_y.pt") + + manager = FreeRiderClientManager(mu=0, sigma=1.0) + FreeRiderFedAvgClient = manager.attach(FedAvgClient) + + clients = [ + FreeRiderFedAvgClient(Net(), user_id=i, lr=lr, server_side_update=False) + for i in range(client_num) + ] + optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients] + + global_model = Net() + server = FedAvgServer(clients, global_model, lr=lr, server_side_update=False) + + criterion = nn.CrossEntropyLoss() + + loss_log = [] + for _ in range(epochs): + temp_loss = 0 + for client_idx in range(client_num): + client = clients[client_idx] + optimizer = optimizers[client_idx] + + optimizer.zero_grad() + client.zero_grad() + + outputs = client(x) + loss = criterion(outputs, y.to(torch.int64)) + client.backward(loss) + temp_loss = loss.item() / client_num + + optimizer.step() + + loss_log.append(temp_loss) + + server.action(use_gradients=True) diff --git a/test/attack/inversion/test_ganattack.py b/test/attack/inversion/test_ganattack.py index aa5ba5f3..44627421 100644 --- a/test/attack/inversion/test_ganattack.py +++ b/test/attack/inversion/test_ganattack.py @@ -55,7 +55,10 @@ def __init__(self): ) self.lin = nn.Sequential( - nn.Linear(256, 200), nn.Tanh(), nn.Linear(200, 11), nn.LogSoftmax() + nn.Linear(256, 200), + nn.Tanh(), + nn.Linear(200, 11), + nn.LogSoftmax(dim=-1), ) def forward(self, x): diff --git a/test/defense/sparse/test_sparse.py b/test/defense/sparse/test_sparse.py new file mode 100644 index 00000000..7be40ee5 --- /dev/null +++ b/test/defense/sparse/test_sparse.py @@ -0,0 +1,72 @@ +def test_fedavg_sparse_gradient(): + import torch + import torch.nn as nn + import torch.optim as optim + + from aijack.collaborative import FedAvgClient, FedAvgServer + from aijack.defense.sparse import ( + SparseGradientClientManager, + SparseGradientServerManager, + ) + + torch.manual_seed(0) + + lr = 0.01 + epochs = 2 + client_num = 2 + + class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.lin = nn.Sequential(nn.Linear(28 * 28, 10)) + + def forward(self, x): + x = x.reshape((-1, 28 * 28)) + x = self.lin(x) + return x + + x = torch.load("test/demodata/demo_mnist_x.pt") + x.requires_grad = True + y = torch.load("test/demodata/demo_mnist_y.pt") + + client_manager = SparseGradientClientManager(k=0.3) + SparseGradientFedAvgClient = client_manager.attach(FedAvgClient) + + server_manager = SparseGradientServerManager() + SparseGradientFedAvgServer = server_manager.attach(FedAvgServer) + + clients = [ + SparseGradientFedAvgClient(Net(), user_id=i, lr=lr, server_side_update=False) + for i in range(client_num) + ] + optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients] + + global_model = Net() + server = SparseGradientFedAvgServer( + clients, global_model, lr=lr, server_side_update=False + ) + + criterion = nn.CrossEntropyLoss() + + loss_log = [] + for _ in range(epochs): + temp_loss = 0 + for client_idx in range(client_num): + client = clients[client_idx] + optimizer = optimizers[client_idx] + + optimizer.zero_grad() + client.zero_grad() + + outputs = client(x) + loss = criterion(outputs, y.to(torch.int64)) + client.backward(loss) + temp_loss = loss.item() / client_num + + optimizer.step() + + loss_log.append(temp_loss) + + server.action(use_gradients=True) + + assert loss_log[0] > loss_log[1]