From 9ab316dd3baf603e55baa794751827465181d52c Mon Sep 17 00:00:00 2001 From: Fucheng Warren Zhu Date: Sun, 4 May 2025 07:25:55 +0800 Subject: [PATCH 1/2] Distinguishing between replica rank and group rank across the project (#181) --- proto/torchft.proto | 10 ++-- src/lib.rs | 26 +++++----- src/manager.rs | 108 +++++++++++++++++++++------------------- torchft/_torchft.pyi | 10 ++-- torchft/data.py | 16 +++--- torchft/data_test.py | 4 +- torchft/manager.py | 92 +++++++++++++++++----------------- torchft/manager_test.py | 44 ++++++++-------- train_ddp.py | 2 +- 9 files changed, 158 insertions(+), 154 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index 1ed754c..7c086eb 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -73,7 +73,7 @@ service LighthouseService { } message ManagerQuorumRequest { - int64 rank = 1; + int64 group_rank = 1; int64 step = 2; string checkpoint_metadata = 3; bool shrink_only = 4; @@ -84,12 +84,12 @@ message ManagerQuorumRequest { message ManagerQuorumResponse { int64 quorum_id = 1; string recover_src_manager_address = 2; - optional int64 recover_src_rank = 3; - repeated int64 recover_dst_ranks = 4; + optional int64 recover_src_replica_rank = 3; + repeated int64 recover_dst_replica_ranks = 4; string store_address = 5; // These are information for the replicas which are at the max step. int64 max_step = 6; - optional int64 max_rank = 7; + optional int64 max_replica_rank = 7; int64 max_world_size = 8; // These are information for all replicas including behind replicas. int64 replica_rank = 9; @@ -108,7 +108,7 @@ message CheckpointMetadataResponse { message ShouldCommitRequest { bool should_commit = 1; - int64 rank = 2; + int64 group_rank = 2; int64 step = 3; } message ShouldCommitResponse { diff --git a/src/lib.rs b/src/lib.rs index e21d414..32a7a37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,7 +172,7 @@ impl ManagerClient { fn _quorum( &self, py: Python<'_>, - rank: i64, + group_rank: i64, step: i64, checkpoint_metadata: String, shrink_only: bool, @@ -182,7 +182,7 @@ impl ManagerClient { ) -> Result { py.allow_threads(move || { let mut request = tonic::Request::new(ManagerQuorumRequest { - rank: rank, + group_rank: group_rank, step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, @@ -201,11 +201,11 @@ impl ManagerClient { replica_rank: resp.replica_rank, replica_world_size: resp.replica_world_size, recover_src_manager_address: resp.recover_src_manager_address, - recover_src_rank: resp.recover_src_rank, - recover_dst_ranks: resp.recover_dst_ranks, + recover_src_replica_rank: resp.recover_src_replica_rank, + recover_dst_replica_ranks: resp.recover_dst_replica_ranks, store_address: resp.store_address, max_step: resp.max_step, - max_rank: resp.max_rank, + max_replica_rank: resp.max_replica_rank, max_world_size: resp.max_world_size, heal: resp.heal, }) @@ -250,14 +250,14 @@ impl ManagerClient { fn should_commit( &self, py: Python<'_>, - rank: i64, + group_rank: i64, step: i64, should_commit: bool, timeout: Duration, ) -> Result { py.allow_threads(move || { let mut request = tonic::Request::new(ShouldCommitRequest { - rank: rank, + group_rank: group_rank, step: step, should_commit: should_commit, }); @@ -281,11 +281,11 @@ struct QuorumResult { replica_rank: i64, replica_world_size: i64, recover_src_manager_address: String, - recover_src_rank: Option, - recover_dst_ranks: Vec, + recover_src_replica_rank: Option, + recover_dst_replica_ranks: Vec, store_address: String, max_step: i64, - max_rank: Option, + max_replica_rank: Option, max_world_size: i64, heal: bool, } @@ -299,11 +299,11 @@ impl QuorumResult { replica_rank: 0, replica_world_size: 1, recover_src_manager_address: "".to_string(), - recover_src_rank: None, - recover_dst_ranks: Vec::new(), + recover_src_replica_rank: None, + recover_dst_replica_ranks: Vec::new(), store_address: "".to_string(), max_step: 0, - max_rank: None, + max_replica_rank: None, max_world_size: 1, heal: false, } diff --git a/src/manager.rs b/src/manager.rs index f0cc026..5cf6577 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -235,9 +235,9 @@ impl ManagerService for Arc { request: Request, ) -> Result, Status> { let req = request.get_ref(); - let rank = req.rank; + let group_rank = req.group_rank; - info_with_replica!(self.replica_id, "Start quorum for rank {}", rank); + info_with_replica!(self.replica_id, "Start quorum for group_rank {}", group_rank); let timeout = try_parse_grpc_timeout(&request.metadata()) .map_err(|e| { @@ -255,7 +255,7 @@ impl ManagerService for Arc { // TODO: make separate call to set? state .checkpoint_metadata - .insert(req.rank, req.checkpoint_metadata.clone()); + .insert(req.group_rank, req.checkpoint_metadata.clone()); let member = QuorumMember { replica_id: self.replica_id.clone(), @@ -268,7 +268,7 @@ impl ManagerService for Arc { commit_failures: req.commit_failures, }; // TODO check step - state.participants.insert(rank, member.clone()); + state.participants.insert(group_rank, member.clone()); let rx = state.channel.subscribe(); self._run_quorum(&mut state, member, timeout).await?; @@ -281,9 +281,13 @@ impl ManagerService for Arc { .await .map_err(|e| Status::internal(e.to_string()))?; - info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank); + info_with_replica!( + self.replica_id, + "Finished quorum for group_rank {}", + group_rank + ); - let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?; + let reply = compute_quorum_results(&self.replica_id, group_rank, &quorum, req.init_sync)?; Ok(Response::new(reply)) } @@ -312,12 +316,12 @@ impl ManagerService for Arc { request: Request, ) -> Result, Status> { let req = request.into_inner(); - let rank = req.rank; + let group_rank = req.group_rank; info_with_replica!( self.replica_id, "should_commit request from {} should_commit={}", - rank, + group_rank, req.should_commit ); @@ -327,9 +331,9 @@ impl ManagerService for Arc { let mut state = self.state.lock().await; if !req.should_commit { - state.should_commit_failures.insert(rank); + state.should_commit_failures.insert(group_rank); } - state.should_commit_count.insert(rank); + state.should_commit_count.insert(group_rank); let rx = state.should_commit_channel.subscribe(); @@ -377,7 +381,7 @@ impl ManagerService for Arc { fn compute_quorum_results( replica_id: &str, - rank: i64, + group_rank: i64, quorum: &Quorum, init_sync: bool, ) -> Result { @@ -408,7 +412,7 @@ fn compute_quorum_results( let max_step = participants.iter().map(|p| p.step).max().unwrap(); let max_participants: Vec<&QuorumMember> = participants.iter().filter(|p| p.step == max_step).collect(); - let max_rank = max_participants.iter().enumerate().find_map(|(i, p)| { + let max_replica_rank = max_participants.iter().enumerate().find_map(|(i, p)| { if p.replica_id == replica_id { Some(i as i64) } else { @@ -417,8 +421,9 @@ fn compute_quorum_results( }); // The primary TCPStore to use for this rank. - let primary_rank = rank as usize % max_participants.len(); - let primary = max_participants[primary_rank]; + // There is one TCPStore per replica. + let primary_replica_rank = group_rank as usize % max_participants.len(); + let primary = max_participants[primary_replica_rank]; // Compute recovery assignments @@ -427,7 +432,7 @@ fn compute_quorum_results( // Nodes are recovering if // 1. not at the max step (init_sync) // 2. max_step == 0 and not the primary replica - let all_recover_dst_ranks: Vec = participants + let all_recover_dst_replica_ranks: Vec = participants .iter() .enumerate() .filter_map(|(i, p)| { @@ -439,12 +444,13 @@ fn compute_quorum_results( }) .collect(); - let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::>(); + let all_recover_dst_replica_ranks_set = + all_recover_dst_replica_ranks.iter().collect::>(); let up_to_date_ranks: Vec = participants .iter() .enumerate() .filter_map(|(i, _p)| { - if !all_recover_dst_ranks_set.contains(&i) { + if !all_recover_dst_replica_ranks_set.contains(&i) { Some(i) } else { None @@ -455,34 +461,34 @@ fn compute_quorum_results( // This is a map of rank to the ranks that are recovering from that node. let mut recovery_assignments: HashMap> = HashMap::new(); // The rank of the node that this rank is recovering from. - let mut recover_src_rank: Option = None; - for (i, recovering_rank) in all_recover_dst_ranks.iter().enumerate() { - let up_to_date_idx = (i + rank as usize) % up_to_date_ranks.len(); - let recovering_recover_src_rank = up_to_date_ranks[up_to_date_idx]; - if !recovery_assignments.contains_key(&recovering_recover_src_rank) { - recovery_assignments.insert(recovering_recover_src_rank, Vec::new()); + let mut recover_src_replica_rank: Option = None; + for (i, recovering_rank) in all_recover_dst_replica_ranks.iter().enumerate() { + let up_to_date_idx = (i + group_rank as usize) % up_to_date_ranks.len(); + let recovering_recover_src_replica_rank = up_to_date_ranks[up_to_date_idx]; + if !recovery_assignments.contains_key(&recovering_recover_src_replica_rank) { + recovery_assignments.insert(recovering_recover_src_replica_rank, Vec::new()); } recovery_assignments - .get_mut(&recovering_recover_src_rank) + .get_mut(&recovering_recover_src_replica_rank) .unwrap() .push(*recovering_rank as i64); if *recovering_rank == replica_rank { - recover_src_rank = Some(recovering_recover_src_rank as i64); + recover_src_replica_rank = Some(recovering_recover_src_replica_rank as i64); } } - let heal = recover_src_rank.is_some(); + let heal = recover_src_replica_rank.is_some(); if heal { info_with_replica!( replica_id, - "healing is required step={}, max_step={}, recover_src_rank={}", + "healing is required step={}, max_step={}, recover_src_replica_rank={}", step, max_step, - recover_src_rank.unwrap() + recover_src_replica_rank.unwrap() ); } - let recover_src_manager_address = match recover_src_rank { + let recover_src_manager_address = match recover_src_replica_rank { Some(r) => participants[r as usize].address.clone(), None => "".to_string(), }; @@ -491,13 +497,13 @@ fn compute_quorum_results( quorum_id: quorum.quorum_id, // address is used for looking up the checkpoint server address. recover_src_manager_address: recover_src_manager_address, - recover_src_rank: recover_src_rank, - recover_dst_ranks: recovery_assignments + recover_src_replica_rank: recover_src_replica_rank, + recover_dst_replica_ranks: recovery_assignments .get(&replica_rank) .map_or_else(Vec::new, |v| v.clone()), store_address: primary.store_address.clone(), max_step: max_step, - max_rank: max_rank, + max_replica_rank: max_replica_rank, max_world_size: max_participants.len() as i64, replica_rank: replica_rank as i64, replica_world_size: participants.len() as i64, @@ -515,7 +521,7 @@ mod tests { use super::*; use crate::lighthouse::{Lighthouse, LighthouseOpt}; - async fn should_commit(rank: i64, should_commit: bool) -> Result { + async fn should_commit(group_rank: i64, should_commit: bool) -> Result { let mut client = manager_client_new( "http://localhost:29531".to_string(), Duration::from_secs(10), @@ -523,7 +529,7 @@ mod tests { .await?; let request = tonic::Request::new(ShouldCommitRequest { - rank: rank, + group_rank: group_rank, step: 1, should_commit: should_commit, }); @@ -607,7 +613,7 @@ mod tests { let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?; let mut request = tonic::Request::new(ManagerQuorumRequest { - rank: 0, + group_rank: 0, step: 123, checkpoint_metadata: "addr".to_string(), shrink_only: false, @@ -624,7 +630,7 @@ mod tests { assert_eq!(resp.recover_src_manager_address, "".to_string()); assert_eq!(resp.store_address, "store_addr".to_string()); assert_eq!(resp.max_step, 123); - assert_eq!(resp.max_rank, Some(0)); + assert_eq!(resp.max_replica_rank, Some(0)); assert_eq!(resp.max_world_size, 1); assert_eq!(resp.replica_rank, 0); assert_eq!(resp.replica_world_size, 1); @@ -669,7 +675,7 @@ mod tests { manager_client_new(manager.address(), Duration::from_secs(10)).await?; let mut request = tonic::Request::new(ManagerQuorumRequest { - rank: 0, + group_rank: 0, step: 0, checkpoint_metadata: "addr".to_string(), shrink_only: false, @@ -787,22 +793,22 @@ mod tests { let results = compute_quorum_results("replica_0", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 0); - assert_eq!(results.recover_src_rank, None); - assert_eq!(results.recover_dst_ranks, vec![1]); + assert_eq!(results.recover_src_replica_rank, None); + assert_eq!(results.recover_dst_replica_ranks, vec![1]); let results = compute_quorum_results("replica_1", 0, &quorum, true)?; assert!(results.heal); assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, Some(0)); - assert_eq!(results.recover_dst_ranks, Vec::::new()); + assert_eq!(results.recover_src_replica_rank, Some(0)); + assert_eq!(results.recover_dst_replica_ranks, Vec::::new()); // rank 1 assignments should be offset from rank 0 above and the primary let results = compute_quorum_results("replica_1", 1, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, None); - assert_eq!(results.recover_dst_ranks, vec![0]); + assert_eq!(results.recover_src_replica_rank, None); + assert_eq!(results.recover_dst_replica_ranks, vec![0]); Ok(()) } @@ -872,29 +878,29 @@ mod tests { assert!(results.heal); assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); assert_eq!(results.replica_rank, 0); - assert_eq!(results.recover_src_rank, Some(1)); - assert!(results.recover_dst_ranks.is_empty()); + assert_eq!(results.recover_src_replica_rank, Some(1)); + assert!(results.recover_dst_replica_ranks.is_empty()); let results = compute_quorum_results("replica_1", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.recover_src_manager_address, "".to_string()); assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, None); - assert_eq!(results.recover_dst_ranks, vec![0, 4]); + assert_eq!(results.recover_src_replica_rank, None); + assert_eq!(results.recover_dst_replica_ranks, vec![0, 4]); let results = compute_quorum_results("replica_3", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 3); - assert_eq!(results.recover_src_rank, None); - assert_eq!(results.recover_dst_ranks, vec![2]); + assert_eq!(results.recover_src_replica_rank, None); + assert_eq!(results.recover_dst_replica_ranks, vec![2]); // rank 1 assignments should be offset from rank 0 above let results = compute_quorum_results("replica_1", 1, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, None); - assert_eq!(results.recover_dst_ranks, vec![2]); + assert_eq!(results.recover_src_replica_rank, None); + assert_eq!(results.recover_dst_replica_ranks, vec![2]); Ok(()) } diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index faf9ffa..9614d1b 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -6,7 +6,7 @@ class ManagerClient: def __init__(self, addr: str, connect_timeout: timedelta) -> None: ... def _quorum( self, - rank: int, + group_rank: int, step: int, checkpoint_metadata: str, shrink_only: bool, @@ -17,7 +17,7 @@ class ManagerClient: def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... def should_commit( self, - rank: int, + group_rank: int, step: int, should_commit: bool, timeout: timedelta, @@ -28,11 +28,11 @@ class QuorumResult: replica_rank: int replica_world_size: int recover_src_manager_address: str - recover_src_rank: Optional[int] - recover_dst_ranks: List[int] + recover_src_replica_rank: Optional[int] + recover_dst_replica_ranks: List[int] store_address: str max_step: int - max_rank: Optional[int] + max_replica_rank: Optional[int] max_world_size: int heal: bool commit_failures: int diff --git a/torchft/data.py b/torchft/data.py index 337292c..02e5b3b 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -38,34 +38,34 @@ class DistributedSampler(data.distributed.DistributedSampler): This will shard the input dataset into ``num_replicas*num_replica_group`` number of shards. - Each shard rank is calculated via: ``rank + num_replicas*replica_group`` + Each shard rank is calculated via: ``rank + num_replicas*replica_rank`` - num_replicas and replica_group must be the same on all workers. + num_replicas and replica_rank must be the same on all workers. """ def __init__( self, dataset: data.Dataset, - replica_group: int, + replica_rank: int, num_replica_groups: int, - rank: Optional[int] = None, + group_rank: Optional[int] = None, num_replicas: Optional[int] = None, **kwargs: object, ) -> None: """ Args: data: the dataset to use - replica_group: the group ID (0-num_replica_groups) to use for this shard of data. + replica_rank: the group ID (0-num_replica_groups) to use for this shard of data. num_replica_groups: the max number of global replica groups rank: the local group rank num_replicas: the local group world size """ - if rank is None: - rank = dist.get_rank() + if group_rank is None: + group_rank = dist.get_rank() if num_replicas is None: num_replicas = dist.get_world_size() - self.global_rank: int = rank + num_replicas * replica_group + self.global_rank: int = group_rank + num_replicas * replica_rank self.global_world_size: int = num_replicas * num_replica_groups super().__init__( diff --git a/torchft/data_test.py b/torchft/data_test.py index 206dd4b..8dae190 100644 --- a/torchft/data_test.py +++ b/torchft/data_test.py @@ -27,9 +27,9 @@ def test_distributed_sampler(self) -> None: dataset = DummyDataset(1000) sampler = DistributedSampler( dataset, - replica_group=1, + replica_rank=1, num_replica_groups=2, - rank=3, + group_rank=3, num_replicas=4, ) self.assertEqual(sampler.global_rank, 3 + 1 * 4) diff --git a/torchft/manager.py b/torchft/manager.py index 55cf821..37acfb0 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -136,8 +136,8 @@ def __init__( depending on how frequently the syncs occur. connect_timeout: the timeout used for establishing rpc connections to ManagerServer and Lighthouse - rank: the replica group local rank - world_size: the replica group local world size + rank: the replica group local rank, referred to as group_rank in manager.py for clarity + world_size: the replica group local world size, referred to as group_world_size in manager.py for clarity store_addr: TCPStore address for this replica group store_port: TCPStore port for this replica group lighthouse_addr: if rank==0, the address of the lighthouse server @@ -158,16 +158,16 @@ def __init__( self._timeout = timeout self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout - self._world_size_mode = world_size_mode + self._replica_world_size_mode = world_size_mode self._init_sync = init_sync self._max_retries = max_retries self._commit_failures = 0 store_addr = store_addr or os.environ["MASTER_ADDR"] store_port = store_port or int(os.environ["MASTER_PORT"]) - self._rank: int = rank if rank is not None else int(os.environ["RANK"]) - rank = self._rank - world_size = world_size or int(os.environ["WORLD_SIZE"]) + self._group_rank: int = rank if rank is not None else int(os.environ["RANK"]) + group_rank = self._group_rank + group_world_size = world_size or int(os.environ["WORLD_SIZE"]) self._min_replica_size = min_replica_size if checkpoint_transport is None: @@ -197,7 +197,7 @@ def __init__( torch.cuda.Stream() if torch.cuda.is_available() else None ) - if rank == 0: + if self._group_rank == 0: if port is None: port = int(os.environ.get(MANAGER_PORT_ENV, 0)) @@ -217,7 +217,7 @@ def __init__( hostname=hostname, bind=bind, store_addr=f"{store_addr}:{store_port}", - world_size=world_size, + world_size=group_world_size, heartbeat_interval=heartbeat_interval, connect_timeout=connect_timeout, ) @@ -230,7 +230,7 @@ def __init__( replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8") self._logger = _ManagerLogger( - manager=self, replica_id=replica_id or "", rank=rank + manager=self, replica_id=replica_id or "", group_rank=group_rank ) self._step = 0 @@ -241,8 +241,8 @@ def __init__( self._batches_committed = 0 # first step is 1 - self._participating_rank: Optional[int] = None - self._participating_world_size: int = 0 + self._participating_replica_rank: Optional[int] = None + self._participating_replica_world_size: int = 0 def set_state_dict_fns( self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T] @@ -464,7 +464,7 @@ def _async_quorum( quorum = None with torch.profiler.record_function("torchft::manager::_client::_quorum"): quorum = self._client._quorum( - rank=self._rank, + group_rank=self._group_rank, step=self._step, checkpoint_metadata=self._checkpoint_transport.metadata(), shrink_only=shrink_only, @@ -479,33 +479,35 @@ def _async_quorum( recover_src_manager_address = quorum.recover_src_manager_address store_address = quorum.store_address max_step = quorum.max_step - max_rank = quorum.max_rank - max_world_size = quorum.max_world_size + max_replica_rank = quorum.max_replica_rank + max_replica_world_size = quorum.max_world_size heal = quorum.heal # When using async quorum we need to take the recovered workers. # When not using async quorum we need to take the max world size as all # workers will be healthy. - self._participating_rank, self._participating_world_size = ( - (max_rank, max_world_size) + self._participating_replica_rank, self._participating_replica_world_size = ( + (max_replica_rank, max_replica_world_size) if self._use_async_quorum or not allow_heal else (replica_rank, replica_world_size) ) # For fixed with spares we need to ensure that we don't have more # participating replicas than the min replica size. - if self._world_size_mode == WorldSizeMode.FIXED_WITH_SPARES: - self._participating_world_size = min( - self._participating_world_size, self._min_replica_size + if self._replica_world_size_mode == WorldSizeMode.FIXED_WITH_SPARES: + self._participating_replica_world_size = min( + self._participating_replica_world_size, self._min_replica_size ) if ( - self._participating_rank is not None - and self._participating_rank >= self._min_replica_size + self._participating_replica_rank is not None + and self._participating_replica_rank >= self._min_replica_size ): - self._participating_rank = None + self._participating_replica_rank = None if quorum_id != self._quorum_id: - store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}" + store_prefixed_addr = ( + f"{store_address}/torchft/{quorum_id}/{self._group_rank}" + ) self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}") # We use the replica rank and world as we want all replicas in the PG. @@ -529,15 +531,15 @@ def _async_quorum( else nullcontext() ): try: - if quorum.recover_dst_ranks: + if quorum.recover_dst_replica_ranks: self._logger.info( - f"peers need recovery from us {quorum.recover_dst_ranks}" + f"peers need recovery from us {quorum.recover_dst_replica_ranks}" ) with torch.profiler.record_function( "torchft::manager::_checkpoint_transport::send_checkpoint" ): self._checkpoint_transport.send_checkpoint( - dst_ranks=quorum.recover_dst_ranks, + dst_ranks=quorum.recover_dst_replica_ranks, step=max_step, state_dict=self._manager_state_dict(), timeout=self._timeout, @@ -554,15 +556,15 @@ def _async_quorum( connect_timeout=self._connect_timeout, ) checkpoint_metadata = primary_client._checkpoint_metadata( - self._rank, timeout=self._timeout + self._group_rank, timeout=self._timeout ) - recover_src_rank = quorum.recover_src_rank + recover_src_replica_rank = quorum.recover_src_replica_rank assert ( - recover_src_rank is not None + recover_src_replica_rank is not None ), "must have a recover rank when healing" self._logger.info( - f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}" + f"fetching checkpoint from {recover_src_replica_rank=} with {checkpoint_metadata=}" ) # we apply the user state dict only when safe from the main thread @@ -570,13 +572,11 @@ def _async_quorum( with torch.profiler.record_function( "torchft::manager::_checkpoint_transport::recv_checkpoint" ): - self._pending_state_dict = ( - self._checkpoint_transport.recv_checkpoint( - src_rank=recover_src_rank, - metadata=checkpoint_metadata, - step=max_step, - timeout=self._timeout, - ) + self._pending_state_dict = self._checkpoint_transport.recv_checkpoint( + src_rank=recover_src_replica_rank, + metadata=checkpoint_metadata, # Depending on group rank + step=max_step, + timeout=self._timeout, ) # pyre-fixme[6]: got object @@ -661,7 +661,7 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool: enough_replicas = self.num_participants() >= self._min_replica_size local_should_commit = enough_replicas and self._errored is None should_commit = self._client.should_commit( - self._rank, + self._group_rank, self._step, local_should_commit, timeout=timeout or self._timeout, @@ -762,7 +762,7 @@ def participating_rank(self) -> Optional[int]: self.wait_quorum() - return self._participating_rank + return self._participating_replica_rank def num_participants(self) -> int: """ @@ -780,8 +780,8 @@ def num_participants(self) -> int: self.wait_quorum() - assert self._participating_world_size >= 0, "internal error" - return self._participating_world_size + assert self._participating_replica_world_size >= 0, "internal error" + return self._participating_replica_world_size def is_participating(self) -> bool: """ @@ -790,7 +790,7 @@ def is_participating(self) -> bool: Returns: whether this replica is participating in the current quorum """ - if self._participating_rank is None: + if self._participating_replica_rank is None: return False if self._healing: assert self._use_async_quorum @@ -799,16 +799,14 @@ def is_participating(self) -> bool: class _ManagerLogger: - def __init__(self, manager: Manager, replica_id: str, rank: int) -> None: + def __init__(self, manager: Manager, replica_id: str, group_rank: int) -> None: self._logger: logging.Logger = logging.getLogger(__name__) self._replica_id = replica_id - self._rank = rank + self._group_rank = group_rank self._manager = manager def prefix(self) -> str: - return ( - f"[{self._replica_id}/{self._rank} - step {self._manager.current_step()}]" - ) + return f"[{self._replica_id}/{self._group_rank} - step {self._manager.current_step()}]" def info(self, msg: str) -> None: self._logger.info(f"{self.prefix()} {msg}") diff --git a/torchft/manager_test.py b/torchft/manager_test.py index be2dec2..998a2df 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -146,7 +146,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None: quorum.recover_src_manager_address = "manager address" quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = 1 + quorum.max_replica_rank = 1 quorum.max_world_size = 2 quorum.heal = False @@ -180,10 +180,10 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None: quorum.replica_rank = 1 quorum.replica_world_size = 2 quorum.recover_src_manager_address = "manager address" - quorum.recover_src_rank = 0 + quorum.recover_src_replica_rank = 0 quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 20 - quorum.max_rank = None + quorum.max_replica_rank = None quorum.max_world_size = 2 quorum.heal = True @@ -234,10 +234,10 @@ def test_quorum_heal_async_not_enough_participants( quorum.replica_rank = 1 quorum.replica_world_size = 2 quorum.recover_src_manager_address = "manager address" - quorum.recover_src_rank = 0 + quorum.recover_src_replica_rank = 0 quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 20 - quorum.max_rank = None + quorum.max_replica_rank = None quorum.max_world_size = 1 quorum.heal = True @@ -296,10 +296,10 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None: quorum.replica_rank = 1 quorum.replica_world_size = 2 quorum.recover_src_manager_address = "manager address" - quorum.recover_src_rank = 0 + quorum.recover_src_replica_rank = 0 quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 20 - quorum.max_rank = None + quorum.max_replica_rank = None quorum.max_world_size = 1 quorum.heal = True @@ -358,7 +358,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: quorum.recover_src_manager_address = "manager address" quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = 1 + quorum.max_replica_rank = 1 quorum.max_world_size = 2 quorum.heal = False @@ -427,7 +427,7 @@ def test_pg_errored(self, client_mock: MagicMock) -> None: quorum.recover_src_manager_address = "manager address" quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = 1 + quorum.max_replica_rank = 1 quorum.max_world_size = 2 quorum.heal = False @@ -465,7 +465,7 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None: quorum.recover_src_manager_address = "manager address" quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = rank + quorum.max_replica_rank = rank quorum.max_world_size = 3 quorum.heal = False @@ -497,10 +497,10 @@ def test_quorum_no_healing(self, client_mock: MagicMock) -> None: quorum.replica_rank = 0 quorum.replica_world_size = 3 quorum.recover_src_manager_address = "manager address" - quorum.recover_src_rank = 1 + quorum.recover_src_replica_rank = 1 quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = None + quorum.max_replica_rank = None quorum.max_world_size = 2 quorum.heal = True client_mock()._quorum.return_value = quorum @@ -568,8 +568,8 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None: manager._quorum_future = quorum_future = MagicMock( spec=concurrent.futures.Future ) - manager._participating_rank = 1 - manager._participating_world_size = 5 + manager._participating_replica_rank = 1 + manager._participating_replica_world_size = 5 self.assertEqual(manager.num_participants(), 5) self.assertEqual(quorum_future.result.call_count, 1) self.assertEqual(manager.participating_rank(), 1) @@ -603,7 +603,7 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: quorum.recover_src_manager_address = "manager address" quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = 1 + quorum.max_replica_rank = 1 quorum.max_world_size = 2 quorum.heal = False @@ -636,7 +636,7 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None: quorum.recover_src_manager_address = "manager address" quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = 1 + quorum.max_replica_rank = 1 quorum.max_world_size = 2 quorum.heal = False @@ -664,10 +664,10 @@ def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None: quorum.replica_rank = 1 quorum.replica_world_size = 2 quorum.recover_src_manager_address = "manager address" - quorum.recover_src_rank = 0 + quorum.recover_src_replica_rank = 0 quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 20 - quorum.max_rank = None + quorum.max_replica_rank = None quorum.max_world_size = 2 quorum.heal = True @@ -682,7 +682,7 @@ def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None: with self.assertRaisesRegex(RuntimeError, "recv failure"): raise error - quorum.recover_dst_ranks = [0] + quorum.recover_dst_replica_ranks = [0] manager.start_quorum() manager.wait_quorum() self.assertFalse(manager.should_commit()) @@ -705,10 +705,10 @@ def test_quorum_configure_errors(self, client_mock: MagicMock) -> None: quorum.replica_rank = 1 quorum.replica_world_size = 2 quorum.recover_src_manager_address = "manager address" - quorum.recover_src_rank = 0 + quorum.recover_src_replica_rank = 0 quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 20 - quorum.max_rank = None + quorum.max_replica_rank = None quorum.max_world_size = 2 client_mock()._quorum.return_value = quorum @@ -735,7 +735,7 @@ def test_max_retries(self, client_mock: MagicMock) -> None: quorum.recover_src_manager_address = "manager address" quorum.store_address = f"localhost:{self.store.port}" quorum.max_step = 1 - quorum.max_rank = 1 + quorum.max_replica_rank = 1 quorum.max_world_size = 2 quorum.heal = False client_mock()._quorum.return_value = quorum diff --git a/train_ddp.py b/train_ddp.py index 1140e3b..fd79b8a 100644 --- a/train_ddp.py +++ b/train_ddp.py @@ -53,7 +53,7 @@ def main() -> None: trainset, replica_group=REPLICA_GROUP_ID, num_replica_groups=NUM_REPLICA_GROUPS, - rank=0, + group_rank=0, # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. num_replicas=1, shuffle=True, From 840cd31fa12a6f26d59ade9e961946e3da4356d3 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 5 May 2025 16:52:59 -0700 Subject: [PATCH 2/2] lint --- src/manager.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/manager.rs b/src/manager.rs index 5cf6577..e28cbeb 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -237,7 +237,11 @@ impl ManagerService for Arc { let req = request.get_ref(); let group_rank = req.group_rank; - info_with_replica!(self.replica_id, "Start quorum for group_rank {}", group_rank); + info_with_replica!( + self.replica_id, + "Start quorum for group_rank {}", + group_rank + ); let timeout = try_parse_grpc_timeout(&request.metadata()) .map_err(|e| {