Skip to content

Commit 5c405a9

Browse files
emlin authored and facebook-github-bot committed
support get state dict and apply state dict (pytorch#2976)
Summary: X-link: pytorch/FBGEMM#4145 X-link: facebookresearch/FBGEMM#1226 # Functions **Saving State Dict** When saving state dict, we convert IDs from local to global. This allows us to avoid shifting IDs based on sharding decisions when tables are resharded. **Checkpoint Loading Mode** We have enabled load state dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading. # Current Solution The current solution involves caching all data in Python tensors, following these steps: - Set self.local_weight_counts based on checkpoint bucket tensor size. - Enable load state dict mode to initialize local cache tensors. - Call state_dict to get empty tensors for the checkpoint loader. - Write checkpoint data to cached tensors from persisted checkpoint by the checkpoint loader. - Call apply_state_dict to write all cached tensors to the backend. **Apply State Dict Flow** During the apply_state_dict step, we perform the following operations: - If optimizer offloading is enabled: - Loop through chunks of weight and optimizer. - Concatenate weight and optimizer together. - Write to backend using KVTensorWrapper interface. - If optimizer offloading is disabled: - Set optimizer to device tensor based on ID. - Write ID weight to backend for each table. # Limitations The current solution has two limitations: - Memory overhead: - When writing data to the backend, the Python tensor's memory cannot be released until the whole tensor data is duplicated in the backend. This can lead to high memory usage, especially when dealing with single large tables. - Performance regression: - With optimizer offloading, we need to concatenate weight and optimizer together before writing to the backend. 
To avoid tripling one large tensor's memory, we loop through smaller chunks during writing, which can cause performance regression. # Future Improvements After the first version e2e is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead. Reviewed By: bobbyliujb Differential Revision: D74790154
1 parent 4410a29 commit 5c405a9

File tree

1 file changed

+78
-58
lines changed

1 file changed

+78
-58
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 78 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929

3030
import torch
3131
import torch.distributed as dist
32+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
33+
BackendType,
34+
KVZCHParams,
35+
)
3236
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
3337
IntNBitTableBatchedEmbeddingBagsCodegen,
3438
)
@@ -42,7 +46,6 @@
4246
SplitTableBatchedEmbeddingBagsCodegen,
4347
)
4448
from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
45-
from fbgemm_gpu.tbe.ssd.training import BackendType, KVZCHParams
4649
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
4750
PartiallyMaterializedTensor,
4851
)
@@ -1254,6 +1257,16 @@ def __init__(
12541257
compute_kernel = config.embedding_tables[0].compute_kernel
12551258
embedding_location = compute_kernel_to_embedding_location(compute_kernel)
12561259

1260+
# every split_embeding_weights call is expensive, since it iterates over all the elements in the backend kv db
1261+
# use split weights result cache so that multiple calls in the same train iteration will only trigger once
1262+
self._split_weights_res: Optional[
1263+
Tuple[
1264+
List[ShardedTensor],
1265+
List[ShardedTensor],
1266+
List[ShardedTensor],
1267+
]
1268+
] = None
1269+
12571270
self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
12581271
embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
12591272
feature_table_map=self._feature_table_map,
@@ -1266,11 +1279,17 @@ def __init__(
12661279
logger.info(
12671280
f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
12681281
)
1269-
1270-
self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
1271-
config,
1272-
self._emb_module,
1273-
pg,
1282+
self._table_name_to_weight_count_per_rank: Dict[str, List[int]] = {}
1283+
self._init_sharded_split_embedding_weights() # this will populate self._split_weights_res
1284+
self._optim: ZeroCollisionKeyValueEmbeddingFusedOptimizer = (
1285+
ZeroCollisionKeyValueEmbeddingFusedOptimizer(
1286+
config,
1287+
self._emb_module,
1288+
sharded_embedding_weights_by_table=self._split_weights_res[0], # pyre-ignore[16]
1289+
table_name_to_weight_count_per_rank=self._table_name_to_weight_count_per_rank,
1290+
sharded_embedding_weight_ids=self._split_weights_res[1],
1291+
pg=pg,
1292+
)
12741293
)
12751294
self._param_per_table: Dict[str, nn.Parameter] = dict(
12761295
_gen_named_parameters_by_table_ssd_pmt(
@@ -1282,16 +1301,6 @@ def __init__(
12821301
)
12831302
self.init_parameters()
12841303

1285-
# every split_embeding_weights call is expensive, since it iterates over all the elements in the backend kv db
1286-
# use split weights result cache so that multiple calls in the same train iteration will only trigger once
1287-
self._split_weights_res: Optional[
1288-
Tuple[
1289-
List[ShardedTensor],
1290-
List[ShardedTensor],
1291-
List[ShardedTensor],
1292-
]
1293-
] = None
1294-
12951304
def init_parameters(self) -> None:
12961305
"""
12971306
An advantage of KV TBE is that we don't need to init weights. Hence skipping.
@@ -1394,7 +1403,7 @@ def named_parameters(
13941403
# pyre-ignore [15]
13951404
def named_split_embedding_weights(
13961405
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
1397-
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
1406+
) -> Iterator[Tuple[str, PartiallyMaterializedTensor | torch.Tensor]]:
13981407
assert (
13991408
remove_duplicate
14001409
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -1405,50 +1414,41 @@ def named_split_embedding_weights(
14051414
key = append_prefix(prefix, f"{config.name}.weight")
14061415
yield key, tensor
14071416

1408-
def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
1409-
Tuple[
1410-
str,
1411-
Union[ShardedTensor, PartiallyMaterializedTensor],
1412-
Optional[ShardedTensor],
1413-
Optional[ShardedTensor],
1414-
]
1415-
]:
1416-
"""
1417-
Return an iterator over embedding tables, for each table yielding
1418-
table name,
1419-
PMT for embedding table with a valid RocksDB snapshot to support tensor IO
1420-
optional ShardedTensor for weight_id
1421-
optional ShardedTensor for bucket_cnt
1422-
"""
1423-
if self._split_weights_res is not None:
1424-
pmt_sharded_t_list = self._split_weights_res[0]
1425-
# pyre-ignore
1426-
weight_id_sharded_t_list = self._split_weights_res[1]
1427-
bucket_cnt_sharded_t_list = self._split_weights_res[2]
1428-
for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
1429-
table_config = self._config.embedding_tables[table_idx]
1430-
key = append_prefix(prefix, f"{table_config.name}")
1431-
1432-
yield key, pmt_sharded_t, weight_id_sharded_t_list[
1433-
table_idx
1434-
], bucket_cnt_sharded_t_list[table_idx]
1417+
# initialize sharded _split_weights_res if it's None
1418+
# this method is used to generate sharded embedding weights once for all following state_dict
1419+
# calls in checkpointing and publishing.
1420+
# When training is resumed, the cached value will be reset to None and the value needs to be
1421+
# rebuilt for next checkpointing and publishing, as the weight id, weight embedding will be updated
1422+
# during training in backend k/v store.
1423+
def _init_sharded_split_embedding_weights(
1424+
self, prefix: str = "", force_regenerate: bool = False
1425+
) -> None:
1426+
if not force_regenerate and self._split_weights_res is not None:
14351427
return
14361428

14371429
pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
1438-
no_snapshot=False, should_flush=True
1430+
no_snapshot=False,
1431+
should_flush=True,
14391432
)
14401433
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
14411434
for emb_table in emb_table_config_copy:
14421435
emb_table.local_metadata.placement._device = torch.device("cpu")
14431436

14441437
pmt_sharded_t_list = create_virtual_sharded_tensors(
1445-
emb_table_config_copy, pmt_list, self._pg, prefix
1438+
emb_table_config_copy,
1439+
pmt_list,
1440+
self._pg,
1441+
prefix,
14461442
)
14471443
weight_id_sharded_t_list = create_virtual_sharded_tensors(
14481444
emb_table_config_copy, weight_ids_list, self._pg, prefix # pyre-ignore
14491445
)
14501446
bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
1451-
emb_table_config_copy, bucket_cnt_list, self._pg, prefix # pyre-ignore
1447+
emb_table_config_copy,
1448+
# pyre-ignore [6]
1449+
bucket_cnt_list,
1450+
self._pg,
1451+
prefix,
14521452
)
14531453
# pyre-ignore
14541454
assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
@@ -1457,6 +1457,33 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
14571457
== len(weight_id_sharded_t_list)
14581458
== len(bucket_cnt_sharded_t_list)
14591459
)
1460+
self._split_weights_res = (
1461+
pmt_sharded_t_list,
1462+
weight_id_sharded_t_list,
1463+
bucket_cnt_sharded_t_list,
1464+
)
1465+
1466+
def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
1467+
Tuple[
1468+
str,
1469+
Union[ShardedTensor, PartiallyMaterializedTensor],
1470+
Optional[ShardedTensor],
1471+
Optional[ShardedTensor],
1472+
]
1473+
]:
1474+
"""
1475+
Return an iterator over embedding tables, for each table yielding
1476+
table name,
1477+
PMT for embedding table with a valid RocksDB snapshot to support tensor IO
1478+
optional ShardedTensor for weight_id
1479+
optional ShardedTensor for bucket_cnt
1480+
"""
1481+
self._init_sharded_split_embedding_weights()
1482+
self._optim.set_sharded_embedding_weight_ids(self._split_weights_res[1]) # pyre-ignore[16]
1483+
1484+
pmt_sharded_t_list = self._split_weights_res[0]
1485+
weight_id_sharded_t_list = self._split_weights_res[1]
1486+
bucket_cnt_sharded_t_list = self._split_weights_res[2]
14601487
for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
14611488
table_config = self._config.embedding_tables[table_idx]
14621489
key = append_prefix(prefix, f"{table_config.name}")
@@ -1465,12 +1492,6 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
14651492
table_idx
14661493
], bucket_cnt_sharded_t_list[table_idx]
14671494

1468-
self._split_weights_res = (
1469-
pmt_sharded_t_list,
1470-
weight_id_sharded_t_list,
1471-
bucket_cnt_sharded_t_list,
1472-
)
1473-
14741495
def flush(self) -> None:
14751496
"""
14761497
Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -1487,19 +1508,18 @@ def purge(self) -> None:
14871508

14881509
# pyre-ignore [15]
14891510
def split_embedding_weights(
1490-
self, no_snapshot: bool = True, should_flush: bool = False
1511+
self, no_snapshot: bool = True, should_flush: bool = True
14911512
) -> Tuple[
1492-
List[PartiallyMaterializedTensor],
1513+
List[PartiallyMaterializedTensor] | List[torch.Tensor],
14931514
Optional[List[torch.Tensor]],
14941515
Optional[List[torch.Tensor]],
14951516
]:
1496-
return self.emb_module.split_embedding_weights(
1497-
no_snapshot, should_flush=should_flush
1498-
)
1517+
return self.emb_module.split_embedding_weights(no_snapshot, should_flush)
14991518

15001519
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
15011520
# reset split weights during training
15021521
self._split_weights_res = None
1522+
self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)
15031523

15041524
return self.emb_module(
15051525
indices=features.values().long(),

0 commit comments

Comments
 (0)