
Commit 998f96f

Guangyu Wang authored and facebook-github-bot committed
support MTIA in DMPCollection (#3100)
Summary: Pull Request resolved: #3100
tsia
Reviewed By: jvandebon
Differential Revision: D76608091
fbshipit-source-id: 579f97c614bd2997150a7fdc315c191341c9056a
1 parent 820bc7a · commit 998f96f

File tree

1 file changed: +4, -2 lines

torchrec/distributed/model_parallel.py

Lines changed: 4 additions & 2 deletions
@@ -776,7 +776,9 @@ def __init__(
         use_inter_host_allreduce: bool = False,
         custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
     ) -> None:
-        assert device.type == "cuda", "DMPCollection only supports CUDA"
+        assert (
+            device.type == "cuda" or device.type == "mtia"
+        ), "DMPCollection only supports CUDA or MTIA"
         self._device = device
         self._pg: dist.ProcessGroup = global_pg
         self._plan: ShardingPlan = plan
@@ -1013,7 +1015,7 @@ def _remap_sharding_plan(
                 else:
                     shard_rank = shard.placement._rank * step + group_start
                 shard.placement = _remote_device(
-                    f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
+                    f"rank:{shard_rank}/{self._device.type}:{shard_rank % get_local_size()}"
                 )
         return
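
With this change, DMPCollection accepts any device whose type is "cuda" or "mtia", and remapped shard placements follow the wrapper's device type instead of hard-coding "cuda". Below is a minimal standalone sketch of those two behaviors; check_dmp_collection_device, remapped_placement, and local_size are hypothetical names used only for illustration and are not part of torchrec's API (local_size stands in for torchrec's get_local_size()).

import torch


def check_dmp_collection_device(device: torch.device) -> None:
    # Mirrors the updated assertion: CUDA and MTIA device types are both accepted.
    assert (
        device.type == "cuda" or device.type == "mtia"
    ), "DMPCollection only supports CUDA or MTIA"


def remapped_placement(device: torch.device, shard_rank: int, local_size: int) -> str:
    # Mirrors the updated placement string in _remap_sharding_plan: the remote
    # device now uses the wrapper's device type rather than a hard-coded "cuda".
    return f"rank:{shard_rank}/{device.type}:{shard_rank % local_size}"


if __name__ == "__main__":
    check_dmp_collection_device(torch.device("cuda"))    # passes, as before
    # check_dmp_collection_device(torch.device("mtia"))  # now also passes on MTIA-enabled builds
    print(remapped_placement(torch.device("cuda"), shard_rank=9, local_size=8))
    # prints: rank:9/cuda:1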
