@@ -1,5 +1,6 @@
 import copy
 import logging
+import os
 import re
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -11,8 +12,10 @@
 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.tensor import DTensor, Replicate
 
 from torchft._torchft import LighthouseServer
+from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
@@ -64,6 +67,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         stack.callback(lambda: manager.shutdown(wait=False))
 
         m: nn.Module = MyModel().to(device)
+
         optimizer: optim.Optimizer = optim.Adam(m.parameters())
         criterion = nn.CrossEntropyLoss()
 
@@ -156,6 +160,29 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             **runner.manager_args,
         )
         stack.callback(manager.shutdown)
+        # initialize default group for device mesh to work
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                init_method=f"tcp://localhost:0",
+                rank=rank,
+                world_size=runner.world_size,
+            )
+
+        device_type = device.type
+        ft_device_mesh = ft_init_device_mesh(
+            device_type=device_type,
+            mesh_shape=(runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        for layer in m.layers:
+            if isinstance(layer, nn.Linear):
+                for param in layer.parameters():
+                    param = DTensor.from_local(
+                        param,
+                        device_mesh=ft_device_mesh,
+                    )
 
         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
@@ -170,13 +197,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             while True:
                 manager_curr_step = manager.current_step()
                 if manager_curr_step not in all_state_dicts:
-                    print(
-                        f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
-                    )
                     all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
                 batch_size = 1
-                inputs = m.get_rand_inputs(batch_size).to(device)
-                labels = m.get_rand_labels(batch_size).to(device)
+                inputs = m.get_rand_inputs(batch_size, device=device)
+                labels = m.get_rand_labels(batch_size, device=device)
 
                 out = m(inputs)
                 loss = criterion(out, labels)
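
For readers unfamiliar with the new imports, the sketch below shows the DTensor-on-a-device-mesh pattern the test exercises, reduced to a single process. It is an illustration, not part of this change: it substitutes the stock torch.distributed.device_mesh.init_device_mesh for torchft's ft_init_device_mesh (which additionally takes a manager, as in the diff above), and the backend, port, mesh shape (1, 1), and nn.Linear(4, 4) layer are arbitrary choices for the example.

# Illustrative only: single-process CPU sketch of the DTensor wrapping used in
# the new test code. Uses the stock torch.distributed device-mesh API instead
# of torchft's ft_init_device_mesh, so no Manager or lighthouse is required.
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate


def main() -> None:
    # Stand-in for the default process group the test initializes before
    # building the device mesh (rank/world_size/port are arbitrary here).
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=0,
        world_size=1,
    )

    # 2-D mesh shaped like mesh_shape=(world_size, 1) in the diff, with the
    # same dimension names; here world_size is 1, so the mesh is (1, 1).
    mesh = init_device_mesh("cpu", (1, 1), mesh_dim_names=("replicate", "none"))

    # A single Linear layer stands in for the Linear layers of MyModel.
    m = nn.Linear(4, 4)
    for param in m.parameters():
        # Wrap each local parameter as a DTensor replicated over both mesh
        # dimensions, mirroring the loop over m.layers in the test.
        dt = DTensor.from_local(
            param.data, device_mesh=mesh, placements=[Replicate(), Replicate()]
        )
        print(f"{type(dt).__name__}: placements={dt.placements}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()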