
Added model config for supporting SparseNN, TowerSparseNN, TowerCollectionSparseNN models. #3104


Closed · wants to merge 1 commit
180 changes: 166 additions & 14 deletions torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -8,7 +8,8 @@
# pyre-strict

import copy
from dataclasses import dataclass
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import torch
@@ -24,8 +25,9 @@
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.distributed.test_utils.test_model import (
TestEBCSharder,
TestOverArchLarge,
TestSparseNN,
TestTowerCollectionSparseNN,
TestTowerSparseNN,
)
from torchrec.distributed.train_pipeline import (
TrainPipelineBase,
@@ -41,16 +43,56 @@


@dataclass
class ModelConfig:
batch_size: int = 8192
num_float_features: int = 10
feature_pooling_avg: int = 10
use_offsets: bool = False
dev_str: str = ""
long_kjt_indices: bool = True
long_kjt_offsets: bool = True
long_kjt_lengths: bool = True
pin_memory: bool = True
class BaseModelConfig(ABC):
"""
Abstract base class for model configurations.

This class defines the common parameters shared across all model types
and requires each concrete implementation to provide its own generate_model method.
"""

# Common parameters for all model types
batch_size: int
num_float_features: int
feature_pooling_avg: int
use_offsets: bool
dev_str: str
long_kjt_indices: bool
long_kjt_offsets: bool
long_kjt_lengths: bool
pin_memory: bool

@abstractmethod
def generate_model(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
dense_device: torch.device,
) -> nn.Module:
"""
Generate a model instance based on the configuration.

Args:
tables: List of unweighted embedding tables
weighted_tables: List of weighted embedding tables
dense_device: Device to place dense layers on

Returns:
A neural network module instance
"""
pass


@dataclass
class TestSparseNNConfig(BaseModelConfig):
"""Configuration for TestSparseNN model."""

embedding_groups: Optional[Dict[str, List[str]]]
feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
max_feature_lengths: Optional[Dict[str, int]]
over_arch_clazz: Type[nn.Module]
postproc_module: Optional[nn.Module]
zch: bool

def generate_model(
self,
@@ -60,13 +102,123 @@ def generate_model(
) -> nn.Module:
return TestSparseNN(
tables=tables,
num_float_features=self.num_float_features,
weighted_tables=weighted_tables,
dense_device=dense_device,
sparse_device=torch.device("meta"),
over_arch_clazz=TestOverArchLarge,
max_feature_lengths=self.max_feature_lengths,
feature_processor_modules=self.feature_processor_modules,
over_arch_clazz=self.over_arch_clazz,
postproc_module=self.postproc_module,
embedding_groups=self.embedding_groups,
zch=self.zch,
)


@dataclass
class TestTowerSparseNNConfig(BaseModelConfig):
"""Configuration for TestTowerSparseNN model."""

embedding_groups: Optional[Dict[str, List[str]]] = None
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None

def generate_model(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
dense_device: torch.device,
) -> nn.Module:
return TestTowerSparseNN(
num_float_features=self.num_float_features,
tables=tables,
weighted_tables=weighted_tables,
dense_device=dense_device,
sparse_device=torch.device("meta"),
embedding_groups=self.embedding_groups,
feature_processor_modules=self.feature_processor_modules,
)


@dataclass
class TestTowerCollectionSparseNNConfig(BaseModelConfig):
"""Configuration for TestTowerCollectionSparseNN model."""

embedding_groups: Optional[Dict[str, List[str]]] = None
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None

def generate_model(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
dense_device: torch.device,
) -> nn.Module:
return TestTowerCollectionSparseNN(
tables=tables,
weighted_tables=weighted_tables,
dense_device=dense_device,
sparse_device=torch.device("meta"),
num_float_features=self.num_float_features,
embedding_groups=self.embedding_groups,
feature_processor_modules=self.feature_processor_modules,
)


@dataclass
class DeepFMConfig(BaseModelConfig):
"""Configuration for DeepFM model."""

hidden_layer_size: int
deep_fm_dimension: int

def generate_model(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
dense_device: torch.device,
) -> nn.Module:
# TODO: Implement DeepFM model generation
raise NotImplementedError("DeepFM model generation not yet implemented")


@dataclass
class DLRMConfig(BaseModelConfig):
"""Configuration for DLRM model."""

dense_arch_layer_sizes: List[int]
over_arch_layer_sizes: List[int]

def generate_model(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
dense_device: torch.device,
) -> nn.Module:
# TODO: Implement DLRM model generation
raise NotImplementedError("DLRM model generation not yet implemented")
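
Editor's sketch, not part of this PR: one possible way the DLRM TODO above could eventually be filled in, assuming the standard torchrec.models.dlrm.DLRM constructor and mapping num_float_features to the dense input width. Weighted tables are ignored here, and DLRM additionally expects dense_arch_layer_sizes[-1] to match the tables' embedding dimension.

from torchrec.models.dlrm import DLRM
from torchrec.modules.embedding_modules import EmbeddingBagCollection

def _sketch_generate_dlrm(
    config: DLRMConfig,
    tables: List[EmbeddingBagConfig],
    dense_device: torch.device,
) -> nn.Module:
    # Keep embedding weights on the meta device, mirroring the other configs,
    # so the planner/sharder decides actual placement later.
    ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
    return DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=config.num_float_features,  # assumption, not taken from the PR
        dense_arch_layer_sizes=config.dense_arch_layer_sizes,
        over_arch_layer_sizes=config.over_arch_layer_sizes,
        dense_device=dense_device,
    )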


# pyre-ignore[2]: Missing parameter annotation
def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
"""Factory that maps a model name to its config class and instantiates it, keeping only the kwargs that correspond to fields of that class."""

model_configs = {
"test_sparse_nn": TestSparseNNConfig,
"test_tower_sparse_nn": TestTowerSparseNNConfig,
"test_tower_collection_sparse_nn": TestTowerCollectionSparseNNConfig,
"deepfm": DeepFMConfig,
"dlrm": DLRMConfig,
}

if model_name not in model_configs:
raise ValueError(f"Unknown model name: {model_name}. Expected one of: {sorted(model_configs)}")

# Filter kwargs to only include valid parameters for the specific model config class
model_class = model_configs[model_name]
valid_field_names = {field.name for field in fields(model_class)}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_field_names}

return model_class(**filtered_kwargs)
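
An illustrative usage sketch (editor's note, not part of this PR): because the factory filters kwargs against the chosen dataclass's fields, callers can pass one flat keyword namespace for every model type and let irrelevant keys be dropped. The table objects and device below are placeholders.

config = create_model_config(
    "test_tower_sparse_nn",
    batch_size=8192,
    num_float_features=10,
    feature_pooling_avg=10,
    use_offsets=False,
    dev_str="",
    long_kjt_indices=True,
    long_kjt_offsets=True,
    long_kjt_lengths=True,
    pin_memory=True,
    embedding_groups=None,
    feature_processor_modules=None,
    zch=False,  # TestSparseNNConfig-only field, filtered out for this model
)
model = config.generate_model(
    tables=tables,                    # e.g. from generate_tables(...)
    weighted_tables=weighted_tables,
    dense_device=torch.device("cpu"),
)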


def generate_tables(
num_unweighted_features: int,
num_weighted_features: int,
@@ -319,7 +471,7 @@ def generate_sharded_model_and_optimizer(
def generate_data(
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
model_config: ModelConfig,
model_config: BaseModelConfig,
num_batches: int,
) -> List[ModelInput]:
"""
82 changes: 77 additions & 5 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -9,19 +9,25 @@

#!/usr/bin/env python3

from dataclasses import dataclass
from typing import List, Optional
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Type, Union

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torch import nn
from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
BaseModelConfig,
create_model_config,
DeepFMConfig,
DLRMConfig,
generate_data,
generate_pipeline,
generate_planner,
generate_sharded_model_and_optimizer,
generate_tables,
ModelConfig,
TestSparseNNConfig,
TestTowerCollectionSparseNNConfig,
TestTowerSparseNNConfig,
)
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
from torchrec.distributed.comm import get_local_size
@@ -33,6 +39,7 @@
run_multi_process_func,
)
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.distributed.test_utils.test_model import TestOverArchLarge
from torchrec.distributed.train_pipeline import TrainPipeline
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -124,19 +131,84 @@ class PipelineConfig:
emb_lookup_stream: str = "data_dist"


@dataclass
class ModelSelectionConfig:
model_name: str = "test_sparse_nn"

# Common config for all model types
batch_size: int = 8192
num_float_features: int = 10
feature_pooling_avg: int = 10
use_offsets: bool = False
dev_str: str = ""
long_kjt_indices: bool = True
long_kjt_offsets: bool = True
long_kjt_lengths: bool = True
pin_memory: bool = True

# TestSparseNN specific config
embedding_groups: Optional[Dict[str, List[str]]] = None
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
max_feature_lengths: Optional[Dict[str, int]] = None
over_arch_clazz: Type[nn.Module] = TestOverArchLarge
postproc_module: Optional[nn.Module] = None
zch: bool = False

# DeepFM specific config
hidden_layer_size: int = 20
deep_fm_dimension: int = 5

# DLRM specific config
dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 10])
over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 3])


@cmd_conf
def main(
run_option: RunOptions,
table_config: EmbeddingTablesConfig,
model_config: ModelConfig,
model_selection: ModelSelectionConfig,
pipeline_config: PipelineConfig,
model_config: Optional[
Union[
TestSparseNNConfig,
TestTowerCollectionSparseNNConfig,
TestTowerSparseNNConfig,
DeepFMConfig,
DLRMConfig,
]
] = None,
) -> None:
tables, weighted_tables = generate_tables(
num_unweighted_features=table_config.num_unweighted_features,
num_weighted_features=table_config.num_weighted_features,
embedding_feature_dim=table_config.embedding_feature_dim,
)

if model_config is None:
model_config = create_model_config(
model_name=model_selection.model_name,
batch_size=model_selection.batch_size,
num_float_features=model_selection.num_float_features,
feature_pooling_avg=model_selection.feature_pooling_avg,
use_offsets=model_selection.use_offsets,
dev_str=model_selection.dev_str,
long_kjt_indices=model_selection.long_kjt_indices,
long_kjt_offsets=model_selection.long_kjt_offsets,
long_kjt_lengths=model_selection.long_kjt_lengths,
pin_memory=model_selection.pin_memory,
embedding_groups=model_selection.embedding_groups,
feature_processor_modules=model_selection.feature_processor_modules,
max_feature_lengths=model_selection.max_feature_lengths,
over_arch_clazz=model_selection.over_arch_clazz,
postproc_module=model_selection.postproc_module,
zch=model_selection.zch,
hidden_layer_size=model_selection.hidden_layer_size,
deep_fm_dimension=model_selection.deep_fm_dimension,
dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
)

# launch trainers
run_multi_process_func(
func=runner,
@@ -155,7 +227,7 @@ def runner(
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
run_option: RunOptions,
model_config: ModelConfig,
model_config: BaseModelConfig,
pipeline_config: PipelineConfig,
) -> None:
# Ensure GPUs are available and we have enough of them