6 files changed: +12 -12 lines changed

@@ -519,8 +519,7 @@ def storage_usage(
         storage_map = {
             "cuda": ParameterStorage.HBM,
             "cpu": ParameterStorage.DDR,
-            # TODO: Update it later. Setting for MTIA is same as CPU's for now.
-            "mtia": ParameterStorage.DDR,
+            "mtia": ParameterStorage.HBM,
         }
         return {
             storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)
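This hunk (and the matching one in the last file below) routes tensors on "mtia" to the HBM bucket instead of DDR. A minimal, self-contained sketch of that lookup, with stand-in ParameterStorage and get_tensor_size_bytes definitions rather than the real TorchRec ones:

from enum import Enum

import torch


class ParameterStorage(Enum):
    # Stand-in for the real enum; the actual member values may differ.
    HBM = "hbm"
    DDR = "ddr"


def get_tensor_size_bytes(tensor: torch.Tensor) -> int:
    # Simplified estimate: element count times per-element size.
    return tensor.numel() * tensor.element_size()


def storage_usage(tensor: torch.Tensor, compute_device_type: str) -> dict:
    # After this diff, "mtia" is billed against HBM rather than DDR.
    storage_map = {
        "cuda": ParameterStorage.HBM,
        "cpu": ParameterStorage.DDR,
        "mtia": ParameterStorage.HBM,
    }
    return {storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)}


table = torch.empty(1000, 64)  # 1000 x 64 fp32 table: 256,000 bytes
print(storage_usage(table, "mtia"))  # {'hbm': 256000}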
@@ -80,7 +80,9 @@ def __init__(
         self._use_exact_enumerate_order: bool = (
             use_exact_enumerate_order if use_exact_enumerate_order else False
         )
-        memory_type = "hbm_cap" if topology.compute_device == "cuda" else "ddr_cap"
+        memory_type = (
+            "hbm_cap" if topology.compute_device in {"cuda", "mtia"} else "ddr_cap"
+        )

         self._device_memory_sizes: Optional[
             List[int]
         ] = (  # only used with custom topology where memory is different within a topology
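In the enumerator's constructor, the per-device cap now comes from the HBM budget for MTIA as well. A hypothetical helper (select_memory_type is not a function in the codebase, just an illustration of the ternary above):

def select_memory_type(compute_device: str) -> str:
    # MTIA is grouped with CUDA under the HBM cap; everything else uses the DDR cap.
    return "hbm_cap" if compute_device in {"cuda", "mtia"} else "ddr_cap"


assert select_memory_type("cuda") == "hbm_cap"
assert select_memory_type("mtia") == "hbm_cap"  # previously "ddr_cap"
assert select_memory_type("cpu") == "ddr_cap"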
@@ -1261,7 +1261,7 @@ def calculate_shard_storages(
                 count_ephemeral_storage_cost=count_ephemeral_storage_cost,
                 is_inference=is_inference,
             )
-            if compute_device == "cuda"
+            if compute_device in {"cuda", "mtia"}
             else 0
         )
         for input_size, output_size, hbm_specific_size in zip(
@@ -1273,7 +1273,7 @@ def calculate_shard_storages(
     ddr_sizes: List[int] = [
         (
             input_size + output_size + ddr_specific_size
-            if compute_device in {"cpu", "mtia"} and not is_inference
+            if compute_device == "cpu" and not is_inference
             else ddr_specific_size
         )
         for input_size, output_size, ddr_specific_size in zip(
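These two hunks move the per-shard input/output cost from the DDR estimate to the HBM estimate when the compute device is MTIA. A simplified sketch of the split; the real calculate_shard_storages takes many more arguments and computes the I/O cost via calculate_pipeline_io_cost, so io_costs below is a stand-in:

from typing import List, Tuple


def split_shard_sizes(
    compute_device: str,
    is_inference: bool,
    io_costs: List[int],  # stand-in for input_size + output_size per shard
    hbm_specific_sizes: List[int],
    ddr_specific_sizes: List[int],
) -> Tuple[List[int], List[int]]:
    # I/O cost counts against HBM for accelerators (cuda, mtia)...
    hbm_sizes = [
        hbm + (io if compute_device in {"cuda", "mtia"} else 0)
        for io, hbm in zip(io_costs, hbm_specific_sizes)
    ]
    # ...and against DDR only for CPU training.
    ddr_sizes = [
        ddr + (io if compute_device == "cpu" and not is_inference else 0)
        for io, ddr in zip(io_costs, ddr_specific_sizes)
    ]
    return hbm_sizes, ddr_sizes


hbm, ddr = split_shard_sizes("mtia", False, [100], [1_000], [0])
print(hbm, ddr)  # [1100] [0]: the I/O cost now lands on HBM for MTIA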
@@ -73,8 +73,8 @@ def _reserve_dense_storage(
         dense_tensor_size = dense_tensor_estimate

     dense_tensor_storage = Storage(
-        hbm=dense_tensor_size if topology.compute_device == "cuda" else 0,
-        ddr=dense_tensor_size if topology.compute_device in {"cpu", "mtia"} else 0,
+        hbm=dense_tensor_size if topology.compute_device in {"cuda", "mtia"} else 0,
+        ddr=dense_tensor_size if topology.compute_device == "cpu" else 0,
     )

     for device in topology.devices:
@@ -93,8 +93,8 @@ def _reserve_kjt_storage(
     kjt_size = math.ceil(sum(batch_inputs) * float(input_data_type_size)) * multiplier

     kjt_storage = Storage(
-        hbm=kjt_size if topology.compute_device == "cuda" else 0,
-        ddr=kjt_size if topology.compute_device in {"cpu", "mtia"} else 0,
+        hbm=kjt_size if topology.compute_device in {"cuda", "mtia"} else 0,
+        ddr=kjt_size if topology.compute_device == "cpu" else 0,
     )

     for device in topology.devices:
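Both reservation helpers now place the dense-parameter and KJT reservations on HBM for MTIA topologies. A minimal sketch with a stand-in Storage dataclass; the real class lives in the planner's types module and carries more behavior:

import math
from dataclasses import dataclass
from typing import List


@dataclass
class Storage:
    # Stand-in for the planner's Storage type: bytes reserved per memory kind.
    hbm: int = 0
    ddr: int = 0


def reserve_kjt_storage(
    compute_device: str,
    batch_inputs: List[int],
    input_data_type_size: int,
    multiplier: int,
) -> Storage:
    kjt_size = math.ceil(sum(batch_inputs) * float(input_data_type_size)) * multiplier
    return Storage(
        hbm=kjt_size if compute_device in {"cuda", "mtia"} else 0,
        ddr=kjt_size if compute_device == "cpu" else 0,
    )


print(reserve_kjt_storage("mtia", [512, 512], 8, 20))  # Storage(hbm=163840, ddr=0)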
@@ -284,7 +284,7 @@ def __init__(
         self._world_size = world_size

         hbm_per_device = [0] * world_size
-        if self._compute_device == "cuda":
+        if self._compute_device == "cuda" or self._compute_device == "mtia":
            hbm_per_device = [hbm_cap if hbm_cap else HBM_CAP] * world_size
         ddr_cap_per_rank = [ddr_cap if ddr_cap else DDR_CAP] * world_size
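With this change, an MTIA topology also gets a per-rank HBM budget, falling back to the HBM_CAP default when hbm_cap is not supplied, instead of being left at zero. Roughly (the HBM_CAP and DDR_CAP values below are placeholders, not the library's constants):

from typing import List, Optional, Tuple

HBM_CAP = 32 * 1024 * 1024 * 1024  # placeholder default, not the actual constant
DDR_CAP = 128 * 1024 * 1024 * 1024  # placeholder default, not the actual constant


def per_rank_caps(
    compute_device: str,
    world_size: int,
    hbm_cap: Optional[int] = None,
    ddr_cap: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    hbm_per_device: List[int] = [0] * world_size
    if compute_device == "cuda" or compute_device == "mtia":
        hbm_per_device = [hbm_cap if hbm_cap else HBM_CAP] * world_size
    ddr_cap_per_rank = [ddr_cap if ddr_cap else DDR_CAP] * world_size
    return hbm_per_device, ddr_cap_per_rank


hbm, ddr = per_rank_caps("mtia", world_size=2)
print(hbm[0] > 0)  # True: MTIA ranks now start with a non-zero HBM budget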
@@ -1197,8 +1197,7 @@ def storage_usage(
         storage_map = {
             "cuda": ParameterStorage.HBM,
             "cpu": ParameterStorage.DDR,
-            # TODO: Update it later. Setting for MTIA is same as CPU's for now.
-            "mtia": ParameterStorage.DDR,
+            "mtia": ParameterStorage.HBM,
         }
         return {storage_map[compute_device_type].value: get_tensor_size_bytes(tensor)}