Skip to content

Commit 3f7c960

Browse files
SSYernar authored and facebook-github-bot committed
Added model wrappers for DeepFM and DLRM (#3115)
Summary: Pull Request resolved: #3115 * Added model wrappers for DeepFM and DLRM. Each wrapper takes a ModelInput as the only parameter of its forward method. * Added unit tests covering the model wrappers. Differential Revision: D76916471
1 parent 8a4378f commit 3f7c960

File tree

4 files changed

+273
-1
lines changed

4 files changed

+273
-1
lines changed

torchrec/models/deepfm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
from torch import nn
14+
from torchrec.distributed.test_utils.test_input import ModelInput
1415
from torchrec.modules.deepfm import DeepFM, FactorizationMachine
1516
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1617
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -350,3 +351,21 @@ def forward(
350351
)
351352
logits = self.over_arch(concatenated_dense)
352353
return logits
354+
355+
356+
class SimpleDeepFMNNWrapper(SimpleDeepFMNN):
    """Adapter that lets ``SimpleDeepFMNN`` be called with a single ``ModelInput``."""

    # pyre-ignore[14]
    def forward(self, model_input: ModelInput) -> torch.Tensor:
        """
        Unpack ``model_input`` and delegate to ``SimpleDeepFMNN.forward``.

        Args:
            model_input (ModelInput): its ``float_features`` are forwarded as
                ``dense_features`` and its ``idlist_features`` as
                ``sparse_features``.

        Returns:
            torch.Tensor: the value returned by the parent ``forward``.
        """
        dense = model_input.float_features
        sparse = model_input.idlist_features
        return super().forward(
            dense_features=dense,
            sparse_features=sparse,  # pyre-ignore[6]
        )

torchrec/models/dlrm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from torch import nn
1414
from torchrec.datasets.utils import Batch
15+
from torchrec.distributed.test_utils.test_input import ModelInput
1516
from torchrec.modules.crossnet import LowRankCrossNet
1617
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1718
from torchrec.modules.mlp import MLP
@@ -899,3 +900,21 @@ def forward(
899900
loss = self.loss_fn(logits, batch.labels.float())
900901

901902
return loss, (loss.detach(), logits.detach(), batch.labels.detach())
903+
904+
905+
class DLRMWrapper(DLRM):
    """Adapter that lets ``DLRM`` be called with a single ``ModelInput``."""

    # pyre-ignore[14]
    def forward(self, model_input: ModelInput) -> torch.Tensor:
        """
        Unpack ``model_input`` and delegate to ``DLRM.forward``.

        Args:
            model_input (ModelInput): its ``float_features`` are forwarded as
                ``dense_features`` and its ``idlist_features`` as
                ``sparse_features``.

        Returns:
            torch.Tensor: the value returned by the parent ``forward``.
        """
        dense = model_input.float_features
        sparse = model_input.idlist_features
        return super().forward(
            dense_features=dense,
            sparse_features=sparse,  # pyre-ignore[6]
        )

torchrec/models/tests/test_deepfm.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,14 @@
1111

1212
import torch
1313
from torch.testing import FileCheck # @manual
14+
from torchrec.distributed.test_utils.test_input import ModelInput
1415
from torchrec.fx import symbolic_trace, Tracer
15-
from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN
16+
from torchrec.models.deepfm import (
17+
DenseArch,
18+
FMInteractionArch,
19+
SimpleDeepFMNN,
20+
SimpleDeepFMNNWrapper,
21+
)
1622
from torchrec.modules.embedding_configs import EmbeddingBagConfig
1723
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1824
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -210,5 +216,118 @@ def test_fx_script(self) -> None:
210216
self.assertEqual(logits.size(), (B, 1))
211217

212218

219+
class SimpleDeepFMNNWrapperTest(unittest.TestCase):
    """Shape checks for ``SimpleDeepFMNNWrapper`` when called with a ``ModelInput``."""

    B = 2  # number of rows in the dense input and in the expected logits
    D = 8  # embedding_dim used by every table
    NUM_DENSE_FEATURES = 100

    def _make_wrapper(self, *tables: EmbeddingBagConfig) -> SimpleDeepFMNNWrapper:
        """Build a wrapper whose embedding bag collection holds ``tables``."""
        ebc = EmbeddingBagCollection(tables=list(tables))
        return SimpleDeepFMNNWrapper(
            num_dense_features=self.NUM_DENSE_FEATURES,
            embedding_bag_collection=ebc,
            hidden_layer_size=20,
            deep_fm_dimension=5,
        )

    def _make_single_table_wrapper(self) -> SimpleDeepFMNNWrapper:
        """Build a wrapper with one table ``t1`` serving feature ``f1``."""
        return self._make_wrapper(
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=self.D,
                num_embeddings=100,
                feature_names=["f1"],
            )
        )

    def _empty_f1_features(self) -> KeyedJaggedTensor:
        """Return a KJT for key ``f1`` with no values and all-zero offsets."""
        return KeyedJaggedTensor.from_offsets_sync(
            keys=["f1"],
            values=torch.tensor([], dtype=torch.long),
            offsets=torch.tensor([0] * (self.B + 1), dtype=torch.long),
        )

    def _assert_logits_shape(
        self,
        wrapper: SimpleDeepFMNNWrapper,
        sparse_features: KeyedJaggedTensor,
    ) -> None:
        """Run ``wrapper`` on a fresh ``ModelInput`` and check logits are ``(B, 1)``."""
        model_input = ModelInput(
            float_features=torch.rand((self.B, self.NUM_DENSE_FEATURES)),
            idlist_features=sparse_features,
            idscore_features=None,
            label=torch.rand((self.B,)),
        )
        logits = wrapper(model_input)
        self.assertEqual(logits.size(), (self.B, 1))

    def test_basic(self) -> None:
        """Two tables, three features, non-empty sparse input."""
        wrapper = self._make_wrapper(
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=self.D,
                num_embeddings=100,
                feature_names=["f1", "f3"],
            ),
            EmbeddingBagConfig(
                name="t2",
                embedding_dim=self.D,
                num_embeddings=100,
                feature_names=["f2"],
            ),
        )
        sparse_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f3", "f2"],
            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
        )
        self._assert_logits_shape(wrapper, sparse_features)

    def test_no_sparse_features(self) -> None:
        """Single table with an empty KJT for its only feature."""
        # NOTE(review): this scenario was originally written out identically to
        # test_empty_sparse_features; both names are kept so existing test
        # selectors keep working. TODO: differentiate or drop one of them.
        self._assert_logits_shape(
            self._make_single_table_wrapper(), self._empty_f1_features()
        )

    def test_empty_sparse_features(self) -> None:
        """Single table with an empty KJT for its only feature."""
        self._assert_logits_shape(
            self._make_single_table_wrapper(), self._empty_f1_features()
        )
330+
331+
213332
if __name__ == "__main__":
214333
unittest.main()

torchrec/models/tests/test_dlrm.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch import nn
1414
from torch.testing import FileCheck # @manual
1515
from torchrec.datasets.utils import Batch
16+
from torchrec.distributed.test_utils.test_input import ModelInput
1617
from torchrec.fx import symbolic_trace
1718
from torchrec.ir.serializer import JsonSerializer
1819
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
@@ -23,6 +24,7 @@
2324
DLRM_DCN,
2425
DLRM_Projection,
2526
DLRMTrain,
27+
DLRMWrapper,
2628
InteractionArch,
2729
InteractionDCNArch,
2830
InteractionProjectionArch,
@@ -1283,3 +1285,116 @@ def test_export_serialization(self) -> None:
12831285
deserialized_logits = deserialized_model(features, sparse_features)
12841286

12851287
self.assertEqual(deserialized_logits.size(), (B, 1))
1288+
1289+
1290+
class DLRMWrapperTest(unittest.TestCase):
    """Shape checks for ``DLRMWrapper`` when called with a ``ModelInput``."""

    B = 2  # number of rows in the dense input and in the expected logits
    D = 8  # embedding_dim of every table and last dense-arch layer size
    DENSE_IN_FEATURES = 100

    def _make_wrapper(self, *tables: EmbeddingBagConfig) -> DLRMWrapper:
        """Build a wrapper whose embedding bag collection holds ``tables``."""
        ebc = EmbeddingBagCollection(tables=list(tables))
        return DLRMWrapper(
            embedding_bag_collection=ebc,
            dense_in_features=self.DENSE_IN_FEATURES,
            dense_arch_layer_sizes=[20, self.D],
            over_arch_layer_sizes=[5, 1],
        )

    def _make_single_table_wrapper(self) -> DLRMWrapper:
        """Build a wrapper with one table ``t1`` serving feature ``f1``."""
        return self._make_wrapper(
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=self.D,
                num_embeddings=100,
                feature_names=["f1"],
            )
        )

    def _empty_f1_features(self) -> KeyedJaggedTensor:
        """Return a KJT for key ``f1`` with no values and all-zero offsets."""
        return KeyedJaggedTensor.from_offsets_sync(
            keys=["f1"],
            values=torch.tensor([], dtype=torch.long),
            offsets=torch.tensor([0] * (self.B + 1), dtype=torch.long),
        )

    def _assert_logits_shape(
        self,
        wrapper: DLRMWrapper,
        sparse_features: KeyedJaggedTensor,
    ) -> None:
        """Run ``wrapper`` on a fresh ``ModelInput`` and check logits are ``(B, 1)``."""
        model_input = ModelInput(
            float_features=torch.rand((self.B, self.DENSE_IN_FEATURES)),
            idlist_features=sparse_features,
            idscore_features=None,
            label=torch.rand((self.B,)),
        )
        logits = wrapper(model_input)
        self.assertEqual(logits.size(), (self.B, 1))

    def test_basic(self) -> None:
        """Two tables, three features, non-empty sparse input."""
        wrapper = self._make_wrapper(
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=self.D,
                num_embeddings=100,
                feature_names=["f1", "f3"],
            ),
            EmbeddingBagConfig(
                name="t2",
                embedding_dim=self.D,
                num_embeddings=100,
                feature_names=["f2"],
            ),
        )
        sparse_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f3", "f2"],
            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
        )
        self._assert_logits_shape(wrapper, sparse_features)

    def test_no_sparse_features(self) -> None:
        """Single table with an empty KJT for its only feature."""
        # NOTE(review): this scenario was originally written out identically to
        # test_empty_sparse_features; both names are kept so existing test
        # selectors keep working. TODO: differentiate or drop one of them.
        self._assert_logits_shape(
            self._make_single_table_wrapper(), self._empty_f1_features()
        )

    def test_empty_sparse_features(self) -> None:
        """Single table with an empty KJT for its only feature."""
        self._assert_logits_shape(
            self._make_single_table_wrapper(), self._empty_f1_features()
        )

0 commit comments

Comments
 (0)