Add demo code for distributed training with ddp
1 parent 564a0a1 · commit 9ef0c5f · 1 changed file with 153 additions and 0 deletions.
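
The full script is reproduced below. It defaults to a single-process run (train_ddp(0, 1)); pass ddp or ddp_with_trainer as the first argument to pick a path. One possible invocation, assuming the file is saved as ddp_demo.py (the file path is not visible in this view, so the name is illustrative):

    python ddp_demo.py ddp
    torchrun --nproc_per_node=2 ddp_demo.py ddp_with_trainer

The torchrun launch only applies to the Trainer path, which reads the distributed environment that torchrun sets up; train_ddp initializes its own process group and is meant to be spawned directly (see the sketch after the script).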
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
from torchvision import datasets, transforms
from transformers import Trainer, TrainingArguments


class BasicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.act = F.relu

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def collate_fn(examples):
    # Batch (image, label) tuples into the dict format the Trainer expects
    pixel_values = torch.stack([example[0] for example in examples])
    labels = torch.tensor([example[1] for example in examples])
    return {"x": pixel_values, "labels": labels}


class MyTrainer(Trainer):
    # Override compute_loss so the Trainer can drive a plain nn.Module
    # that does not compute a loss itself. **kwargs absorbs extra arguments
    # (e.g. num_items_in_batch) passed by newer transformers versions.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(inputs["x"])
        target = inputs["labels"]
        loss = F.nll_loss(outputs, target)
        return (loss, outputs) if return_outputs else loss


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module-level model used by the Trainer path below; the DDP path
# builds its own replica per process.
model = BasicNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)


def setup(rank, world_size):
    "Sets up the process group and configuration for PyTorch Distributed Data Parallelism"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group (gloo also works on CPU; nccl is the
    # usual choice when every rank has its own GPU)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    "Cleans up the distributed environment"
    dist.destroy_process_group()


def train_ddp(rank, world_size):
    setup(rank, world_size)
    # Build DataLoaders
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dset = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dset = datasets.MNIST('data', train=False, transform=transform)

    # Note: a real multi-process run would wrap train_dset in a
    # DistributedSampler so each rank sees a disjoint shard of the data
    train_loader = torch.utils.data.DataLoader(train_dset, shuffle=True, batch_size=64)
    test_loader = torch.utils.data.DataLoader(test_dset, shuffle=False, batch_size=64)

    # Build model; each process gets its own replica on its own device
    # (device_ids=[rank] assumes one CUDA device per rank)
    model = BasicNet().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Build optimizer over the DDP-wrapped parameters
    optimizer = optim.AdamW(ddp_model.parameters(), lr=1e-3)

    # Train for a single epoch; forward/backward must go through ddp_model
    # so gradients are synchronized across ranks
    ddp_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(rank), target.to(rank)
        output = ddp_model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Evaluate; only rank 0 reports, to avoid duplicated prints
    ddp_model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(rank), target.to(rank)
            output = ddp_model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    if rank == 0:
        print(f'Accuracy: {100. * correct / len(test_loader.dataset)}')

    cleanup()


def train_ddp_with_trainer():
    # The HF Trainer handles device placement and DDP wrapping itself
    # when the script is launched via torchrun
    training_args = TrainingArguments(
        "basic-trainer",
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=1,
        evaluation_strategy="epoch",  # newer transformers versions rename this to eval_strategy
        remove_unused_columns=False
    )

    # Build datasets; the Trainer constructs its own DataLoaders
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dset = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dset = datasets.MNIST('data', train=False, transform=transform)

    trainer = MyTrainer(
        model,
        training_args,
        train_dataset=train_dset,
        eval_dataset=test_dset,
        data_collator=collate_fn,
    )
    trainer.train()


if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1:
        target = sys.argv[1]
        if target == "ddp":
            train_ddp(0, 1)
        elif target == "ddp_with_trainer":
            train_ddp_with_trainer()
        else:
            raise ValueError("Unknown target")
    else:
        train_ddp(0, 1)
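
As committed, the __main__ block only exercises the single-process case (rank 0, world size 1). A minimal sketch of a true multi-process launch for train_ddp, assuming a single machine with one device per rank (the main() helper below is illustrative and not part of the commit):

    import torch.multiprocessing as mp

    def main(world_size=2):
        # Spawn one process per rank; mp.spawn passes the rank as the first
        # argument, so each worker runs train_ddp(rank, world_size).
        mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)

Note that train_ddp moves the model to device rank and passes device_ids=[rank] to DDP, so this sketch assumes one CUDA device per process; on a CPU-only machine you would drop device_ids and keep everything on the CPU with the gloo backend.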