8
8
# pyre-strict
9
9
10
10
import copy
11
- from dataclasses import dataclass
11
+ from abc import ABC , abstractmethod
12
+ from dataclasses import dataclass , fields
12
13
from typing import Any , cast , Dict , List , Optional , Tuple , Type , Union
13
14
14
15
import torch
24
25
from torchrec .distributed .test_utils .test_input import ModelInput
25
26
from torchrec .distributed .test_utils .test_model import (
26
27
TestEBCSharder ,
27
- TestOverArchLarge ,
28
28
TestSparseNN ,
29
+ TestTowerCollectionSparseNN ,
30
+ TestTowerSparseNN ,
29
31
)
30
32
from torchrec .distributed .train_pipeline import (
31
33
TrainPipelineBase ,
41
43
42
44
43
45
@dataclass
class BaseModelConfig(ABC):
    """
    Abstract base class for model configurations.

    This class defines the common parameters shared across all model types
    and requires each concrete implementation to provide its own generate_model method.
    """

    # Common parameters for all model types
    batch_size: int  # number of examples per generated input batch
    num_float_features: int  # width of the dense (float) feature vector
    feature_pooling_avg: int  # average number of values pooled per sparse feature -- TODO confirm against generate_data
    use_offsets: bool  # generate KJTs with offsets rather than lengths -- presumably; verify in ModelInput generation
    dev_str: str  # device string for generated data (e.g. "cuda:0"); "" presumably means default device -- verify against caller
    long_kjt_indices: bool  # use 64-bit dtype for KJT indices -- presumably; confirm in data generation
    long_kjt_offsets: bool  # use 64-bit dtype for KJT offsets -- presumably; confirm in data generation
    long_kjt_lengths: bool  # use 64-bit dtype for KJT lengths -- presumably; confirm in data generation
    pin_memory: bool  # pin host memory for generated batches (useful for async H2D copies)

    @abstractmethod
    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        """
        Generate a model instance based on the configuration.

        Args:
            tables: List of unweighted embedding tables
            weighted_tables: List of weighted embedding tables
            dense_device: Device to place dense layers on

        Returns:
            A neural network module instance
        """
        pass
86
@dataclass
class TestSparseNNConfig(BaseModelConfig):
    """Configuration for TestSparseNN model."""

    # Optional mapping of group name -> feature names for embedding grouping
    embedding_groups: Optional[Dict[str, List[str]]]
    # Optional per-feature processor modules applied to weighted features -- TODO confirm
    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
    # Optional per-feature truncation lengths
    max_feature_lengths: Optional[Dict[str, int]]
    # Class used to build the over-arch (dense interaction) module
    over_arch_clazz: Type[nn.Module]
    # Optional preprocessing module run before the sparse arch -- TODO confirm placement
    postproc_module: Optional[nn.Module]
    # Enable zero-collision hashing in the sparse arch -- presumably; verify against TestSparseNN
    zch: bool

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        """Instantiate a TestSparseNN from this config.

        Dense layers go on ``dense_device``; sparse parameters are created on
        the meta device (materialized later by sharding).
        """
        return TestSparseNN(
            tables=tables,
            num_float_features=self.num_float_features,
            weighted_tables=weighted_tables,
            dense_device=dense_device,
            sparse_device=torch.device("meta"),
            max_feature_lengths=self.max_feature_lengths,
            feature_processor_modules=self.feature_processor_modules,
            over_arch_clazz=self.over_arch_clazz,
            postproc_module=self.postproc_module,
            embedding_groups=self.embedding_groups,
            zch=self.zch,
        )
118
@dataclass
class TestTowerSparseNNConfig(BaseModelConfig):
    """Configuration for TestTowerSparseNN model."""

    # Tower-specific options; ``None`` defers to the model's own defaults.
    embedding_groups: Optional[Dict[str, List[str]]] = None
    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        """Build a TestTowerSparseNN; sparse params live on the meta device."""
        sparse_device = torch.device("meta")
        model_kwargs: Dict[str, Any] = {
            "num_float_features": self.num_float_features,
            "tables": tables,
            "weighted_tables": weighted_tables,
            "dense_device": dense_device,
            "sparse_device": sparse_device,
            "embedding_groups": self.embedding_groups,
            "feature_processor_modules": self.feature_processor_modules,
        }
        return TestTowerSparseNN(**model_kwargs)
142
@dataclass
class TestTowerCollectionSparseNNConfig(BaseModelConfig):
    """Configuration for TestTowerCollectionSparseNN model."""

    # Tower-collection options; ``None`` defers to the model's own defaults.
    embedding_groups: Optional[Dict[str, List[str]]] = None
    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        """Build a TestTowerCollectionSparseNN; sparse params live on meta."""
        sparse_device = torch.device("meta")
        model_kwargs: Dict[str, Any] = {
            "tables": tables,
            "weighted_tables": weighted_tables,
            "dense_device": dense_device,
            "sparse_device": sparse_device,
            "num_float_features": self.num_float_features,
            "embedding_groups": self.embedding_groups,
            "feature_processor_modules": self.feature_processor_modules,
        }
        return TestTowerCollectionSparseNN(**model_kwargs)
166
@dataclass
class DeepFMConfig(BaseModelConfig):
    """Configuration for DeepFM model."""

    # Width of the dense hidden layer -- TODO confirm once generate_model is implemented
    hidden_layer_size: int
    # Dimension of the DeepFM interaction arch -- TODO confirm once implemented
    deep_fm_dimension: int

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        """Not implemented yet; always raises NotImplementedError."""
        # TODO: Implement DeepFM model generation
        raise NotImplementedError("DeepFM model generation not yet implemented")
183
@dataclass
class DLRMConfig(BaseModelConfig):
    """Configuration for DLRM model."""

    # Layer sizes for the dense (bottom) MLP -- TODO confirm once generate_model is implemented
    dense_arch_layer_sizes: List[int]
    # Layer sizes for the over (top) MLP -- TODO confirm once implemented
    over_arch_layer_sizes: List[int]

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        """Not implemented yet; always raises NotImplementedError."""
        # TODO: Implement DLRM model generation
        raise NotImplementedError("DLRM model generation not yet implemented")
200
# pyre-ignore[2]: Missing parameter annotation
def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
    """
    Factory for BaseModelConfig subclasses.

    Args:
        model_name: registry key selecting the model family; one of
            "test_sparse_nn", "test_tower_sparse_nn",
            "test_tower_collection_sparse_nn", "deepfm", "dlrm".
        **kwargs: candidate constructor arguments. Keys that are not
            dataclass fields of the selected config class are silently
            dropped, so a single flat kwargs dict can serve every model
            type. Note this also swallows typos in kwarg names.

    Returns:
        An instance of the config class registered under ``model_name``.

    Raises:
        ValueError: if ``model_name`` is not a known model name.
    """
    model_configs: Dict[str, Type[BaseModelConfig]] = {
        "test_sparse_nn": TestSparseNNConfig,
        "test_tower_sparse_nn": TestTowerSparseNNConfig,
        "test_tower_collection_sparse_nn": TestTowerCollectionSparseNNConfig,
        "deepfm": DeepFMConfig,
        "dlrm": DLRMConfig,
    }

    if model_name not in model_configs:
        # Include the known names so a typo is immediately diagnosable.
        raise ValueError(
            f"Unknown model name: {model_name}. "
            f"Expected one of: {sorted(model_configs)}"
        )

    # Filter kwargs to only include valid parameters for the specific model
    # config class; unknown keys are intentionally ignored (see docstring).
    model_class = model_configs[model_name]
    valid_field_names = {field.name for field in fields(model_class)}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_field_names}

    return model_class(**filtered_kwargs)
70
222
def generate_tables (
71
223
num_unweighted_features : int ,
72
224
num_weighted_features : int ,
@@ -319,7 +471,7 @@ def generate_sharded_model_and_optimizer(
319
471
def generate_data (
320
472
tables : List [EmbeddingBagConfig ],
321
473
weighted_tables : List [EmbeddingBagConfig ],
322
- model_config : ModelConfig ,
474
+ model_config : BaseModelConfig ,
323
475
num_batches : int ,
324
476
) -> List [ModelInput ]:
325
477
"""
0 commit comments