Skip to content

Commit 078b6c0

Browse files
committed
add TensorBoard to training script
Summary: add TensorBoard integration, separate the metrics by run ID and replica ID, and write to an output folder per replica ID.
1 parent 7898bfd commit 078b6c0

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

train_diloco.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch.distributed.pipelining import SplitPoint, pipeline
2525
from torch.export import export
2626
from torchdata.stateful_dataloader import StatefulDataLoader
27+
from torch.utils.tensorboard import SummaryWriter
2728

2829
from torchft import (
2930
DistributedSampler,
@@ -41,7 +42,11 @@
4142
@record
4243
def main() -> None:
4344
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
44-
NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
45+
RUN = int(os.environ.get("RUN", 0))
46+
47+
output_folder = f"output/replica-{REPLICA_GROUP_ID}"
48+
49+
writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
4550

4651
def load_state_dict(state_dict):
4752
m.load_state_dict(state_dict["model"])
@@ -175,7 +180,7 @@ def forward(self, x):
175180

176181
def trace_handler(p):
177182
p.export_chrome_trace(
178-
f"/home/tushar00jain/trace_{p.step_num}_{REPLICA_GROUP_ID}.json"
183+
f"{output_folder}/profiles/step-{p.step_num}.json"
179184
)
180185

181186
# You can use an epoch based training but with faults it's easier to use step
@@ -188,6 +193,7 @@ def trace_handler(p):
188193
)
189194

190195
prof.start()
196+
tensorboard_key_prefix = f"Run:{RUN}"
191197
with DiLoCo(
192198
manager,
193199
module_partitions if USE_STREAMING else [m],
@@ -210,16 +216,21 @@ def trace_handler(p):
210216
out = m(inputs)
211217
loss = criterion(out, labels)
212218

219+
writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)
220+
213221
loss.backward()
214222

215223
inner_optimizer.step()
216224

225+
writer.add_scalar(f"{tensorboard_key_prefix}/num_participants", manager.num_participants(), i)
226+
writer.add_scalar(f"{tensorboard_key_prefix}/current_step", manager.current_step(), i)
217227
if manager.current_step() % 100 == 0:
218228
print(f"[{manager.current_step()}] loss = {loss.item()}")
219229

220230
if manager.current_step() >= 15:
221231
# complete training
222232
prof.stop()
233+
writer.flush()
223234
exit()
224235

225236

0 commit comments

Comments
 (0)