
Commit 714d996

Caner Gocmen authored and facebook-github-bot committed
Add hashing for ParameterConstraints (#3009)
Summary:
Pull Request resolved: #3009

We are adding a hashing method for ParameterConstraints. We will later use this to ensure that a previously generated sharding plan is still valid when we want to reuse it later on.

Reviewed By: iamzainhuda

Differential Revision: D75553542

fbshipit-source-id: 7192459c9f4f749e82fac844119d3035721877e8
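The reuse-validation flow mentioned in the summary is not implemented in this commit. A minimal sketch of how the new hash could gate plan reuse, assuming a hypothetical cache that stores the constraint hash alongside a generated plan (plan_is_reusable and cached_plan_hash are illustrative names, not torchrec APIs):

from typing import Dict, Optional

from torchrec.distributed.planner.types import ParameterConstraints


def plan_is_reusable(
    constraints: Dict[str, ParameterConstraints],
    cached_plan_hash: Optional[int],
) -> bool:
    # Hypothetical helper: reuse a cached sharding plan only if the
    # constraints it was generated under hash identically to the
    # current per-table constraints.
    current_hash = hash(
        tuple(sorted((name, hash(pc)) for name, pc in constraints.items()))
    )
    return cached_plan_hash == current_hash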
1 parent e0d4b6c commit 714d996

File tree

2 files changed, +114 -1 lines changed


torchrec/distributed/planner/tests/test_types.py

Lines changed: 93 additions & 1 deletion
@@ -14,12 +14,17 @@
 import torch
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 
-from torchrec.distributed.planner.types import Shard, ShardingOption
+from torchrec.distributed.planner.types import (
+    ParameterConstraints,
+    Shard,
+    ShardingOption,
+)
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheAlgorithm,
     CacheParams,
     DataType,
+    KeyValueParams,
     ShardingType,
 )
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
@@ -207,3 +212,90 @@ def test_module_pooled_mch_ec(self) -> None:
             shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets],
         )
         self.assertEqual(sharding_option.is_pooled, False)
+
+
+class TestParameterConstraintsHash(unittest.TestCase):
+
+    def test_hash_equality(self) -> None:
+        # Create two identical instances
+        pc1 = ParameterConstraints(
+            sharding_types=["type1", "type2"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0, 2.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1", "feature2"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        pc2 = ParameterConstraints(
+            sharding_types=["type1", "type2"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0, 2.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1", "feature2"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        self.assertEqual(
+            hash(pc1), hash(pc2), "Hashes should be equal for identical instances"
+        )
+
+    def test_hash_inequality(self) -> None:
+        # Create two different instances
+        pc1 = ParameterConstraints(
+            sharding_types=["type1"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        pc2 = ParameterConstraints(
+            sharding_types=["type2"],
+            compute_kernels=["kernel2"],
+            min_partition=8,
+            pooling_factors=[2.0],
+            num_poolings=[2.0],
+            batch_sizes=[64],
+            is_weighted=False,
+            cache_params=CacheParams(),
+            enforce_hbm=False,
+            stochastic_rounding=True,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature2"],
+            output_dtype=DataType.FP16,
+            device_group="cpu",
+            key_value_params=KeyValueParams(),
+        )
+
+        self.assertNotEqual(
+            hash(pc1), hash(pc2), "Hashes should be different for different instances"
+        )
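The new test class can be run on its own with the standard unittest runner, e.g.:

    python -m unittest torchrec.distributed.planner.tests.test_types.TestParameterConstraintsHash -v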

torchrec/distributed/planner/types.py

Lines changed: 21 additions & 0 deletions
@@ -703,6 +703,27 @@ class ParameterConstraints:
     device_group: Optional[str] = None
     key_value_params: Optional[KeyValueParams] = None
 
+    def __hash__(self) -> int:
+        return hash(
+            (
+                tuple(self.sharding_types) if self.sharding_types else None,
+                tuple(self.compute_kernels) if self.compute_kernels else None,
+                self.min_partition,
+                tuple(self.pooling_factors),
+                tuple(self.num_poolings) if self.num_poolings else None,
+                tuple(self.batch_sizes) if self.batch_sizes else None,
+                self.is_weighted,
+                self.cache_params,
+                self.enforce_hbm,
+                self.stochastic_rounding,
+                self.bounds_check_mode,
+                tuple(self.feature_names) if self.feature_names else None,
+                self.output_dtype,
+                self.device_group,
+                self.key_value_params,
+            )
+        )
+
 
 class PlannerErrorType(Enum):
     """
