Skip to content

Commit 9240bf0

Browse files
committed
Update planner to use consistent hashing
Differential Revision: D76303748
1 parent 3b6b537 commit 9240bf0

File tree

5 files changed

+213
-34
lines changed

5 files changed

+213
-34
lines changed

torchrec/distributed/planner/enumerators.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def __init__(
102102
EmbeddingStorageEstimator(topology=topology, constraints=constraints),
103103
]
104104

105+
# Initializing caching for enumerate
106+
self._last_stored_search_space: Optional[List[ShardingOption]] = None
107+
self._last_stored_module: Optional[nn.Module] = None
108+
self._last_stored_sharders: Optional[List[ModuleSharder[nn.Module]]] = None
109+
105110
def enumerate(
106111
self,
107112
module: nn.Module,
@@ -118,6 +123,12 @@ def enumerate(
118123
List[ShardingOption]: valid sharding options with values populated.
119124
"""
120125

126+
if (
127+
self._last_stored_module == module
128+
and self._last_stored_sharders == sharders
129+
):
130+
return self._last_stored_search_space # pyre-ignore
131+
121132
self._sharder_map = {
122133
sharder_name(sharder.module_type): sharder for sharder in sharders
123134
}
@@ -230,8 +241,18 @@ def enumerate(
230241

231242
self.populate_estimates(sharding_options)
232243

244+
self._last_stored_module = module
245+
self._last_stored_sharders = sharders
246+
self._last_stored_search_space = sharding_options
233247
return sharding_options
234248

249+
@property
250+
def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
251+
# NOTE: This is the last search space stored by enumerate(...), do not use
252+
# this field in place of actually calling enumerate(...) as it will varie for each
253+
# module/sharders passed in.
254+
return self._last_stored_search_space
255+
235256
def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
236257
for estimator in self._estimators:
237258
estimator.estimate(sharding_options, self._sharder_map)

torchrec/distributed/planner/planners.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from torchrec.distributed.planner.types import (
4141
Enumerator,
42+
hash_planner_context_inputs,
4243
ParameterConstraints,
4344
Partitioner,
4445
PerfModel,
@@ -280,25 +281,21 @@ def collective_plan(
280281
sharders,
281282
)
282283

283-
def hash_planner_context_inputs(self) -> str:
284+
def hash_planner_context_inputs(self) -> int:
284285
"""
285286
Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
286287
These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
287288
288289
Returns:
289290
Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints.
290291
"""
291-
hashable_list = [
292+
return hash_planner_context_inputs(
292293
self._topology,
293294
self._batch_size,
294295
self._enumerator,
295296
self._storage_reservation,
296-
frozenset(self._constraints.items()) if self._constraints else None,
297-
]
298-
serialized_list = str(hashable_list).encode("utf-8")
299-
hash_object = hashlib.sha256(serialized_list)
300-
hash_digest = hash_object.hexdigest()
301-
return hash_digest
297+
self._constraints,
298+
)
302299

303300
def plan(
304301
self,

torchrec/distributed/planner/storage_reservations.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ class FixedPercentageStorageReservation(StorageReservation):
163163
def __init__(self, percentage: float) -> None:
164164
assert percentage >= 0 and percentage <= 1
165165
self._percentage: float = percentage
166+
self._last_reserved_toplogy: Optional[Topology] = None
166167

167168
def reserve(
168169
self,
@@ -174,8 +175,14 @@ def reserve(
174175
) -> Topology:
175176
reserved_topology = copy.deepcopy(topology)
176177
_reserve_storage_percentage(reserved_topology, self._percentage)
178+
self._last_reserved_toplogy = reserved_topology
177179
return reserved_topology
178180

181+
@property
182+
def last_reserved_toplogy(self) -> Optional[Topology]:
183+
"Cached value of the most recent output from the reserve() method."
184+
return self._last_reserved_toplogy
185+
179186

180187
class HeuristicalStorageReservation(StorageReservation):
181188
"""
@@ -206,6 +213,7 @@ def __init__(
206213

207214
self._dense_storage: Optional[Storage] = None
208215
self._kjt_storage: Optional[Storage] = None
216+
self._last_reserved_toplogy: Optional[Topology] = None
209217

210218
def reserve(
211219
self,
@@ -215,6 +223,7 @@ def reserve(
215223
sharders: List[ModuleSharder[nn.Module]],
216224
constraints: Optional[Dict[str, ParameterConstraints]] = None,
217225
) -> Topology:
226+
# TODO: enable proper caching of topology values through _last_reserved_toplogy
218227
reserved_topology = copy.deepcopy(topology)
219228

220229
batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters(
@@ -262,8 +271,14 @@ def reserve(
262271
message=negative_storage_solution,
263272
)
264273

274+
self._last_reserved_toplogy = reserved_topology
265275
return reserved_topology
266276

277+
@property
278+
def last_reserved_toplogy(self) -> Optional[Topology]:
279+
"Cached value of the most recent output from the reserve() method."
280+
return self._last_reserved_toplogy
281+
267282

268283
class InferenceStorageReservation(StorageReservation):
269284
"""

torchrec/distributed/planner/tests/test_types.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,30 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import cast
11+
from typing import cast, Dict, Optional
1212
from unittest.mock import MagicMock
1313

1414
import torch
15+
from torch import multiprocessing
1516
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
17+
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
18+
from torchrec.distributed.planner import EmbeddingShardingPlanner
19+
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
20+
from torchrec.distributed.planner.perf_models import NoopPerfModel
21+
from torchrec.distributed.planner.storage_reservations import (
22+
HeuristicalStorageReservation,
23+
)
1624

1725
from torchrec.distributed.planner.types import (
1826
ParameterConstraints,
1927
Shard,
2028
ShardingOption,
2129
Topology,
2230
)
31+
from torchrec.distributed.test_utils.multi_process import (
32+
MultiProcessContext,
33+
MultiProcessTestBase,
34+
)
2335
from torchrec.distributed.types import (
2436
BoundsCheckMode,
2537
CacheAlgorithm,
@@ -348,3 +360,75 @@ def test_hash_inequality(self) -> None:
348360
self.assertNotEqual(
349361
hash(pc1), hash(pc2), "Hashes should be different for different instances"
350362
)
363+
364+
365+
def _test_hashing_consistency(
366+
rank: int,
367+
world_size: int,
368+
backend: str,
369+
return_hash_dict: Dict[str, int],
370+
local_size: Optional[int] = None,
371+
) -> None:
372+
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
373+
topology = Topology(
374+
local_world_size=8,
375+
world_size=1,
376+
compute_device="cuda",
377+
)
378+
batch_size = 128
379+
enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
380+
eb_config = EmbeddingBagConfig(
381+
name="table_0",
382+
embedding_dim=160,
383+
num_embeddings=10000,
384+
feature_names=["f1"],
385+
data_type=DataType.FP16,
386+
)
387+
module = EmbeddingBagCollection(
388+
tables=[eb_config],
389+
is_weighted=False,
390+
device=torch.device(
391+
"meta"
392+
), # Using meta device for now since only getting search space
393+
)
394+
sharders = [EmbeddingBagCollectionSharder()]
395+
enumerator.enumerate(module, sharders) # pyre-ignore
396+
storage_reservation = HeuristicalStorageReservation(percentage=0.15)
397+
constraints = {"table1": ParameterConstraints()}
398+
399+
storage_reservation.reserve(
400+
topology=topology,
401+
batch_size=batch_size,
402+
module=module,
403+
sharders=sharders, # pyre-ignore
404+
constraints=constraints,
405+
)
406+
perf_model = NoopPerfModel(topology=topology)
407+
408+
planner1 = EmbeddingShardingPlanner(
409+
topology=topology,
410+
batch_size=batch_size,
411+
enumerator=enumerator,
412+
storage_reservation=storage_reservation,
413+
performance_model=perf_model,
414+
constraints=constraints,
415+
)
416+
417+
h = planner1.hash_planner_context_inputs()
418+
return_hash_dict[str(rank)] = h
419+
420+
421+
class TestConsistentHashingBetweenProcesses(MultiProcessTestBase):
422+
423+
def test_hash_consistency(self) -> None:
424+
# planner
425+
world_size = 2
426+
return_hash_dict = multiprocessing.Manager().dict()
427+
self._run_multi_process_test(
428+
callable=_test_hashing_consistency,
429+
world_size=world_size,
430+
backend="nccl" if torch.cuda.is_available() else "gloo",
431+
return_hash_dict=return_hash_dict,
432+
)
433+
hashes = return_hash_dict.values()
434+
assert hashes[0] == hashes[1], "hash values are different."

torchrec/distributed/planner/types.py

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from copy import deepcopy
1313
from dataclasses import dataclass, field
1414
from enum import Enum
15-
from typing import cast, Dict, List, Optional, Tuple, Union
15+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
1616

1717
import torch
1818
from torch import nn
@@ -401,12 +401,15 @@ def __repr__(self) -> str:
401401
topology_repr += str(self._comms_bandwidths) + "\n"
402402
return topology_repr
403403

404-
def _hash(self) -> str:
404+
def _hash(self) -> int:
405405
"""
406406
Compute a consistent hash value for this Topology instance.
407407
408408
Returns:
409409
str: A hash value for this Topology instance.
410+
411+
NOTE: Not overriding the __hash__ method here to account for other
412+
potential variables that may be unchecked by the following list
410413
"""
411414

412415
# Compute hbms and ddrs from the decives
@@ -430,10 +433,7 @@ def _hash(self) -> str:
430433
self._uneven_sharding_perf_multiplier,
431434
]
432435

433-
serialized_list = str(hashable_list).encode("utf-8")
434-
hash_object = hashlib.sha256(serialized_list)
435-
hash_digest = hash_object.hexdigest()
436-
return hash_digest
436+
return hash_sha256_to_int(hashable_list)
437437

438438

439439
# ---- INPUT / OUTPUT ----- #
@@ -743,25 +743,25 @@ class ParameterConstraints:
743743
key_value_params: Optional[KeyValueParams] = None
744744

745745
def __hash__(self) -> int:
746-
return hash(
747-
(
748-
tuple(self.sharding_types) if self.sharding_types else None,
749-
tuple(self.compute_kernels) if self.compute_kernels else None,
750-
self.min_partition,
751-
tuple(self.pooling_factors),
752-
tuple(self.num_poolings) if self.num_poolings else None,
753-
tuple(self.batch_sizes) if self.batch_sizes else None,
754-
self.is_weighted,
755-
self.cache_params,
756-
self.enforce_hbm,
757-
self.stochastic_rounding,
758-
self.bounds_check_mode,
759-
tuple(self.feature_names) if self.feature_names else None,
760-
self.output_dtype,
761-
self.device_group,
762-
self.key_value_params,
763-
)
764-
)
746+
hashable_list = [
747+
tuple(self.sharding_types) if self.sharding_types else None,
748+
tuple(self.compute_kernels) if self.compute_kernels else None,
749+
self.min_partition,
750+
tuple(self.pooling_factors),
751+
tuple(self.num_poolings) if self.num_poolings else None,
752+
tuple(self.batch_sizes) if self.batch_sizes else None,
753+
self.is_weighted,
754+
self.cache_params,
755+
self.enforce_hbm,
756+
self.stochastic_rounding,
757+
self.bounds_check_mode,
758+
tuple(self.feature_names) if self.feature_names else None,
759+
self.output_dtype,
760+
self.device_group,
761+
self.key_value_params,
762+
]
763+
764+
return hash_sha256_to_int(hashable_list)
765765

766766

767767
class PlannerErrorType(Enum):
@@ -967,3 +967,65 @@ class CriticalPathEstimate:
967967

968968
def total(self) -> float:
969969
return self.comms_estimate + self.comp_estimate
970+
971+
972+
# ---- Types Utils ---- #
973+
def hash_sha256_to_int(hashable_list: List[Any]) -> int: # pyre-ignore
974+
"""
975+
Hashes the given data using SHA256 and returns the hash as an integer
976+
"""
977+
serialized_list = str(hashable_list).encode("utf-8")
978+
hash_object = hashlib.sha256(serialized_list)
979+
hash_digest = hash_object.hexdigest()
980+
return int(hash_digest, 16)
981+
982+
983+
def hash_planner_context_inputs(
984+
topology: Topology,
985+
batch_size: int,
986+
enumerator: Enumerator,
987+
storage_reservation: StorageReservation,
988+
constraints: Optional[Dict[str, ParameterConstraints]],
989+
# pyre-ignore
990+
hash_function: Callable[[List[Any]], int] = hash_sha256_to_int,
991+
) -> int:
992+
assert hasattr(
993+
enumerator, "last_stored_search_space"
994+
), "This enumerator is not compatible with hashing"
995+
assert (
996+
enumerator.last_stored_search_space is not None # pyre-ignore
997+
), "Unable to hash planner context without an enumerator that has a precomputed search space"
998+
search_space = enumerator.last_stored_search_space
999+
storage_reservation_policy = type(storage_reservation).__name__
1000+
1001+
# TODO: Not the best code, but will result in circular dependency if we import the actaul
1002+
# storage reservation classes - will come back here to refactor more cleanly
1003+
assert storage_reservation_policy in [
1004+
"HeuristicalStorageReservation",
1005+
"FixedPercentageStorageReservation",
1006+
]
1007+
assert hasattr(
1008+
storage_reservation, "last_reserved_toplogy"
1009+
), "This storage reservation is not compatible with hashing"
1010+
assert (
1011+
storage_reservation.last_reserved_toplogy is not None # pyre-ignore
1012+
), "Unable to hash planner context without a storage reservation that has a precomputed topology"
1013+
1014+
hashable_list = [
1015+
topology,
1016+
batch_size,
1017+
[
1018+
[
1019+
shard_option.fqn,
1020+
shard_option.sharding_type,
1021+
shard_option.compute_kernel,
1022+
tuple(shard_option.shards),
1023+
shard_option.cache_params,
1024+
]
1025+
for shard_option in search_space
1026+
],
1027+
storage_reservation_policy,
1028+
storage_reservation.last_reserved_toplogy,
1029+
frozenset(constraints.items()) if constraints else None,
1030+
]
1031+
return hash_function(hashable_list)

0 commit comments

Comments
 (0)