Commit a86ed0d

add tensorboard to training script
Summary:
- add TensorBoard integration and separate the metrics by run ID and replica ID
- have an output folder per replica ID
1 parent: 7898bfd

File tree: 1 file changed (+19 −4 lines)
train_diloco.py

Lines changed: 19 additions & 4 deletions
@@ -23,6 +23,7 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.export import export
+from torch.utils.tensorboard import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader

 from torchft import (
@@ -41,7 +42,11 @@
 @record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
-    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    RUN = int(os.environ.get("RUN", 0))
+
+    output_folder = f"output/replica-{REPLICA_GROUP_ID}"
+
+    writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)

     def load_state_dict(state_dict):
         m.load_state_dict(state_dict["model"])
@@ -174,9 +179,7 @@ def forward(self, x):
     sort_by_keyword = "self_" + device + "_time_total"

     def trace_handler(p):
-        p.export_chrome_trace(
-            f"/home/tushar00jain/trace_{p.step_num}_{REPLICA_GROUP_ID}.json"
-        )
+        p.export_chrome_trace(f"{output_folder}/profiles/step-{p.step_num}.json")

     # You can use an epoch based training but with faults it's easier to use step
     # based training.
@@ -188,6 +191,7 @@ def trace_handler(p):
     )

     prof.start()
+    tensorboard_key_prefix = f"Run:{RUN}"
     with DiLoCo(
         manager,
         module_partitions if USE_STREAMING else [m],
@@ -210,16 +214,27 @@ def trace_handler(p):
             out = m(inputs)
             loss = criterion(out, labels)

+            writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)
+
             loss.backward()

             inner_optimizer.step()

+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/num_participants",
+                manager.num_participants(),
+                i,
+            )
+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/current_step", manager.current_step(), i
+            )
             if manager.current_step() % 100 == 0:
                 print(f"[{manager.current_step()}] loss = {loss.item()}")

             if manager.current_step() >= 15:
                 # complete training
                 prof.stop()
+                writer.flush()
                 exit()
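For reference, here is a minimal standalone sketch of the logging pattern this commit adds. It assumes the same REPLICA_GROUP_ID and RUN environment variables that train_diloco.py reads; the variable names and logged values below are illustrative placeholders, not the commit's actual training code.

# Hypothetical standalone sketch of the per-replica TensorBoard setup above.
import os

from torch.utils.tensorboard import SummaryWriter

replica_group_id = int(os.environ.get("REPLICA_GROUP_ID", 0))
run = int(os.environ.get("RUN", 0))

# One output folder per replica group; TensorBoard event files go underneath it.
output_folder = f"output/replica-{replica_group_id}"
writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)

# Prefixing every tag with the run id separates metrics from different runs
# that land in the same replica's log directory.
prefix = f"Run:{run}"
for step in range(5):
    writer.add_scalar(f"{prefix}/loss", 1.0 / (step + 1), step)  # placeholder value

writer.flush()
writer.close()

Because each replica writes under its own subdirectory, pointing TensorBoard at the parent folder (tensorboard --logdir output) shows every replica as a separate run, and the profiler traces written under output/replica-*/profiles are standard Chrome trace JSON, viewable in chrome://tracing or Perfetto.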