@@ -24,6 +24,7 @@
 from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.export import export
 from torchdata.stateful_dataloader import StatefulDataLoader
+from torch.utils.tensorboard import SummaryWriter
 
 from torchft import (
     DistributedSampler,
@@ -41,7 +42,11 @@
 @record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
-    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    RUN = int(os.environ.get("RUN", 0))
+
+    output_folder = f"output/replica-{REPLICA_GROUP_ID}"
+
+    writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
 
     def load_state_dict(state_dict):
         m.load_state_dict(state_dict["model"])
@@ -175,7 +180,7 @@ def forward(self, x):
 
     def trace_handler(p):
         p.export_chrome_trace(
-            f"/home/tushar00jain/trace_{p.step_num}_{REPLICA_GROUP_ID}.json"
+            f"{output_folder}/profiles/step-{p.step_num}.json"
         )
 
     # You can use an epoch based training but with faults it's easier to use step
@@ -188,6 +193,7 @@ def trace_handler(p):
     )
 
     prof.start()
+    tensorboard_key_prefix = f"Run:{RUN}"
     with DiLoCo(
         manager,
         module_partitions if USE_STREAMING else [m],
@@ -210,16 +216,21 @@ def trace_handler(p):
             out = m(inputs)
             loss = criterion(out, labels)
 
+            writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)
+
             loss.backward()
 
             inner_optimizer.step()
 
+            writer.add_scalar(f"{tensorboard_key_prefix}/num_participants", manager.num_participants(), i)
+            writer.add_scalar(f"{tensorboard_key_prefix}/current_step", manager.current_step(), i)
             if manager.current_step() % 100 == 0:
                 print(f"[{manager.current_step()}] loss = {loss.item()}")
 
             if manager.current_step() >= 15:
                 # complete training
                 prof.stop()
+                writer.flush()
                 exit()
 
 
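For context, the pattern this commit introduces is per-step scalar logging with `torch.utils.tensorboard.SummaryWriter`: one writer per replica group, a `Run:{RUN}` key prefix so successive runs land in separate series, and a `flush()` before exit so points still buffered (up to `max_queue=1000`) are not lost. A minimal standalone sketch of that pattern, where the linear model and synthetic batches are placeholders standing in for the example's DiLoCo training loop:

```python
import os

import torch
from torch.utils.tensorboard import SummaryWriter

REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
RUN = int(os.environ.get("RUN", 0))

# One writer per replica group; events go under output/replica-<id>/tensorboard.
writer = SummaryWriter(f"output/replica-{REPLICA_GROUP_ID}/tensorboard", max_queue=1000)
prefix = f"Run:{RUN}"  # keeps scalars from successive runs apart

# Placeholder model and loop (not the example's real training code).
model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
for i in range(100):
    x, y = torch.randn(32, 8), torch.randn(32, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    writer.add_scalar(f"{prefix}/loss", loss, i)  # accepts a 0-dim tensor or float
    opt.zero_grad()
    loss.backward()
    opt.step()

writer.flush()  # drain the queue before the process exits
```

The logs for all replica groups can then be viewed with `tensorboard --logdir output`, which picks up each `replica-*/tensorboard` subfolder.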
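The chrome-trace path change in the third hunk follows the standard `torch.profiler` trace-handler pattern (`on_trace_ready` plus `p.step_num`). A minimal sketch; the schedule numbers and the stand-in workload here are illustrative, not the example's actual profiler config:

```python
import os

import torch
from torch.profiler import ProfilerActivity, profile, schedule

output_folder = "output/replica-0"  # per-replica folder, as in the diff
os.makedirs(f"{output_folder}/profiles", exist_ok=True)

def trace_handler(p):
    # p.step_num is the profiler's step counter; one trace file per active cycle.
    p.export_chrome_trace(f"{output_folder}/profiles/step-{p.step_num}.json")

prof = profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_handler,
)
prof.start()
for _ in range(8):
    torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in work
    prof.step()
prof.stop()
```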