Skip to content

Commit 86f9615

Browse files
committed
DLRM examples to create prediction using TorchRec library. This shows basic example of how to use TorchRec library quickly locally. (#3043)
Summary: Pull Request resolved: #3043 Adding end to end example of using TorchRec library with DLRM to create recommendation. You can use this example in local machine to run and test how TorchRec is used for simple use case. Differential Revision: D75989524
1 parent e62add5 commit 86f9615

File tree

4 files changed

+1142
-0
lines changed

4 files changed

+1142
-0
lines changed

examples/prediction/README.md

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# DLRM Prediction Example
2+
3+
This example demonstrates how to use a Deep Learning Recommendation Model (DLRM) for making predictions using TorchRec capabilities. The code includes:
4+
5+
1. A DLRM implementation using TorchRec's EmbeddingBagCollection and KeyedJaggedTensor
6+
2. Training with random data
7+
3. Evaluation
8+
4. Making sample predictions
9+
10+
## TorchRec Integration
11+
12+
This implementation has been updated to use TorchRec's capabilities:
13+
- Uses `KeyedJaggedTensor` for sparse features
14+
- Uses `EmbeddingBagCollection` for embedding tables
15+
- Follows the DLRM architecture as described in the paper: https://arxiv.org/abs/1906.00091
16+
17+
The example demonstrates how to leverage TorchRec's efficient sparse feature handling for recommendation models.
18+
19+
## Dependencies
20+
21+
Install the required dependencies:
22+
23+
```bash
24+
# Install PyTorch
25+
pip install torch torchvision
26+
27+
# Install NumPy
28+
pip install numpy
29+
30+
# Install TorchRec
31+
pip install torchrec
32+
```
33+
34+
**Important**: This implementation now requires torchrec to run, as it uses TorchRec's specialized modules for recommendation systems.
35+
36+
## Running the Example Locally
37+
38+
1. Download the `predict_using_torchrec.py` file to your local machine.
39+
40+
2. Run the example:
41+
42+
```bash
43+
python3 predict_using_torchrec.py
44+
```
45+
46+
3. If you're using a different Python environment:
47+
48+
```bash
49+
# For conda environments
50+
conda activate your_environment_name
51+
python predict_using_torchrec.py
52+
53+
# For virtual environments
54+
source your_venv/bin/activate
55+
python predict_using_torchrec.py
56+
```
57+
58+
## What to Expect
59+
60+
When you run the example, you'll see:
61+
62+
1. Training progress for 10 epochs with loss and learning rate information
63+
2. Evaluation results showing MSE and RMSE metrics
64+
3. Sample predictions for a specific user on multiple items
65+
66+
## Implementation Details
67+
68+
This example uses TorchRec's capabilities to implement a DLRM model that:
69+
70+
- Takes dense features and sparse features (as KeyedJaggedTensor) as input
71+
- Processes dense features through a bottom MLP
72+
- Processes sparse features through EmbeddingBagCollection
73+
- Computes feature interactions using dot products
74+
- Processes the interactions through a top MLP
75+
- Outputs rating predictions on a 0-5 scale
76+
77+
The implementation demonstrates how to use TorchRec's specialized modules for recommendation systems, making it more efficient and scalable than a custom implementation.
78+
79+
## Key TorchRec Components Used
80+
81+
1. **KeyedJaggedTensor**: Efficiently represents sparse features with variable lengths
82+
2. **EmbeddingBagConfig**: Configures embedding tables with parameters like dimensions and feature names
83+
3. **EmbeddingBagCollection**: Manages multiple embedding tables for different categorical features
84+
85+
## Troubleshooting
86+
87+
If you encounter any issues:
88+
89+
1. **Python version**: This code has been tested with Python 3.8+. Make sure you're using a compatible version.
90+
91+
2. **PyTorch and TorchRec installation**: If you have issues with PyTorch or TorchRec, try installing specific versions:
92+
```bash
93+
pip install torch==2.0.0 torchvision==0.15.0
94+
pip install torchrec==0.5.0
95+
```
96+
97+
3. **Memory issues**: If you run out of memory, try reducing the batch size by modifying this line in the code:
98+
```python
99+
batch_size = 256 # Try a smaller value like 64 or 32
100+
```
101+
102+
4. **CPU vs GPU**: The code automatically uses CUDA if available. To force CPU usage, modify:
103+
```python
104+
device = torch.device("cpu")
105+
```
106+
107+
5. **TorchRec compatibility**: If you encounter compatibility issues with TorchRec, make sure you're using compatible versions of PyTorch and TorchRec.

examples/prediction/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
11+
def main() -> None:
12+
"""DOC_STRING"""
13+
14+
15+
if __name__ == "__main__":
16+
main()

0 commit comments

Comments
 (0)