1 file changed: +4 -2 lines changed

@@ -776,7 +776,9 @@ def __init__(
         use_inter_host_allreduce: bool = False,
         custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
     ) -> None:
-        assert device.type == "cuda", "DMPCollection only supports CUDA"
+        assert (
+            device.type == "cuda" or device.type == "mtia"
+        ), "DMPCollection only supports CUDA or MTIA"
         self._device = device
         self._pg: dist.ProcessGroup = global_pg
         self._plan: ShardingPlan = plan
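
For context on the first hunk: the constructor assertion now admits MTIA devices alongside CUDA. A minimal standalone sketch of the widened check, assuming only torch (the helper name below is illustrative, not part of TorchRec):

import torch

def _check_dmp_device(device: torch.device) -> None:
    # Hypothetical helper mirroring the new assertion: DMPCollection
    # accepts CUDA or MTIA devices and rejects everything else.
    assert (
        device.type == "cuda" or device.type == "mtia"
    ), "DMPCollection only supports CUDA or MTIA"

_check_dmp_device(torch.device("cuda:0"))  # passes
_check_dmp_device(torch.device("mtia:0"))  # passes; constructing the device
                                           # object does not require MTIA hardware
# _check_dmp_device(torch.device("cpu"))   # would raise AssertionError
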
@@ -1013,7 +1015,7 @@ def _remap_sharding_plan(
             else:
                 shard_rank = shard.placement._rank * step + group_start
             shard.placement = _remote_device(
-                f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
+                f"rank:{shard_rank}/{self._device.type}:{shard_rank % get_local_size()}"
             )
         return
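
The second hunk makes the shard placement string follow the collection's device type instead of hard-coding "cuda". A hedged sketch of the string being built; _placement_str is an illustrative stand-in for the f-string passed to _remote_device, not a TorchRec helper:

def _placement_str(shard_rank: int, device_type: str, local_size: int) -> str:
    # The device segment ("cuda"/"mtia") is now parameterized on
    # self._device.type rather than fixed to "cuda".
    return f"rank:{shard_rank}/{device_type}:{shard_rank % local_size}"

print(_placement_str(9, "cuda", 8))  # rank:9/cuda:1
print(_placement_str(9, "mtia", 8))  # rank:9/mtia:1  (new behavior on MTIA hosts)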