
Commit 45223e0

duduyi2013 authored and facebook-github-bot committed

chunking opt split and fix duplicate flush (#3042)
Summary: Pull Request resolved: #3042. X-link: pytorch/FBGEMM#4260. X-link: facebookresearch/FBGEMM#1338

Changes:
1. In ZeroCollisionKeyValueEmbedding, split_embedding_weights forced a flush on every call. Remove that so the cached weights are reused within the same global step.
2. For the split embedding optimizer, rocksdb has to read the whole value part (embedding + optimizer) into DRAM. Without chunking, everything is read at once, causing a temporarily huge memory spike; with chunked loading, the spike stays low.

Reviewed By: steven1327, emlin

Differential Revision: D75988991

fbshipit-source-id: 414fab2aad45e05e1da12f95a7ab99fb82c4f8aa
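The chunked-loading idea from change 2 can be sketched as below. This is a hypothetical illustration, not the actual FBGEMM/torchrec API: the `read_range` callback and `chunk_rows` parameter are invented stand-ins for whatever range-read primitive the rocksdb backend exposes. The point is that a generator bounds peak DRAM usage by `chunk_rows` instead of the full table size.

```python
from typing import Callable, Iterator, List


def load_all_at_once(read_range: Callable[[int, int], List[int]],
                     total_rows: int) -> List[int]:
    # One bulk read: peak memory is proportional to the whole
    # (embedding + optimizer) value column.
    return read_range(0, total_rows)


def load_in_chunks(read_range: Callable[[int, int], List[int]],
                   total_rows: int,
                   chunk_rows: int) -> Iterator[List[int]]:
    # Same rows, read chunk by chunk: peak memory is proportional
    # to chunk_rows, so the transient spike stays low.
    for start in range(0, total_rows, chunk_rows):
        end = min(start + chunk_rows, total_rows)
        yield read_range(start, end)


backing = list(range(10))  # stand-in for the rocksdb value column

def read_range(start: int, end: int) -> List[int]:
    return backing[start:end]

chunks = list(load_in_chunks(read_range, len(backing), chunk_rows=4))
# 3 chunks of at most 4 rows, instead of one 10-row read
```

Consumers can process and release each chunk before requesting the next, which is what keeps the spike low.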
1 parent e21e2ed · commit 45223e0

File tree

1 file changed: +1 −2 lines changed


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 1 addition & 2 deletions
@@ -1435,7 +1435,6 @@ def _init_sharded_split_embedding_weights(
 
         pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
             no_snapshot=False,
-            should_flush=True,
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
@@ -1528,7 +1527,7 @@ def purge(self) -> None:
 
     # pyre-ignore [15]
     def split_embedding_weights(
-        self, no_snapshot: bool = True, should_flush: bool = True
+        self, no_snapshot: bool = True, should_flush: bool = False
     ) -> Tuple[
         Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
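The effect of flipping the default can be sketched with a simplified stand-in for the method (the real signature returns tensors; this toy version just reports whether a flush would happen). After this patch, a plain call reuses cached weights, and callers that still need a fresh flush must opt in explicitly:

```python
def split_embedding_weights(no_snapshot: bool = True,
                            should_flush: bool = False) -> dict:
    # Toy stand-in: records whether the call would trigger a flush.
    return {"no_snapshot": no_snapshot, "flushed": should_flush}


# Default call no longer flushes, so cached weights from the same
# global step are reused:
default_call = split_embedding_weights()

# Callers that need freshly flushed weights opt in explicitly:
explicit_call = split_embedding_weights(should_flush=True)
```

This matches change 1 above: `_init_sharded_split_embedding_weights` drops its explicit `should_flush=True`, so checkpointing paths that hit it within the same step avoid a duplicate flush.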
