Add demo code for distributed training with ddp
YeonwooSung committed Feb 26, 2023
1 parent 564a0a1 commit 9ef0c5f
Showing 1 changed file with 153 additions and 0 deletions.
153 changes: 153 additions & 0 deletions Experiments/CV/src/ddp_mnist_example.py
@@ -0,0 +1,153 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
from torchvision import datasets, transforms
from transformers import Trainer, TrainingArguments


class BasicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.act = F.relu

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def collate_fn(examples):
    pixel_values = torch.stack([example[0] for example in examples])
    labels = torch.tensor([example[1] for example in examples])
    return {"x": pixel_values, "labels": labels}


class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(inputs["x"])
        target = inputs["labels"]
        loss = F.nll_loss(outputs, target)
        return (loss, outputs) if return_outputs else loss


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BasicNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)


def setup(rank, world_size):
    "Sets up the process group and configuration for PyTorch Distributed Data Parallelism"
    os.environ["MASTER_ADDR"] = 'localhost'
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    "Cleans up the distributed environment"
    dist.destroy_process_group()


def train_ddp(rank, world_size):
    setup(rank, world_size)

    # Build DataLoaders
    # (a real multi-process run would normally use a DistributedSampler so that
    # each rank sees a distinct shard of the training data)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dset = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dset = datasets.MNIST('data', train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dset, shuffle=True, batch_size=64)
    test_loader = torch.utils.data.DataLoader(test_dset, shuffle=False, batch_size=64)

    # Build model: one replica per process, placed on this rank's device and wrapped in DDP
    model = BasicNet().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Build optimizer
    optimizer = optim.AdamW(ddp_model.parameters(), lr=1e-3)

    # Train for a single epoch
    ddp_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(rank), target.to(rank)
        output = ddp_model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Evaluate
    ddp_model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(rank), target.to(rank)
            output = ddp_model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    print(f'Accuracy: {100. * correct / len(test_loader.dataset)}')

    cleanup()



def train_ddp_with_trainer():
    training_args = TrainingArguments(
        "basic-trainer",
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=1,
        evaluation_strategy="epoch",
        remove_unused_columns=False
    )

    # Build datasets (the Trainer constructs its own DataLoaders from these)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dset = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dset = datasets.MNIST('data', train=False, transform=transform)

    trainer = MyTrainer(
        model,
        training_args,
        train_dataset=train_dset,
        eval_dataset=test_dset,
        data_collator=collate_fn,
    )
    trainer.train()


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        target = sys.argv[1]
        if target == "ddp":
            train_ddp(0, 1)
        elif target == "ddp_with_trainer":
            train_ddp_with_trainer()
        else:
            raise ValueError("Unknown target")
    else:
        train_ddp(0, 1)
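
Note: as committed, the entry point calls train_ddp(0, 1), i.e. a single process with world size 1. To exercise actual data parallelism, the same function could be launched with one process per GPU, for example via torch.multiprocessing.spawn. The snippet below is a minimal launcher sketch, not part of this commit; it assumes the committed file is importable as ddp_mnist_example and that at least one CUDA device is visible.

# Hypothetical launcher (not part of this commit): spawns one DDP process per
# visible GPU and calls train_ddp(rank, world_size) in each of them.
import torch
import torch.multiprocessing as mp

from ddp_mnist_example import train_ddp  # assumes the committed module is on PYTHONPATH

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # one process per visible GPU
    # mp.spawn passes the process index (the rank) as the first argument to train_ddp
    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)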
