# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
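"""
Prototype of DiLoCo training with torchft on a toy MLP over random data.

``regular_diloco`` trains with the ``torchft.local_sgd.DiLoCo`` wrapper.
``streaming_diloco`` additionally partitions the model with
``torch.distributed.pipelining`` and experiments with syncing one partition at a
time via a quantized allreduce.
"""
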
import logging
import os
from datetime import timedelta

import torch
import torch.utils.data
from torch import nn, optim
from torch.distributed import ReduceOp
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.pipelining import SplitPoint, pipeline

from torchft import Manager, ProcessGroupGloo, ProcessGroupNCCL
from torchft.checkpointing.pg_transport import PGTransport
from torchft.collectives import allreduce_quantized
from torchft.local_sgd import DiLoCo

logging.basicConfig(level=logging.INFO)

class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, size=10000, feature_dim=128, num_classes=10):
        """
        Create a dummy dataset suitable for MLP models.

        Args:
            size: Number of samples in the dataset
            feature_dim: Dimension of the feature vector (should match d_hid in MultiMLP)
            num_classes: Number of output classes
        """
        self.size = size
        self.feature_dim = feature_dim
        self.num_classes = num_classes

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Generate random feature vector (1D) instead of image (3D)
        features = torch.rand(self.feature_dim)
        label = torch.randint(0, self.num_classes, (1,)).item()
        return features, label

# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x

class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2, num_classes: int = 10):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # Add a final classification layer
        self.classifier = torch.nn.Linear(d_hid, num_classes)
        # For demonstration purposes only; in practice the split points should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        # Apply the classification layer to get logits
        x = self.classifier(x)
        return x

REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "0")

print(f"{CUDA_VISIBLE_DEVICES=}, {REPLICA_GROUP_ID=}")
print(f"{NUM_REPLICA_GROUPS=}")
if torch.cuda.is_available():
    torch.cuda.set_device(0)

# Model hyperparameters
d_hid = 128  # Feature dimension for the MLP
n_layers = 8  # Number of MLP layers

# Create dummy dataset with random data matching the model's input dimension
dataset_size = 10000
trainset = DummyDataset(size=dataset_size, feature_dim=d_hid)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, num_workers=2, shuffle=True
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pg = (
    ProcessGroupNCCL(
        timeout=timedelta(seconds=30),
    )
    if torch.cuda.is_available()
    else ProcessGroupGloo(timeout=timedelta(seconds=5))
)
print(f"{device=} {pg=}")

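# PGTransport streams checkpoints (the state_dict registered with the Manager)
# over the process group so a recovering replica can pull live weights from a peer.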
transport = PGTransport(
    pg,
    timeout=timedelta(seconds=10),
    device=device,
)

num_classes = trainset.num_classes
m = MultiMLP(d_hid=d_hid, n_layers=n_layers, num_classes=num_classes).to(device)
inner_optimizer = optim.AdamW(
    m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = optim.SGD(m.parameters(), lr=0.7, momentum=0.9, nesterov=True)
criterion = nn.CrossEntropyLoss()

print(m)
num_params = sum(p.numel() for p in m.parameters())
print(f"DiLoCo: Total number of parameters: {num_params}")

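# DiLoCo in a nutshell: each replica group takes `sync_every` local steps with the
# inner optimizer (AdamW); the DiLoCo wrapper then forms a pseudo-gradient from the
# difference between the saved original parameters and the current weights, averages
# it across replica groups, and applies it with the outer optimizer (SGD + Nesterov).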
@record
def regular_diloco() -> None:
    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        m.to(device)
        diloco.original_parameters = state_dict["original_params"]
        for name in diloco.original_parameters.keys():
            diloco.original_parameters[name] = diloco.original_parameters[name].to(
                device
            )
        inner_optimizer.load_state_dict(state_dict["inner_optim"])
        outer_optimizer.load_state_dict(state_dict["outer_optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "original_params": diloco.original_parameters,
            "inner_optim": inner_optimizer.state_dict(),
            "outer_optim": outer_optimizer.state_dict(),
        }

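    # The Manager maintains the quorum across replica groups and, on recovery,
    # ships the state_dict above to the joining replica via `transport`.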
    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"regular_diloco_{REPLICA_GROUP_ID}",
        timeout=timedelta(seconds=30),
        checkpoint_transport=transport,
        use_async_quorum=False,
    )

    num_local_steps = 0
    sync_every = 100
    max_outer_steps = 10

    with DiLoCo(
        manager,
        m,
        inner_optimizer,
        outer_optimizer,
        backup_device=device,
        sync_every=sync_every,
    ) as diloco:
        while True:
            for i, (inputs, labels) in enumerate(trainloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                inner_optimizer.zero_grad()

                out = m(inputs)
                loss = criterion(out, labels)
                loss.backward()

                inner_optimizer.step()
                num_local_steps += 1

                if num_local_steps % sync_every == 0:
                    print(
                        f"DiLoCo: Number of inner optimizer steps completed: {num_local_steps}"
                    )
                    print(
                        f"DiLoCo: Number of outer optimizer steps completed: {manager.current_step()} loss = {loss.item()}"
                    )

                if manager.current_step() >= max_outer_steps:
                    exit()

@record
def streaming_diloco() -> None:
    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        m.to(device)
        inner_optimizer.load_state_dict(state_dict["inner_optim"])
        outer_optimizer.load_state_dict(state_dict["outer_optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "inner_optim": inner_optimizer.state_dict(),
            "outer_optim": outer_optimizer.state_dict(),
        }

    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"streaming_diloco_{REPLICA_GROUP_ID}",
        timeout=timedelta(seconds=30),
        checkpoint_transport=transport,
        use_async_quorum=False,
    )

    # Part 1: specify model partitions using the torch.distributed.pipelining APIs
    # TODO: how to map a partition back to the original model
    example_input, _ = next(iter(trainloader))
    pipe = pipeline(
        module=m, mb_args=(example_input.to(device),), split_spec=m.split_spec
    )
    module_partitions = [pipe.get_stage_module(idx) for idx in range(n_layers)]
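    # With split points at layers.1..layers.{n_layers-1}, `pipeline` yields
    # n_layers stages; the classifier lands in the last stage's module.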
    # for module in module_partitions:
    #     print(f"DiLoCo: {module=}, params: {[p for p in module.parameters()]}")

    # Part 2: run DiLoCo as usual
    num_local_steps = 0
    sync_every = 100
    max_outer_steps = 5

    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        inner_optimizer.zero_grad()

        out = m(inputs)
        loss = criterion(out, labels)
        loss.backward()

        inner_optimizer.step()
        num_local_steps += 1

        if num_local_steps % sync_every == 0:
            print(
                f"DiLoCo: Number of inner optimizer steps completed: {num_local_steps}"
            )
            print(
                f"DiLoCo: Number of outer optimizer steps completed: {manager.current_step()} loss = {loss.item()}"
            )
            manager.start_quorum()
            # On a sync step, sync the model weights across replica groups
            # (here only the first partition, as a partial demonstration)
            params_data = []
            for p in module_partitions[0].parameters():
                tensor = p.data
                # TODO: only 2D tensors supported for quantization?
                # replica_0/0 File "/data/users/howardhuang/torchft/torchft/quantization.py", line 450, in _prepare_quantize_fp8
                # replica_0/0     assert len(inputs[i].shape) == 2, "Only 2D tensors are supported"
                # replica_0/0 AssertionError: Only 2D tensors are supported
                if tensor.dim() == 1:
                    # Convert 1D tensors to 2D by adding a dimension
                    tensor = tensor.unsqueeze(0)
                params_data.append(tensor)
            # print(f"Transferring {params_data=} tensors")
            print(f"param shapes {[(p.shape) for p in params_data]}")
            # TODO: error blocking
            # replica_1/0 File "/data/users/howardhuang/torchft/torchft/quantization.py", line 531, in fused_quantize_into_fp8
            # replica_1/0     _fused_kernel_quantize_into_fp8[grid](
            # replica_1/0 File "/home/howardhuang/.conda/envs/torchft/lib/python3.10/site-packages/triton/runtime/jit.py", line 499, in run
            # replica_1/0     if key not in self.cache[device]:
            # replica_1/0 TypeError: unhashable type: 'constexpr'
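            # allreduce_quantized quantizes params_data to fp8, averages it across
            # the replica process group (ReduceOp.AVG), and returns a future; this
            # trades a little precision for much less synchronization bandwidth.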
            fut = allreduce_quantized(params_data, ReduceOp.AVG, pg)
            # TODO: add allreduce_quantized as a manager collective option
            fut.wait()
            print("finished")

        if manager.current_step() >= max_outer_steps:
            print("exiting")
            exit()

if __name__ == "__main__":
    # regular_diloco()
    streaming_diloco()
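
# Typical launch (assumption; check the torchft README for the exact flags): start a
# lighthouse, then run one torchrun instance per replica group, e.g.
#   torchft_lighthouse --min_replicas 1 &
#   REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 <this script>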