Commit fad3f3e
fix infinite recovery
Summary:
- We don't increase `max_step` when a node is catching up because we don't call `should_commit`. This can leave the node permanently behind and put it into an infinite recovery loop.
- Note: this can result in the global parameters falling out of sync; the diff includes an RFC on how to fix that if we need to.
- Document another case where `should_commit` can return `True` but shouldn't, because the allreduce failed (this is only relevant when there can be a pending inflight allreduce).
- Add an assert based on the fragment sync schedule to make sure we don't run into this.

Test Plan:
- Tested on a cluster of 3 nodes by removing and then re-adding a node.
- `max_step` and `local_step` increase in the manager's state dict after both failure and recovery.

Metrics from the healthy node:

<img width="1103" alt="Screenshot 2025-06-15 at 10 53 28 PM copy" src="https://github.com/user-attachments/assets/8640780c-fd20-4266-aa3c-3116776a9c68" />

Metrics from the failed and recovered node:

<img width="1101" alt="Screenshot 2025-06-15 at 10 56 49 PM copy" src="https://github.com/user-attachments/assets/cc2a1c57-715f-4e0a-8e00-7c62da525dc3" />
1 parent 9241a8b commit fad3f3e
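
For context, here is a toy model of the failure mode described in the summary. The names `max_step`, `local_step`, and `should_commit` come from the commit message; the loop structure and the helper below are purely illustrative and are not torchft code.

```python
# Illustrative only: a toy model of the recovery loop, not torchft code.
# Healthy replicas bump the shared `max_step` whenever `should_commit`
# succeeds. Before this fix, a node that was catching up never reached
# `should_commit`, so its own step never advanced and it stayed behind
# the quorum forever.

def catch_up(node_step: int, quorum_max_step: int, call_should_commit: bool) -> None:
    for round_num in range(1, 11):  # a handful of recovery rounds
        if node_step >= quorum_max_step:
            print(f"caught up after {round_num - 1} round(s) at step {node_step}")
            return
        node_step = quorum_max_step  # recovery copies the quorum's state
        quorum_max_step += 1         # healthy replicas keep committing
        if call_should_commit:
            node_step += 1           # the fix: the catching-up node commits too
    print("still behind after 10 rounds: infinite recovery")


catch_up(node_step=0, quorum_max_step=5, call_should_commit=False)  # never converges
catch_up(node_step=0, quorum_max_step=5, call_should_commit=True)   # converges
```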

File tree: 2 files changed (+46, -3 lines)
torchft/local_sgd.py

Lines changed: 45 additions & 2 deletions
@@ -356,14 +356,47 @@ def perform_sync(self) -> bool:
         Overrides the sync method to wait for the scheduled allreduce to finish and
         steps using the outer optimizer.
         """
-        if len(self._allreduce_futures) == 0:
-            return True
+        # Waiting for an allreduce is currently not supported. Please make
+        # sure to not do this to avoid running into inconsistencies.
+        #
+        # This can happen when using large values of `fragment_sync_delay`.
+        # The node might not have participated in syncing of this fragment.
+        #
+        # The allreduce for other nodes who did might actually
+        # succeed and in that case, we shouldn't allow recovery
+        # from this node.
+        #
+        # We do need to increase the `max_step` here so we
+        # don't end up in an infinite loop of needing to recover.
+        #
+        # We can add a `is_catching_up` flag to the state_dict
+        # to disallow recoveries from this node. Such nodes can
+        # be excluded from `max_step` calculation unless all
+        # nodes are catching up. This approach makes the replica state
+        # of global parameters diverge though. So we could add recovery
+        # for a particular fragment from a peer node as a part of the
+        # `should_commit` when a node is catching up.
+        assert len(self._allreduce_futures) > 0

         self.wait()

         # Restore the parameters back to the previous state
         self.restore_parameters()

+        # For large values of `fragment_sync_delay`, this call can be
+        # a problem.
+        #
+        # This can return success even if the allreduce failed. Because
+        # the process group could have been reconfigured while the
+        # allreduce was inflight. The inflight allreduce may or may
+        # not have been aborted.
+        #
+        # We consider it successful anyway.
+        #
+        # We can track errors per allreduce to
+        # let the commit fail here. But this has the downside of
+        # reconfiguring the pg too many times resulting in
+        # more aborts and more commit failures.
         should_commit = self._manager.should_commit()

         if should_commit:

@@ -708,6 +741,16 @@ def _step_post_hook(
         # waste after recovery
         self._quorum_loop()

+        # TODO: Since we do quorum after commit, there might be a big gap until
+        # the next allreduce. This increases the chances of nodes failing
+        # and so the allreduce to fail.
+        # - We could maybe do a quorum again right before preparing for a fragment
+        # using `shrink_only`. This might make it tricky for new nodes to join
+        # though.
+        # - Maintain a sequence number in the state dict that gets bumped at every
+        # quorum call. Then we can do a quorum right before allreduce and avoid
+        # doing quorums after commit.
+
         # We need to set make sure `_local_step` is still
         # the same across all replicas if `quorum_id` changed.
         #
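
A rough sketch of the `is_catching_up` RFC mentioned in the new comment in `perform_sync`. The flag, the state shape, and both helpers below are hypothetical and are not part of torchft's API.

```python
# Hypothetical sketch of the `is_catching_up` RFC: nodes that are catching up
# are excluded from the `max_step` calculation and cannot serve as recovery
# sources, unless *every* node is catching up. Not torchft code.
from typing import Dict, TypedDict


class ReplicaState(TypedDict):
    max_step: int
    is_catching_up: bool


def compute_max_step(states: Dict[str, ReplicaState]) -> int:
    caught_up = [s for s in states.values() if not s["is_catching_up"]]
    # If everyone is catching up, fall back to considering all replicas.
    candidates = caught_up or list(states.values())
    return max(s["max_step"] for s in candidates)


def can_recover_from(state: ReplicaState) -> bool:
    # Disallow recovery from a node that has not finished catching up,
    # since its copy of the global parameters may have diverged.
    return not state["is_catching_up"]
```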
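
The comment about `should_commit` returning success even when the allreduce failed suggests tracking errors per allreduce. Below is a minimal sketch of that idea, assuming futures with a `concurrent.futures`-style interface; the class and method names are made up for illustration, and the comment in the diff already notes the downside (failing commits this way forces more process-group reconfigurations and aborts).

```python
# Hypothetical sketch of "track errors per allreduce": record whether each
# fragment's allreduce future failed, and only commit if none did.
from concurrent.futures import Future
from typing import Dict


class AllreduceErrorTracker:
    def __init__(self) -> None:
        self._errors: Dict[int, BaseException] = {}

    def watch(self, fragment_id: int, fut: Future) -> None:
        def _record(done: Future) -> None:
            exc = done.exception()
            if exc is not None:
                self._errors[fragment_id] = exc

        fut.add_done_callback(_record)

    def all_succeeded(self) -> bool:
        # A commit would only be allowed when no watched allreduce failed.
        return not self._errors
```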
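
The second bullet of the TODO in `_step_post_hook` could look roughly like the following; the `quorum_seq` field and every method here are assumptions for illustration, not the actual manager interface.

```python
# Hypothetical sketch of the TODO's second idea: keep a sequence number in
# the state dict that is bumped on every quorum, so a quorum can be taken
# right before the allreduce instead of right after the commit, and a stale
# quorum can be detected cheaply. None of these names exist in torchft.
class QuorumTracker:
    def __init__(self) -> None:
        self.quorum_seq = 0

    def state_dict(self) -> dict:
        return {"quorum_seq": self.quorum_seq}

    def load_state_dict(self, state: dict) -> None:
        self.quorum_seq = state["quorum_seq"]

    def on_quorum(self) -> None:
        # Bumped every time a quorum completes.
        self.quorum_seq += 1

    def needs_fresh_quorum(self, peer_seq: int) -> bool:
        # If a peer has seen a newer quorum, re-run quorum before scheduling
        # the next allreduce rather than after the commit.
        return peer_seq > self.quorum_seq
```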

train_diloco.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def trace_handler(p):
         outer_optimizer,
         backup_device=device,
         sync_every=20 if USE_STREAMING else 20,
-        fragment_sync_delay=10 if USE_STREAMING else 0,
+        fragment_sync_delay=5 if USE_STREAMING else 0,
         should_quantize=False,
     ) as diloco:
         while True:
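
Lowering `fragment_sync_delay` from 10 to 5 keeps the delay comfortably below `sync_every=20`. As a rough guard in a training script, one could add a check like the one below; the exact inequality (delay strictly less than `sync_every // num_fragments`) is my assumption about what the new assert in `local_sgd.py` relies on, not a constraint documented by torchft, and `num_fragments` is a hypothetical parameter.

```python
# Hypothetical sanity check for the fragment sync schedule. The bound used
# here (fragment_sync_delay strictly below the per-fragment stride) is an
# assumption for illustration only; its intent is to keep an allreduce in
# flight whenever perform_sync runs.
def check_sync_schedule(sync_every: int, fragment_sync_delay: int, num_fragments: int = 1) -> None:
    stride = sync_every // num_fragments
    assert 0 <= fragment_sync_delay < stride, (
        f"fragment_sync_delay={fragment_sync_delay} should be smaller than "
        f"sync_every // num_fragments = {stride}; otherwise perform_sync "
        f"may run with no allreduce in flight"
    )


check_sync_schedule(sync_every=20, fragment_sync_delay=5, num_fragments=2)   # ok: 5 < 10
# check_sync_schedule(sync_every=20, fragment_sync_delay=10, num_fragments=2)  # would fail: 10 is not < 10
```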
