Skip to content

Commit e0310fa

Browse files
SSYernar authored and facebook-github-bot committed
Added model config for supporting SparseNN, TowerSparseNN, TowerCollectionSparseNN models.
Summary: Added a model configuration that supports the SparseNN, TowerSparseNN, and TowerCollectionSparseNN models. Future commits will add support for DeepFM and DLRM models.
Differential Revision: D76833673
1 parent 7133194 commit e0310fa

2 files changed

Lines changed: 61 additions & 29 deletions

File tree

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
TestEBCSharder,
2727
TestOverArchLarge,
2828
TestSparseNN,
29+
TestTowerCollectionSparseNN,
30+
TestTowerSparseNN,
2931
)
3032
from torchrec.distributed.train_pipeline import (
3133
TrainPipelineBase,
@@ -42,6 +44,8 @@
4244

4345
@dataclass
4446
class ModelConfig:
47+
model_name: str = "test_sparsenn"
48+
4549
batch_size: int = 8192
4650
num_float_features: int = 10
4751
feature_pooling_avg: int = 10
@@ -58,13 +62,32 @@ def generate_model(
5862
weighted_tables: List[EmbeddingBagConfig],
5963
dense_device: torch.device,
6064
) -> nn.Module:
61-
return TestSparseNN(
62-
tables=tables,
63-
weighted_tables=weighted_tables,
64-
dense_device=dense_device,
65-
sparse_device=torch.device("meta"),
66-
over_arch_clazz=TestOverArchLarge,
67-
)
65+
if self.model_name == "test_sparsenn":
66+
return TestSparseNN(
67+
tables=tables,
68+
weighted_tables=weighted_tables,
69+
dense_device=dense_device,
70+
sparse_device=torch.device("meta"),
71+
over_arch_clazz=TestOverArchLarge,
72+
)
73+
elif self.model_name == "test_tower_sparsenn":
74+
return TestTowerSparseNN(
75+
tables=tables,
76+
weighted_tables=weighted_tables,
77+
dense_device=dense_device,
78+
sparse_device=torch.device("meta"),
79+
num_float_features=self.num_float_features,
80+
)
81+
elif self.model_name == "test_tower_collection_sparsenn":
82+
return TestTowerCollectionSparseNN(
83+
tables=tables,
84+
weighted_tables=weighted_tables,
85+
dense_device=dense_device,
86+
sparse_device=torch.device("meta"),
87+
num_float_features=self.num_float_features,
88+
)
89+
else:
90+
raise RuntimeError(f"Unknown model name: {self.model_name}")
6891

6992

7093
def generate_tables(
@@ -317,6 +340,7 @@ def generate_sharded_model_and_optimizer(
317340

318341

319342
def generate_data(
343+
model_class_name: str,
320344
tables: List[EmbeddingBagConfig],
321345
weighted_tables: List[EmbeddingBagConfig],
322346
model_config: ModelConfig,
@@ -336,25 +360,32 @@ def generate_data(
336360
"""
337361
device = torch.device(model_config.dev_str) if model_config.dev_str else None
338362

339-
return [
340-
ModelInput.generate(
341-
batch_size=model_config.batch_size,
342-
tables=tables,
343-
weighted_tables=weighted_tables,
344-
num_float_features=model_config.num_float_features,
345-
pooling_avg=model_config.feature_pooling_avg,
346-
use_offsets=model_config.use_offsets,
347-
device=device,
348-
indices_dtype=(
349-
torch.int64 if model_config.long_kjt_indices else torch.int32
350-
),
351-
offsets_dtype=(
352-
torch.int64 if model_config.long_kjt_offsets else torch.int32
353-
),
354-
lengths_dtype=(
355-
torch.int64 if model_config.long_kjt_lengths else torch.int32
356-
),
357-
pin_memory=model_config.pin_memory,
358-
)
359-
for _ in range(num_batches)
360-
]
363+
if (
364+
model_class_name == "TestSparseNN"
365+
or model_class_name == "TestTowerSparseNN"
366+
or model_class_name == "TestTowerCollectionSparseNN"
367+
):
368+
return [
369+
ModelInput.generate(
370+
batch_size=model_config.batch_size,
371+
tables=tables,
372+
weighted_tables=weighted_tables,
373+
num_float_features=model_config.num_float_features,
374+
pooling_avg=model_config.feature_pooling_avg,
375+
use_offsets=model_config.use_offsets,
376+
device=device,
377+
indices_dtype=(
378+
torch.int64 if model_config.long_kjt_indices else torch.int32
379+
),
380+
offsets_dtype=(
381+
torch.int64 if model_config.long_kjt_offsets else torch.int32
382+
),
383+
lengths_dtype=(
384+
torch.int64 if model_config.long_kjt_lengths else torch.int32
385+
),
386+
pin_memory=model_config.pin_memory,
387+
)
388+
for _ in range(num_batches)
389+
]
390+
else:
391+
raise RuntimeError(f"Unknown model name: {model_config.model_name}")

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def runner(
197197
num_poolings=run_option.num_poolings,
198198
)
199199
bench_inputs = generate_data(
200+
model_class_name=unsharded_model.__class__.__name__,
200201
tables=tables,
201202
weighted_tables=weighted_tables,
202203
model_config=model_config,

0 commit comments

Comments
 (0)