Skip to content

Commit b1aa49c

Browse files
emasap authored and facebook-github-bot committed
Add MTIA info to sharder (#3032)
Summary: Pull Request resolved: #3032. Reviewed By: kausv. Differential Revision: D74064134. fbshipit-source-id: 3a56ced167b2cf0a0559ef106a902c68bd241eae
1 parent 45223e0 commit b1aa49c

File tree

6 files changed

+12
-12
lines changed

6 files changed

+12
-12
lines changed

torchrec/distributed/embedding_types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,7 @@ def storage_usage(
519519
storage_map = {
520520
"cuda": ParameterStorage.HBM,
521521
"cpu": ParameterStorage.DDR,
522-
# TODO: Update it later. Setting for MTIA is same as CPU's for now.
523-
"mtia": ParameterStorage.DDR,
522+
"mtia": ParameterStorage.HBM,
524523
}
525524
return {
526525
storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)

torchrec/distributed/planner/enumerators.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def __init__(
8080
self._use_exact_enumerate_order: bool = (
8181
use_exact_enumerate_order if use_exact_enumerate_order else False
8282
)
83-
memory_type = "hbm_cap" if topology.compute_device == "cuda" else "ddr_cap"
83+
memory_type = (
84+
"hbm_cap" if topology.compute_device in {"cuda", "mtia"} else "ddr_cap"
85+
)
8486
self._device_memory_sizes: Optional[
8587
List[int]
8688
] = ( # only used with custom topology where memory is different within a topology

torchrec/distributed/planner/shard_estimators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,7 +1261,7 @@ def calculate_shard_storages(
12611261
count_ephemeral_storage_cost=count_ephemeral_storage_cost,
12621262
is_inference=is_inference,
12631263
)
1264-
if compute_device == "cuda"
1264+
if compute_device in {"cuda", "mtia"}
12651265
else 0
12661266
)
12671267
for input_size, output_size, hbm_specific_size in zip(
@@ -1273,7 +1273,7 @@ def calculate_shard_storages(
12731273
ddr_sizes: List[int] = [
12741274
(
12751275
input_size + output_size + ddr_specific_size
1276-
if compute_device in {"cpu", "mtia"} and not is_inference
1276+
if compute_device == "cpu" and not is_inference
12771277
else ddr_specific_size
12781278
)
12791279
for input_size, output_size, ddr_specific_size in zip(

torchrec/distributed/planner/storage_reservations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def _reserve_dense_storage(
7373
dense_tensor_size = dense_tensor_estimate
7474

7575
dense_tensor_storage = Storage(
76-
hbm=dense_tensor_size if topology.compute_device == "cuda" else 0,
77-
ddr=dense_tensor_size if topology.compute_device in {"cpu", "mtia"} else 0,
76+
hbm=dense_tensor_size if topology.compute_device in {"cuda", "mtia"} else 0,
77+
ddr=dense_tensor_size if topology.compute_device == "cpu" else 0,
7878
)
7979

8080
for device in topology.devices:
@@ -93,8 +93,8 @@ def _reserve_kjt_storage(
9393
kjt_size = math.ceil(sum(batch_inputs) * float(input_data_type_size)) * multiplier
9494

9595
kjt_storage = Storage(
96-
hbm=kjt_size if topology.compute_device == "cuda" else 0,
97-
ddr=kjt_size if topology.compute_device in {"cpu", "mtia"} else 0,
96+
hbm=kjt_size if topology.compute_device in {"cuda", "mtia"} else 0,
97+
ddr=kjt_size if topology.compute_device == "cpu" else 0,
9898
)
9999

100100
for device in topology.devices:

torchrec/distributed/planner/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def __init__(
284284
self._world_size = world_size
285285

286286
hbm_per_device = [0] * world_size
287-
if self._compute_device == "cuda":
287+
if self._compute_device == "cuda" or self._compute_device == "mtia":
288288
hbm_per_device = [hbm_cap if hbm_cap else HBM_CAP] * world_size
289289
ddr_cap_per_rank = [ddr_cap if ddr_cap else DDR_CAP] * world_size
290290

torchrec/distributed/types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,8 +1197,7 @@ def storage_usage(
11971197
storage_map = {
11981198
"cuda": ParameterStorage.HBM,
11991199
"cpu": ParameterStorage.DDR,
1200-
# TODO: Update it later. Setting for MTIA is same as CPU's for now.
1201-
"mtia": ParameterStorage.DDR,
1200+
"mtia": ParameterStorage.HBM,
12021201
}
12031202
return {storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)}
12041203

0 commit comments

Comments (0)