diff --git a/docs/aijack_fedavg.ipynb b/docs/aijack_fedavg.ipynb index c926174f..2cd27c8a 100644 --- a/docs/aijack_fedavg.ipynb +++ b/docs/aijack_fedavg.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "SROAZ9dO80s8" @@ -38,8 +39,7 @@ "from mpi4py import MPI\n", "from torchvision import datasets, transforms\n", "\n", - "from aijack.collaborative import FedAvgClient, FedAvgServer\n", - "from aijack.collaborative.fedavg import FedAVGAPI" + "from aijack.collaborative import FedAVGClient, FedAVGServer, FedAVGAPI" ] }, { @@ -300,10 +300,10 @@ } ], "source": [ - "clients = [FedAvgClient(Net().to(device), user_id=c) for c in range(client_size)]\n", + "clients = [FedAVGClient(Net().to(device), user_id=c) for c in range(client_size)]\n", "local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]\n", "\n", - "server = FedAvgServer(clients, Net().to(device))\n", + "server = FedAVGServer(clients, Net().to(device))\n", "\n", "for round in range(1, num_rounds + 1):\n", " for client, local_trainloader, local_optimizer in zip(\n", @@ -386,15 +386,15 @@ "pk, sk = keygenerator.generate_keypair()\n", "\n", "manager = PaillierGradientClientManager(pk, sk)\n", - "PaillierGradFedAvgClient = manager.attach(FedAvgClient)\n", + "PaillierGradFedAVGClient = manager.attach(FedAVGClient)\n", "\n", "clients = [\n", - " PaillierGradFedAvgClient(Net().to(device), user_id=c, server_side_update=False)\n", + " PaillierGradFedAVGClient(Net().to(device), user_id=c, server_side_update=False)\n", " for c in range(client_size)\n", "]\n", "local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]\n", "\n", - "server = FedAvgServer(clients, Net().to(device), server_side_update=False)\n", + "server = FedAVGServer(clients, Net().to(device), server_side_update=False)\n", "\n", "for round in range(1, num_rounds + 1):\n", " for client, local_trainloader, local_optimizer in zip(\n", @@ -498,7 +498,7 @@ } ], 
"source": [ - "%%writefile mpi_fedavg.py\n", + "%%writefile mpi_FedAVG.py\n", "import random\n", "from logging import getLogger\n", "\n", @@ -510,8 +510,7 @@ "from mpi4py import MPI\n", "from torchvision import datasets, transforms\n", "\n", - "from aijack.collaborative import FedAvgClient, FedAvgServer\n", - "from aijack.collaborative.fedavg import MPIFedAVGAPI, MPIFedAvgClient, MPIFedAvgServer\n", + "from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClient, MPIFedAVGServer\n", "\n", "logger = getLogger(__name__)\n", "\n", @@ -602,7 +601,7 @@ " if myid == 0:\n", " dataloader = prepare_dataloader(size - 1, myid, train=False)\n", " client_ids = list(range(1, size))\n", - " server = MPIFedAvgServer(comm, FedAvgServer([1, 2], model))\n", + " server = MPIFedAVGServer(comm, FedAVGServer([1, 2], model))\n", " api = MPIFedAVGAPI(\n", " comm,\n", " server,\n", @@ -617,7 +616,7 @@ " )\n", " else:\n", " dataloader = prepare_dataloader(size - 1, myid, train=True)\n", - " client = MPIFedAvgClient(comm, FedAvgClient(model, user_id=myid))\n", + " client = MPIFedAVGClient(comm, FedAVGClient(model, user_id=myid))\n", " api = MPIFedAVGAPI(\n", " comm,\n", " client,\n", @@ -661,7 +660,7 @@ } ], "source": [ - "!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedavg.py" + "!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG.py" ] }, { @@ -693,7 +692,7 @@ } ], "source": [ - "%%writefile mpi_fedavg_sparse.py\n", + "%%writefile mpi_FedAVG_sparse.py\n", "import random\n", "from logging import getLogger\n", "\n", @@ -705,8 +704,7 @@ "from mpi4py import MPI\n", "from torchvision import datasets, transforms\n", "\n", - "from aijack.collaborative import FedAvgClient, FedAvgServer\n", - "from aijack.collaborative.fedavg import MPIFedAVGAPI, MPIFedAvgClient, MPIFedAvgServer\n", + "from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClient, MPIFedAVGServer\n", "from aijack.defense.sparse import (\n", " 
SparseGradientClientManager,\n", " SparseGradientServerManager,\n", @@ -799,15 +797,15 @@ " optimizer = optim.SGD(model.parameters(), lr=lr)\n", "\n", " client_manager = SparseGradientClientManager(k=0.03)\n", - " SparseGradientFedAvgClient = client_manager.attach(FedAvgClient)\n", + " SparseGradientFedAVGClient = client_manager.attach(FedAVGClient)\n", "\n", " server_manager = SparseGradientServerManager()\n", - " SparseGradientFedAvgServer = server_manager.attach(FedAvgServer)\n", + " SparseGradientFedAVGServer = server_manager.attach(FedAVGServer)\n", "\n", " if myid == 0:\n", " dataloader = prepare_dataloader(size - 1, myid, train=False)\n", " client_ids = list(range(1, size))\n", - " server = MPIFedAvgServer(comm, SparseGradientFedAvgServer([1, 2], model))\n", + " server = MPIFedAVGServer(comm, SparseGradientFedAVGServer([1, 2], model))\n", " api = MPIFedAVGAPI(\n", " comm,\n", " server,\n", @@ -822,7 +820,7 @@ " )\n", " else:\n", " dataloader = prepare_dataloader(size - 1, myid, train=True)\n", - " client = MPIFedAvgClient(comm, SparseGradientFedAvgClient(model, user_id=myid))\n", + " client = MPIFedAVGClient(comm, SparseGradientFedAVGClient(model, user_id=myid))\n", " api = MPIFedAVGAPI(\n", " comm,\n", " client,\n", @@ -866,7 +864,7 @@ } ], "source": [ - "!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedavg_sparse.py" + "!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG_sparse.py" ] } ], @@ -894,4 +892,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/docs/aijack_soteria.ipynb b/docs/aijack_soteria.ipynb index dc2e7430..083ecdcf 100644 --- a/docs/aijack_soteria.ipynb +++ b/docs/aijack_soteria.ipynb @@ -9,10 +9,21 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "name": "python", + "version": "3.9.1 (tags/v3.9.1:1e5d33e, Dec 7 2020, 17:08:21) [MSC v.1927 64 bit (AMD64)]" }, - 
"orig_nbformat": 4 + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "caa2b01f75ba60e629eaa9e4dabde0c46b243c9a0484934eeb17ad8b3fc9c91a" + } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/src/aijack/attack/backdoor/__init__.py b/src/aijack/attack/backdoor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aijack/attack/backdoor/dba.py b/src/aijack/attack/backdoor/dba.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aijack/collaborative/fedavg/__init__.py b/src/aijack/collaborative/fedavg/__init__.py index ff033311..ce4b2588 100644 --- a/src/aijack/collaborative/fedavg/__init__.py +++ b/src/aijack/collaborative/fedavg/__init__.py @@ -1,3 +1,3 @@ from .api import FedAVGAPI, MPIFedAVGAPI # noqa: F401 -from .client import FedAvgClient, MPIFedAvgClient # noqa: F401 -from .server import FedAvgServer, MPIFedAvgServer # noqa: F401 +from .client import FedAVGClient, MPIFedAVGClient # noqa: F401 +from .server import FedAVGServer, MPIFedAVGServer # noqa: F401 diff --git a/src/aijack/collaborative/fedavg/api.py b/src/aijack/collaborative/fedavg/api.py index c4e245c4..79d74912 100644 --- a/src/aijack/collaborative/fedavg/api.py +++ b/src/aijack/collaborative/fedavg/api.py @@ -54,41 +54,19 @@ def __init__( for dataset_size in local_dataset_sizes ] - def local_train(self, com): + def local_train(self, i): for client_idx in range(self.client_num): - client = self.clients[client_idx] - trainloader = self.local_dataloaders[client_idx] - optimizer = self.local_optimizers[client_idx] - - for i in range(self.local_epoch): - running_loss = 0.0 - running_data_num = 0 - for _, data in enumerate(trainloader, 0): - inputs, labels = data - inputs = inputs.to(self.device) - inputs.requires_grad = True - labels = labels.to(self.device) - - optimizer.zero_grad() - client.zero_grad() - - outputs = client(inputs) - loss = self.criterion(outputs, labels) - - loss.backward() - optimizer.step() - - running_loss += loss.item() - 
running_data_num += inputs.shape[0] - - print( - f"communication {com}, epoch {i}: client-{client_idx+1}", - running_loss / running_data_num, - ) + self.clients[client_idx].local_train( + self.local_epoch, + self.criterion, + self.local_dataloaders[client_idx], + self.local_optimizers[client_idx], + communication_id=i, + ) def run(self): - for com in range(self.num_communication): - self.local_train(com) + for i in range(self.num_communication): + self.local_train(i) self.server.receive(use_gradients=self.use_gradients) if self.use_gradients: self.server.updata_from_gradients(weight=self.clients_weight) @@ -128,27 +106,23 @@ def run(self): self.party.mpi_initialize() self.comm.Barrier() - for _ in range(self.num_communication): + for i in range(self.num_communication): if not self.is_server: - self.local_train() + self.local_train(i) self.party.action() self.custom_action(self) self.comm.Barrier() - def local_train(self): + def local_train(self, com_cnt): self.party.prev_parameters = [] for param in self.party.client.model.parameters(): self.party.client.prev_parameters.append(copy.deepcopy(param)) - for _ in range(self.local_epoch): - running_loss = 0 - for (data, target) in self.local_dataloader: - self.local_optimizer.zero_grad() - data = data.to(self.device) - target = target.to(self.device) - output = self.party.client.model(data) - loss = self.criterion(output, target) - loss.backward() - self.local_optimizer.step() - running_loss += loss.item() + self.party.client.local_train( + self.local_epoch, + self.criterion, + self.local_dataloader, + self.local_optimizer, + communication_id=com_cnt, + ) diff --git a/src/aijack/collaborative/fedavg/client.py b/src/aijack/collaborative/fedavg/client.py index 00f573f8..51bc568e 100644 --- a/src/aijack/collaborative/fedavg/client.py +++ b/src/aijack/collaborative/fedavg/client.py @@ -5,7 +5,7 @@ from ..optimizer import AdamFLOptimizer, SGDFLOptimizer -class FedAvgClient(BaseClient): +class FedAVGClient(BaseClient): 
"""Client of FedAVG for single process simulation Args: @@ -30,7 +30,7 @@ def __init__( optimizer_kwargs_for_global_grad={}, device="cpu", ): - super(FedAvgClient, self).__init__(model, user_id=user_id) + super(FedAVGClient, self).__init__(model, user_id=user_id) self.lr = lr self.send_gradient = send_gradient self.server_side_update = server_side_update @@ -104,8 +104,37 @@ def download(self, new_global_model): for param in self.model.parameters(): self.prev_parameters.append(copy.deepcopy(param)) + def local_train( + self, local_epoch, criterion, trainloader, optimizer, communication_id=0 + ): + for i in range(local_epoch): + running_loss = 0.0 + running_data_num = 0 + for _, data in enumerate(trainloader, 0): + inputs, labels = data + inputs = inputs.to(self.device) + inputs.requires_grad = True + labels = labels.to(self.device) + + optimizer.zero_grad() + self.zero_grad() + + outputs = self(inputs) + loss = criterion(outputs, labels) + + loss.backward() + optimizer.step() + + running_loss += loss.item() + running_data_num += inputs.shape[0] + + print( + f"communication {communication_id}, epoch {i}: client-{self.user_id+1}", + running_loss / running_data_num, + ) -class MPIFedAvgClient: + +class MPIFedAVGClient(BaseClient): def __init__(self, comm, client): self.comm = comm self.client = client @@ -114,20 +143,20 @@ def __call__(self, *args, **kwargs): return self.client(*args, **kwargs) def action(self): - self.mpi_upload() + self.upload() self.client.model.zero_grad() - self.mpi_download() + self.download() - def mpi_upload(self): - self.mpi_upload_gradient() + def upload(self): + self.upload_gradient() - def mpi_upload_gradient(self, destination_id=0): + def upload_gradient(self, destination_id=0): self.comm.send( self.client.upload_gradients(), dest=destination_id, tag=GRADIENTS_TAG ) - def mpi_download(self): + def download(self): self.client.download(self.comm.recv(tag=PARAMETERS_TAG)) def mpi_initialize(self): - self.mpi_download() + self.download() diff 
--git a/src/aijack/collaborative/fedavg/server.py b/src/aijack/collaborative/fedavg/server.py index 5be44915..cc85cbcc 100644 --- a/src/aijack/collaborative/fedavg/server.py +++ b/src/aijack/collaborative/fedavg/server.py @@ -6,7 +6,7 @@ from ..optimizer import AdamFLOptimizer, SGDFLOptimizer -class FedAvgServer(BaseServer): +class FedAVGServer(BaseServer): """Server of FedAVG for single process simulation Args: @@ -30,7 +30,7 @@ def __init__( optimizer_kwargs={}, device="cpu", ): - super(FedAvgServer, self).__init__(clients, global_model, server_id=server_id) + super(FedAVGServer, self).__init__(clients, global_model, server_id=server_id) self.lr = lr self._setup_optimizer(optimizer_type, **optimizer_kwargs) self.server_side_update = server_side_update @@ -156,12 +156,12 @@ def distribute(self, force_send_model_state_dict=False): client.download(self.aggregated_gradients) -class MPIFedAvgServer: - """MPI Wrapper for FedAvgServer +class MPIFedAVGServer(BaseServer): + """MPI Wrapper for FedAVGServer Args: comm: MPI.COMM_WORLD - server: the instance of FedAvgServer. The `clients` member variable shoud be the list of id. + server: the instance of FedAVGServer. The `clients` member variable should be the list of client ids. 
""" def __init__(self, comm, server): @@ -174,15 +174,18 @@ def __call__(self, *args, **kwargs): return self.server(*args, **kwargs) def action(self): - self.mpi_receive() - self.server.update() - self.mpi_distribute() + self.receive() + self.update() + self.distribute() self.round += 1 - def mpi_receive(self): - self.mpi_receive_local_gradients() + def receive(self): + self.receive_local_gradients() - def mpi_receive_local_gradients(self): + def update(self): + self.server.update() + + def receive_local_gradients(self): self.server.uploaded_gradients = [] while len(self.server.uploaded_gradients) < self.num_clients: @@ -191,7 +194,7 @@ def mpi_receive_local_gradients(self): self.server._preprocess_local_gradients(gradients_received) ) - def mpi_distribute(self): + def distribute(self): # global_parameters = [] # for params in self.server.server_model.parameters(): # global_parameters.append(copy.copy(params).reshape(-1).tolist()) diff --git a/src/aijack/collaborative/fedkd/client.py b/src/aijack/collaborative/fedkd/client.py index 255a4966..e3491f3a 100644 --- a/src/aijack/collaborative/fedkd/client.py +++ b/src/aijack/collaborative/fedkd/client.py @@ -1,9 +1,9 @@ import torch.nn.functional as F -from ..fedavg import FedAvgClient +from ..fedavg import FedAVGClient -class FedKDClient(FedAvgClient): +class FedKDClient(FedAVGClient): """Implementation of FedKD (https://arxiv.org/abs/2108.13323)""" def __init__( diff --git a/src/aijack/collaborative/fedprox/api.py b/src/aijack/collaborative/fedprox/api.py index 16206dbf..2fec3df1 100644 --- a/src/aijack/collaborative/fedprox/api.py +++ b/src/aijack/collaborative/fedprox/api.py @@ -1,7 +1,5 @@ import copy -import torch - from ..fedavg import FedAVGAPI, MPIFedAVGAPI @@ -12,50 +10,20 @@ def __init__(self, *args, mu=0.01, **kwargs): super().__init__(*args, **kwargs) self.mu = mu - def run(self): - for com in range(self.num_communication): - for client_idx in range(self.client_num): - client = self.clients[client_idx] - 
trainloader = self.local_dataloaders[client_idx] - optimizer = self.local_optimizers[client_idx] - - for i in range(self.local_epoch): - running_loss = 0.0 - running_data_num = 0 - for _, data in enumerate(trainloader, 0): - inputs, labels = data - inputs = inputs.to(self.device) - inputs.requires_grad = True - labels = labels.to(self.device) - - optimizer.zero_grad() - client.zero_grad() - - outputs = client(inputs) - loss = self.criterion(outputs, labels) - loss.backward() - - for local_param, global_param in zip( - client.parameters(), self.server.parameters() - ): - loss += ( - self.mu - / 2 - * torch.norm(local_param.data - global_param.data) - ) - local_param.grad.data += self.mu * ( - local_param.data - global_param.data - ) + def local_train(self, i): + for client_idx in range(self.client_num): + self.clients[client_idx].local_train( + self.server.parameters(), + self.local_epoch, + self.criterion, + self.local_dataloaders[client_idx], + self.local_optimizers[client_idx], + communication_id=i, + ) - optimizer.step() - - running_loss += loss.item() - running_data_num += inputs.shape[0] - - print( - f"communication {com}, epoch {i}: client-{client_idx+1}", - running_loss / running_data_num, - ) + def run(self): + for i in range(self.num_communication): + self.local_train(i) self.server.receive(use_gradients=self.use_gradients) if self.use_gradients: @@ -71,31 +39,16 @@ def __init__(self, *args, mu=0.01, **kwargs): super().__init__(*args, **kwargs) self.mu = mu - def local_train(self): + def local_train(self, com_cnt): self.party.prev_parameters = [] for param in self.party.model.parameters(): self.party.prev_parameters.append(copy.deepcopy(param)) - for _ in range(self.local_epoch): - running_loss = 0 - for (data, target) in self.local_dataloader: - self.local_optimizer.zero_grad() - data = data.to(self.device) - target = target.to(self.device) - output = self.party.model(data) - - loss = self.criterion(output, target) - loss.backward() - - for local_param, 
global_param in zip( - self.party.model.parameters(), self.party.prev_parameters - ): - loss += ( - self.mu / 2 * torch.norm(local_param.data - global_param.data) - ) - local_param.grad.data += self.mu * ( - local_param.data - global_param.data - ) - - self.local_optimizer.step() - running_loss += loss.item() + self.party.client.local_train( + self.party.prev_parameters, + self.local_epoch, + self.criterion, + self.local_dataloader, + self.local_optimizer, + communication_id=com_cnt, + ) diff --git a/src/aijack/collaborative/fedprox/client.py b/src/aijack/collaborative/fedprox/client.py new file mode 100644 index 00000000..3d52987e --- /dev/null +++ b/src/aijack/collaborative/fedprox/client.py @@ -0,0 +1,50 @@ +import torch + +from ..fedavg import FedAVGClient + + +class FedProxClient(FedAVGClient): + def local_train( + self, + server_parameters, + local_epoch, + criterion, + trainloader, + optimizer, + communication_id=0, + ): + for i in range(local_epoch): + running_loss = 0.0 + running_data_num = 0 + for _, data in enumerate(trainloader, 0): + inputs, labels = data + inputs = inputs.to(self.device) + inputs.requires_grad = True + labels = labels.to(self.device) + + optimizer.zero_grad() + self.zero_grad() + + outputs = self(inputs) + loss = criterion(outputs, labels) + loss.backward() + + for local_param, global_param in zip( + self.parameters(), server_parameters + ): + loss += ( + self.mu / 2 * torch.norm(local_param.data - global_param.data) + ) + local_param.grad.data += self.mu * ( + local_param.data - global_param.data + ) + + optimizer.step() + + running_loss += loss.item() + running_data_num += inputs.shape[0] + + print( + f"communication {communication_id}, epoch {i}: client-{self.user_id+1}", + running_loss / running_data_num, + ) diff --git a/test/attack/freerider/test_freerider.py b/test/attack/freerider/test_freerider.py index 22214938..07520f1a 100644 --- a/test/attack/freerider/test_freerider.py +++ b/test/attack/freerider/test_freerider.py @@ 
-1,10 +1,10 @@ -def test_fedavg_delta_weight(): +def test_FedAVG_delta_weight(): import torch import torch.nn as nn import torch.optim as optim from aijack.attack.freerider import FreeRiderClientManager - from aijack.collaborative import FedAvgClient, FedAvgServer + from aijack.collaborative import FedAVGClient, FedAVGServer torch.manual_seed(0) @@ -27,16 +27,16 @@ def forward(self, x): y = torch.load("test/demodata/demo_mnist_y.pt") manager = FreeRiderClientManager(mu=0, sigma=1.0) - FreeRiderFedAvgClient = manager.attach(FedAvgClient) + FreeRiderFedAVGClient = manager.attach(FedAVGClient) clients = [ - FreeRiderFedAvgClient(Net(), user_id=i, lr=lr, server_side_update=False) + FreeRiderFedAVGClient(Net(), user_id=i, lr=lr, server_side_update=False) for i in range(client_num) ] optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients] global_model = Net() - server = FedAvgServer(clients, global_model, lr=lr, server_side_update=False) + server = FedAVGServer(clients, global_model, lr=lr, server_side_update=False) criterion = nn.CrossEntropyLoss() diff --git a/test/attack/inversion/test_ganattack.py b/test/attack/inversion/test_ganattack.py index 44627421..425c24ad 100644 --- a/test/attack/inversion/test_ganattack.py +++ b/test/attack/inversion/test_ganattack.py @@ -4,7 +4,7 @@ def test_ganattack(): import torch.optim as optim from aijack.attack import GANAttackManager - from aijack.collaborative import FedAvgClient, FedAvgServer + from aijack.collaborative import FedAVGClient, FedAVGServer nc = 1 nz = 100 @@ -68,7 +68,7 @@ def forward(self, x): return x net_1 = Net() - client_1 = FedAvgClient(net_1, user_id=0) + client_1 = FedAVGClient(net_1, user_id=0) optimizer_1 = optim.SGD( client_1.parameters(), lr=0.02, weight_decay=1e-7, momentum=0.9 ) @@ -87,10 +87,10 @@ def forward(self, x): criterion, nz=nz, ) - GANAttackFedAvgClient = manager.attach(FedAvgClient) + GANAttackFedAVGClient = manager.attach(FedAVGClient) net_2 = Net() - client_2 = 
GANAttackFedAvgClient(net_2, user_id=1) + client_2 = GANAttackFedAVGClient(net_2, user_id=1) optimizer_2 = optim.SGD( client_2.parameters(), lr=0.02, weight_decay=1e-7, momentum=0.9 ) @@ -99,7 +99,7 @@ def forward(self, x): optimizers = [optimizer_1, optimizer_2] global_model = Net() - server = FedAvgServer(clients, global_model) + server = FedAVGServer(clients, global_model) inputs = torch.load("test/demodata/demo_mnist_x.pt") labels = torch.load("test/demodata/demo_mnist_y.pt") diff --git a/test/collaborative/fedkd/test_fedkd.py b/test/collaborative/fedkd/test_fedkd.py index 2d6bf45e..7be95d75 100644 --- a/test/collaborative/fedkd/test_fedkd.py +++ b/test/collaborative/fedkd/test_fedkd.py @@ -3,7 +3,7 @@ def test_fedkd(): import torch.nn as nn import torch.optim as optim - from aijack.collaborative import FedAvgServer, FedKDClient + from aijack.collaborative import FedAVGServer, FedKDClient torch.manual_seed(0) @@ -58,7 +58,7 @@ def get_hidden_states(self): ] global_model = Net() - server = FedAvgServer(clients, global_model, lr=lr) + server = FedAVGServer(clients, global_model, lr=lr) teacher_loss_log = [] student_loss_log = [] diff --git a/test/defense/paillier/test_paillier.py b/test/defense/paillier/test_paillier.py index 639bcc1d..2e8efbe0 100644 --- a/test/defense/paillier/test_paillier.py +++ b/test/defense/paillier/test_paillier.py @@ -61,12 +61,12 @@ def test_paillier_torch(): ) -def test_pailier_fedavg(): +def test_pailier_FedAVG(): import torch import torch.nn as nn import torch.optim as optim - from aijack.collaborative import FedAvgClient, FedAvgServer + from aijack.collaborative import FedAVGClient, FedAVGServer from aijack.defense import PaillierGradientClientManager, PaillierKeyGenerator torch.manual_seed(0) @@ -93,16 +93,16 @@ def forward(self, x): y = torch.load("test/demodata/demo_mnist_y.pt") manager = PaillierGradientClientManager(pk, sk) - PaillierGradFedAvgClient = manager.attach(FedAvgClient) + PaillierGradFedAVGClient = 
manager.attach(FedAVGClient) clients = [ - PaillierGradFedAvgClient(Net(), user_id=i, lr=lr, server_side_update=False) + PaillierGradFedAVGClient(Net(), user_id=i, lr=lr, server_side_update=False) for i in range(client_num) ] optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients] global_model = Net() - server = FedAvgServer(clients, global_model, lr=lr, server_side_update=False) + server = FedAVGServer(clients, global_model, lr=lr, server_side_update=False) criterion = nn.CrossEntropyLoss() diff --git a/test/defense/soteria/test_soteria.py b/test/defense/soteria/test_soteria.py index b331a7d3..aa525c04 100644 --- a/test/defense/soteria/test_soteria.py +++ b/test/defense/soteria/test_soteria.py @@ -3,7 +3,7 @@ def test_soteria(): import torch.nn as nn import torch.optim as optim - from aijack.collaborative import FedAvgClient, FedAvgServer + from aijack.collaborative import FedAVGClient, FedAVGServer from aijack.defense import SoteriaManager torch.manual_seed(0) @@ -37,10 +37,10 @@ def forward(self, x): y = torch.load("test/demodata/demo_mnist_y.pt") manager = SoteriaManager("conv", "lin", target_layer_name="lin.0.weight") - SoteriaFedAvgClient = manager.attach(FedAvgClient) + SoteriaFedAVGClient = manager.attach(FedAVGClient) clients = [ - SoteriaFedAvgClient( + SoteriaFedAVGClient( Net(), user_id=i, lr=lr, @@ -50,7 +50,7 @@ def forward(self, x): optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients] global_model = Net() - server = FedAvgServer(clients, global_model, lr=lr) + server = FedAVGServer(clients, global_model, lr=lr) criterion = nn.CrossEntropyLoss() diff --git a/test/defense/sparse/test_sparse.py b/test/defense/sparse/test_sparse.py index 7be40ee5..2b3e9cd2 100644 --- a/test/defense/sparse/test_sparse.py +++ b/test/defense/sparse/test_sparse.py @@ -1,9 +1,9 @@ -def test_fedavg_sparse_gradient(): +def test_FedAVG_sparse_gradient(): import torch import torch.nn as nn import torch.optim as optim - from 
aijack.collaborative import FedAvgClient, FedAvgServer + from aijack.collaborative import FedAVGClient, FedAVGServer from aijack.defense.sparse import ( SparseGradientClientManager, SparseGradientServerManager, @@ -30,19 +30,19 @@ def forward(self, x): y = torch.load("test/demodata/demo_mnist_y.pt") client_manager = SparseGradientClientManager(k=0.3) - SparseGradientFedAvgClient = client_manager.attach(FedAvgClient) + SparseGradientFedAVGClient = client_manager.attach(FedAVGClient) server_manager = SparseGradientServerManager() - SparseGradientFedAvgServer = server_manager.attach(FedAvgServer) + SparseGradientFedAVGServer = server_manager.attach(FedAVGServer) clients = [ - SparseGradientFedAvgClient(Net(), user_id=i, lr=lr, server_side_update=False) + SparseGradientFedAVGClient(Net(), user_id=i, lr=lr, server_side_update=False) for i in range(client_num) ] optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients] global_model = Net() - server = SparseGradientFedAvgServer( + server = SparseGradientFedAVGServer( clients, global_model, lr=lr, server_side_update=False )