
Commit e0d4b6c

iamzainhuda authored and facebook-github-bot committed
make process group arg optional (#3003)
Summary:
Pull Request resolved: #3003

keeping process group optional should a user require more control over the all gather

Reviewed By: aliafzal

Differential Revision: D75488211

fbshipit-source-id: 064e283bf3296d21a7e878c36981ad6aaf094e95
1 parent f1ab0e7 commit e0d4b6c

1 file changed: +7 -7 lines changed

torchrec/metrics/metric_module.py

Lines changed: 7 additions & 7 deletions
@@ -395,7 +395,7 @@ def _get_metric_states(
         return state_aggregated
 
     def get_pre_compute_states(
-        self, pg: Union[dist.ProcessGroup, DeviceMesh]
+        self, pg: Optional[Union[dist.ProcessGroup, DeviceMesh]] = None
     ) -> Dict[str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]]:
         """
         This function returns the states per rank for each metric to be saved. The states are are aggregated by the state defined reduction_function.
@@ -411,16 +411,16 @@ def get_pre_compute_states(
         applied to them. Typical state dict exposes just the metric states that live on the rank it's called from.
 
         Args:
-            pg (Union[dist.ProcessGroup, DeviceMesh]): the process group to use for all gather.
-            reduce_metrics (bool): whether to reduce the metrics or not. Default is True.
+            pg (Optional[Union[dist.ProcessGroup, DeviceMesh]]): the process group to use for all gather, defaults to WORLD process group.
 
         Returns:
             Dict[str, Dict[str, Dict[str, torch.Tensor]]]: the states for each metric to be saved
         """
-        if isinstance(pg, DeviceMesh):
-            process_group: dist.ProcessGroup = pg.get_group(mesh_dim="shard")
-        else:
-            process_group: dist.ProcessGroup = pg
+        pg = pg if pg is not None else dist.group.WORLD
+        process_group: dist.ProcessGroup = (  # pyre-ignore[9]
+            pg.get_group(mesh_dim="shard") if isinstance(pg, DeviceMesh) else pg
+        )
+
         aggregated_states = {}
         world_size = dist.get_world_size(
             process_group
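
Note on the new resolution order: the change first falls back to the WORLD group when pg is None, and only then unwraps a DeviceMesh into its "shard" group. The sketch below mirrors that two-step logic without torch so it can be run anywhere. StubProcessGroup, StubDeviceMesh, WORLD, and resolve_process_group are hypothetical stand-ins invented for illustration; they are not part of torchrec or torch.distributed, and only the get_group(mesh_dim=...) call shape is taken from the diff.

```python
from typing import Dict, Optional, Union


class StubProcessGroup:
    """Hypothetical stand-in for dist.ProcessGroup (illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name


class StubDeviceMesh:
    """Hypothetical stand-in for DeviceMesh; only get_group is modeled."""

    def __init__(self, groups: Dict[str, StubProcessGroup]) -> None:
        self._groups = groups

    def get_group(self, mesh_dim: str) -> StubProcessGroup:
        return self._groups[mesh_dim]


# Plays the role of dist.group.WORLD in this sketch.
WORLD = StubProcessGroup("world")


def resolve_process_group(
    pg: Optional[Union[StubProcessGroup, StubDeviceMesh]] = None,
) -> StubProcessGroup:
    # Step 1: an omitted argument falls back to the WORLD group.
    pg = pg if pg is not None else WORLD
    # Step 2: a mesh is unwrapped to its "shard" group; a plain group passes through.
    return pg.get_group(mesh_dim="shard") if isinstance(pg, StubDeviceMesh) else pg


custom = StubProcessGroup("custom")
mesh = StubDeviceMesh({"shard": custom})

default_group = resolve_process_group()        # no argument: WORLD
explicit_group = resolve_process_group(custom)  # explicit group: unchanged
mesh_group = resolve_process_group(mesh)        # mesh: its "shard" group
```

Under these assumptions, existing call sites that pass a process group or a mesh behave as before, and new call sites may omit the argument.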
