diff --git a/Experiments/CV/src/ddp_mnist_example.py b/Experiments/CV/src/ddp_mnist_example.py
new file mode 100644
index 0000000..d5d1c3e
--- /dev/null
+++ b/Experiments/CV/src/ddp_mnist_example.py
@@ -0,0 +1,171 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+import torch.optim as optim
+from torchvision import datasets, transforms
+from transformers import Trainer, TrainingArguments
+
+
+class BasicNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+        self.act = F.relu
+
+    def forward(self, x):
+        x = self.act(self.conv1(x))
+        x = self.act(self.conv2(x))
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.act(self.fc1(x))
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+
+def collate_fn(examples):
+    # Stack raw (image, label) pairs into the dict format the Trainer expects
+    pixel_values = torch.stack([example[0] for example in examples])
+    labels = torch.tensor([example[1] for example in examples])
+    return {"x": pixel_values, "labels": labels}
+
+
+class MyTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that some transformers versions pass in
+        outputs = model(inputs["x"])
+        target = inputs["labels"]
+        loss = F.nll_loss(outputs, target)
+        return (loss, outputs) if return_outputs else loss
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+model = BasicNet().to(device)
+optimizer = optim.AdamW(model.parameters(), lr=1e-3)
+
+
+def setup(rank, world_size):
+    "Sets up the process group and configuration for PyTorch Distributed Data Parallelism"
+    os.environ["MASTER_ADDR"] = 'localhost'
+    os.environ["MASTER_PORT"] = "12355"
+
+    # Initialize the process group (gloo also runs on CPU; nccl is the usual choice when every rank owns a GPU)
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    "Cleans up the distributed environment"
+    dist.destroy_process_group()
+
+
+def train_ddp(rank, world_size):
+    setup(rank, world_size)
+
+    # Build DataLoaders; the DistributedSampler gives each rank its own shard of the training set
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    train_dset = datasets.MNIST('data', train=True, download=True, transform=transform)
+    test_dset = datasets.MNIST('data', train=False, transform=transform)
+
+    train_sampler = DistributedSampler(train_dset, num_replicas=world_size, rank=rank)
+    train_loader = torch.utils.data.DataLoader(train_dset, sampler=train_sampler, batch_size=64)
+    test_loader = torch.utils.data.DataLoader(test_dset, shuffle=False, batch_size=64)
+
+    # Build the model on this rank's device and wrap it in DDP
+    model = BasicNet().to(rank)
+    ddp_model = DDP(model, device_ids=[rank])
+
+    # Build optimizer
+    optimizer = optim.AdamW(ddp_model.parameters(), lr=1e-3)
+
+    # Train for a single epoch
+    ddp_model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(rank), target.to(rank)
+        output = ddp_model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
+    # Evaluate
+    ddp_model.eval()
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(rank), target.to(rank)
+            output = ddp_model(data)
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+    print(f'[rank {rank}] Accuracy: {100. * correct / len(test_loader.dataset)}')
+
+    cleanup()
+
+
+def train_ddp_with_trainer():
+    training_args = TrainingArguments(
+        "basic-trainer",
+        per_device_train_batch_size=64,
+        per_device_eval_batch_size=64,
+        num_train_epochs=1,
+        evaluation_strategy="epoch",
+        remove_unused_columns=False
+    )
+
+    # Build datasets
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    train_dset = datasets.MNIST('data', train=True, download=True, transform=transform)
+    test_dset = datasets.MNIST('data', train=False, transform=transform)
+
+    trainer = MyTrainer(
+        model,
+        training_args,
+        train_dataset=train_dset,
+        eval_dataset=test_dset,
+        data_collator=collate_fn,
+    )
+    trainer.train()
+
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) > 1:
+        target = sys.argv[1]
+        if target == "ddp":
+            train_ddp(0, 1)
+        elif target == "ddp_with_trainer":
+            train_ddp_with_trainer()
+        else:
+            raise ValueError("Unknown target")
+    else:
+        train_ddp(0, 1)
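+
+# NOTE (illustrative sketch, not wired into the CLI above): to run the DDP path
+# on several GPUs of one machine, one would typically spawn one process per
+# device, for example with torch.multiprocessing:
+#
+#     import torch.multiprocessing as mp
+#     world_size = torch.cuda.device_count()
+#     mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)
+#
+# Alternatively, launch this file with `torchrun --nproc_per_node=<num_gpus>`
+# and derive rank/world size from the environment instead of hard-coding them.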