@@ -351,13 +351,39 @@ def perform_sync(self) -> bool:
351
351
steps using the outer optimizer.
352
352
"""
353
353
if len (self ._allreduce_futures ) == 0 :
354
- return True
354
+ assert self ._fragment_sync_delay > 0
355
+ # This can happen when using `fragment_sync_delay`. The node
356
+ # might not have participated in syncing of this fragment.
357
+ #
358
+ # The allreduce for other nodes who did might actually
359
+ # succeed and in that case, we shouldn't allow recovery
360
+ # from this node.
361
+ #
362
+ # We do need to increase the `max_step` here so we
363
+ # don't end up in an infinite loop of needing to recover.
364
+ #
365
+ # TODO: We can add a `is_catching_up` flag to the state_dict
366
+ # to disallow recoveries from this node. Such nodes can
367
+ # be excluded from `max_step` calculation unless all
368
+ # nodes are catching up.
369
+ return self ._manager .should_commit ()
355
370
356
371
self .wait ()
357
372
358
373
# Restore the parameters back to the previous state
359
374
self .restore_parameters ()
360
375
376
+ # This can return success even if the allreduce failed. Because
377
+ # the process group could have been reconfigured while the
378
+ # allreduce was inflight. The inflight allreduce may or may
379
+ # not have been aborted.
380
+ #
381
+ # We consider it successful anyway.
382
+ #
383
+ # TODO: We can track errors per allreduce to
384
+ # let the commit fail here. But this has the downside of
385
+ # reconfiguring the pg too many times resulting in
386
+ # more aborts and more commit failures.
361
387
should_commit = self ._manager .should_commit ()
362
388
363
389
if should_commit :
@@ -702,6 +728,16 @@ def _step_post_hook(
702
728
# waste after recovery
703
729
self ._quorum_loop ()
704
730
731
+ # TODO: Since we do quorum after commit, there might be a big gap until
732
+ # the next allreduce. This increases the chances of nodes failing
733
+ # and so the allreduce to fail.
734
+ # - We could maybe do a quorum again right before preparing for a fragment
735
+ # using `shring_only`. This might make it tricky for new nodes to join
736
+ # though.
737
+ # - Maintain a sequence number in the state dict that gets bumped at every
738
+ # quorum call. Then we can do a quorum right before allreduce and avoid
739
+ # doing quorums after commit.
740
+
705
741
# We need to set make sure `_local_step` is still
706
742
# the same across all replicas if `quorum_id` changed.
707
743
#
0 commit comments