diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 7f7856cad..dcb2dfc33 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -1245,7 +1245,8 @@ def calculate_shard_storages( for _ in hbm_specific_sizes ] ddr_specific_sizes = [ - (key_value_params.l2_cache_size or 0) * 1024 * 1024 * 1024 + # TODO: revisit the logic for SSD virtual table + 0 for _ in ddr_specific_sizes ] diff --git a/torchrec/distributed/planner/tests/test_planners.py b/torchrec/distributed/planner/tests/test_planners.py index 3d3b233db..60c1ffdc6 100644 --- a/torchrec/distributed/planner/tests/test_planners.py +++ b/torchrec/distributed/planner/tests/test_planners.py @@ -618,7 +618,7 @@ def test_planner_with_virtual_table(self) -> None: # L1 cache size is 64GB per shard and L2 cache size is 128MB per shard per table self.assertTrue( any( - "dram_virtual_table: HBM: 0.501 GB, DDR: 256.0 GB" in line + "dram_virtual_table: HBM: 0.501 GB, DDR: 0.0 GB" in line for line in stats ) ) @@ -748,7 +748,7 @@ def test_planner_with_virtual_table(self) -> None: # L2 cache size is 128MB per shard per table self.assertTrue( any( - "dram_virtual_table: HBM: 0.002 GB, DDR: 256.0 GB" in line + "dram_virtual_table: HBM: 0.002 GB, DDR: 0.0 GB" in line for line in stats ) ) @@ -800,7 +800,7 @@ def test_planner_with_virtual_table(self) -> None: # L2 cache size is 128MB per shard per table self.assertTrue( any( - "dram_virtual_table: HBM: 0.005 GB, DDR: 256.0 GB" in line + "dram_virtual_table: HBM: 0.005 GB, DDR: 0.0 GB" in line for line in stats ) )