|
8 | 8 | # pyre-strict |
9 | 9 |
|
10 | 10 | import abc |
| 11 | +import hashlib |
| 12 | +import pickle |
11 | 13 | from copy import deepcopy |
12 | 14 | from dataclasses import dataclass, field |
13 | 15 | from enum import Enum |
@@ -248,6 +250,10 @@ def get_bw( |
248 | 250 |
|
249 | 251 |
|
250 | 252 | class Topology: |
| 253 | + """ |
| 254 | + Representation of a network of devices in a cluster. |
| 255 | + """ |
| 256 | + |
251 | 257 | def __init__( |
252 | 258 | self, |
253 | 259 | world_size: int, |
@@ -396,6 +402,40 @@ def __repr__(self) -> str: |
396 | 402 | topology_repr += str(self._comms_bandwidths) + "\n" |
397 | 403 | return topology_repr |
398 | 404 |
|
| 405 | + def _hash(self) -> str: |
| 406 | + """ |
| 407 | + Compute a consistent hash value for this Topology instance. |
| 408 | +
|
| 409 | + Returns: |
| 410 | + str: A hash value for this Topology instance. |
| 411 | + """ |
| 412 | + |
| 413 | + # Compute hbms and ddrs from the decives |
| 414 | + hbms = [device.storage.hbm for device in self._devices] |
| 415 | + ddrs = [device.storage.ddr for device in self._devices] |
| 416 | + |
| 417 | + # Combine all attributes into a hashable tuple |
| 418 | + hashable_list = [ |
| 419 | + self._world_size, |
| 420 | + self._compute_device, |
| 421 | + hbms, |
| 422 | + ddrs, |
| 423 | + self._local_world_size, |
| 424 | + self._hbm_mem_bw, |
| 425 | + self._ddr_mem_bw, |
| 426 | + self._hbm_to_ddr_mem_bw, |
| 427 | + self._comms_bandwidths.intra_host_bw, |
| 428 | + self._comms_bandwidths.inter_host_bw, |
| 429 | + self._bwd_compute_multiplier, |
| 430 | + self._weighted_feature_bwd_compute_multiplier, |
| 431 | + self._uneven_sharding_perf_multiplier, |
| 432 | + ] |
| 433 | + |
| 434 | + serialized_list = str(hashable_list).encode("utf-8") |
| 435 | + hash_object = hashlib.sha256(serialized_list) |
| 436 | + hash_digest = hash_object.hexdigest() |
| 437 | + return hash_digest |
| 438 | + |
399 | 439 |
|
400 | 440 | # ---- INPUT / OUTPUT ----- # |
401 | 441 |
|
|
0 commit comments