References:
- PyTorch gradient checkpointing - API reference
- PyTorch native ZeRO - FullyShardedDataParallel (a minimal usage sketch follows this list)
- GPipe (one good implementation of pipelining) - arxiv
- Megatron-LM - one honking great implementation of large-scale training for transformers - repo
- DeepSpeed (a library of many tricks) - repo
- Alpa (automated parallelism in Jax) - https://github.com/alpa-projects/alpa
- ICML'22 tutorial: https://sites.google.com/view/icml-2022-big-model
- FairScale - sharded DDP and pipeline from Meta - repo
- tensor_parallel - automated tensor parallelism in PyTorch
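For reference, here is a minimal FullyShardedDataParallel sketch (a single-node example with placeholder layer sizes; it assumes one process per GPU, e.g. launched with torchrun):
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # one process per GPU
torch.cuda.set_device(dist.get_rank())

model = nn.Sequential(nn.Linear(1000, 1000), nn.ReLU(), nn.Linear(1000, 10)).cuda()
model = FSDP(model)  # shard parameters, gradients and optimizer state across ranks (ZeRO-3-style)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss = model(torch.randn(16, 1000, device="cuda")).sum()
loss.backward()  # gradients are reduced and re-sharded automatically
optimizer.step()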
During the in-class practice, we also had several PyTorch code examples that could come in handy when training large models:
Automatic tensor parallelism:
%pip install tensor_parallel
import tensor_parallel as tp
model = create_a_regular_pytorch_model()  # any regular PyTorch model (placeholder)
model = tp.tensor_parallel(model, ['cuda:0', 'cuda:1'])  # split each layer's weights across the listed devices
outputs_as_usual = model(input_as_usual)  # forward and backward work as before
Gradient checkpointing:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
class Checkpoint(nn.Sequential):
    # wraps a block so its activations are recomputed during backward instead of being stored
    def forward(self, *inputs):
        return checkpoint(super().forward, *inputs)

class Echo(nn.Module):
    def __init__(self, msg: str):
        super().__init__()
        self.msg = msg  # print this message during forward (for debugging)

    def forward(self, x):
        print("forward", self.msg)
        return x
model = nn.Sequential(
    Checkpoint(nn.Linear(1000, 1000), nn.ReLU(), Echo("layer1 done"),
               nn.Linear(1000, 1000), nn.ReLU(), Echo("layer2 done")),
    Checkpoint(nn.Linear(1000, 1000), nn.ReLU(), Echo("layer3 done"),
               nn.Linear(1000, 1000), nn.ReLU(), Echo("layer4 done")),
    nn.Linear(1000, 1000), nn.ReLU(), Echo("layer5 done"),
)
inputs = torch.randn(16, 1000, requires_grad=True)
# note: inputs must have requires_grad=True because checkpoint needs at least one input that requires grad to build the backward graph
outputs = model(inputs)
outputs.norm().backward()  # Echo layers print 1 2 3 4 5 during forward, then 3 4 and 1 2 again as each Checkpoint block is recomputed during backward
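The import above also brings in checkpoint_sequential, which applies the same trick to an nn.Sequential model without a manual wrapper class; a minimal sketch (the flat model and the split into 3 segments are arbitrary choices here):
flat_model = nn.Sequential(
    nn.Linear(1000, 1000), nn.ReLU(),
    nn.Linear(1000, 1000), nn.ReLU(),
    nn.Linear(1000, 1000), nn.ReLU(),
)
inputs = torch.randn(16, 1000, requires_grad=True)
outputs = checkpoint_sequential(flat_model, 3, inputs)  # checkpoint in 3 segments; each segment is recomputed during backward
outputs.norm().backward()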