diff --git a/torchrec/models/deepfm.py b/torchrec/models/deepfm.py
index 322992f89..056e60369 100644
--- a/torchrec/models/deepfm.py
+++ b/torchrec/models/deepfm.py
@@ -11,6 +11,7 @@
 
 import torch
 from torch import nn
+from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.modules.deepfm import DeepFM, FactorizationMachine
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -350,3 +351,21 @@ def forward(
         )
         logits = self.over_arch(concatenated_dense)
         return logits
+
+
+class SimpleDeepFMNNWrapper(SimpleDeepFMNN):
+    # pyre-ignore[14]
+    def forward(self, model_input: ModelInput) -> torch.Tensor:
+        """
+        Forward pass for the SimpleDeepFMNNWrapper.
+
+        Args:
+            model_input (ModelInput): Contains dense and sparse features.
+
+        Returns:
+            torch.Tensor: Output tensor from the SimpleDeepFMNN model.
+        """
+        return super().forward(
+            dense_features=model_input.float_features,
+            sparse_features=model_input.idlist_features,  # pyre-ignore[6]
+        )
diff --git a/torchrec/models/dlrm.py b/torchrec/models/dlrm.py
index ad1975eba..396dd5bd1 100644
--- a/torchrec/models/dlrm.py
+++ b/torchrec/models/dlrm.py
@@ -12,6 +12,7 @@
 import torch
 from torch import nn
 from torchrec.datasets.utils import Batch
+from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.modules.crossnet import LowRankCrossNet
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.mlp import MLP
@@ -899,3 +900,21 @@ def forward(
         loss = self.loss_fn(logits, batch.labels.float())
 
         return loss, (loss.detach(), logits.detach(), batch.labels.detach())
+
+
+class DLRMWrapper(DLRM):
+    # pyre-ignore[14]
+    def forward(self, model_input: ModelInput) -> torch.Tensor:
+        """
+        Forward pass for the DLRMWrapper.
+
+        Args:
+            model_input (ModelInput): Contains dense and sparse features.
+
+        Returns:
+            torch.Tensor: Output tensor from the DLRM model.
+ """ + return super().forward( + dense_features=model_input.float_features, + sparse_features=model_input.idlist_features, # pyre-ignore[6] + ) diff --git a/torchrec/models/tests/test_deepfm.py b/torchrec/models/tests/test_deepfm.py index 154b4c1c3..6df27a41a 100644 --- a/torchrec/models/tests/test_deepfm.py +++ b/torchrec/models/tests/test_deepfm.py @@ -11,8 +11,14 @@ import torch from torch.testing import FileCheck # @manual +from torchrec.distributed.test_utils.test_input import ModelInput from torchrec.fx import symbolic_trace, Tracer -from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN +from torchrec.models.deepfm import ( + DenseArch, + FMInteractionArch, + SimpleDeepFMNN, + SimpleDeepFMNNWrapper, +) from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -210,5 +216,118 @@ def test_fx_script(self) -> None: self.assertEqual(logits.size(), (B, 1)) +class SimpleDeepFMNNWrapperTest(unittest.TestCase): + def test_basic(self) -> None: + B = 2 + D = 8 + num_dense_features = 100 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + deepfm_wrapper = SimpleDeepFMNNWrapper( + num_dense_features=num_dense_features, + embedding_bag_collection=ebc, + hidden_layer_size=20, + deep_fm_dimension=5, + ) + + # Create ModelInput with both dense and sparse features + dense_features = torch.rand((B, num_dense_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), + ) + + model_input = ModelInput( + float_features=dense_features, + idlist_features=sparse_features, + idscore_features=None, + label=torch.rand((B,)), + ) + + logits = deepfm_wrapper(model_input) + self.assertEqual(logits.size(), (B, 1)) + + def test_no_sparse_features(self) -> None: + B = 2 + D = 8 + num_dense_features = 100 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config]) + + deepfm_wrapper = SimpleDeepFMNNWrapper( + num_dense_features=num_dense_features, + embedding_bag_collection=ebc, + hidden_layer_size=20, + deep_fm_dimension=5, + ) + + # Create ModelInput with empty sparse features that match expected feature names + dense_features = torch.rand((B, num_dense_features)) + empty_sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1"], + values=torch.tensor([], dtype=torch.long), + offsets=torch.tensor([0] * (B + 1), dtype=torch.long), + ) + + model_input = ModelInput( + float_features=dense_features, + idlist_features=empty_sparse_features, + idscore_features=None, + label=torch.rand((B,)), + ) + + logits = deepfm_wrapper(model_input) + self.assertEqual(logits.size(), (B, 1)) + + def test_empty_sparse_features(self) -> None: + B = 2 + D = 8 + num_dense_features = 100 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config]) + + deepfm_wrapper = SimpleDeepFMNNWrapper( + num_dense_features=num_dense_features, + embedding_bag_collection=ebc, + 
hidden_layer_size=20, + deep_fm_dimension=5, + ) + + # Create ModelInput with empty sparse features that match expected feature names + dense_features = torch.rand((B, num_dense_features)) + empty_sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1"], + values=torch.tensor([], dtype=torch.long), + offsets=torch.tensor([0] * (B + 1), dtype=torch.long), + ) + + model_input = ModelInput( + float_features=dense_features, + idlist_features=empty_sparse_features, + idscore_features=None, + label=torch.rand((B,)), + ) + + logits = deepfm_wrapper(model_input) + self.assertEqual(logits.size(), (B, 1)) + + if __name__ == "__main__": unittest.main() diff --git a/torchrec/models/tests/test_dlrm.py b/torchrec/models/tests/test_dlrm.py index e01976404..e72fbc12c 100644 --- a/torchrec/models/tests/test_dlrm.py +++ b/torchrec/models/tests/test_dlrm.py @@ -13,6 +13,7 @@ from torch import nn from torch.testing import FileCheck # @manual from torchrec.datasets.utils import Batch +from torchrec.distributed.test_utils.test_input import ModelInput from torchrec.fx import symbolic_trace from torchrec.ir.serializer import JsonSerializer from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules @@ -23,6 +24,7 @@ DLRM_DCN, DLRM_Projection, DLRMTrain, + DLRMWrapper, InteractionArch, InteractionDCNArch, InteractionProjectionArch, @@ -1283,3 +1285,116 @@ def test_export_serialization(self) -> None: deserialized_logits = deserialized_model(features, sparse_features) self.assertEqual(deserialized_logits.size(), (B, 1)) + + +class DLRMWrapperTest(unittest.TestCase): + def test_basic(self) -> None: + B = 2 + D = 8 + dense_in_features = 100 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + dlrm_wrapper = DLRMWrapper( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + # Create ModelInput with both dense and sparse features + dense_features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), + ) + + model_input = ModelInput( + float_features=dense_features, + idlist_features=sparse_features, + idscore_features=None, + label=torch.rand((B,)), + ) + + logits = dlrm_wrapper(model_input) + self.assertEqual(logits.size(), (B, 1)) + + def test_no_sparse_features(self) -> None: + B = 2 + D = 8 + dense_in_features = 100 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config]) + + dlrm_wrapper = DLRMWrapper( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + # Create ModelInput with empty sparse features that match expected feature names + dense_features = torch.rand((B, dense_in_features)) + empty_sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1"], + values=torch.tensor([], dtype=torch.long), + offsets=torch.tensor([0] * (B + 1), dtype=torch.long), + ) + + model_input = ModelInput( + float_features=dense_features, + idlist_features=empty_sparse_features, + 
+            idscore_features=None,
+            label=torch.rand((B,)),
+        )
+
+        logits = dlrm_wrapper(model_input)
+        self.assertEqual(logits.size(), (B, 1))
+
+    def test_empty_sparse_features(self) -> None:
+        B = 2
+        D = 8
+        dense_in_features = 100
+        eb1_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
+        )
+
+        ebc = EmbeddingBagCollection(tables=[eb1_config])
+
+        dlrm_wrapper = DLRMWrapper(
+            embedding_bag_collection=ebc,
+            dense_in_features=dense_in_features,
+            dense_arch_layer_sizes=[20, D],
+            over_arch_layer_sizes=[5, 1],
+        )
+
+        # Create ModelInput with empty sparse features that match expected feature names
+        dense_features = torch.rand((B, dense_in_features))
+        empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1"],
+            values=torch.tensor([], dtype=torch.long),
+            offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
+        )
+
+        model_input = ModelInput(
+            float_features=dense_features,
+            idlist_features=empty_sparse_features,
+            idscore_features=None,
+            label=torch.rand((B,)),
+        )
+
+        logits = dlrm_wrapper(model_input)
+        self.assertEqual(logits.size(), (B, 1))
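
For reference, below is a minimal sketch of how the new wrappers are called, distilled from the tests in this diff: both `DLRMWrapper` and `SimpleDeepFMNNWrapper` accept a single `ModelInput` instead of separate dense and sparse arguments, presumably so the models plug directly into harnesses that produce `ModelInput` batches (the type comes from `torchrec.distributed.test_utils.test_input`). The table sizes, layer sizes, and feature values here are illustrative, not prescribed by the PR.

```python
# Usage sketch for the wrappers added in this diff. Only the ModelInput-based
# calling convention comes from the PR; the concrete sizes are illustrative.
import torch

from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.models.dlrm import DLRMWrapper
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

B, D, DENSE = 2, 8, 100  # batch size, embedding dim, dense feature width

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
        )
    ]
)
model = DLRMWrapper(
    embedding_bag_collection=ebc,
    dense_in_features=DENSE,
    dense_arch_layer_sizes=[20, D],  # last layer must match the embedding dim
    over_arch_layer_sizes=[5, 1],
)

# A single ModelInput bundles dense features, id-list sparse features,
# optional id-score features, and labels.
model_input = ModelInput(
    float_features=torch.rand((B, DENSE)),
    idlist_features=KeyedJaggedTensor.from_offsets_sync(
        keys=["f1"],
        values=torch.tensor([1, 2, 3], dtype=torch.long),
        offsets=torch.tensor([0, 2, 3], dtype=torch.long),  # 2 bags for f1
    ),
    idscore_features=None,
    label=torch.rand((B,)),
)

logits = model(model_input)  # shape: (B, 1)
```

`SimpleDeepFMNNWrapper` is driven the same way; only the constructor arguments (`num_dense_features`, `hidden_layer_size`, `deep_fm_dimension`) differ, as the tests above show.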