Skip to content

Commit 0301f50

Browse files
yupadhyayfacebook-github-bot
authored andcommitted
DLRM examples to create prediction using TorchRec library. This shows basic example of how to use TorchRec library quickly locally. (#3043)
Summary: Adding end to end example of using TorchRec library with DLRM to create recommendation. You can use this example in local machine to run and test how TorchRec is used for simple use case. Differential Revision: D75989524
1 parent 76a0826 commit 0301f50

File tree

4 files changed

+815
-0
lines changed

4 files changed

+815
-0
lines changed

examples/prediction/README.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# DLRM Prediction Example
2+
3+
This example demonstrates how to use a Deep Learning Recommendation Model (DLRM) for making predictions. The code includes:
4+
5+
1. A custom DLRM implementation
6+
2. Training with random data
7+
3. Evaluation
8+
4. Making sample predictions
9+
10+
## Cross-Platform Compatibility
11+
12+
This implementation has been specifically designed to work on all platforms, including:
13+
- Linux
14+
- macOS
15+
- Windows
16+
17+
Unlike the original torchrec implementation, this version uses a custom SimpleDLRM class that doesn't depend on torchrec or fbgemm_gpu, avoiding compatibility issues on macOS and other platforms.
18+
19+
## Dependencies
20+
21+
Install the required dependencies:
22+
23+
```bash
24+
# Install PyTorch
25+
pip install torch torchvision
26+
27+
# Install NumPy
28+
pip install numpy
29+
```
30+
31+
**Important**: torchrec is NOT required or used in this implementation. The code has been completely rewritten to avoid any dependencies on torchrec or fbgemm_gpu.
32+
33+
## Running the Example Locally
34+
35+
1. Download the `predict_using_torchrec.py` file to your local machine.
36+
37+
2. Run the example:
38+
39+
```bash
40+
python3 predict_using_torchrec.py
41+
```
42+
43+
3. If you're using a different Python environment:
44+
45+
```bash
46+
# For conda environments
47+
conda activate your_environment_name
48+
python predict_using_torchrec.py
49+
50+
# For virtual environments
51+
source your_venv/bin/activate
52+
python predict_using_torchrec.py
53+
```
54+
55+
## What to Expect
56+
57+
When you run the example, you'll see:
58+
59+
1. Training progress for 10 epochs with loss and learning rate information
60+
2. Evaluation results showing MSE and RMSE metrics
61+
3. Sample predictions for a specific user on multiple items
62+
63+
## Implementation Details
64+
65+
This example uses a custom SimpleDLRM implementation that:
66+
67+
- Takes dense features and categorical features as input
68+
- Processes dense features through a bottom MLP
69+
- Processes categorical features through embedding tables
70+
- Computes feature interactions using dot products
71+
- Processes the interactions through a top MLP
72+
- Outputs rating predictions on a 0-5 scale
73+
74+
The implementation is designed to be simple and easy to understand, while still capturing the key components of a DLRM model.
75+
76+
## Troubleshooting
77+
78+
If you encounter any issues:
79+
80+
1. **Python version**: This code has been tested with Python 3.8+. Make sure you're using a compatible version.
81+
82+
2. **PyTorch installation**: If you have issues with PyTorch, try installing a specific version:
83+
```bash
84+
pip install torch==2.0.0 torchvision==0.15.0
85+
```
86+
87+
3. **Memory issues**: If you run out of memory, try reducing the batch size by modifying this line in the code:
88+
```python
89+
batch_size = 256 # Try a smaller value like 64 or 32
90+
```
91+
92+
4. **CPU vs GPU**: The code automatically uses CUDA if available. To force CPU usage, modify:
93+
```python
94+
device = torch.device("cpu")
95+
```

examples/prediction/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
11+
def main() -> None:
12+
"""DOC_STRING"""
13+
14+
15+
if __name__ == "__main__":
16+
main()

0 commit comments

Comments
 (0)