Commit 4533426

fix infinite recovery
Summary:
- We don't increase the `max_step` when a node is catching up because we don't call `should_commit`.
- This can lead to the node always being behind and getting into an infinite recovery loop.
- So simply call `should_commit`.
- Note: this can result in the global parameters falling out of sync; the diff includes an RFC on how to fix that.

Test Plan:
- Tested on a cluster of 3 nodes by removing and adding a node.
- The `max_step` and `local_step` increase in the manager's state dict after both failure and recovery.

Metrics from the healthy node:

<img width="1103" alt="Screenshot 2025-06-15 at 10 53 28 PM copy" src="https://github.com/user-attachments/assets/8640780c-fd20-4266-aa3c-3116776a9c68" />

Metrics from the failed and recovered node:

<img width="1101" alt="Screenshot 2025-06-15 at 10 56 49 PM copy" src="https://github.com/user-attachments/assets/cc2a1c57-715f-4e0a-8e00-7c62da525dc3" />
1 parent d607d2d commit 4533426

File tree

1 file changed: +37 -1 lines changed


torchft/local_sgd.py

Lines changed: 37 additions & 1 deletion
```diff
@@ -354,13 +354,39 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         if len(self._allreduce_futures) == 0:
-            return True
+            assert self._fragment_sync_delay > 0
+            # This can happen when using `fragment_sync_delay`. The node
+            # might not have participated in syncing of this fragment.
+            #
+            # The allreduce for other nodes who did might actually
+            # succeed and in that case, we shouldn't allow recovery
+            # from this node.
+            #
+            # We do need to increase the `max_step` here so we
+            # don't end up in an infinite loop of needing to recover.
+            #
+            # TODO: We can add an `is_catching_up` flag to the state_dict
+            # to disallow recoveries from this node. Such nodes can
+            # be excluded from the `max_step` calculation unless all
+            # nodes are catching up.
+            return self._manager.should_commit()

         self.wait()

         # Restore the parameters back to the previous state
         self.restore_parameters()

+        # This can return success even if the allreduce failed, because
+        # the process group could have been reconfigured while the
+        # allreduce was inflight. The inflight allreduce may or may
+        # not have been aborted.
+        #
+        # We consider it successful anyway.
+        #
+        # TODO: We can track errors per allreduce to
+        # let the commit fail here. But this has the downside of
+        # reconfiguring the pg too many times, resulting in
+        # more aborts and more commit failures.
         should_commit = self._manager.should_commit()

         if should_commit:
```
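
The `is_catching_up` TODO above is the RFC referenced in the commit summary. Below is a minimal sketch of what that exclusion rule could look like; the function name and data layout are illustrative assumptions, not torchft's actual API.

```python
# Hypothetical sketch of the TODO above: exclude catching-up nodes from the
# `max_step` calculation unless every node is catching up. Names and shapes
# here are illustrative only, not part of torchft.
from typing import List, Tuple


def compute_max_step(nodes: List[Tuple[int, bool]]) -> int:
    """Pick the target step from (step, is_catching_up) pairs gathered at quorum time."""
    synced_steps = [step for step, catching_up in nodes if not catching_up]
    if synced_steps:
        # Normal case: only fully synced nodes define the target step.
        return max(synced_steps)
    # Degenerate case: every node is catching up, so fall back to all reported steps.
    return max(step for step, _ in nodes)


# Example: the recovering node reports (5, True); it no longer drags down the
# quorum's view of max_step, but it can still bump its own step each round.
assert compute_max_step([(12, False), (12, False), (5, True)]) == 12
assert compute_max_step([(5, True), (6, True)]) == 6
```

Under this scheme, a node flagged as catching up would keep advancing its own step (avoiding the infinite recovery loop) without becoming a recovery source for others.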
```diff
@@ -705,6 +731,16 @@ def _step_post_hook(
         # waste after recovery
         self._quorum_loop()

+        # TODO: Since we do quorum after commit, there might be a big gap until
+        # the next allreduce. This increases the chances of nodes failing
+        # and so of the allreduce failing.
+        # - We could maybe do a quorum again right before preparing for a fragment
+        #   using `shrink_only`. This might make it tricky for new nodes to join
+        #   though.
+        # - Maintain a sequence number in the state dict that gets bumped at every
+        #   quorum call. Then we can do a quorum right before the allreduce and avoid
+        #   doing quorums after commit.
+
         # We need to make sure `_local_step` is still
         # the same across all replicas if `quorum_id` changed.
         #
```
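
The second bullet in the TODO above suggests keeping a sequence number in the state dict that is bumped on every quorum, so a quorum can run right before the allreduce instead of after commit. A minimal, hypothetical sketch of that bookkeeping (not existing torchft code):

```python
# Hypothetical sketch of the sequence-number idea from the TODO above.
# The class name and fields are illustrative assumptions, not part of torchft.
from typing import Dict


class QuorumSeq:
    def __init__(self) -> None:
        self._seq = 0

    def bump(self) -> int:
        # Called once per completed quorum.
        self._seq += 1
        return self._seq

    def state_dict(self) -> Dict[str, int]:
        # Checkpointed alongside the rest of the manager state so a recovering
        # node learns how many quorums the group has already gone through.
        return {"quorum_seq": self._seq}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self._seq = state["quorum_seq"]

    def in_sync_with(self, other_seq: int) -> bool:
        # Right before an allreduce, replicas compare sequence numbers and only
        # re-run the quorum when they disagree, avoiding the post-commit quorum.
        return self._seq == other_seq
```

With something like this, replicas that already agree on the latest quorum could skip straight to the allreduce, shrinking the window in which a node failure invalidates the collective.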
