Skip to content

Commit 10686e2

Browse files
Fix compatibility between data.py and train_ddp.py with replica rank terminology (meta-pytorch#187)
1 parent 4f5837b commit 10686e2

File tree

2 files changed

+3
-3
lines changed

torchft/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class DistributedSampler(data.distributed.DistributedSampler):
4646
def __init__(
4747
self,
4848
dataset: data.Dataset,
49-
replica_rank: int,
49+
replica_group_id: int,
5050
num_replica_groups: int,
5151
group_rank: Optional[int] = None,
5252
num_replicas: Optional[int] = None,
@@ -65,7 +65,7 @@ def __init__(
6565
if num_replicas is None:
6666
num_replicas = dist.get_world_size()
6767

68-
self.global_rank: int = group_rank + num_replicas * replica_rank
68+
self.global_rank: int = group_rank + num_replicas * replica_group_id
6969
self.global_world_size: int = num_replicas * num_replica_groups
7070

7171
super().__init__(

train_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def main() -> None:
5151
# majority of groups will be available so few batches will be dropped.
5252
sampler = DistributedSampler(
5353
trainset,
54-
replica_group=REPLICA_GROUP_ID,
54+
replica_group_id=REPLICA_GROUP_ID,
5555
num_replica_groups=NUM_REPLICA_GROUPS,
5656
group_rank=0,
5757
# for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.

0 commit comments

Comments (0)