From a978020ecb0060bb03ffde59727ebfc6582ac7a3 Mon Sep 17 00:00:00 2001 From: Yernar Sadybekov Date: Tue, 24 Jun 2025 13:18:00 -0700 Subject: [PATCH] Model wrapper for DeepFM (#3115) Summary: * Added model wrapper for DeepFM. The wrapper will take ModelInput as the only parameter in the forward method. The forward method will return just the prediction if it's in inference mode and losses with prediction if it's in training mode. (Because the training pipeline expects loss and prediction. See https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py#L670) * Added the parameterized unit tests to cover the model's wrapper Differential Revision: D76916471 --- torchrec/models/deepfm.py | 34 +++++++- torchrec/models/tests/test_deepfm.py | 114 ++++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/torchrec/models/deepfm.py b/torchrec/models/deepfm.py index 322992f89..99bae0711 100644 --- a/torchrec/models/deepfm.py +++ b/torchrec/models/deepfm.py @@ -7,10 +7,11 @@ # pyre-strict -from typing import List +from typing import List, Tuple, Union import torch from torch import nn +from torchrec.distributed.test_utils.test_input import ModelInput from torchrec.modules.deepfm import DeepFM, FactorizationMachine from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -350,3 +351,34 @@ def forward( ) logits = self.over_arch(concatenated_dense) return logits + + +class SimpleDeepFMNNWrapper(SimpleDeepFMNN): + # pyre-ignore[14, 15] + def forward( + self, model_input: ModelInput + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass for the SimpleDeepFMNNWrapper. + + Args: + model_input (ModelInput): Contains dense and sparse features. + + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + If training, returns (loss, prediction). Otherwise, returns prediction. 
+ """ + pred = super().forward( + dense_features=model_input.float_features, + sparse_features=model_input.idlist_features, # pyre-ignore[6] + ) + + if self.training: + # Calculate loss and return both loss and prediction + loss = torch.nn.functional.binary_cross_entropy_with_logits( + pred.squeeze(), model_input.label + ) + return (loss, pred) + else: + # Return just the prediction + return pred diff --git a/torchrec/models/tests/test_deepfm.py b/torchrec/models/tests/test_deepfm.py index 154b4c1c3..cabe26b77 100644 --- a/torchrec/models/tests/test_deepfm.py +++ b/torchrec/models/tests/test_deepfm.py @@ -8,11 +8,20 @@ # pyre-strict import unittest +from dataclasses import dataclass +from typing import List import torch +from parameterized import parameterized from torch.testing import FileCheck # @manual +from torchrec.distributed.test_utils.test_input import ModelInput from torchrec.fx import symbolic_trace, Tracer -from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN +from torchrec.models.deepfm import ( + DenseArch, + FMInteractionArch, + SimpleDeepFMNN, + SimpleDeepFMNNWrapper, +) from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -210,5 +219,108 @@ def test_fx_script(self) -> None: self.assertEqual(logits.size(), (B, 1)) +class SimpleDeepFMNNWrapperTest(unittest.TestCase): + @dataclass + class WrapperTestParams: + # input parameters + embedding_configs: List[EmbeddingBagConfig] + sparse_feature_keys: List[str] + sparse_feature_values: List[int] + sparse_feature_offsets: List[int] + # expected output parameters + expected_output_size: tuple[int, ...] 
+ + @parameterized.expand( + [ + ( + "basic_with_multiple_features", + WrapperTestParams( + embedding_configs=[ + EmbeddingBagConfig( + name="t1", + embedding_dim=8, + num_embeddings=100, + feature_names=["f1", "f3"], + ), + EmbeddingBagConfig( + name="t2", + embedding_dim=8, + num_embeddings=100, + feature_names=["f2"], + ), + ], + sparse_feature_keys=["f1", "f3", "f2"], + sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3], + sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11], + expected_output_size=(2, 1), + ), + ), + ( + "empty_sparse_features", + WrapperTestParams( + embedding_configs=[ + EmbeddingBagConfig( + name="t1", + embedding_dim=8, + num_embeddings=100, + feature_names=["f1"], + ), + ], + sparse_feature_keys=["f1"], + sparse_feature_values=[], + sparse_feature_offsets=[0, 0, 0], + expected_output_size=(2, 1), + ), + ), + ] + ) + def test_wrapper_functionality( + self, _test_name: str, test_params: WrapperTestParams + ) -> None: + B = 2 + num_dense_features = 100 + + ebc = EmbeddingBagCollection(tables=test_params.embedding_configs) + + deepfm_wrapper = SimpleDeepFMNNWrapper( + num_dense_features=num_dense_features, + embedding_bag_collection=ebc, + hidden_layer_size=20, + deep_fm_dimension=5, + ) + + # Create ModelInput + dense_features = torch.rand((B, num_dense_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=test_params.sparse_feature_keys, + values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long), + offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long), + ) + + model_input = ModelInput( + float_features=dense_features, + idlist_features=sparse_features, + idscore_features=None, + label=torch.rand((B,)), + ) + + # Test eval mode - should return just logits + deepfm_wrapper.eval() + logits = deepfm_wrapper(model_input) + self.assertIsInstance(logits, torch.Tensor) + self.assertEqual(logits.size(), test_params.expected_output_size) + + # Test training mode - should return (loss, logits) 
tuple + deepfm_wrapper.train() + result = deepfm_wrapper(model_input) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + loss, pred = result + self.assertIsInstance(loss, torch.Tensor) + self.assertIsInstance(pred, torch.Tensor) + self.assertEqual(loss.size(), ()) # scalar loss + self.assertEqual(pred.size(), test_params.expected_output_size) + + if __name__ == "__main__": unittest.main()