From 03ca37bc01fd7113b3abd1561ef563ae710c294c Mon Sep 17 00:00:00 2001 From: Jeet Kanjani Date: Wed, 21 May 2025 13:03:54 -0700 Subject: [PATCH] Enable SDD for standalone CINT expert model Summary: This diff applies Sparse Data Distribution (SDD) to the CINT expert model. SDD distributes the feature to the right trainers before running lookup. We make the modules fx tracable to enable SDD which remove s the communication/computation overlap. Local run logs: https://www.internalfb.com/intern/everpaste/?handle=GPCynx2gWWqpyV4DAFLy0frFtTpPbsIXAAAz&phabricator_paste_number=1811578826 Mast job: fire-linjianma-20250516-1506-3ae5c1eb (peak qps: 1.17) Baseline: fire-linjianma-20250426-2216-8530f5d8 (peak qps: 1.11) Verified from the trace that SDD is applied correctly: {F1978480235} Differential Revision: D74751782 --- torchrec/sparse/tensor_dict.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py index 3f00d5275..23b7ebdd0 100644 --- a/torchrec/sparse/tensor_dict.py +++ b/torchrec/sparse/tensor_dict.py @@ -15,6 +15,7 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +@torch.fx.wrap def maybe_td_to_kjt( features: KeyedJaggedTensor, keys: Optional[List[str]] = None ) -> KeyedJaggedTensor: