@@ -508,12 +508,16 @@ def _async_quorum(
508508
509509 self ._logger .info (f"reconfiguring for { quorum_id = } { store_prefixed_addr = } " )
510510 # We use the replica rank and world as we want all replicas in the PG.
511- # TODO: handle configure errors
512- with torch .profiler .record_function ("torchft::manager::_pg.configure" ):
513- self ._pg .configure (
514- store_prefixed_addr , replica_rank , replica_world_size
515- )
516- self ._quorum_id = quorum_id
511+ try :
512+ with torch .profiler .record_function ("torchft::manager::_pg.configure" ):
513+ self ._pg .configure (
514+ store_prefixed_addr , replica_rank , replica_world_size
515+ )
516+ self ._quorum_id = quorum_id
517+ except Exception as e :
518+ self ._logger .exception (f"got exception in pg configure: { e } " )
519+ self .report_error (e )
520+ return
517521
518522 if allow_heal :
519523 # run recovery on the recovery stream if available
@@ -523,62 +527,67 @@ def _async_quorum(
523527 if recovery_stream is not None
524528 else nullcontext ()
525529 ):
526- if quorum .recover_dst_ranks :
527- self ._logger .info (
528- f"peers need recovery from us { quorum .recover_dst_ranks } "
529- )
530- with torch .profiler .record_function (
531- "torchft::manager::_checkpoint_transport::send_checkpoint"
532- ):
533- self ._checkpoint_transport .send_checkpoint (
534- dst_ranks = quorum .recover_dst_ranks ,
535- step = max_step ,
536- state_dict = self ._manager_state_dict (),
537- timeout = self ._timeout ,
530+ try :
531+ if quorum .recover_dst_ranks :
532+ self ._logger .info (
533+ f"peers need recovery from us { quorum .recover_dst_ranks } "
538534 )
539-
540- # See manager.rs for healing conditions
541- if heal :
542- self ._healing = True
543- self ._logger .info (
544- f"healing required, fetching checkpoint metadata from { recover_src_manager_address = } { max_step = } "
545- )
546- primary_client = ManagerClient (
547- recover_src_manager_address ,
548- connect_timeout = self ._connect_timeout ,
549- )
550- checkpoint_metadata = primary_client ._checkpoint_metadata (
551- self ._rank , timeout = self ._timeout
552- )
553- recover_src_rank = quorum .recover_src_rank
554- assert (
555- recover_src_rank is not None
556- ), "must have a recover rank when healing"
557-
558- self ._logger .info (
559- f"fetching checkpoint from { recover_src_rank = } with { checkpoint_metadata = } "
560- )
561-
562- # we apply the user state dict only when safe from the main thread
563- # save it for now
564- with torch .profiler .record_function (
565- "torchft::manager::_checkpoint_transport::recv_checkpoint"
566- ):
567- self ._pending_state_dict = (
568- self ._checkpoint_transport .recv_checkpoint (
569- src_rank = recover_src_rank ,
570- metadata = checkpoint_metadata ,
535+ with torch .profiler .record_function (
536+ "torchft::manager::_checkpoint_transport::send_checkpoint"
537+ ):
538+ self ._checkpoint_transport .send_checkpoint (
539+ dst_ranks = quorum .recover_dst_ranks ,
571540 step = max_step ,
541+ state_dict = self ._manager_state_dict (),
572542 timeout = self ._timeout ,
573543 )
544+
545+ # See manager.rs for healing conditions
546+ if heal :
547+ self ._healing = True
548+ self ._logger .info (
549+ f"healing required, fetching checkpoint metadata from { recover_src_manager_address = } { max_step = } "
574550 )
551+ primary_client = ManagerClient (
552+ recover_src_manager_address ,
553+ connect_timeout = self ._connect_timeout ,
554+ )
555+ checkpoint_metadata = primary_client ._checkpoint_metadata (
556+ self ._rank , timeout = self ._timeout
557+ )
558+ recover_src_rank = quorum .recover_src_rank
559+ assert (
560+ recover_src_rank is not None
561+ ), "must have a recover rank when healing"
575562
576- # pyre-fixme[6]: got object
577- self .load_state_dict (self ._pending_state_dict ["torchft" ])
563+ self ._logger .info (
564+ f"fetching checkpoint from { recover_src_rank = } with { checkpoint_metadata = } "
565+ )
578566
579- # This isn't strictly needed as loading the state_dict above should
580- # restore the correct step but it makes writing tests simpler.
581- self ._step = max_step
567+ # we apply the user state dict only when safe from the main thread
568+ # save it for now
569+ with torch .profiler .record_function (
570+ "torchft::manager::_checkpoint_transport::recv_checkpoint"
571+ ):
572+ self ._pending_state_dict = (
573+ self ._checkpoint_transport .recv_checkpoint (
574+ src_rank = recover_src_rank ,
575+ metadata = checkpoint_metadata ,
576+ step = max_step ,
577+ timeout = self ._timeout ,
578+ )
579+ )
580+
581+ # pyre-fixme[6]: got object
582+ self .load_state_dict (self ._pending_state_dict ["torchft" ])
583+
584+ # This isn't strictly needed as loading the state_dict above should
585+ # restore the correct step but it makes writing tests simpler.
586+ self ._step = max_step
587+ except Exception as e :
588+ self ._logger .exception (f"got exception in recovery: { e } " )
589+ self .report_error (e )
590+ return
582591
583592 def _apply_pending_state_dict (self ) -> None :
584593 assert self ._healing , "must be in healing state"
@@ -587,15 +596,19 @@ def _apply_pending_state_dict(self) -> None:
587596 assert self ._quorum_future is not None , "must call step before should_commit"
588597 self ._quorum_future .result ()
589598
590- self . _logger . info ( "applying pending state dict" )
599+ pending_state_dict = self . _pending_state_dict
591600
592- assert self ._pending_state_dict is not None , "checkpoint was not staged"
593- assert (
594- self ._load_state_dict is not None
595- ), "user load_state_dict is not initialized."
596- self ._load_state_dict (self ._pending_state_dict ["user" ])
597- self ._pending_state_dict = None
598- self ._logger .info ("Loaded state dict." )
601+ if pending_state_dict is None :
602+ assert self .errored (), "checkpoint was not staged and no error occured"
603+ else :
604+ self ._logger .info ("applying pending state dict" )
605+
606+ assert (
607+ self ._load_state_dict is not None
608+ ), "user load_state_dict is not initialized."
609+ self ._load_state_dict (pending_state_dict ["user" ])
610+ self ._pending_state_dict = None
611+ self ._logger .info ("Loaded state dict." )
599612
600613 @torch .profiler .record_function ("torchft::manager::should_commit" )
601614 def should_commit (self , timeout : Optional [timedelta ] = None ) -> bool :
0 commit comments