Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Koukyosyumei committed Dec 27, 2022
1 parent 8e6830d commit c131f98
Show file tree
Hide file tree
Showing 17 changed files with 210 additions and 192 deletions.
42 changes: 20 additions & 22 deletions docs/aijack_fedavg.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "SROAZ9dO80s8"
Expand Down Expand Up @@ -38,8 +39,7 @@
"from mpi4py import MPI\n",
"from torchvision import datasets, transforms\n",
"\n",
"from aijack.collaborative import FedAvgClient, FedAvgServer\n",
"from aijack.collaborative.fedavg import FedAVGAPI"
"from aijack.collaborative import FedAVGClient, FedAVGServer, FedAVGAPI"
]
},
{
Expand Down Expand Up @@ -300,10 +300,10 @@
}
],
"source": [
"clients = [FedAvgClient(Net().to(device), user_id=c) for c in range(client_size)]\n",
"clients = [FedAVGClient(Net().to(device), user_id=c) for c in range(client_size)]\n",
"local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]\n",
"\n",
"server = FedAvgServer(clients, Net().to(device))\n",
"server = FedAVGServer(clients, Net().to(device))\n",
"\n",
"for round in range(1, num_rounds + 1):\n",
" for client, local_trainloader, local_optimizer in zip(\n",
Expand Down Expand Up @@ -386,15 +386,15 @@
"pk, sk = keygenerator.generate_keypair()\n",
"\n",
"manager = PaillierGradientClientManager(pk, sk)\n",
"PaillierGradFedAvgClient = manager.attach(FedAvgClient)\n",
"PaillierGradFedAVGClient = manager.attach(FedAVGClient)\n",
"\n",
"clients = [\n",
" PaillierGradFedAvgClient(Net().to(device), user_id=c, server_side_update=False)\n",
" PaillierGradFedAVGClient(Net().to(device), user_id=c, server_side_update=False)\n",
" for c in range(client_size)\n",
"]\n",
"local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]\n",
"\n",
"server = FedAvgServer(clients, Net().to(device), server_side_update=False)\n",
"server = FedAVGServer(clients, Net().to(device), server_side_update=False)\n",
"\n",
"for round in range(1, num_rounds + 1):\n",
" for client, local_trainloader, local_optimizer in zip(\n",
Expand Down Expand Up @@ -498,7 +498,7 @@
}
],
"source": [
"%%writefile mpi_fedavg.py\n",
"%%writefile mpi_FedAVG.py\n",
"import random\n",
"from logging import getLogger\n",
"\n",
Expand All @@ -510,8 +510,7 @@
"from mpi4py import MPI\n",
"from torchvision import datasets, transforms\n",
"\n",
"from aijack.collaborative import FedAvgClient, FedAvgServer\n",
"from aijack.collaborative.fedavg import MPIFedAVGAPI, MPIFedAvgClient, MPIFedAvgServer\n",
"from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClient, MPIFedAVGServer\n",
"\n",
"logger = getLogger(__name__)\n",
"\n",
Expand Down Expand Up @@ -602,7 +601,7 @@
" if myid == 0:\n",
" dataloader = prepare_dataloader(size - 1, myid, train=False)\n",
" client_ids = list(range(1, size))\n",
" server = MPIFedAvgServer(comm, FedAvgServer([1, 2], model))\n",
" server = MPIFedAVGServer(comm, FedAVGServer([1, 2], model))\n",
" api = MPIFedAVGAPI(\n",
" comm,\n",
" server,\n",
Expand All @@ -617,7 +616,7 @@
" )\n",
" else:\n",
" dataloader = prepare_dataloader(size - 1, myid, train=True)\n",
" client = MPIFedAvgClient(comm, FedAvgClient(model, user_id=myid))\n",
" client = MPIFedAVGClient(comm, FedAVGClient(model, user_id=myid))\n",
" api = MPIFedAVGAPI(\n",
" comm,\n",
" client,\n",
Expand Down Expand Up @@ -661,7 +660,7 @@
}
],
"source": [
"!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedavg.py"
"!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG.py"
]
},
{
Expand Down Expand Up @@ -693,7 +692,7 @@
}
],
"source": [
"%%writefile mpi_fedavg_sparse.py\n",
"%%writefile mpi_FedAVG_sparse.py\n",
"import random\n",
"from logging import getLogger\n",
"\n",
Expand All @@ -705,8 +704,7 @@
"from mpi4py import MPI\n",
"from torchvision import datasets, transforms\n",
"\n",
"from aijack.collaborative import FedAvgClient, FedAvgServer\n",
"from aijack.collaborative.fedavg import MPIFedAVGAPI, MPIFedAvgClient, MPIFedAvgServer\n",
"from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClient, MPIFedAVGServer\n",
"from aijack.defense.sparse import (\n",
" SparseGradientClientManager,\n",
" SparseGradientServerManager,\n",
Expand Down Expand Up @@ -799,15 +797,15 @@
" optimizer = optim.SGD(model.parameters(), lr=lr)\n",
"\n",
" client_manager = SparseGradientClientManager(k=0.03)\n",
" SparseGradientFedAvgClient = client_manager.attach(FedAvgClient)\n",
" SparseGradientFedAVGClient = client_manager.attach(FedAVGClient)\n",
"\n",
" server_manager = SparseGradientServerManager()\n",
" SparseGradientFedAvgServer = server_manager.attach(FedAvgServer)\n",
" SparseGradientFedAVGServer = server_manager.attach(FedAVGServer)\n",
"\n",
" if myid == 0:\n",
" dataloader = prepare_dataloader(size - 1, myid, train=False)\n",
" client_ids = list(range(1, size))\n",
" server = MPIFedAvgServer(comm, SparseGradientFedAvgServer([1, 2], model))\n",
" server = MPIFedAVGServer(comm, SparseGradientFedAVGServer([1, 2], model))\n",
" api = MPIFedAVGAPI(\n",
" comm,\n",
" server,\n",
Expand All @@ -822,7 +820,7 @@
" )\n",
" else:\n",
" dataloader = prepare_dataloader(size - 1, myid, train=True)\n",
" client = MPIFedAvgClient(comm, SparseGradientFedAvgClient(model, user_id=myid))\n",
" client = MPIFedAVGClient(comm, SparseGradientFedAVGClient(model, user_id=myid))\n",
" api = MPIFedAVGAPI(\n",
" comm,\n",
" client,\n",
Expand Down Expand Up @@ -866,7 +864,7 @@
}
],
"source": [
"!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedavg_sparse.py"
"!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG_sparse.py"
]
}
],
Expand Down Expand Up @@ -894,4 +892,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
15 changes: 13 additions & 2 deletions docs/aijack_soteria.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,21 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.9.1 (tags/v3.9.1:1e5d33e, Dec 7 2020, 17:08:21) [MSC v.1927 64 bit (AMD64)]"
},
"orig_nbformat": 4
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "caa2b01f75ba60e629eaa9e4dabde0c46b243c9a0484934eeb17ad8b3fc9c91a"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
Expand Down
Empty file.
Empty file.
4 changes: 2 additions & 2 deletions src/aijack/collaborative/fedavg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .api import FedAVGAPI, MPIFedAVGAPI # noqa: F401
from .client import FedAvgClient, MPIFedAvgClient # noqa: F401
from .server import FedAvgServer, MPIFedAvgServer # noqa: F401
from .client import FedAVGClient, MPIFedAVGClient # noqa: F401
from .server import FedAVGServer, MPIFedAVGServer # noqa: F401
66 changes: 20 additions & 46 deletions src/aijack/collaborative/fedavg/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,41 +54,19 @@ def __init__(
for dataset_size in local_dataset_sizes
]

def local_train(self, com):
def local_train(self, i):
for client_idx in range(self.client_num):
client = self.clients[client_idx]
trainloader = self.local_dataloaders[client_idx]
optimizer = self.local_optimizers[client_idx]

for i in range(self.local_epoch):
running_loss = 0.0
running_data_num = 0
for _, data in enumerate(trainloader, 0):
inputs, labels = data
inputs = inputs.to(self.device)
inputs.requires_grad = True
labels = labels.to(self.device)

optimizer.zero_grad()
client.zero_grad()

outputs = client(inputs)
loss = self.criterion(outputs, labels)

loss.backward()
optimizer.step()

running_loss += loss.item()
running_data_num += inputs.shape[0]

print(
f"communication {com}, epoch {i}: client-{client_idx+1}",
running_loss / running_data_num,
)
self.clients[client_idx].local_train(
self.local_epoch,
self.criterion,
self.local_dataloaders[client_idx],
self.local_optimizers[client_idx],
communication_id=i,
)

def run(self):
for com in range(self.num_communication):
self.local_train(com)
for i in range(self.num_communication):
self.local_train(i)
self.server.receive(use_gradients=self.use_gradients)
if self.use_gradients:
self.server.updata_from_gradients(weight=self.clients_weight)
Expand Down Expand Up @@ -128,27 +106,23 @@ def run(self):
self.party.mpi_initialize()
self.comm.Barrier()

for _ in range(self.num_communication):
for i in range(self.num_communication):
if not self.is_server:
self.local_train()
self.local_train(i)
self.party.action()

self.custom_action(self)
self.comm.Barrier()

def local_train(self):
def local_train(self, com_cnt):
self.party.prev_parameters = []
for param in self.party.client.model.parameters():
self.party.client.prev_parameters.append(copy.deepcopy(param))

for _ in range(self.local_epoch):
running_loss = 0
for (data, target) in self.local_dataloader:
self.local_optimizer.zero_grad()
data = data.to(self.device)
target = target.to(self.device)
output = self.party.client.model(data)
loss = self.criterion(output, target)
loss.backward()
self.local_optimizer.step()
running_loss += loss.item()
self.party.client.local_train(
self.local_epoch,
self.criterion,
self.local_dataloader,
self.local_optimizer,
communication_id=com_cnt,
)
49 changes: 39 additions & 10 deletions src/aijack/collaborative/fedavg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..optimizer import AdamFLOptimizer, SGDFLOptimizer


class FedAvgClient(BaseClient):
class FedAVGClient(BaseClient):
"""Client of FedAVG for single process simulation
Args:
Expand All @@ -30,7 +30,7 @@ def __init__(
optimizer_kwargs_for_global_grad={},
device="cpu",
):
super(FedAvgClient, self).__init__(model, user_id=user_id)
super(FedAVGClient, self).__init__(model, user_id=user_id)
self.lr = lr
self.send_gradient = send_gradient
self.server_side_update = server_side_update
Expand Down Expand Up @@ -104,8 +104,37 @@ def download(self, new_global_model):
for param in self.model.parameters():
self.prev_parameters.append(copy.deepcopy(param))

def local_train(
self, local_epoch, criterion, trainloader, optimizer, communication_id=0
):
for i in range(local_epoch):
running_loss = 0.0
running_data_num = 0
for _, data in enumerate(trainloader, 0):
inputs, labels = data
inputs = inputs.to(self.device)
inputs.requires_grad = True
labels = labels.to(self.device)

optimizer.zero_grad()
self.zero_grad()

outputs = self(inputs)
loss = criterion(outputs, labels)

loss.backward()
optimizer.step()

running_loss += loss.item()
running_data_num += inputs.shape[0]

print(
f"communication {communication_id}, epoch {i}: client-{self.user_id+1}",
running_loss / running_data_num,
)

class MPIFedAvgClient:

class MPIFedAVGClient(BaseClient):
def __init__(self, comm, client):
self.comm = comm
self.client = client
Expand All @@ -114,20 +143,20 @@ def __call__(self, *args, **kwargs):
return self.client(*args, **kwargs)

def action(self):
self.mpi_upload()
self.upload()
self.client.model.zero_grad()
self.mpi_download()
self.download()

def mpi_upload(self):
self.mpi_upload_gradient()
def upload(self):
self.upload_gradient()

def mpi_upload_gradient(self, destination_id=0):
def upload_gradient(self, destination_id=0):
self.comm.send(
self.client.upload_gradients(), dest=destination_id, tag=GRADIENTS_TAG
)

def mpi_download(self):
def download(self):
self.client.download(self.comm.recv(tag=PARAMETERS_TAG))

def mpi_initialize(self):
self.mpi_download()
self.download()
Loading

0 comments on commit c131f98

Please sign in to comment.