2626 TestEBCSharder ,
2727 TestOverArchLarge ,
2828 TestSparseNN ,
29+ TestTowerCollectionSparseNN ,
30+ TestTowerSparseNN ,
2931)
3032from torchrec .distributed .train_pipeline import (
3133 TrainPipelineBase ,
4244
4345@dataclass
4446class ModelConfig :
47+ model_name : str = "test_sparsenn"
48+
4549 batch_size : int = 8192
4650 num_float_features : int = 10
4751 feature_pooling_avg : int = 10
@@ -58,13 +62,32 @@ def generate_model(
5862 weighted_tables : List [EmbeddingBagConfig ],
5963 dense_device : torch .device ,
6064 ) -> nn .Module :
61- return TestSparseNN (
62- tables = tables ,
63- weighted_tables = weighted_tables ,
64- dense_device = dense_device ,
65- sparse_device = torch .device ("meta" ),
66- over_arch_clazz = TestOverArchLarge ,
67- )
65+ if self .model_name == "test_sparsenn" :
66+ return TestSparseNN (
67+ tables = tables ,
68+ weighted_tables = weighted_tables ,
69+ dense_device = dense_device ,
70+ sparse_device = torch .device ("meta" ),
71+ over_arch_clazz = TestOverArchLarge ,
72+ )
73+ elif self .model_name == "test_tower_sparsenn" :
74+ return TestTowerSparseNN (
75+ tables = tables ,
76+ weighted_tables = weighted_tables ,
77+ dense_device = dense_device ,
78+ sparse_device = torch .device ("meta" ),
79+ num_float_features = self .num_float_features ,
80+ )
81+ elif self .model_name == "test_tower_collection_sparsenn" :
82+ return TestTowerCollectionSparseNN (
83+ tables = tables ,
84+ weighted_tables = weighted_tables ,
85+ dense_device = dense_device ,
86+ sparse_device = torch .device ("meta" ),
87+ num_float_features = self .num_float_features ,
88+ )
89+ else :
90+ raise RuntimeError (f"Unknown model name: { self .model_name } " )
6891
6992
7093def generate_tables (
@@ -317,6 +340,7 @@ def generate_sharded_model_and_optimizer(
317340
318341
319342def generate_data (
343+ model_class_name : str ,
320344 tables : List [EmbeddingBagConfig ],
321345 weighted_tables : List [EmbeddingBagConfig ],
322346 model_config : ModelConfig ,
@@ -336,25 +360,32 @@ def generate_data(
336360 """
337361 device = torch .device (model_config .dev_str ) if model_config .dev_str else None
338362
339- return [
340- ModelInput .generate (
341- batch_size = model_config .batch_size ,
342- tables = tables ,
343- weighted_tables = weighted_tables ,
344- num_float_features = model_config .num_float_features ,
345- pooling_avg = model_config .feature_pooling_avg ,
346- use_offsets = model_config .use_offsets ,
347- device = device ,
348- indices_dtype = (
349- torch .int64 if model_config .long_kjt_indices else torch .int32
350- ),
351- offsets_dtype = (
352- torch .int64 if model_config .long_kjt_offsets else torch .int32
353- ),
354- lengths_dtype = (
355- torch .int64 if model_config .long_kjt_lengths else torch .int32
356- ),
357- pin_memory = model_config .pin_memory ,
358- )
359- for _ in range (num_batches )
360- ]
363+ if (
364+ model_class_name == "TestSparseNN"
365+ or model_class_name == "TestTowerSparseNN"
366+ or model_class_name == "TestTowerCollectionSparseNN"
367+ ):
368+ return [
369+ ModelInput .generate (
370+ batch_size = model_config .batch_size ,
371+ tables = tables ,
372+ weighted_tables = weighted_tables ,
373+ num_float_features = model_config .num_float_features ,
374+ pooling_avg = model_config .feature_pooling_avg ,
375+ use_offsets = model_config .use_offsets ,
376+ device = device ,
377+ indices_dtype = (
378+ torch .int64 if model_config .long_kjt_indices else torch .int32
379+ ),
380+ offsets_dtype = (
381+ torch .int64 if model_config .long_kjt_offsets else torch .int32
382+ ),
383+ lengths_dtype = (
384+ torch .int64 if model_config .long_kjt_lengths else torch .int32
385+ ),
386+ pin_memory = model_config .pin_memory ,
387+ )
388+ for _ in range (num_batches )
389+ ]
390+ else :
391+ raise RuntimeError (f"Unknown model name: { model_config .model_name } " )
0 commit comments