diff --git a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
index b0a968aa5..ccff6b2ed 100644
--- a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -8,7 +8,8 @@
 # pyre-strict
 
 import copy
-from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, fields
 from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
 
 import torch
@@ -24,8 +25,9 @@
 from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.distributed.test_utils.test_model import (
     TestEBCSharder,
-    TestOverArchLarge,
     TestSparseNN,
+    TestTowerCollectionSparseNN,
+    TestTowerSparseNN,
 )
 from torchrec.distributed.train_pipeline import (
     TrainPipelineBase,
@@ -41,16 +43,56 @@
 
 
 @dataclass
-class ModelConfig:
-    batch_size: int = 8192
-    num_float_features: int = 10
-    feature_pooling_avg: int = 10
-    use_offsets: bool = False
-    dev_str: str = ""
-    long_kjt_indices: bool = True
-    long_kjt_offsets: bool = True
-    long_kjt_lengths: bool = True
-    pin_memory: bool = True
+class BaseModelConfig(ABC):
+    """
+    Abstract base class for model configurations.
+
+    This class defines the common parameters shared across all model types
+    and requires each concrete implementation to provide its own generate_model method.
+    """
+
+    # Common parameters for all model types
+    batch_size: int
+    num_float_features: int
+    feature_pooling_avg: int
+    use_offsets: bool
+    dev_str: str
+    long_kjt_indices: bool
+    long_kjt_offsets: bool
+    long_kjt_lengths: bool
+    pin_memory: bool
+
+    @abstractmethod
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        """
+        Generate a model instance based on the configuration.
+
+        Args:
+            tables: List of unweighted embedding tables
+            weighted_tables: List of weighted embedding tables
+            dense_device: Device to place dense layers on
+
+        Returns:
+            A neural network module instance
+        """
+        pass
+
+
+@dataclass
+class TestSparseNNConfig(BaseModelConfig):
+    """Configuration for TestSparseNN model."""
+
+    embedding_groups: Optional[Dict[str, List[str]]]
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
+    max_feature_lengths: Optional[Dict[str, int]]
+    over_arch_clazz: Type[nn.Module]
+    postproc_module: Optional[nn.Module]
+    zch: bool
 
     def generate_model(
         self,
@@ -60,13 +102,135 @@ def generate_model(
     ) -> nn.Module:
         return TestSparseNN(
             tables=tables,
+            num_float_features=self.num_float_features,
             weighted_tables=weighted_tables,
             dense_device=dense_device,
             sparse_device=torch.device("meta"),
-            over_arch_clazz=TestOverArchLarge,
+            max_feature_lengths=self.max_feature_lengths,
+            feature_processor_modules=self.feature_processor_modules,
+            over_arch_clazz=self.over_arch_clazz,
+            postproc_module=self.postproc_module,
+            embedding_groups=self.embedding_groups,
+            zch=self.zch,
         )
 
 
+@dataclass
+class TestTowerSparseNNConfig(BaseModelConfig):
+    """Configuration for TestTowerSparseNN model."""
+
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        return TestTowerSparseNN(
+            num_float_features=self.num_float_features,
+            tables=tables,
+            weighted_tables=weighted_tables,
+            dense_device=dense_device,
+            sparse_device=torch.device("meta"),
+            embedding_groups=self.embedding_groups,
+            feature_processor_modules=self.feature_processor_modules,
+        )
+
+
+@dataclass
+class TestTowerCollectionSparseNNConfig(BaseModelConfig):
+    """Configuration for TestTowerCollectionSparseNN model."""
+
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        return TestTowerCollectionSparseNN(
+            tables=tables,
+            weighted_tables=weighted_tables,
+            dense_device=dense_device,
+            sparse_device=torch.device("meta"),
+            num_float_features=self.num_float_features,
+            embedding_groups=self.embedding_groups,
+            feature_processor_modules=self.feature_processor_modules,
+        )
+
+
+@dataclass
+class DeepFMConfig(BaseModelConfig):
+    """Configuration for DeepFM model."""
+
+    hidden_layer_size: int
+    deep_fm_dimension: int
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        # TODO: Implement DeepFM model generation
+        raise NotImplementedError("DeepFM model generation not yet implemented")
+
+
+@dataclass
+class DLRMConfig(BaseModelConfig):
+    """Configuration for DLRM model."""
+
+    dense_arch_layer_sizes: List[int]
+    over_arch_layer_sizes: List[int]
+
+    def generate_model(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        dense_device: torch.device,
+    ) -> nn.Module:
+        # TODO: Implement DLRM model generation
+        raise NotImplementedError("DLRM model generation not yet implemented")
+
+
+# pyre-ignore[2]: Missing parameter annotation
+def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
+    """
+    Create a model configuration for the given model name.
+
+    Args:
+        model_name: One of "test_sparse_nn", "test_tower_sparse_nn",
+            "test_tower_collection_sparse_nn", "deepfm", or "dlrm".
+        **kwargs: Candidate config parameters; keys that are not fields of
+            the selected config class are filtered out.
+
+    Returns:
+        An instance of the matching BaseModelConfig subclass.
+    """
+
+    model_configs = {
+        "test_sparse_nn": TestSparseNNConfig,
+        "test_tower_sparse_nn": TestTowerSparseNNConfig,
+        "test_tower_collection_sparse_nn": TestTowerCollectionSparseNNConfig,
+        "deepfm": DeepFMConfig,
+        "dlrm": DLRMConfig,
+    }
+
+    if model_name not in model_configs:
+        raise ValueError(f"Unknown model name: {model_name}")
+
+    # Filter kwargs to only include valid parameters for the specific model config class
+    model_class = model_configs[model_name]
+    valid_field_names = {field.name for field in fields(model_class)}
+    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_field_names}
+
+    return model_class(**filtered_kwargs)
+
+
 def generate_tables(
     num_unweighted_features: int,
     num_weighted_features: int,
@@ -319,7 +483,7 @@ def generate_sharded_model_and_optimizer(
 def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
-    model_config: ModelConfig,
+    model_config: BaseModelConfig,
     num_batches: int,
 ) -> List[ModelInput]:
     """
diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index 537163f6c..e8aa5fd2d 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -9,19 +9,25 @@
 
 #!/usr/bin/env python3
 
-from dataclasses import dataclass
-from typing import List, Optional
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Type, Union
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
 from torch import nn
 from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
+    BaseModelConfig,
+    create_model_config,
+    DeepFMConfig,
+    DLRMConfig,
     generate_data,
     generate_pipeline,
     generate_planner,
     generate_sharded_model_and_optimizer,
     generate_tables,
-    ModelConfig,
+    TestSparseNNConfig,
+    TestTowerCollectionSparseNNConfig,
+    TestTowerSparseNNConfig,
 )
 from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
 from torchrec.distributed.comm import get_local_size
@@ -33,6 +39,7 @@
     run_multi_process_func,
 )
 from torchrec.distributed.test_utils.test_input import ModelInput
+from torchrec.distributed.test_utils.test_model import TestOverArchLarge
 from torchrec.distributed.train_pipeline import TrainPipeline
 from torchrec.distributed.types import ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -124,12 +131,53 @@ class PipelineConfig:
     emb_lookup_stream: str = "data_dist"
 
 
+@dataclass
+class ModelSelectionConfig:
+    model_name: str = "test_sparse_nn"
+
+    # Common config for all model types
+    batch_size: int = 8192
+    num_float_features: int = 10
+    feature_pooling_avg: int = 10
+    use_offsets: bool = False
+    dev_str: str = ""
+    long_kjt_indices: bool = True
+    long_kjt_offsets: bool = True
+    long_kjt_lengths: bool = True
+    pin_memory: bool = True
+
+    # TestSparseNN specific config
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+    max_feature_lengths: Optional[Dict[str, int]] = None
+    over_arch_clazz: Type[nn.Module] = TestOverArchLarge
+    postproc_module: Optional[nn.Module] = None
+    zch: bool = False
+
+    # DeepFM specific config
+    hidden_layer_size: int = 20
+    deep_fm_dimension: int = 5
+
+    # DLRM specific config
+    dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 10])
+    over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 3])
+
+
 @cmd_conf
 def main(
     run_option: RunOptions,
     table_config: EmbeddingTablesConfig,
-    model_config: ModelConfig,
+    model_selection: ModelSelectionConfig,
     pipeline_config: PipelineConfig,
+    model_config: Optional[
+        Union[
+            TestSparseNNConfig,
+            TestTowerCollectionSparseNNConfig,
+            TestTowerSparseNNConfig,
+            DeepFMConfig,
+            DLRMConfig,
+        ]
+    ] = None,
 ) -> None:
     tables, weighted_tables = generate_tables(
         num_unweighted_features=table_config.num_unweighted_features,
@@ -137,6 +185,30 @@ def main(
         embedding_feature_dim=table_config.embedding_feature_dim,
     )
 
+    if model_config is None:
+        model_config = create_model_config(
+            model_name=model_selection.model_name,
+            batch_size=model_selection.batch_size,
+            num_float_features=model_selection.num_float_features,
+            feature_pooling_avg=model_selection.feature_pooling_avg,
+            use_offsets=model_selection.use_offsets,
+            dev_str=model_selection.dev_str,
+            long_kjt_indices=model_selection.long_kjt_indices,
+            long_kjt_offsets=model_selection.long_kjt_offsets,
+            long_kjt_lengths=model_selection.long_kjt_lengths,
+            pin_memory=model_selection.pin_memory,
+            embedding_groups=model_selection.embedding_groups,
+            feature_processor_modules=model_selection.feature_processor_modules,
+            max_feature_lengths=model_selection.max_feature_lengths,
+            over_arch_clazz=model_selection.over_arch_clazz,
+            postproc_module=model_selection.postproc_module,
+            zch=model_selection.zch,
+            hidden_layer_size=model_selection.hidden_layer_size,
+            deep_fm_dimension=model_selection.deep_fm_dimension,
+            dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
+            over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
+        )
+
     # launch trainers
     run_multi_process_func(
         func=runner,
@@ -155,7 +227,7 @@ def runner(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
     run_option: RunOptions,
-    model_config: ModelConfig,
+    model_config: BaseModelConfig,
     pipeline_config: PipelineConfig,
 ) -> None:
     # Ensure GPUs are available and we have enough of them
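
Usage sketch (editor's addition, not part of the patch): the factory plus generate_model round-trip, using only function and parameter names that appear in this diff; the feature counts, embedding dim, and CPU dense device are illustrative assumptions.

import torch

from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
    create_model_config,
    generate_tables,
)

# Keys that are not fields of the selected config class are filtered out,
# so one flat kwargs namespace can serve every model type.
model_config = create_model_config(
    "test_tower_sparse_nn",
    batch_size=8192,
    num_float_features=10,
    feature_pooling_avg=10,
    use_offsets=False,
    dev_str="",
    long_kjt_indices=True,
    long_kjt_offsets=True,
    long_kjt_lengths=True,
    pin_memory=True,
    dense_arch_layer_sizes=[20, 10],  # DLRM-only key, dropped for this model
)

tables, weighted_tables = generate_tables(
    num_unweighted_features=26,  # illustrative values
    num_weighted_features=8,
    embedding_feature_dim=128,
)

# Dense layers go on the given device; the configs above place the sparse
# side on the meta device until sharding.
model = model_config.generate_model(
    tables=tables,
    weighted_tables=weighted_tables,
    dense_device=torch.device("cpu"),
)

Note that all nine BaseModelConfig fields lack defaults, so the common parameters must always be supplied; only the model-specific extras are optional per config class.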
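The ABC also makes the extension contract explicit. Below is a sketch of adding a new model type, with an invented MyCustomModelConfig and my_hidden_dims field (hypothetical names). The subclass alone is not enough: create_model_config only knows the keys in its local model_configs dict, so a real change would add an entry there and, for CLI use, matching fields on ModelSelectionConfig.

from dataclasses import dataclass
from typing import List

import torch
from torch import nn

from torchrec.distributed.benchmark.benchmark_pipeline_utils import BaseModelConfig
from torchrec.modules.embedding_configs import EmbeddingBagConfig


@dataclass
class MyCustomModelConfig(BaseModelConfig):  # hypothetical example
    """Inherits the nine common fields; adds model-specific ones."""

    my_hidden_dims: List[int]  # invented parameter for illustration

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        # Construct and return the nn.Module here, mirroring the concrete
        # configs in benchmark_pipeline_utils.py.
        raise NotImplementedError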