
Commit dafb968

Fix DiLoCo with DTensor (#197)
1 parent c4f0e72 commit dafb968

3 files changed: 59 additions, 17 deletions


torchft/local_sgd.py

Lines changed: 10 additions & 2 deletions
@@ -132,7 +132,11 @@ def _perform_sync(self) -> None:
                 # we averaged the local version of the tensor so need to copy it back as a DTensor
                 param.data.copy_(
                     DTensor.from_local(
-                        avg_param, param.device_mesh, param.placements
+                        avg_param,
+                        param.device_mesh,
+                        param.placements,
+                        shape=param.shape,
+                        stride=param.stride(),
                     )
                 )
             else:
@@ -249,7 +253,11 @@ def _restore_parameters(self) -> None:
                 # we averaged the local version of the tensor so need to copy it back as a DTensor
                 p.data.copy_(
                     DTensor.from_local(
-                        self.original_parameters[name], p.device_mesh, p.placements
+                        self.original_parameters[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
                     ),
                     non_blocking=False,
                 )
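By default, `DTensor.from_local` re-derives the global shape and stride from the local shard, which can disagree with the parameter's original global metadata (e.g. for unevenly sharded tensors); passing `shape=` and `stride=` pins the original geometry instead. A minimal sketch of the pattern, not torchft code, assuming a single-host gloo setup launched via torchrun:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# Illustrative names only: a replicated 4x4 parameter on a 1-D mesh.
dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

param = DTensor.from_local(torch.randn(4, 4), mesh, [Replicate()])
avg_param = param.to_local().clone()  # stand-in for the averaged local tensor

param.data.copy_(
    DTensor.from_local(
        avg_param,
        param.device_mesh,
        param.placements,
        shape=param.shape,      # keep the original global shape ...
        stride=param.stride(),  # ... and stride, rather than inferring them
    )
)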

torchft/local_sgd_integ_test.py

Lines changed: 29 additions & 5 deletions
@@ -1,5 +1,6 @@
 import copy
 import logging
+import os
 import re
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -11,8 +12,10 @@
 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.tensor import DTensor, Replicate

 from torchft._torchft import LighthouseServer
+from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
@@ -64,6 +67,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         stack.callback(lambda: manager.shutdown(wait=False))

         m: nn.Module = MyModel().to(device)
+
         optimizer: optim.Optimizer = optim.Adam(m.parameters())
         criterion = nn.CrossEntropyLoss()

@@ -156,6 +160,29 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
             **runner.manager_args,
         )
         stack.callback(manager.shutdown)
+        # initialize default group for device mesh to work
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                init_method=f"tcp://localhost:0",
+                rank=rank,
+                world_size=runner.world_size,
+            )
+
+        device_type = device.type
+        ft_device_mesh = ft_init_device_mesh(
+            device_type=device_type,
+            mesh_shape=(runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        for layer in m.layers:
+            if isinstance(layer, nn.Linear):
+                for param in layer.parameters():
+                    param = DTensor.from_local(
+                        param,
+                        device_mesh=ft_device_mesh,
+                    )

         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
@@ -170,13 +197,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
             while True:
                 manager_curr_step = manager.current_step()
                 if manager_curr_step not in all_state_dicts:
-                    print(
-                        f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
-                    )
                     all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
                 batch_size = 1
-                inputs = m.get_rand_inputs(batch_size).to(device)
-                labels = m.get_rand_labels(batch_size).to(device)
+                inputs = m.get_rand_inputs(batch_size, device=device)
+                labels = m.get_rand_labels(batch_size, device=device)

                 out = m(inputs)
                 loss = criterion(out, labels)
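The wrapping step above relies on `DTensor.from_local` defaulting to a `Replicate()` placement when no placements are given. A small sanity sketch of that behavior with a plain 1-D mesh (plain PyTorch standing in for torchft's `ft_init_device_mesh`; run under torchrun):

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

torch.manual_seed(0)  # identical init on every rank, as replication assumes
layer = nn.Linear(3, 8)
for param in layer.parameters():
    # With no placements given, from_local replicates across the mesh, so the
    # DTensor reports the same global shape as the local tensor.
    dt = DTensor.from_local(param.data, device_mesh=mesh)
    assert dt.shape == param.shape
    assert torch.equal(dt.to_local(), param.data)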

torchft/manager_integ_test.py

Lines changed: 20 additions & 10 deletions
@@ -33,19 +33,29 @@ def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
         super().__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.model = nn.Sequential(
-            nn.Linear(in_dim, out_dim),
-            nn.Sigmoid(),
+        self.layers = nn.ModuleList(
+            [
+                nn.Linear(in_dim, 8),
+                nn.ReLU(),
+                nn.Linear(8, out_dim),
+                nn.ReLU(),
+            ]
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
-
-    def get_rand_inputs(self, batch_size: int) -> torch.Tensor:
-        return torch.rand(batch_size, self.in_dim)
-
-    def get_rand_labels(self, batch_size: int) -> torch.Tensor:
-        return torch.randint(3, (batch_size,))
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim, device=device)
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.randint(3, (batch_size,), device=device)


 class InjectedFailure(Exception):
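For reference, the reshaped test model and its device-aware helpers compose as below (a usage sketch assuming torchft is importable; shapes follow the defaults above):

import torch
from torch import nn
from torchft.manager_integ_test import MyModel

device = torch.device("cpu")
m = MyModel(in_dim=3, out_dim=4).to(device)

batch_size = 1
inputs = m.get_rand_inputs(batch_size, device=device)  # shape (1, 3)
labels = m.get_rand_labels(batch_size, device=device)  # class ids in [0, 3)

out = m(inputs)  # forward walks self.layers: Linear -> ReLU -> Linear -> ReLU
assert out.shape == (batch_size, 4)
loss = nn.CrossEntropyLoss()(out, labels)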
