
Commit 9e4bc3c

manager: gracefully handle errors from configure+checkpoint (#182)
1 parent 0a5bc89

File tree

2 files changed: +150 -63 lines changed

torchft/manager.py

Lines changed: 76 additions & 63 deletions
@@ -508,12 +508,16 @@ def _async_quorum(
 
         self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
         # We use the replica rank and world as we want all replicas in the PG.
-        # TODO: handle configure errors
-        with torch.profiler.record_function("torchft::manager::_pg.configure"):
-            self._pg.configure(
-                store_prefixed_addr, replica_rank, replica_world_size
-            )
-        self._quorum_id = quorum_id
+        try:
+            with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                self._pg.configure(
+                    store_prefixed_addr, replica_rank, replica_world_size
+                )
+            self._quorum_id = quorum_id
+        except Exception as e:
+            self._logger.exception(f"got exception in pg configure: {e}")
+            self.report_error(e)
+            return
 
         if allow_heal:
             # run recovery on the recovery stream if available
@@ -523,62 +527,67 @@ def _async_quorum(
                 if recovery_stream is not None
                 else nullcontext()
             ):
-                if quorum.recover_dst_ranks:
-                    self._logger.info(
-                        f"peers need recovery from us {quorum.recover_dst_ranks}"
-                    )
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::send_checkpoint"
-                    ):
-                        self._checkpoint_transport.send_checkpoint(
-                            dst_ranks=quorum.recover_dst_ranks,
-                            step=max_step,
-                            state_dict=self._manager_state_dict(),
-                            timeout=self._timeout,
+                try:
+                    if quorum.recover_dst_ranks:
+                        self._logger.info(
+                            f"peers need recovery from us {quorum.recover_dst_ranks}"
                         )
-
-                # See manager.rs for healing conditions
-                if heal:
-                    self._healing = True
-                    self._logger.info(
-                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
-                    )
-                    primary_client = ManagerClient(
-                        recover_src_manager_address,
-                        connect_timeout=self._connect_timeout,
-                    )
-                    checkpoint_metadata = primary_client._checkpoint_metadata(
-                        self._rank, timeout=self._timeout
-                    )
-                    recover_src_rank = quorum.recover_src_rank
-                    assert (
-                        recover_src_rank is not None
-                    ), "must have a recover rank when healing"
-
-                    self._logger.info(
-                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
-                    )
-
-                    # we apply the user state dict only when safe from the main thread
-                    # save it for now
-                    with torch.profiler.record_function(
-                        "torchft::manager::_checkpoint_transport::recv_checkpoint"
-                    ):
-                        self._pending_state_dict = (
-                            self._checkpoint_transport.recv_checkpoint(
-                                src_rank=recover_src_rank,
-                                metadata=checkpoint_metadata,
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::send_checkpoint"
+                        ):
+                            self._checkpoint_transport.send_checkpoint(
+                                dst_ranks=quorum.recover_dst_ranks,
                                 step=max_step,
+                                state_dict=self._manager_state_dict(),
                                 timeout=self._timeout,
                             )
+
+                    # See manager.rs for healing conditions
+                    if heal:
+                        self._healing = True
+                        self._logger.info(
+                            f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
                         )
+                        primary_client = ManagerClient(
+                            recover_src_manager_address,
+                            connect_timeout=self._connect_timeout,
+                        )
+                        checkpoint_metadata = primary_client._checkpoint_metadata(
+                            self._rank, timeout=self._timeout
+                        )
+                        recover_src_rank = quorum.recover_src_rank
+                        assert (
+                            recover_src_rank is not None
+                        ), "must have a recover rank when healing"
 
-                    # pyre-fixme[6]: got object
-                    self.load_state_dict(self._pending_state_dict["torchft"])
+                        self._logger.info(
+                            f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                        )
 
-                    # This isn't strictly needed as loading the state_dict above should
-                    # restore the correct step but it makes writing tests simpler.
-                    self._step = max_step
+                        # we apply the user state dict only when safe from the main thread
+                        # save it for now
+                        with torch.profiler.record_function(
+                            "torchft::manager::_checkpoint_transport::recv_checkpoint"
+                        ):
+                            self._pending_state_dict = (
+                                self._checkpoint_transport.recv_checkpoint(
+                                    src_rank=recover_src_rank,
+                                    metadata=checkpoint_metadata,
+                                    step=max_step,
+                                    timeout=self._timeout,
+                                )
+                            )
+
+                        # pyre-fixme[6]: got object
+                        self.load_state_dict(self._pending_state_dict["torchft"])
+
+                        # This isn't strictly needed as loading the state_dict above should
+                        # restore the correct step but it makes writing tests simpler.
+                        self._step = max_step
+                except Exception as e:
+                    self._logger.exception(f"got exception in recovery: {e}")
+                    self.report_error(e)
+                    return
 
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
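
Both hunks above share one shape: work that can fail on the async quorum thread (process-group configure, checkpoint send/recv) is wrapped in try/except, logged with a traceback, recorded via report_error, and the thread returns early instead of crashing the replica. The sketch below distills that pattern; it is a minimal illustration, not torchft's actual classes, and the names ErrorSurfaceSketch, risky_setup, and _async_work are hypothetical.

import logging
import threading
from typing import Callable, Optional


class ErrorSurfaceSketch:
    """Hypothetical stand-in for the manager's error plumbing: the
    background thread records a failure so the main thread can inspect
    it later, instead of letting the exception kill the process."""

    def __init__(self) -> None:
        self._logger = logging.getLogger(__name__)
        self._error_lock = threading.Lock()
        self._error: Optional[Exception] = None

    def report_error(self, e: Exception) -> None:
        # Stash the failure for the main thread to inspect.
        with self._error_lock:
            self._error = e

    def errored(self) -> Optional[Exception]:
        with self._error_lock:
            return self._error

    def _async_work(self, risky_setup: Callable[[], None]) -> None:
        # Same shape as the commit: catch, log with traceback via
        # logger.exception, record, and bail out early.
        try:
            risky_setup()
        except Exception as e:
            self._logger.exception(f"got exception in setup: {e}")
            self.report_error(e)
            return
        # ... only reached when setup succeeded ...

The third hunk makes the consuming side tolerant: _apply_pending_state_dict may now find no staged checkpoint, which is acceptable exactly when an error was recorded.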
@@ -587,15 +596,19 @@ def _apply_pending_state_dict(self) -> None:
         assert self._quorum_future is not None, "must call step before should_commit"
         self._quorum_future.result()
 
-        self._logger.info("applying pending state dict")
+        pending_state_dict = self._pending_state_dict
 
-        assert self._pending_state_dict is not None, "checkpoint was not staged"
-        assert (
-            self._load_state_dict is not None
-        ), "user load_state_dict is not initialized."
-        self._load_state_dict(self._pending_state_dict["user"])
-        self._pending_state_dict = None
-        self._logger.info("Loaded state dict.")
+        if pending_state_dict is None:
+            assert self.errored(), "checkpoint was not staged and no error occured"
+        else:
+            self._logger.info("applying pending state dict")
+
+            assert (
+                self._load_state_dict is not None
+            ), "user load_state_dict is not initialized."
+            self._load_state_dict(pending_state_dict["user"])
+            self._pending_state_dict = None
+            self._logger.info("Loaded state dict.")
 
     @torch.profiler.record_function("torchft::manager::should_commit")
     def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
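
From the training loop's perspective, a recorded error surfaces through the existing commit protocol: should_commit() comes back False for the step, and errored() exposes the exception. Below is a hedged sketch of a caller, using only the Manager calls exercised by the new tests; the try_step wrapper and its control flow are placeholders, not torchft's prescribed loop.

from torchft.manager import Manager


def try_step(manager: Manager) -> bool:
    """Returns True if this step committed, False if it was vetoed,
    e.g. by a configure/checkpoint error recorded on the quorum
    thread. Illustrative only."""
    manager.start_quorum()
    # ... run forward/backward for the step while quorum, configure,
    # and recovery proceed on the background thread ...
    manager.wait_quorum()

    if manager.should_commit():
        return True

    # should_commit() returned False; if the background thread hit an
    # exception, report_error() made it visible here.
    err = manager.errored()
    if err is not None:
        print(f"step vetoed by background error: {err}")
    return False

The tests below drive exactly this sequence and assert that the forced failures show up through errored().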

torchft/manager_test.py

Lines changed: 74 additions & 0 deletions
@@ -14,6 +14,7 @@
 from torch.distributed import TCPStore
 
 from torchft._torchft import QuorumResult
+from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup, _DummyWork
 
@@ -648,6 +649,79 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
         manager.start_quorum()
         self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
 
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=True)
+        client_mock().should_commit = MagicMock(return_value=False)
+
+        transport = MagicMock(spec=CheckpointTransport)
+        transport.send_checkpoint.side_effect = RuntimeError("send failure")
+        transport.recv_checkpoint.side_effect = RuntimeError("recv failure")
+        manager._checkpoint_transport = transport
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+        quorum.heal = True
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        manager.wait_quorum()
+        self.assertFalse(manager.should_commit())
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "recv failure"):
+            raise error
+
+        quorum.recover_dst_ranks = [0]
+        manager.start_quorum()
+        manager.wait_quorum()
+        self.assertFalse(manager.should_commit())
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "send failure"):
+            raise error
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=True)
+        client_mock().should_commit = MagicMock(return_value=False)
+
+        # pyre-ignore[16]: mock
+        manager._pg.configure.side_effect = RuntimeError("configure failure")
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.recover_src_rank = 0
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 20
+        quorum.max_rank = None
+        quorum.max_world_size = 2
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        manager.wait_quorum()
+        self.assertFalse(manager.should_commit())
+
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(RuntimeError, "configure failure"):
+            raise error
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_max_retries(self, client_mock: MagicMock) -> None:
         # Create a manager with max_retries=2
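
The testing idiom here is worth noting: each test forces a dependency to raise via a mock's side_effect, then re-raises the stored exception inside assertRaisesRegex to check both its type and its message. Below is a standalone distillation of that idiom, with illustrative names rather than torchft APIs.

import unittest
from typing import Optional
from unittest.mock import MagicMock


class ForcedFailureIdiom(unittest.TestCase):
    def test_stored_error_message(self) -> None:
        # Force the collaborator to fail on call.
        transport = MagicMock()
        transport.send_checkpoint.side_effect = RuntimeError("send failure")

        # Code under test would catch and stash the error, like the
        # manager's report_error() path does.
        stored: Optional[Exception] = None
        try:
            transport.send_checkpoint(dst_ranks=[0])
        except Exception as e:
            stored = e

        # Re-raising inside assertRaisesRegex checks type and message,
        # the same assertion the new tests make against errored().
        self.assertIsNotNone(stored)
        assert stored is not None  # for type-checkers
        with self.assertRaisesRegex(RuntimeError, "send failure"):
            raise stored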
