
Commit d5a014c

emlin authored and facebook-github-bot committed
support get state dict and apply state dict (pytorch#2976)
Summary:
Pull Request resolved: pytorch#2976

X-link: pytorch/FBGEMM#4145
X-link: facebookresearch/FBGEMM#1226

# Functions

**Saving State Dict**

When saving the state dict, we convert IDs from local to global (see the first sketch below). This allows us to avoid shifting IDs based on sharding decisions when tables are resharded.

**Checkpoint Loading Mode**

We have enabled a load-state-dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading.

# Current Solution

The current solution caches all data in Python tensors, following these steps (see the load-flow sketch below):

- Set self.local_weight_counts based on the checkpoint bucket tensor size.
- Enable load-state-dict mode to initialize the local cache tensors.
- Call state_dict to get empty tensors for the checkpoint loader.
- Have the checkpoint loader write data from the persisted checkpoint into the cached tensors.
- Call apply_state_dict to write all cached tensors to the backend.

**Apply State Dict Flow**

During the apply_state_dict step, we perform the following operations (see the chunked-write sketch below):

- If optimizer offloading is enabled:
  - Loop through chunks of weight and optimizer state.
  - Concatenate the weight and optimizer state together.
  - Write them to the backend through the KVTensorWrapper interface.
- If optimizer offloading is disabled:
  - Set the optimizer state to a device tensor based on IDs.
  - Write the IDs and weights to the backend for each table.

# Limitations

The current solution has two limitations:

- Memory overhead: when writing data to the backend, the Python tensor's memory cannot be released until the whole tensor has been duplicated in the backend. This can lead to high memory usage, especially with a single large table.
- Performance regression: with optimizer offloading, we need to concatenate the weight and optimizer state before writing to the backend. To avoid tripling one large tensor's memory footprint, we loop through smaller chunks during writing, which can cause a performance regression.

# Future Improvements

After the first end-to-end version is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead.

Reviewed By: bobbyliujb

Differential Revision: D74790154
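As an illustration of the ID conversion described under "Saving State Dict", the sketch below shifts bucket-local row IDs into a global ID space by adding a per-bucket offset. This is a minimal sketch, not code from this PR; the helper name and the fixed-bucket-size assumption are hypothetical.

```python
import torch


def local_to_global_ids(
    local_ids: torch.Tensor, bucket_id: int, bucket_size: int
) -> torch.Tensor:
    """Hypothetical helper: shift bucket-local IDs into the global ID space.

    Saving global IDs means a later resharding only changes which buckets a
    rank owns, not the IDs stored in the checkpoint.
    """
    return local_ids + bucket_id * bucket_size


# e.g. rows 0, 3, 7 of bucket 4 with an assumed bucket size of 1024
print(local_to_global_ids(torch.tensor([0, 3, 7]), bucket_id=4, bucket_size=1024))
# tensor([4096, 4099, 4103])
```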
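The "Current Solution" steps map to roughly the following call sequence from the checkpoint loader's side. This is a hedged sketch: apply_state_dict and local_weight_counts come from the summary above, while enable_load_state_dict_mode and checkpoint_reader.read_into are assumed names used only for illustration.

```python
def load_from_checkpoint(emb_kernel, checkpoint_reader, bucket_sizes) -> None:
    # 1. Size the local caches from the checkpoint bucket tensors.
    emb_kernel.local_weight_counts = bucket_sizes
    # 2. Enter load-state-dict mode so empty cache tensors are allocated.
    emb_kernel.enable_load_state_dict_mode(True)
    # 3. state_dict() now exposes those (empty) cached tensors.
    state = emb_kernel.state_dict()
    # 4. The checkpoint loader fills the cached tensors in place, in any order,
    #    since the checkpoint client does not guarantee ordered tensor loading.
    checkpoint_reader.read_into(state)
    # 5. Push every cached ID/weight/bucket tensor into the backend KV store.
    emb_kernel.apply_state_dict()
    emb_kernel.enable_load_state_dict_mode(False)
```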
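For the optimizer-offloading branch of apply_state_dict, the chunked concatenation could look like the sketch below: only one chunk of rows is fused and written at a time, trading extra write calls for never materializing a full weight-plus-optimizer copy of the table. write_to_backend and the chunk size are placeholders, not the KVTensorWrapper API.

```python
from typing import Callable

import torch


def apply_chunked(
    weights: torch.Tensor,  # [num_rows, emb_dim]
    opt_state: torch.Tensor,  # [num_rows, opt_dim]
    write_to_backend: Callable[[int, torch.Tensor], None],
    chunk_rows: int = 65536,
) -> None:
    for start in range(0, weights.shape[0], chunk_rows):
        end = min(start + chunk_rows, weights.shape[0])
        # Fuse [emb_dim | opt_dim] columns for this chunk of rows only, so at
        # most one chunk-sized copy exists alongside the source tensors.
        fused = torch.cat([weights[start:end], opt_state[start:end]], dim=1)
        write_to_backend(start, fused)
```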
1 parent a2b1ee6 commit d5a014c

File tree

1 file changed (+81, −58 lines)


torchrec/distributed/batched_embedding_kernel.py
Lines changed: 81 additions & 58 deletions

@@ -29,6 +29,10 @@

 import torch
 import torch.distributed as dist
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BackendType,
+    KVZCHParams,
+)
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
@@ -42,7 +46,6 @@
     SplitTableBatchedEmbeddingBagsCodegen,
 )
 from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
-from fbgemm_gpu.tbe.ssd.training import BackendType, KVZCHParams
 from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
     PartiallyMaterializedTensor,
 )
@@ -861,6 +864,7 @@ def _gen_named_parameters_by_table_fused(
         table_count = table_name_to_count.pop(table_name)
         if emb_module.weights_precision == SparseType.INT8:
             dim += emb_module.int8_emb_row_dim_offset
+        # pyre-ignore [29]
         offset = emb_module.weights_physical_offsets[t_idx]
         weights: torch.Tensor
         if location == EmbeddingLocation.DEVICE.value:
@@ -1253,6 +1257,16 @@ def __init__(
         compute_kernel = config.embedding_tables[0].compute_kernel
         embedding_location = compute_kernel_to_embedding_location(compute_kernel)

+        # every split_embeding_weights call is expensive, since it iterates over all the elements in the backend kv db
+        # use split weights result cache so that multiple calls in the same train iteration will only trigger once
+        self._split_weights_res: Optional[
+            Tuple[
+                List[ShardedTensor],
+                List[ShardedTensor],
+                List[ShardedTensor],
+            ]
+        ] = None
+
         self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
             embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
             feature_table_map=self._feature_table_map,
@@ -1265,11 +1279,18 @@ def __init__(
         logger.info(
             f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
         )
-
-        self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
-            config,
-            self._emb_module,
-            pg,
+        self._table_name_to_weight_count_per_rank: Dict[str, List[int]] = {}
+        self._init_sharded_split_embedding_weights()  # this will populate self._split_weights_res
+        self._optim: ZeroCollisionKeyValueEmbeddingFusedOptimizer = (
+            ZeroCollisionKeyValueEmbeddingFusedOptimizer(
+                config,
+                self._emb_module,
+                # pyre-ignore[16]
+                sharded_embedding_weights_by_table=self._split_weights_res[0],
+                table_name_to_weight_count_per_rank=self._table_name_to_weight_count_per_rank,
+                sharded_embedding_weight_ids=self._split_weights_res[1],
+                pg=pg,
+            )
         )
         self._param_per_table: Dict[str, nn.Parameter] = dict(
             _gen_named_parameters_by_table_ssd_pmt(
@@ -1281,16 +1302,6 @@ def __init__(
         )
         self.init_parameters()

-        # every split_embeding_weights call is expensive, since it iterates over all the elements in the backend kv db
-        # use split weights result cache so that multiple calls in the same train iteration will only trigger once
-        self._split_weights_res: Optional[
-            Tuple[
-                List[ShardedTensor],
-                List[ShardedTensor],
-                List[ShardedTensor],
-            ]
-        ] = None
-
     def init_parameters(self) -> None:
         """
         An advantage of KV TBE is that we don't need to init weights. Hence skipping.
@@ -1393,7 +1404,7 @@ def named_parameters(
     # pyre-ignore [15]
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
-    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+    ) -> Iterator[Tuple[str, Union[PartiallyMaterializedTensor, torch.Tensor]]]:
         assert (
             remove_duplicate
         ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -1404,50 +1415,41 @@ def named_split_embedding_weights(
             key = append_prefix(prefix, f"{config.name}.weight")
             yield key, tensor

-    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
-        Tuple[
-            str,
-            Union[ShardedTensor, PartiallyMaterializedTensor],
-            Optional[ShardedTensor],
-            Optional[ShardedTensor],
-        ]
-    ]:
-        """
-        Return an iterator over embedding tables, for each table yielding
-            table name,
-            PMT for embedding table with a valid RocksDB snapshot to support tensor IO
-            optional ShardedTensor for weight_id
-            optional ShardedTensor for bucket_cnt
-        """
-        if self._split_weights_res is not None:
-            pmt_sharded_t_list = self._split_weights_res[0]
-            # pyre-ignore
-            weight_id_sharded_t_list = self._split_weights_res[1]
-            bucket_cnt_sharded_t_list = self._split_weights_res[2]
-            for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
-                table_config = self._config.embedding_tables[table_idx]
-                key = append_prefix(prefix, f"{table_config.name}")
-
-                yield key, pmt_sharded_t, weight_id_sharded_t_list[
-                    table_idx
-                ], bucket_cnt_sharded_t_list[table_idx]
+    # initialize sharded _split_weights_res if it's None
+    # this method is used to generate sharded embedding weights once for all following state_dict
+    # calls in checkpointing and publishing.
+    # When training is resumed, the cached value will be reset to None and the value needs to be
+    # rebuilt for next checkpointing and publishing, as the weight id, weight embedding will be updated
+    # during training in backend k/v store.
+    def _init_sharded_split_embedding_weights(
+        self, prefix: str = "", force_regenerate: bool = False
+    ) -> None:
+        if not force_regenerate and self._split_weights_res is not None:
             return

         pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False, should_flush=True
+            no_snapshot=False,
+            should_flush=True,
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")

         pmt_sharded_t_list = create_virtual_sharded_tensors(
-            emb_table_config_copy, pmt_list, self._pg, prefix
+            emb_table_config_copy,
+            pmt_list,
+            self._pg,
+            prefix,
         )
         weight_id_sharded_t_list = create_virtual_sharded_tensors(
             emb_table_config_copy, weight_ids_list, self._pg, prefix  # pyre-ignore
         )
         bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
-            emb_table_config_copy, bucket_cnt_list, self._pg, prefix  # pyre-ignore
+            emb_table_config_copy,
+            # pyre-ignore [6]
+            bucket_cnt_list,
+            self._pg,
+            prefix,
         )
         # pyre-ignore
         assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
@@ -1456,6 +1458,34 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             == len(weight_id_sharded_t_list)
             == len(bucket_cnt_sharded_t_list)
         )
+        self._split_weights_res = (
+            pmt_sharded_t_list,
+            weight_id_sharded_t_list,
+            bucket_cnt_sharded_t_list,
+        )
+
+    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
+        Tuple[
+            str,
+            Union[ShardedTensor, PartiallyMaterializedTensor],
+            Optional[ShardedTensor],
+            Optional[ShardedTensor],
+        ]
+    ]:
+        """
+        Return an iterator over embedding tables, for each table yielding
+            table name,
+            PMT for embedding table with a valid RocksDB snapshot to support tensor IO
+            optional ShardedTensor for weight_id
+            optional ShardedTensor for bucket_cnt
+        """
+        self._init_sharded_split_embedding_weights()
+        # pyre-ignore[16]
+        self._optim.set_sharded_embedding_weight_ids(self._split_weights_res[1])
+
+        pmt_sharded_t_list = self._split_weights_res[0]
+        weight_id_sharded_t_list = self._split_weights_res[1]
+        bucket_cnt_sharded_t_list = self._split_weights_res[2]
         for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
             table_config = self._config.embedding_tables[table_idx]
             key = append_prefix(prefix, f"{table_config.name}")
@@ -1464,12 +1494,6 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
                 table_idx
             ], bucket_cnt_sharded_t_list[table_idx]

-        self._split_weights_res = (
-            pmt_sharded_t_list,
-            weight_id_sharded_t_list,
-            bucket_cnt_sharded_t_list,
-        )
-
     def flush(self) -> None:
         """
         Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -1486,19 +1510,18 @@ def purge(self) -> None:

     # pyre-ignore [15]
     def split_embedding_weights(
-        self, no_snapshot: bool = True, should_flush: bool = False
+        self, no_snapshot: bool = True, should_flush: bool = True
     ) -> Tuple[
-        List[PartiallyMaterializedTensor],
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
-        return self.emb_module.split_embedding_weights(
-            no_snapshot, should_flush=should_flush
-        )
+        return self.emb_module.split_embedding_weights(no_snapshot, should_flush)

     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         # reset split weights during training
         self._split_weights_res = None
+        self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

         return self.emb_module(
             indices=features.values().long(),

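For context on how the new cache in this diff is exercised, here is a hedged usage sketch (assumed caller code, not part of the diff): a checkpointing pass iterates get_named_split_embedding_weights_snapshot once per table, repeated calls within the same iteration reuse the cached _split_weights_res instead of re-walking the backend KV store, and the next forward() clears the cache so a fresh snapshot is built after training resumes.

```python
def collect_snapshot_tensors(kernel, prefix: str = ""):
    # kernel is assumed to be the KV/SSD embedding kernel modified in this diff.
    out = {}
    for name, pmt_t, weight_id_t, bucket_cnt_t in (
        kernel.get_named_split_embedding_weights_snapshot(prefix)
    ):
        # name         -> table name
        # pmt_t        -> sharded PMT weights backed by a RocksDB snapshot
        # weight_id_t  -> optional sharded weight IDs
        # bucket_cnt_t -> optional sharded bucket counts
        out[name] = (pmt_t, weight_id_t, bucket_cnt_t)
    return out
```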