@@ -14,12 +14,17 @@
 import torch
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 
-from torchrec.distributed.planner.types import Shard, ShardingOption
+from torchrec.distributed.planner.types import (
+    ParameterConstraints,
+    Shard,
+    ShardingOption,
+)
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheAlgorithm,
     CacheParams,
     DataType,
+    KeyValueParams,
     ShardingType,
 )
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
@@ -207,3 +212,90 @@ def test_module_pooled_mch_ec(self) -> None:
             shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets],
         )
         self.assertEqual(sharding_option.is_pooled, False)
+
+
+class TestParameterConstraintsHash(unittest.TestCase):
+
+    def test_hash_equality(self) -> None:
+        # Create two identical instances
+        pc1 = ParameterConstraints(
+            sharding_types=["type1", "type2"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0, 2.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1", "feature2"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        pc2 = ParameterConstraints(
+            sharding_types=["type1", "type2"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0, 2.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1", "feature2"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        self.assertEqual(
+            hash(pc1), hash(pc2), "Hashes should be equal for identical instances"
+        )
+
+    def test_hash_inequality(self) -> None:
+        # Create two different instances
+        pc1 = ParameterConstraints(
+            sharding_types=["type1"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        pc2 = ParameterConstraints(
+            sharding_types=["type2"],
+            compute_kernels=["kernel2"],
+            min_partition=8,
+            pooling_factors=[2.0],
+            num_poolings=[2.0],
+            batch_sizes=[64],
+            is_weighted=False,
+            cache_params=CacheParams(),
+            enforce_hbm=False,
+            stochastic_rounding=True,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature2"],
+            output_dtype=DataType.FP16,
+            device_group="cpu",
+            key_value_params=KeyValueParams(),
+        )
+
+        self.assertNotEqual(
+            hash(pc1), hash(pc2), "Hashes should be different for different instances"
+        )