Commit 04b3a37

Streaming DiLoCo prototype
1 parent dafb968 commit 04b3a37

4 files changed: +309, -1 lines changed
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
[cli:run]
component=../torchft/torchx.py:hsdp
scheduler=local_cwd

[local_cwd]
auto_set_cuda_visible_devices=True

[component:../torchft/torchx.py:hsdp]
script=train.py
gpu=1

streaming_diloco_prototype/README.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
Requirements:

torchx

1. Start lighthouse

`RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000`

2. Start replica groups (see torchft/torchx.py)

`torchx run`
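
For illustration only (not part of this commit): the `hsdp` component defaults to `replicas=2` (see `torchft/torchx.py`), and `train.py` reads `REPLICA_GROUP_ID`, `NUM_REPLICA_GROUPS`, and `CUDA_VISIBLE_DEVICES` from its environment. The hypothetical sketch below launches the replica groups by hand with `subprocess`; it omits the scheduler and elastic setup that torchx provides, so treat it as a conceptual outline of step 2 rather than a replacement for `torchx run`.

```python
# Conceptual sketch only: start NUM_REPLICA_GROUPS copies of train.py, each with
# its own replica-group id and its own GPU (mirroring auto_set_cuda_visible_devices
# and gpu=1 in the torchx config). The supported launch path is `torchx run`.
import os
import subprocess
import sys

NUM_REPLICA_GROUPS = 2  # matches the hsdp component default (replicas=2)

procs = []
for group in range(NUM_REPLICA_GROUPS):
    env = dict(
        os.environ,
        REPLICA_GROUP_ID=str(group),
        NUM_REPLICA_GROUPS=str(NUM_REPLICA_GROUPS),
        CUDA_VISIBLE_DEVICES=str(group),
    )
    procs.append(subprocess.Popen([sys.executable, "train.py"], env=env))

for proc in procs:
    proc.wait()
```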

streaming_diloco_prototype/train.py

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from datetime import timedelta
import torch
import torch.utils.data
from torch import nn, optim
from torch.distributed import ReduceOp
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.pipelining import SplitPoint, pipeline

from torchft import Manager, ProcessGroupGloo, ProcessGroupNCCL
from torchft.checkpointing.pg_transport import PGTransport
from torchft.local_sgd import DiLoCo

from torchft.collectives import allreduce_quantized

logging.basicConfig(level=logging.INFO)


class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, size=10000, feature_dim=128, num_classes=10):
        """
        Create a dummy dataset suitable for MLP models.

        Args:
            size: Number of samples in the dataset
            feature_dim: Dimension of the feature vector (should match d_hid in MultiMLP)
            num_classes: Number of output classes
        """
        self.size = size
        self.feature_dim = feature_dim
        self.num_classes = num_classes

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Generate a random feature vector (1D) instead of an image (3D)
        features = torch.rand(self.feature_dim)
        label = torch.randint(0, self.num_classes, (1,)).item()
        return features, label


# MLP layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2, num_classes: int = 10):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # Add a final classification layer
        self.classifier = torch.nn.Linear(d_hid, num_classes)
        # For demonstration purposes only; this should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        # Apply the classification layer to get logits
        x = self.classifier(x)
        return x


REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "0")

print(f"{CUDA_VISIBLE_DEVICES=}, REPLICA_GROUP_ID: {REPLICA_GROUP_ID}")
print(f"{NUM_REPLICA_GROUPS=}")
torch.cuda.set_device(0)

# Model dimensions (the number of classes comes from the dataset below)
d_hid = 128  # Feature dimension for the MLP
n_layers = 8  # Number of MLP layers

# Create a dummy dataset with random data matching the model's input dimension
dataset_size = 10000
trainset = DummyDataset(size=dataset_size, feature_dim=d_hid)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, num_workers=2, shuffle=True
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pg = (
    ProcessGroupNCCL(
        timeout=timedelta(seconds=30),
    )
    if torch.cuda.is_available()
    else ProcessGroupGloo(timeout=timedelta(seconds=5))
)
print(f"{device=} {pg=}")

transport = PGTransport(
    pg,
    timeout=timedelta(seconds=10),
    device=device,
)

num_classes = trainset.num_classes
m = MultiMLP(d_hid=d_hid, n_layers=n_layers, num_classes=num_classes).to(device)
inner_optimizer = optim.AdamW(
    m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = optim.SGD(m.parameters(), lr=0.7, momentum=0.9, nesterov=True)
criterion = nn.CrossEntropyLoss()

print(m)
num_params = sum(p.numel() for p in m.parameters())
print(f"DiLoCo: Total number of parameters: {num_params}")


@record
def regular_diloco() -> None:
    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        m.to(device)
        diloco.original_parameters = state_dict["original_params"]
        for name in diloco.original_parameters.keys():
            diloco.original_parameters[name] = diloco.original_parameters[name].to(
                device
            )
        inner_optimizer.load_state_dict(state_dict["inner_optim"])
        outer_optimizer.load_state_dict(state_dict["outer_optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "original_params": diloco.original_parameters,
            "inner_optim": inner_optimizer.state_dict(),
            "outer_optim": outer_optimizer.state_dict(),
        }

    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"regular_diloco_{REPLICA_GROUP_ID}",
        timeout=timedelta(seconds=30),
        checkpoint_transport=transport,
        use_async_quorum=False,
    )

    num_local_steps = 0
    sync_every = 100
    max_outer_steps = 10

    with DiLoCo(
        manager,
        m,
        inner_optimizer,
        outer_optimizer,
        backup_device=device,
        sync_every=sync_every,
    ) as diloco:
        while True:
            for i, (inputs, labels) in enumerate(trainloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                inner_optimizer.zero_grad()

                out = m(inputs)
                loss = criterion(out, labels)
                loss.backward()

                inner_optimizer.step()
                num_local_steps += 1

                if num_local_steps % sync_every == 0:
                    print(
                        f"DiLoCo: Number of inner optimizer steps completed: {num_local_steps}"
                    )
                    print(
                        f"DiLoCo: Number of outer optimizer steps completed: {manager.current_step()} loss = {loss.item()}"
                    )

                if manager.current_step() >= max_outer_steps:
                    exit()


@record
def streaming_diloco() -> None:
    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        m.to(device)
        inner_optimizer.load_state_dict(state_dict["inner_optim"])
        outer_optimizer.load_state_dict(state_dict["outer_optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "inner_optim": inner_optimizer.state_dict(),
            "outer_optim": outer_optimizer.state_dict(),
        }

    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"streaming_diloco_{REPLICA_GROUP_ID}",
        timeout=timedelta(seconds=30),
        checkpoint_transport=transport,
        use_async_quorum=False,
    )

    # Part 1: more easily specify model partitions using the pipelining APIs?
    # TODO: how to map a partition back to the original model
    example_input, _ = next(iter(trainloader))
    pipe = pipeline(module=m, mb_args=(example_input.to(device),), split_spec=m.split_spec)
    module_partitions = [pipe.get_stage_module(idx) for idx in range(n_layers)]
    # for module in module_partitions:
    #     print(f"DiLoCo: {module=}, params: {[p for p in module.parameters()]}")

    # Part 2: run DiLoCo as usual
    num_local_steps = 0
    sync_every = 100
    max_outer_steps = 5

    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        inner_optimizer.zero_grad()

        out = m(inputs)
        loss = criterion(out, labels)
        loss.backward()

        inner_optimizer.step()
        num_local_steps += 1

        if num_local_steps % sync_every == 0:
            print(
                f"DiLoCo: Number of inner optimizer steps completed: {num_local_steps}"
            )
            print(
                f"DiLoCo: Number of outer optimizer steps completed: {manager.current_step()} loss = {loss.item()}"
            )
            manager.start_quorum()
            # On a sync step we need to sync the model weights across replica groups
            # (this prototype only syncs part of the model: partition 0).
            params_data = []
            for p in module_partitions[0].parameters():
                tensor = p.data
                # TODO: only 2D tensors supported for quantization?
                # replica_0/0 File "/data/users/howardhuang/torchft/torchft/quantization.py", line 450, in _prepare_quantize_fp8
                # replica_0/0 assert len(inputs[i].shape) == 2, "Only 2D tensors are supported"
                # replica_0/0 AssertionError: Only 2D tensors are supported
                if tensor.dim() == 1:
                    # Convert 1D tensors to 2D by adding a leading dimension
                    tensor = tensor.unsqueeze(0)
                params_data.append(tensor)
            # print(f"Transferring {params_data=} tensors")
            print(f"param shapes {[(p.shape) for p in params_data]}")
            # TODO: blocking error
            # replica_1/0 File "/data/users/howardhuang/torchft/torchft/quantization.py", line 531, in fused_quantize_into_fp8
            # replica_1/0 _fused_kernel_quantize_into_fp8[grid](
            # replica_1/0 File "/home/howardhuang/.conda/envs/torchft/lib/python3.10/site-packages/triton/runtime/jit.py", line 499, in run
            # replica_1/0 if key not in self.cache[device]:
            # replica_1/0 TypeError: unhashable type: 'constexpr'
            fut = allreduce_quantized(params_data, ReduceOp.AVG, pg)
            # TODO: add allreduce_quantized as a manager collective option
            fut.wait()
            print("finished")

        if manager.current_step() >= max_outer_steps:
            print("exiting")
            exit()


if __name__ == "__main__":
    # regular_diloco()
    streaming_diloco()
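
As a possible follow-up (not implemented in this prototype, which always syncs `module_partitions[0]`), streaming DiLoCo would synchronize a different model fragment at each outer step. A minimal sketch under that assumption, where `module_partitions` is the list built from `pipe.get_stage_module(...)` above:

```python
# Hypothetical helper (not in this commit): choose one fragment per outer step so
# each sync only all-reduces a slice of the model instead of every parameter.
from typing import List

from torch import nn


def fragment_to_sync(module_partitions: List[nn.Module], outer_step: int) -> nn.Module:
    # Round-robin over fragments: outer step k syncs partition k mod n.
    return module_partitions[outer_step % len(module_partitions)]
```

With such a helper, the sync branch would iterate over `fragment_to_sync(module_partitions, manager.current_step()).parameters()` instead of `module_partitions[0].parameters()`.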

torchft/torchx.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ def hsdp(
     *script_args: str,
     replicas: int = 2,
     workers_per_replica: int = 1,
-    max_restarts: int = 10,
+    max_restarts: int = 0,
     script: str = "train_ddp.py",
     env: Optional[Dict[str, str]] = None,
     image: str = "",
