Commit c91e7da

SSYernar authored and facebook-github-bot committed
Added model config for supporting SparseNN, TowerSparseNN, TowerCollectionSparseNN models. (#3104)
Summary: Pull Request resolved: #3104

Added a model configuration that supports SparseNN, TowerSparseNN, and TowerCollectionSparseNN models. Future commits will add support for DeepFM and DLRM models.

Reviewed By: aliafzal

Differential Revision: D76833673
1 parent f659b6a commit c91e7da

File tree

2 files changed: +243 −19 lines changed

torchrec/distributed/benchmark/benchmark_pipeline_utils.py
torchrec/distributed/benchmark/benchmark_train_pipeline.py


torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 166 additions & 14 deletions
@@ -8,7 +8,8 @@
 # pyre-strict
 
 import copy
-from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, fields
 from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
 
 import torch
@@ -24,8 +25,9 @@
 from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.distributed.test_utils.test_model import (
     TestEBCSharder,
-    TestOverArchLarge,
     TestSparseNN,
+    TestTowerCollectionSparseNN,
+    TestTowerSparseNN,
 )
 from torchrec.distributed.train_pipeline import (
     TrainPipelineBase,
@@ -41,16 +43,56 @@
 
 
 @dataclass
-class ModelConfig:
-    batch_size: int = 8192
-    num_float_features: int = 10
-    feature_pooling_avg: int = 10
-    use_offsets: bool = False
-    dev_str: str = ""
-    long_kjt_indices: bool = True
-    long_kjt_offsets: bool = True
-    long_kjt_lengths: bool = True
-    pin_memory: bool = True
+class BaseModelConfig(ABC):
+    """
+    Abstract base class for model configurations.
+
+    This class defines the common parameters shared across all model types
+    and requires each concrete implementation to provide its own generate_model method.
+    """
+
+    # Common parameters for all model types
+    batch_size: int
+    num_float_features: int
+    feature_pooling_avg: int
+    use_offsets: bool
+    dev_str: str
+    long_kjt_indices: bool
+    long_kjt_offsets: bool
+    long_kjt_lengths: bool
+    pin_memory: bool
+
+    @abstractmethod
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        """
+        Generate a model instance based on the configuration.
+
+        Args:
+            tables: List of unweighted embedding tables
+            weighted_tables: List of weighted embedding tables
+            dense_device: Device to place dense layers on
+
+        Returns:
+            A neural network module instance
+        """
+        pass
+
+
+@dataclass
+class TestSparseNNConfig(BaseModelConfig):
+    """Configuration for TestSparseNN model."""
+
+    embedding_groups: Optional[Dict[str, List[str]]]
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
+    max_feature_lengths: Optional[Dict[str, int]]
+    over_arch_clazz: Type[nn.Module]
+    postproc_module: Optional[nn.Module]
+    zch: bool
 
     def generate_model(
         self,
@@ -60,13 +102,123 @@ def generate_model(
     ) -> nn.Module:
         return TestSparseNN(
             tables=tables,
+            num_float_features=self.num_float_features,
             weighted_tables=weighted_tables,
             dense_device=dense_device,
             sparse_device=torch.device("meta"),
-            over_arch_clazz=TestOverArchLarge,
+            max_feature_lengths=self.max_feature_lengths,
+            feature_processor_modules=self.feature_processor_modules,
+            over_arch_clazz=self.over_arch_clazz,
+            postproc_module=self.postproc_module,
+            embedding_groups=self.embedding_groups,
+            zch=self.zch,
         )
 
 
+@dataclass
+class TestTowerSparseNNConfig(BaseModelConfig):
+    """Configuration for TestTowerSparseNN model."""
+
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        return TestTowerSparseNN(
+            num_float_features=self.num_float_features,
+            tables=tables,
+            weighted_tables=weighted_tables,
+            dense_device=dense_device,
+            sparse_device=torch.device("meta"),
+            embedding_groups=self.embedding_groups,
+            feature_processor_modules=self.feature_processor_modules,
+        )
+
+
+@dataclass
+class TestTowerCollectionSparseNNConfig(BaseModelConfig):
+    """Configuration for TestTowerCollectionSparseNN model."""
+
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        return TestTowerCollectionSparseNN(
+            tables=tables,
+            weighted_tables=weighted_tables,
+            dense_device=dense_device,
+            sparse_device=torch.device("meta"),
+            num_float_features=self.num_float_features,
+            embedding_groups=self.embedding_groups,
+            feature_processor_modules=self.feature_processor_modules,
+        )
+
+
+@dataclass
+class DeepFMConfig(BaseModelConfig):
+    """Configuration for DeepFM model."""
+
+    hidden_layer_size: int
+    deep_fm_dimension: int
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        # TODO: Implement DeepFM model generation
+        raise NotImplementedError("DeepFM model generation not yet implemented")
+
+
+@dataclass
+class DLRMConfig(BaseModelConfig):
+    """Configuration for DLRM model."""
+
+    dense_arch_layer_sizes: List[int]
+    over_arch_layer_sizes: List[int]
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        # TODO: Implement DLRM model generation
+        raise NotImplementedError("DLRM model generation not yet implemented")
+
+
+# pyre-ignore[2]: Missing parameter annotation
+def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
+
+    model_configs = {
+        "test_sparse_nn": TestSparseNNConfig,
+        "test_tower_sparse_nn": TestTowerSparseNNConfig,
+        "test_tower_collection_sparse_nn": TestTowerCollectionSparseNNConfig,
+        "deepfm": DeepFMConfig,
+        "dlrm": DLRMConfig,
+    }
+
+    if model_name not in model_configs:
+        raise ValueError(f"Unknown model name: {model_name}")
+
+    # Filter kwargs to only include valid parameters for the specific model config class
+    model_class = model_configs[model_name]
+    valid_field_names = {field.name for field in fields(model_class)}
+    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_field_names}
+
+    return model_class(**filtered_kwargs)
+
+
 def generate_tables(
     num_unweighted_features: int,
     num_weighted_features: int,
@@ -319,7 +471,7 @@ def generate_sharded_model_and_optimizer(
 def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
-    model_config: ModelConfig,
+    model_config: BaseModelConfig,
     num_batches: int,
 ) -> List[ModelInput]:
     """

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 77 additions & 5 deletions
@@ -9,19 +9,25 @@
 
 #!/usr/bin/env python3
 
-from dataclasses import dataclass
-from typing import List, Optional
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Type, Union
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
 from torch import nn
 from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
+    BaseModelConfig,
+    create_model_config,
+    DeepFMConfig,
+    DLRMConfig,
     generate_data,
     generate_pipeline,
     generate_planner,
     generate_sharded_model_and_optimizer,
     generate_tables,
-    ModelConfig,
+    TestSparseNNConfig,
+    TestTowerCollectionSparseNNConfig,
+    TestTowerSparseNNConfig,
 )
 from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
 from torchrec.distributed.comm import get_local_size
@@ -33,6 +39,7 @@
     run_multi_process_func,
 )
 from torchrec.distributed.test_utils.test_input import ModelInput
+from torchrec.distributed.test_utils.test_model import TestOverArchLarge
 from torchrec.distributed.train_pipeline import TrainPipeline
 from torchrec.distributed.types import ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -124,19 +131,84 @@ class PipelineConfig:
     emb_lookup_stream: str = "data_dist"
 
 
+@dataclass
+class ModelSelectionConfig:
+    model_name: str = "test_sparse_nn"
+
+    # Common config for all model types
+    batch_size: int = 8192
+    num_float_features: int = 10
+    feature_pooling_avg: int = 10
+    use_offsets: bool = False
+    dev_str: str = ""
+    long_kjt_indices: bool = True
+    long_kjt_offsets: bool = True
+    long_kjt_lengths: bool = True
+    pin_memory: bool = True
+
+    # TestSparseNN specific config
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+    max_feature_lengths: Optional[Dict[str, int]] = None
+    over_arch_clazz: Type[nn.Module] = TestOverArchLarge
+    postproc_module: Optional[nn.Module] = None
+    zch: bool = False
+
+    # DeepFM specific config
+    hidden_layer_size: int = 20
+    deep_fm_dimension: int = 5
+
+    # DLRM specific config
+    dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 10])
+    over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 3])
+
+
 @cmd_conf
 def main(
     run_option: RunOptions,
     table_config: EmbeddingTablesConfig,
-    model_config: ModelConfig,
+    model_selection: ModelSelectionConfig,
     pipeline_config: PipelineConfig,
+    model_config: Optional[
+        Union[
+            TestSparseNNConfig,
+            TestTowerCollectionSparseNNConfig,
+            TestTowerSparseNNConfig,
+            DeepFMConfig,
+            DLRMConfig,
+        ]
+    ] = None,
 ) -> None:
     tables, weighted_tables = generate_tables(
         num_unweighted_features=table_config.num_unweighted_features,
         num_weighted_features=table_config.num_weighted_features,
        embedding_feature_dim=table_config.embedding_feature_dim,
     )
 
+    if model_config is None:
+        model_config = create_model_config(
+            model_name=model_selection.model_name,
+            batch_size=model_selection.batch_size,
+            num_float_features=model_selection.num_float_features,
+            feature_pooling_avg=model_selection.feature_pooling_avg,
+            use_offsets=model_selection.use_offsets,
+            dev_str=model_selection.dev_str,
+            long_kjt_indices=model_selection.long_kjt_indices,
+            long_kjt_offsets=model_selection.long_kjt_offsets,
+            long_kjt_lengths=model_selection.long_kjt_lengths,
+            pin_memory=model_selection.pin_memory,
+            embedding_groups=model_selection.embedding_groups,
+            feature_processor_modules=model_selection.feature_processor_modules,
+            max_feature_lengths=model_selection.max_feature_lengths,
+            over_arch_clazz=model_selection.over_arch_clazz,
+            postproc_module=model_selection.postproc_module,
+            zch=model_selection.zch,
+            hidden_layer_size=model_selection.hidden_layer_size,
+            deep_fm_dimension=model_selection.deep_fm_dimension,
+            dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
+            over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
+        )
+
     # launch trainers
     run_multi_process_func(
         func=runner,
@@ -155,7 +227,7 @@ def runner(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
     run_option: RunOptions,
-    model_config: ModelConfig,
+    model_config: BaseModelConfig,
     pipeline_config: PipelineConfig,
 ) -> None:
     # Ensure GPUs are available and we have enough of them
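
Putting the two files together: a caller can either pass a concrete config object to main(), or just name a model and let the "model_config is None" branch assemble one. The snippet below is a sketch of that programmatic path, assuming a torchrec build that includes this commit; the import paths and parameter names are the ones this diff adds.

# Sketch of the fallback path in main(); assumes a torchrec build that
# includes this commit (plus fbgemm_gpu) is importable.
from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
    create_model_config,
)
from torchrec.distributed.test_utils.test_model import TestOverArchLarge

model_config = create_model_config(
    model_name="test_sparse_nn",
    # Common parameters (the ModelSelectionConfig defaults).
    batch_size=8192,
    num_float_features=10,
    feature_pooling_avg=10,
    use_offsets=False,
    dev_str="",
    long_kjt_indices=True,
    long_kjt_offsets=True,
    long_kjt_lengths=True,
    pin_memory=True,
    # TestSparseNN-specific parameters; keys for other model types would
    # simply be filtered out for this model_name.
    embedding_groups=None,
    feature_processor_modules=None,
    max_feature_lengths=None,
    over_arch_clazz=TestOverArchLarge,
    postproc_module=None,
    zch=False,
)

# Downstream, runner() and generate_data() rely only on the BaseModelConfig
# interface, e.g.:
#   model = model_config.generate_model(tables, weighted_tables, torch.device("cpu"))

From the command line, the same choice is driven by ModelSelectionConfig.model_name, with model_config left unset so the fallback above runs.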
