
Commit a19a0e3

fix infinite recovery

Summary:
- We don't increase `max_step` when a node is catching up, because we never call `should_commit` on that path.
- This can leave the node permanently behind, putting it into an infinite recovery loop.
- Fix: simply call `should_commit` there as well.
- Note: this can result in the global parameters falling out of sync; the diff includes an RFC on how to fix that.
1 parent 77b6330 commit a19a0e3
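
In code terms, the fix lands in the early-return path of `perform_sync`. A condensed sketch distilled from the diff below (class context omitted):

    def perform_sync(self) -> bool:
        if len(self._allreduce_futures) == 0:
            # Previously this was `return True`: a catching-up node never
            # called should_commit(), so its max_step never advanced and
            # it stayed behind forever, re-entering recovery every step.
            # Going through should_commit() advances max_step and breaks
            # the loop.
            return self._manager.should_commit()
        ...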

File tree

1 file changed: +37 -1 lines changed

torchft/local_sgd.py

Lines changed: 37 additions & 1 deletion
@@ -351,13 +351,39 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         if len(self._allreduce_futures) == 0:
-            return True
+            assert self._fragment_sync_delay > 0
+            # This can happen when using `fragment_sync_delay`. The node
+            # might not have participated in the sync of this fragment.
+            #
+            # The allreduce for the nodes that did participate might
+            # actually succeed, and in that case we shouldn't allow
+            # recovery from this node.
+            #
+            # We do need to increase the `max_step` here so we
+            # don't end up in an infinite loop of needing to recover.
+            #
+            # TODO: We can add an `is_catching_up` flag to the state_dict
+            # to disallow recoveries from this node. Such nodes can
+            # be excluded from the `max_step` calculation unless all
+            # nodes are catching up.
+            return self._manager.should_commit()

         self.wait()

         # Restore the parameters back to the previous state
         self.restore_parameters()

+        # This can return success even if the allreduce failed, because
+        # the process group could have been reconfigured while the
+        # allreduce was in flight. The in-flight allreduce may or may
+        # not have been aborted.
+        #
+        # We consider it successful anyway.
+        #
+        # TODO: We can track errors per allreduce to
+        # let the commit fail here. But this has the downside of
+        # reconfiguring the pg too many times, resulting in
+        # more aborts and more commit failures.
         should_commit = self._manager.should_commit()

         if should_commit:
@@ -702,6 +728,16 @@ def _step_post_hook(
         # waste after recovery
         self._quorum_loop()

+        # TODO: Since we do quorum after commit, there might be a big gap until
+        # the next allreduce. This increases the chances of nodes failing,
+        # and so of the allreduce failing.
+        # - We could maybe do a quorum again right before preparing a fragment
+        #   using `shrink_only`. This might make it tricky for new nodes to join
+        #   though.
+        # - Maintain a sequence number in the state dict that gets bumped at every
+        #   quorum call. Then we can do a quorum right before the allreduce and avoid
+        #   doing quorums after commit.
+
         # We need to make sure `_local_step` is still
         # the same across all replicas if `quorum_id` changed.
         #
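
The first hunk's TODO proposes an `is_catching_up` flag. A minimal sketch of what that could look like, assuming a dict-based state_dict and a max-over-steps quorum calculation; the key name, helper functions, and filtering rule are illustrative, not part of this commit:

    from typing import Any

    def make_state_dict(local_step: int, participated_in_sync: bool) -> dict[str, Any]:
        # Hypothetical: a node that skipped this fragment's sync advertises
        # itself as catching up so peers don't recover from its stale weights.
        return {"step": local_step, "is_catching_up": not participated_in_sync}

    def compute_max_step(states: list[dict[str, Any]]) -> int:
        # Hypothetical: exclude catching-up nodes from the max_step
        # calculation unless all nodes are catching up (otherwise there
        # would be no candidates left).
        eligible = [s for s in states if not s["is_catching_up"]] or states
        return max(s["step"] for s in eligible)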
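
The second TODO in that hunk, tracking errors per allreduce so the commit can fail, could look roughly like this; the class and method names are invented for illustration:

    class AllreduceErrorTracker:
        # Hypothetical: record failures of individual allreduces so the
        # commit can fail instead of treating them as successful.
        def __init__(self) -> None:
            self._errors: list[Exception] = []

        def record(self, exc: Exception) -> None:
            self._errors.append(exc)

        def commit_ok(self) -> bool:
            # Fail the commit if any allreduce recorded an error, then reset.
            ok = not self._errors
            self._errors.clear()
            return ok

As the TODO notes, the downside is that failing commits triggers more process-group reconfigurations, which in turn cause more aborts and more commit failures.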
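
For the quorum-timing TODO in `_step_post_hook`, a rough sketch of the sequence-number idea; everything here is hypothetical, including the assumption that `start_quorum` is the manager's quorum entry point:

    class FragmentSyncer:
        # Hypothetical sketch of the TODO, not the commit's code.
        def __init__(self, manager) -> None:
            self._manager = manager
            self._quorum_seq = 0  # bumped on every quorum call

        def _quorum(self) -> None:
            self._manager.start_quorum()
            self._quorum_seq += 1

        def state_dict(self) -> dict:
            # Replicas compare quorum_seq during recovery to tell whether
            # they observed the same quorum round.
            return {"quorum_seq": self._quorum_seq}

        def sync(self, fragment) -> None:
            # Run the quorum right before the allreduce instead of after
            # commit, shrinking the gap in which node failures can break
            # the in-flight allreduce.
            self._quorum()
            fragment.allreduce()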
