@@ -23,6 +23,7 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.export import export
+from torch.utils.tensorboard import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchft import (
@@ -41,7 +42,11 @@
 @record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
-    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    RUN = int(os.environ.get("RUN", 0))
+
+    output_folder = f"output/replica-{REPLICA_GROUP_ID}"
+
+    writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
 
     def load_state_dict(state_dict):
         m.load_state_dict(state_dict["model"])
@@ -174,9 +179,7 @@ def forward(self, x):
     sort_by_keyword = "self_" + device + "_time_total"
 
     def trace_handler(p):
-        p.export_chrome_trace(
-            f"/home/tushar00jain/trace_{p.step_num}_{REPLICA_GROUP_ID}.json"
-        )
+        p.export_chrome_trace(f"{output_folder}/profiles/step-{p.step_num}.json")
 
     # You can use an epoch based training but with faults it's easier to use step
     # based training.
@@ -188,6 +191,7 @@ def trace_handler(p):
     )
 
     prof.start()
+    tensorboard_key_prefix = f"Run:{RUN}"
     with DiLoCo(
         manager,
         module_partitions if USE_STREAMING else [m],
@@ -210,16 +214,27 @@ def trace_handler(p):
             out = m(inputs)
             loss = criterion(out, labels)
 
+            writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)
+
             loss.backward()
 
             inner_optimizer.step()
 
+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/num_participants",
+                manager.num_participants(),
+                i,
+            )
+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/current_step", manager.current_step(), i
+            )
             if manager.current_step() % 100 == 0:
                 print(f"[{manager.current_step()}] loss = {loss.item()}")
 
             if manager.current_step() >= 15:
                 # complete training
                 prof.stop()
+                writer.flush()
                 exit()
 
 
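
For reference, here is a minimal, self-contained sketch of the TensorBoard logging pattern this commit introduces. The folder layout, tag prefix, and flush-before-exit behavior mirror the diff; the toy model, optimizer, and loop are stand-ins for the real DiLoCo training setup and are not part of the commit.

import os

import torch
from torch.utils.tensorboard import SummaryWriter

REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
RUN = int(os.environ.get("RUN", 0))

# One writer per replica group, matching the diff's folder layout.
output_folder = f"output/replica-{REPLICA_GROUP_ID}"
writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
tensorboard_key_prefix = f"Run:{RUN}"

# Toy model and data; stand-ins for the DiLoCo-wrapped model in the diff.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for i in range(15):
    inputs = torch.randn(8, 10)
    labels = torch.randint(0, 2, (8,))

    loss = criterion(model(inputs), labels)
    # The run-id prefix keeps curves from separate restarts distinguishable.
    writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Flush buffered events before exiting, as the diff does before exit().
writer.flush()

With one writer per replica group under output/replica-<id>/tensorboard, a single tensorboard --logdir output invocation can overlay the loss, num_participants, and current_step curves from every replica.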
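The trace_handler change routes per-step chrome traces into the same per-replica output folder. Below is a sketch of how such a handler typically plugs into torch.profiler; the schedule values and the stand-in workload are illustrative assumptions, not taken from the diff, and the profiles directory must exist before the first export.

import os

import torch
from torch.profiler import ProfilerActivity, profile, schedule

output_folder = "output/replica-0"  # hypothetical; the diff derives this from REPLICA_GROUP_ID
os.makedirs(f"{output_folder}/profiles", exist_ok=True)


def trace_handler(p):
    # Same export call as the diff: one chrome trace per profiled step,
    # viewable in chrome://tracing or https://ui.perfetto.dev.
    p.export_chrome_trace(f"{output_folder}/profiles/step-{p.step_num}.json")


prof = profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=1),  # illustrative schedule
    on_trace_ready=trace_handler,
)

prof.start()
for _ in range(6):
    torch.randn(128, 128) @ torch.randn(128, 128)  # stand-in workload
    prof.step()  # advances the schedule; fires trace_handler after "active" steps
prof.stop()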