
Stack pipeline benchmark per-rank-results into a single BenchmarkResult #3258


Closed
wants to merge 2 commits
214 changes: 2 additions & 212 deletions torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -16,24 +16,14 @@
3. Add model-specific params to ModelSelectionConfig and create_model_config's arguments in benchmark_train_pipeline.py
"""

import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
from typing import Dict, List, Optional, Type, Union

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.optim import Optimizer
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.planner.types import ParameterConstraints
from torch import nn
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.distributed.test_utils.test_model import (
TestEBCSharder,
TestSparseNN,
TestTowerCollectionSparseNN,
TestTowerSparseNN,
@@ -47,7 +37,6 @@
PrefetchTrainPipelineSparseDist,
TrainPipelineSemiSync,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.models.deepfm import SimpleDeepFMNNWrapper
from torchrec.models.dlrm import DLRMWrapper
from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -248,55 +237,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
return model_class(**filtered_kwargs)


def generate_tables(
num_unweighted_features: int,
num_weighted_features: int,
embedding_feature_dim: int,
) -> Tuple[
List[EmbeddingBagConfig],
List[EmbeddingBagConfig],
]:
"""
Generate embedding bag configurations for both unweighted and weighted features.

This function creates two lists of EmbeddingBagConfig objects:
1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"

For both types, the number of embeddings scales with the feature index,
calculated as max(i + 1, 100) * 1000.

Args:
num_unweighted_features (int): Number of unweighted features to generate.
num_weighted_features (int): Number of weighted features to generate.
embedding_feature_dim (int): Dimension of the embedding vectors.

Returns:
Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
two lists - the first for unweighted embedding tables and the second for
weighted embedding tables.
"""
tables = [
EmbeddingBagConfig(
num_embeddings=max(i + 1, 100) * 1000,
embedding_dim=embedding_feature_dim,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(num_unweighted_features)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=max(i + 1, 100) * 1000,
embedding_dim=embedding_feature_dim,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(num_weighted_features)
]
return tables, weighted_tables


def generate_pipeline(
pipeline_type: str,
emb_lookup_stream: str,
@@ -371,156 +311,6 @@ def generate_pipeline(
return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)


def generate_planner(
planner_type: str,
topology: Topology,
tables: Optional[List[EmbeddingBagConfig]],
weighted_tables: Optional[List[EmbeddingBagConfig]],
sharding_type: ShardingType,
compute_kernel: EmbeddingComputeKernel,
batch_sizes: List[int],
pooling_factors: Optional[List[float]],
num_poolings: Optional[List[float]],
) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
"""
Generate an embedding sharding planner based on the specified configuration.

Args:
planner_type: Type of planner to use ("embedding" or "hetero")
topology: Network topology for distributed training
tables: List of unweighted embedding tables
weighted_tables: List of weighted embedding tables
sharding_type: Strategy for sharding embedding tables
compute_kernel: Compute kernel to use for embedding tables
batch_sizes: Sizes of each batch
pooling_factors: Pooling factors for each feature of the table
num_poolings: Number of poolings for each feature of the table

Returns:
An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner

Raises:
RuntimeError: If an unknown planner type is specified
"""
# Create parameter constraints for tables
constraints = {}
num_batches = len(batch_sizes)

if pooling_factors is None:
pooling_factors = [POOLING_FACTOR] * num_batches

if num_poolings is None:
num_poolings = [NUM_POOLINGS] * num_batches

assert (
len(pooling_factors) == num_batches and len(num_poolings) == num_batches
), "The length of pooling_factors and num_poolings must match the number of batches."

if tables is not None:
for table in tables:
constraints[table.name] = ParameterConstraints(
sharding_types=[sharding_type.value],
compute_kernels=[compute_kernel.value],
device_group="cuda",
pooling_factors=pooling_factors,
num_poolings=num_poolings,
batch_sizes=batch_sizes,
)

if weighted_tables is not None:
for table in weighted_tables:
constraints[table.name] = ParameterConstraints(
sharding_types=[sharding_type.value],
compute_kernels=[compute_kernel.value],
device_group="cuda",
pooling_factors=pooling_factors,
num_poolings=num_poolings,
batch_sizes=batch_sizes,
is_weighted=True,
)

if planner_type == "embedding":
return EmbeddingShardingPlanner(
topology=topology,
constraints=constraints if constraints else None,
)
elif planner_type == "hetero":
topology_groups = {"cuda": topology}
return HeteroEmbeddingShardingPlanner(
topology_groups=topology_groups,
constraints=constraints if constraints else None,
)
else:
raise RuntimeError(f"Unknown planner type: {planner_type}")


def generate_sharded_model_and_optimizer(
model: nn.Module,
sharding_type: str,
kernel_type: str,
pg: dist.ProcessGroup,
device: torch.device,
fused_params: Dict[str, Any],
dense_optimizer: str,
dense_lr: float,
dense_momentum: Optional[float],
dense_weight_decay: Optional[float],
planner: Optional[
Union[
EmbeddingShardingPlanner,
HeteroEmbeddingShardingPlanner,
]
] = None,
) -> Tuple[nn.Module, Optimizer]:

sharder = TestEBCSharder(
sharding_type=sharding_type,
kernel_type=kernel_type,
fused_params=fused_params,
)
sharders = [cast(ModuleSharder[nn.Module], sharder)]

# Use planner if provided
plan = None
if planner is not None:
if pg is not None:
plan = planner.collective_plan(model, sharders, pg)
else:
plan = planner.plan(model, sharders)

sharded_model = DistributedModelParallel(
module=copy.deepcopy(model),
env=ShardingEnv.from_process_group(pg),
init_data_parallel=True,
device=device,
sharders=sharders,
plan=plan,
).to(device)

# Get dense parameters
dense_params = [
param
for name, param in sharded_model.named_parameters()
if "sparse" not in name
]

# Create optimizer based on the specified type
optimizer_class = getattr(optim, dense_optimizer)

# Create optimizer with momentum and/or weight_decay if provided
optimizer_kwargs = {"lr": dense_lr}

if dense_momentum is not None:
optimizer_kwargs["momentum"] = dense_momentum

if dense_weight_decay is not None:
optimizer_kwargs["weight_decay"] = dense_weight_decay

optimizer = optimizer_class(dense_params, **optimizer_kwargs)

return sharded_model, optimizer


def generate_data(
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
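
For reference, a minimal standalone sketch (not part of this diff) of the sizing rule quoted in the removed generate_tables docstring; per the import changes in benchmark_train_pipeline.py below, generate_tables itself now appears to be imported from the shared benchmark utilities alongside BenchmarkResult and cmd_conf.

# Illustrative sketch only: reproduces the "max(i + 1, 100) * 1000" rule from the
# removed generate_tables docstring; plain Python, no torchrec imports required.
def table_num_embeddings(i: int) -> int:
    # Every table gets at least 100_000 rows; sizes grow once the index passes 99.
    return max(i + 1, 100) * 1000

if __name__ == "__main__":
    for i in (0, 50, 99, 100, 250):
        print(f"table_{i}: {table_num_embeddings(i):,} embeddings")
    # table_0, table_50, table_99 -> 100,000; table_100 -> 101,000; table_250 -> 251,000
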
34 changes: 29 additions & 5 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -33,9 +33,6 @@
DLRMConfig,
generate_data,
generate_pipeline,
generate_planner,
generate_sharded_model_and_optimizer,
generate_tables,
TestSparseNNConfig,
TestTowerCollectionSparseNNConfig,
TestTowerSparseNNConfig,
@@ -44,6 +41,11 @@
benchmark_func,
BenchmarkResult,
cmd_conf,
CPUMemoryStats,
generate_planner,
generate_sharded_model_and_optimizer,
generate_tables,
GPUMemoryStats,
)
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -255,15 +257,15 @@ def run_pipeline(
table_config: EmbeddingTablesConfig,
pipeline_config: PipelineConfig,
model_config: BaseModelConfig,
) -> List[BenchmarkResult]:
) -> BenchmarkResult:

tables, weighted_tables = generate_tables(
num_unweighted_features=table_config.num_unweighted_features,
num_weighted_features=table_config.num_weighted_features,
embedding_feature_dim=table_config.embedding_feature_dim,
)

return run_multi_process_func(
benchmark_res_per_rank = run_multi_process_func(
func=runner,
world_size=run_option.world_size,
tables=tables,
@@ -273,6 +275,28 @@
pipeline_config=pipeline_config,
)

# Combine results from all ranks into a single BenchmarkResult
# Use timing data from rank 0, combine memory stats from all ranks
world_size = run_option.world_size

total_benchmark_res = BenchmarkResult(
short_name=benchmark_res_per_rank[0].short_name,
gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time,
cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time,
gpu_mem_stats=[GPUMemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
cpu_mem_stats=[CPUMemoryStats(rank, 0) for rank in range(world_size)],
rank=0,
)

for res in benchmark_res_per_rank:
# Each rank's BenchmarkResult contains 1 GPU and 1 CPU memory measurement
if len(res.gpu_mem_stats) > 0:
total_benchmark_res.gpu_mem_stats[res.rank] = res.gpu_mem_stats[0]
if len(res.cpu_mem_stats) > 0:
total_benchmark_res.cpu_mem_stats[res.rank] = res.cpu_mem_stats[0]

return total_benchmark_res


def runner(
rank: int,
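
A hypothetical caller-side sketch of the new return shape (run_pipeline, the config parameters, and the BenchmarkResult fields are taken from the diff above; the config objects themselves are assumed to be constructed elsewhere, as in the benchmark script):

# Hypothetical usage sketch, not part of this diff. run_pipeline now returns a
# single BenchmarkResult rather than a List[BenchmarkResult] with one entry per
# rank: timings come from rank 0, while memory stats are indexed by rank.
result = run_pipeline(
    run_option=run_option,            # run options object (provides world_size)
    table_config=table_config,        # EmbeddingTablesConfig
    pipeline_config=pipeline_config,  # PipelineConfig
    model_config=model_config,        # BaseModelConfig subclass
)

print(result.short_name, result.gpu_elapsed_time, result.cpu_elapsed_time)

# One GPU and one CPU memory entry per rank, stored at index == rank.
for rank in range(run_option.world_size):
    print(f"rank {rank}: gpu={result.gpu_mem_stats[rank]}, cpu={result.cpu_mem_stats[rank]}")
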