From 98bcc8a0e49646bd585ecb542edfd631ebef1bde Mon Sep 17 00:00:00 2001
From: yernar
Date: Wed, 6 Aug 2025 18:30:04 -0700
Subject: [PATCH 1/2] Move shared benchmark helpers into benchmark_utils and
 add benchmark_module

Differential Revision: D79515598
---
 .../benchmark/benchmark_pipeline_utils.py     | 214 +-------
 .../benchmark/benchmark_train_pipeline.py     |   6 +-
 .../distributed/benchmark/benchmark_utils.py  | 486 +++++++++++++++++-
 3 files changed, 490 insertions(+), 216 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
index 9a1fa4647..dae5d8842 100644
--- a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py
@@ -16,24 +16,14 @@
 3. Add model-specific params to ModelSelectionConfig and create_model_config's
    arguments in benchmark_train_pipeline.py
 """
-import copy
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Type, Union
 
 import torch
-import torch.distributed as dist
-from torch import nn, optim
-from torch.optim import Optimizer
-from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
-from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
-from torchrec.distributed.planner.types import ParameterConstraints
+from torch import nn
 from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.distributed.test_utils.test_model import (
-    TestEBCSharder,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
@@ -47,7 +37,6 @@
     PrefetchTrainPipelineSparseDist,
     TrainPipelineSemiSync,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
 from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 from torchrec.models.dlrm import DLRMWrapper
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -248,55 +237,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
     return model_class(**filtered_kwargs)
 
 
-def generate_tables(
-    num_unweighted_features: int,
-    num_weighted_features: int,
-    embedding_feature_dim: int,
-) -> Tuple[
-    List[EmbeddingBagConfig],
-    List[EmbeddingBagConfig],
-]:
-    """
-    Generate embedding bag configurations for both unweighted and weighted features.
-
-    This function creates two lists of EmbeddingBagConfig objects:
-    1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
-    2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"
-
-    For both types, the number of embeddings scales with the feature index,
-    calculated as max(i + 1, 100) * 1000.
-
-    Args:
-        num_unweighted_features (int): Number of unweighted features to generate.
-        num_weighted_features (int): Number of weighted features to generate.
-        embedding_feature_dim (int): Dimension of the embedding vectors.
-
-    Returns:
-        Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
-        two lists - the first for unweighted embedding tables and the second for
-        weighted embedding tables.
- """ - tables = [ - EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, - embedding_dim=embedding_feature_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(num_unweighted_features) - ] - weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, - embedding_dim=embedding_feature_dim, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in range(num_weighted_features) - ] - return tables, weighted_tables - - def generate_pipeline( pipeline_type: str, emb_lookup_stream: str, @@ -371,156 +311,6 @@ def generate_pipeline( return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit) -def generate_planner( - planner_type: str, - topology: Topology, - tables: Optional[List[EmbeddingBagConfig]], - weighted_tables: Optional[List[EmbeddingBagConfig]], - sharding_type: ShardingType, - compute_kernel: EmbeddingComputeKernel, - batch_sizes: List[int], - pooling_factors: Optional[List[float]], - num_poolings: Optional[List[float]], -) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: - """ - Generate an embedding sharding planner based on the specified configuration. - - Args: - planner_type: Type of planner to use ("embedding" or "hetero") - topology: Network topology for distributed training - tables: List of unweighted embedding tables - weighted_tables: List of weighted embedding tables - sharding_type: Strategy for sharding embedding tables - compute_kernel: Compute kernel to use for embedding tables - batch_sizes: Sizes of each batch - pooling_factors: Pooling factors for each feature of the table - num_poolings: Number of poolings for each feature of the table - - Returns: - An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner - - Raises: - RuntimeError: If an unknown planner type is specified - """ - # Create parameter constraints for tables - constraints = {} - num_batches = len(batch_sizes) - - if pooling_factors is None: - pooling_factors = [POOLING_FACTOR] * num_batches - - if num_poolings is None: - num_poolings = [NUM_POOLINGS] * num_batches - - assert ( - len(pooling_factors) == num_batches and len(num_poolings) == num_batches - ), "The length of pooling_factors and num_poolings must match the number of batches." 
-
-    if tables is not None:
-        for table in tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-            )
-
-    if weighted_tables is not None:
-        for table in weighted_tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-                is_weighted=True,
-            )
-
-    if planner_type == "embedding":
-        return EmbeddingShardingPlanner(
-            topology=topology,
-            constraints=constraints if constraints else None,
-        )
-    elif planner_type == "hetero":
-        topology_groups = {"cuda": topology}
-        return HeteroEmbeddingShardingPlanner(
-            topology_groups=topology_groups,
-            constraints=constraints if constraints else None,
-        )
-    else:
-        raise RuntimeError(f"Unknown planner type: {planner_type}")
-
-
-def generate_sharded_model_and_optimizer(
-    model: nn.Module,
-    sharding_type: str,
-    kernel_type: str,
-    pg: dist.ProcessGroup,
-    device: torch.device,
-    fused_params: Dict[str, Any],
-    dense_optimizer: str,
-    dense_lr: float,
-    dense_momentum: Optional[float],
-    dense_weight_decay: Optional[float],
-    planner: Optional[
-        Union[
-            EmbeddingShardingPlanner,
-            HeteroEmbeddingShardingPlanner,
-        ]
-    ] = None,
-) -> Tuple[nn.Module, Optimizer]:
-
-    sharder = TestEBCSharder(
-        sharding_type=sharding_type,
-        kernel_type=kernel_type,
-        fused_params=fused_params,
-    )
-    sharders = [cast(ModuleSharder[nn.Module], sharder)]
-
-    # Use planner if provided
-    plan = None
-    if planner is not None:
-        if pg is not None:
-            plan = planner.collective_plan(model, sharders, pg)
-        else:
-            plan = planner.plan(model, sharders)
-
-    sharded_model = DistributedModelParallel(
-        module=copy.deepcopy(model),
-        env=ShardingEnv.from_process_group(pg),
-        init_data_parallel=True,
-        device=device,
-        sharders=sharders,
-        plan=plan,
-    ).to(device)
-
-    # Get dense parameters
-    dense_params = [
-        param
-        for name, param in sharded_model.named_parameters()
-        if "sparse" not in name
-    ]
-
-    # Create optimizer based on the specified type
-    optimizer_class = getattr(optim, dense_optimizer)
-
-    # Create optimizer with momentum and/or weight_decay if provided
-    optimizer_kwargs = {"lr": dense_lr}
-
-    if dense_momentum is not None:
-        optimizer_kwargs["momentum"] = dense_momentum
-
-    if dense_weight_decay is not None:
-        optimizer_kwargs["weight_decay"] = dense_weight_decay
-
-    optimizer = optimizer_class(dense_params, **optimizer_kwargs)
-
-    return sharded_model, optimizer
-
-
 def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index 3010ba9d1..2e22ab0a5 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -33,9 +33,6 @@
     DLRMConfig,
     generate_data,
     generate_pipeline,
-    generate_planner,
-    generate_sharded_model_and_optimizer,
-    generate_tables,
     TestSparseNNConfig,
     TestTowerCollectionSparseNNConfig,
     TestTowerSparseNNConfig,
@@ -44,6 +41,9 @@
     benchmark_func,
     BenchmarkResult,
     cmd_conf,
+    generate_planner,
+    generate_sharded_model_and_optimizer,
+    generate_tables,
 )
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index c78ff5b3a..b70f28662 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -11,6 +11,10 @@
 #!/usr/bin/env python3
 
 import argparse
+import contextlib
+
+# Additional imports for the new benchmark_module function
+import copy
 import inspect
 import json
 import logging
@@ -23,6 +27,7 @@
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     get_args,
     get_origin,
@@ -34,9 +39,23 @@
 )
 
 import torch
+import torch.distributed as dist
 import yaml
-from torch import multiprocessing as mp
+from fbgemm_gpu.split_embedding_configs import EmbOptimType
+from torch import multiprocessing as mp, nn, optim
 from torch.autograd.profiler import record_function
+from torch.optim import Optimizer
+from torchrec.distributed import DistributedModelParallel
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
+from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
+from torchrec.distributed.planner.types import ParameterConstraints
+from torchrec.distributed.test_utils.multi_process import MultiProcessContext
+from torchrec.distributed.test_utils.test_input import ModelInput
+from torchrec.distributed.test_utils.test_model import TestEBCSharder
+from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.test_utils import get_free_port
 
@@ -362,6 +381,471 @@ def set_embedding_config(
     return embedding_configs, pooling_configs
 
 
+def generate_tables(
+    num_unweighted_features: int = 100,
+    num_weighted_features: int = 100,
+    embedding_feature_dim: int = 128,
+) -> Tuple[
+    List[EmbeddingBagConfig],
+    List[EmbeddingBagConfig],
+]:
+    """
+    Generate embedding bag configurations for both unweighted and weighted features.
+
+    This function creates two lists of EmbeddingBagConfig objects:
+    1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
+    2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"
+
+    For both types, the number of embeddings scales with the feature index,
+    calculated as max(i + 1, 100) * 1000.
+
+    Args:
+        num_unweighted_features (int): Number of unweighted features to generate.
+        num_weighted_features (int): Number of weighted features to generate.
+        embedding_feature_dim (int): Dimension of the embedding vectors.
+
+    Returns:
+        Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
+        two lists - the first for unweighted embedding tables and the second for
+        weighted embedding tables.
+ """ + tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=embedding_feature_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_unweighted_features) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=embedding_feature_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + return tables, weighted_tables + + +def generate_planner( + planner_type: str, + topology: Topology, + tables: Optional[List[EmbeddingBagConfig]], + weighted_tables: Optional[List[EmbeddingBagConfig]], + sharding_type: ShardingType, + compute_kernel: EmbeddingComputeKernel, + batch_sizes: List[int], + pooling_factors: Optional[List[float]] = None, + num_poolings: Optional[List[float]] = None, +) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: + """ + Generate an embedding sharding planner based on the specified configuration. + + Args: + planner_type: Type of planner to use ("embedding" or "hetero") + topology: Network topology for distributed training + tables: List of unweighted embedding tables + weighted_tables: List of weighted embedding tables + sharding_type: Strategy for sharding embedding tables + compute_kernel: Compute kernel to use for embedding tables + batch_sizes: Sizes of each batch + pooling_factors: Pooling factors for each feature of the table + num_poolings: Number of poolings for each feature of the table + + Returns: + An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner + + Raises: + RuntimeError: If an unknown planner type is specified + """ + # Create parameter constraints for tables + constraints = {} + num_batches = len(batch_sizes) + + if pooling_factors is None: + pooling_factors = [POOLING_FACTOR] * num_batches + + if num_poolings is None: + num_poolings = [NUM_POOLINGS] * num_batches + + assert ( + len(pooling_factors) == num_batches and len(num_poolings) == num_batches + ), "The length of pooling_factors and num_poolings must match the number of batches." 
+
+    if tables is not None:
+        for table in tables:
+            constraints[table.name] = ParameterConstraints(
+                sharding_types=[sharding_type.value],
+                compute_kernels=[compute_kernel.value],
+                device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
+            )
+
+    if weighted_tables is not None:
+        for table in weighted_tables:
+            constraints[table.name] = ParameterConstraints(
+                sharding_types=[sharding_type.value],
+                compute_kernels=[compute_kernel.value],
+                device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
+                is_weighted=True,
+            )
+
+    if planner_type == "embedding":
+        return EmbeddingShardingPlanner(
+            topology=topology,
+            constraints=constraints if constraints else None,
+        )
+    elif planner_type == "hetero":
+        topology_groups = {"cuda": topology}
+        return HeteroEmbeddingShardingPlanner(
+            topology_groups=topology_groups,
+            constraints=constraints if constraints else None,
+        )
+    else:
+        raise RuntimeError(f"Unknown planner type: {planner_type}")
+
+
+def generate_sharded_model_and_optimizer(
+    model: nn.Module,
+    sharding_type: str,
+    kernel_type: str,
+    pg: dist.ProcessGroup,
+    device: torch.device,
+    fused_params: Dict[str, Any],
+    dense_optimizer: str = "SGD",
+    dense_lr: float = 0.1,
+    dense_momentum: Optional[float] = None,
+    dense_weight_decay: Optional[float] = None,
+    planner: Optional[
+        Union[
+            EmbeddingShardingPlanner,
+            HeteroEmbeddingShardingPlanner,
+        ]
+    ] = None,
+) -> Tuple[nn.Module, Optimizer]:
+    """
+    Generate a sharded model and optimizer for distributed training.
+
+    Args:
+        model: The model to be sharded
+        sharding_type: Type of sharding strategy
+        kernel_type: Type of compute kernel
+        pg: Process group for distributed training
+        device: Device to place the model on
+        fused_params: Parameters for the fused optimizer
+        dense_optimizer: Optimizer type for dense parameters
+        dense_lr: Learning rate for dense parameters
+        dense_momentum: Momentum for dense parameters (optional)
+        dense_weight_decay: Weight decay for dense parameters (optional)
+        planner: Optional planner for sharding strategy
+
+    Returns:
+        Tuple of sharded model and optimizer
+    """
+    sharder = TestEBCSharder(
+        sharding_type=sharding_type,
+        kernel_type=kernel_type,
+        fused_params=fused_params,
+    )
+    sharders = [cast(ModuleSharder[nn.Module], sharder)]
+
+    # Use planner if provided
+    plan = None
+    if planner is not None:
+        if pg is not None:
+            plan = planner.collective_plan(model, sharders, pg)
+        else:
+            plan = planner.plan(model, sharders)
+
+    sharded_model = DistributedModelParallel(
+        module=copy.deepcopy(model),
+        env=ShardingEnv.from_process_group(pg),
+        init_data_parallel=True,
+        device=device,
+        sharders=sharders,
+        plan=plan,
+    ).to(device)
+
+    # Get dense parameters
+    dense_params = [
+        param
+        for name, param in sharded_model.named_parameters()
+        if "sparse" not in name
+    ]
+
+    # Create optimizer based on the specified type
+    optimizer_class = getattr(optim, dense_optimizer)
+
+    # Create optimizer with momentum and/or weight_decay if provided
+    optimizer_kwargs = {"lr": dense_lr}
+
+    if dense_momentum is not None:
+        optimizer_kwargs["momentum"] = dense_momentum
+
+    if dense_weight_decay is not None:
+        optimizer_kwargs["weight_decay"] = dense_weight_decay
+
+    optimizer = optimizer_class(dense_params, **optimizer_kwargs)
+
+    return sharded_model, optimizer
+
+
+def _init_module_and_run_benchmark(
+    module: torch.nn.Module,
+    sharding_type: ShardingType,
+    planner_type: str,
+    compute_kernel: EmbeddingComputeKernel,
+    tables: List[EmbeddingBagConfig],
+    weighted_tables: List[EmbeddingBagConfig],
+    batch_size: int,
+    num_benchmarks: int,
+    world_size: int,
+    num_float_features: int = 0,
+    rank: int = -1,
+    queue: Optional[mp.Queue] = None,
+    device_type: str = "cuda",
+    warmup_iters: int = 20,
+    bench_iters: int = 100,
+    prof_iters: int = 20,
+) -> None:
+    """
+    Initialize module and run benchmark for a single process.
+
+    This is a simplified version of init_module_and_run_benchmark from
+    benchmark_ebc.py that doesn't handle compile modes and focuses on the
+    core benchmarking functionality.
+    """
+    from torchrec.distributed.comm import get_local_size
+
+    # Generate input data
+    num_inputs_to_gen = warmup_iters + bench_iters + prof_iters
+
+    batch_sizes = [batch_size] * num_inputs_to_gen
+    inputs_batch = []
+
+    for _ in range(num_inputs_to_gen):
+        model_input_by_rank = []
+        for _ in range(world_size):
+            model_input_by_rank.append(
+                ModelInput.generate(
+                    batch_size=batch_size,
+                    num_float_features=num_float_features,
+                    tables=tables,
+                    weighted_tables=weighted_tables,
+                    indices_dtype=torch.int32,
+                    lengths_dtype=torch.int32,
+                )
+            )
+
+        inputs_batch.append(model_input_by_rank)
+
+    # Transpose to get inputs by rank: [R x B] format
+    inputs_by_rank = list(zip(*inputs_batch))
+
+    if rank >= 0:
+        warmup_inputs_cuda = [
+            warmup_input.to(torch.device(f"{device_type}:{rank}"))
+            for warmup_input in inputs_by_rank[rank][:warmup_iters]
+        ]
+        bench_inputs_cuda = [
+            bench_input.to(torch.device(f"{device_type}:{rank}"))
+            for bench_input in inputs_by_rank[rank][
+                warmup_iters : warmup_iters + bench_iters
+            ]
+        ]
+        prof_inputs_cuda = [
+            prof_input.to(torch.device(f"{device_type}:{rank}"))
+            for prof_input in inputs_by_rank[rank][-prof_iters:]
+        ]
+    else:
+        warmup_inputs_cuda = [
+            warmup_input.to(torch.device(f"{device_type}:0"))
+            for warmup_input in inputs_by_rank[0][:warmup_iters]
+        ]
+        bench_inputs_cuda = [
+            bench_input.to(torch.device(f"{device_type}:0"))
+            for bench_input in inputs_by_rank[0][
+                warmup_iters : warmup_iters + bench_iters
+            ]
+        ]
+        prof_inputs_cuda = [
+            prof_input.to(torch.device(f"{device_type}:0"))
+            for prof_input in inputs_by_rank[0][-prof_iters:]
+        ]
+
+    with (
+        MultiProcessContext(
+            rank, world_size, "nccl", use_deterministic_algorithms=False
+        )
+        if rank != -1
+        else contextlib.nullcontext()
+    ) as ctx:
+        # Create topology and planner
+        topology = Topology(
+            local_world_size=get_local_size(world_size),
+            world_size=world_size,
+            compute_device=device_type,
+        )
+
+        planner = generate_planner(
+            planner_type=planner_type,
+            topology=topology,
+            tables=tables,
+            weighted_tables=weighted_tables,
+            sharding_type=sharding_type,
+            compute_kernel=compute_kernel,
+            batch_sizes=batch_sizes[
+                :num_benchmarks
+            ],  # Use only benchmark batches for planning
+        )
+
+        # Prepare fused_params for sparse optimizer
+        fused_params = {
+            "optimizer": EmbOptimType.EXACT_ADAGRAD,
+            "learning_rate": 0.1,
+        }
+
+        device = ctx.device if rank != -1 else torch.device(device_type)
+        pg = ctx.pg if rank != -1 else None
+
+        sharded_model, _ = generate_sharded_model_and_optimizer(
+            model=module,
+            sharding_type=sharding_type.value,
+            kernel_type=compute_kernel.value,
+            pg=pg,
+            device=device,
+            fused_params=fused_params,
+            planner=planner,
+        )
+
+        def _func_to_benchmark(
+            model: torch.nn.Module, bench_inputs: List[ModelInput]
+        ) -> None:
+            for bench_input in bench_inputs:
+                model(bench_input)
+
+        name = f"{sharding_type.value}-{planner_type}"
+
+        res = benchmark(
+            name,
+            sharded_model,
+            warmup_inputs_cuda,
+            bench_inputs_cuda,
+            prof_inputs_cuda,
+            world_size=world_size,
+            output_dir="",
+            num_benchmarks=num_benchmarks,
+            func_to_benchmark=_func_to_benchmark,
+            benchmark_func_kwargs=None,
+            rank=rank,
+            device_type=device_type,
+            benchmark_unsharded_module=False,
+        )
+
+        if queue is not None:
+            queue.put(res)
+
+
+def benchmark_module(
+    module: torch.nn.Module,
+    tables: List[EmbeddingBagConfig],
+    weighted_tables: Optional[List[EmbeddingBagConfig]] = None,
+    num_float_features: int = 0,
+    sharding_type: ShardingType = ShardingType.TABLE_WISE,
+    planner_type: str = "embedding",
+    world_size: int = 2,
+    num_benchmarks: int = 5,
+    batch_size: int = 2048,
+    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED,
+    device_type: str = "cuda",
+) -> BenchmarkResult:
+    """
+    Benchmark any PyTorch module with distributed sharding.
+
+    This function provides a simple interface to benchmark arbitrary PyTorch modules
+    using TorchRec's distributed sharding capabilities. It uses the provided embedding
+    tables to generate input data, sets up multiprocessing for distributed training,
+    and returns comprehensive benchmark results.
+
+    Args:
+        module: PyTorch module to benchmark
+        tables: List of unweighted embedding table configurations
+        weighted_tables: Optional list of weighted embedding table configurations
+        sharding_type: Strategy for sharding embedding tables across devices
+        planner_type: Type of planner to use ("embedding" or "hetero")
+        world_size: Number of processes/GPUs to use for distributed training
+        num_benchmarks: Number of iterations to run for statistical analysis
+        batch_size: Batch size to use for benchmarking
+        compute_kernel: Compute kernel to use for embedding tables
+        device_type: Device type to use ("cuda" or "cpu")
+
+    Returns:
+        BenchmarkResult containing timing and memory statistics
+
+    Example:
+        from torchrec.modules.embedding_modules import EmbeddingBagCollection
+        from torchrec.modules.embedding_configs import EmbeddingBagConfig
+
+        # Create embedding tables
+        tables = [
+            EmbeddingBagConfig(
+                name="table_0", embedding_dim=128, num_embeddings=100000,
+                feature_names=["feature_0"]
+            )
+        ]
+
+        # Create a simple EBC module
+        ebc = EmbeddingBagCollection(tables=tables)
+
+        # Benchmark it
+        result = benchmark_module(
+            module=ebc,
+            tables=tables,
+            world_size=2,
+            num_benchmarks=10
+        )
+        print(result)
+    """
+    logger.info(f"Starting benchmark for module: {type(module).__name__}")
+    logger.info(f"Sharding type: {sharding_type}")
+    logger.info(f"Planner type: {planner_type}")
+    logger.info(f"World size: {world_size}")
+    logger.info(f"Batch size: {batch_size}")
+    logger.info(f"Number of benchmarks: {num_benchmarks}")
+
+    assert (
+        num_benchmarks > 2
+    ), "num_benchmarks needs to be greater than 2 for statistical analysis"
+
+    # Use provided tables or default to empty list for weighted tables
+    if weighted_tables is None:
+        weighted_tables = []
+
+    res = multi_process_benchmark(
+        callable=_init_module_and_run_benchmark,
+        module=module,
+        sharding_type=sharding_type,
+        planner_type=planner_type,
+        compute_kernel=compute_kernel,
+        tables=tables,
+        weighted_tables=weighted_tables,
+        batch_size=batch_size,
+        num_benchmarks=num_benchmarks,
+        world_size=world_size,
+        num_float_features=num_float_features,
+        device_type=device_type,
+    )
+
+    return res
+
+
 # pyre-ignore [24]
 def cmd_conf(func: Callable) -> Callable:

From 0b785c165b324100c03af26298be0433a0eac7c2 Mon Sep 17 00:00:00 2001
From: Yernar Sadybekov
Date: Fri, 8 Aug 2025 09:24:46 -0700
Subject: [PATCH 2/2] Stack pipeline benchmark per-rank results into a single
 BenchmarkResult (#3258)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3258

Stack the pipeline benchmark's per-rank results into a single BenchmarkResult
to measure overall performance across GPUs.

This causes an intentional **2x regression** in the allocated-memory metrics,
since we now accumulate total memory stats across the two benchmarked GPUs
instead of measuring only one. After the change lands, I will **re-register**
the ServiceLab task with the new baseline metrics to avoid future false
regression warnings.

Reviewed By: aliafzal

Differential Revision: D79537357
---
 .../benchmark/benchmark_train_pipeline.py     | 28 +++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index 2e22ab0a5..8b990b888 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -41,9 +41,11 @@
     benchmark_func,
     BenchmarkResult,
     cmd_conf,
+    CPUMemoryStats,
     generate_planner,
     generate_sharded_model_and_optimizer,
     generate_tables,
+    GPUMemoryStats,
 )
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -255,7 +257,7 @@ def run_pipeline(
     table_config: EmbeddingTablesConfig,
     pipeline_config: PipelineConfig,
     model_config: BaseModelConfig,
-) -> List[BenchmarkResult]:
+) -> BenchmarkResult:
 
     tables, weighted_tables = generate_tables(
         num_unweighted_features=table_config.num_unweighted_features,
@@ -263,7 +265,7 @@ def run_pipeline(
         embedding_feature_dim=table_config.embedding_feature_dim,
     )
 
-    return run_multi_process_func(
+    benchmark_res_per_rank = run_multi_process_func(
         func=runner,
         world_size=run_option.world_size,
         tables=tables,
@@ -273,6 +275,28 @@ def run_pipeline(
         pipeline_config=pipeline_config,
     )
 
+    # Combine results from all ranks into a single BenchmarkResult
+    # Use timing data from rank 0, combine memory stats from all ranks
+    world_size = run_option.world_size
+
+    total_benchmark_res = BenchmarkResult(
+        short_name=benchmark_res_per_rank[0].short_name,
+        gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time,
+        cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time,
+        gpu_mem_stats=[GPUMemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
+        cpu_mem_stats=[CPUMemoryStats(rank, 0) for rank in range(world_size)],
+        rank=0,
+    )
+
+    for res in benchmark_res_per_rank:
+        # Each rank's BenchmarkResult contains 1 GPU and 1 CPU memory measurement
+        if len(res.gpu_mem_stats) > 0:
+            total_benchmark_res.gpu_mem_stats[res.rank] = res.gpu_mem_stats[0]
+        if len(res.cpu_mem_stats) > 0:
+            total_benchmark_res.cpu_mem_stats[res.rank] = res.cpu_mem_stats[0]
+
+    return total_benchmark_res
+
 
 def runner(
     rank: int,
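
Illustration of the stacking semantics implemented in run_pipeline above, reduced to a minimal, self-contained sketch. The dataclasses below (GpuStats, RankResult) are simplified stand-ins for TorchRec's GPUMemoryStats and BenchmarkResult, so this models only the indexing logic (timing taken from rank 0, one memory slot per rank), not the production types:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class GpuStats:  # stand-in for GPUMemoryStats
        rank: int
        peak_active_mb: float

    @dataclass
    class RankResult:  # stand-in for a per-rank BenchmarkResult
        rank: int
        gpu_elapsed_time_ms: List[float]
        gpu_mem_stats: List[GpuStats]

    def stack_per_rank_results(per_rank: List[RankResult], world_size: int) -> RankResult:
        # Timing comes from rank 0; memory gets one pre-allocated slot per
        # rank, so the combined result reports usage across all GPUs.
        combined = RankResult(
            rank=0,
            gpu_elapsed_time_ms=per_rank[0].gpu_elapsed_time_ms,
            gpu_mem_stats=[GpuStats(r, 0.0) for r in range(world_size)],
        )
        for res in per_rank:
            # Each rank contributes exactly one GPU memory measurement.
            if res.gpu_mem_stats:
                combined.gpu_mem_stats[res.rank] = res.gpu_mem_stats[0]
        return combined

    # Two ranks at ~1 GB each: the combined result carries both entries,
    # which is why the allocated-memory baseline roughly doubles.
    per_rank = [
        RankResult(0, [12.5, 12.6], [GpuStats(0, 1024.0)]),
        RankResult(1, [12.7, 12.8], [GpuStats(1, 1030.0)]),
    ]
    print(stack_per_rank_results(per_rank, world_size=2).gpu_mem_stats)

With world_size ranks, the combined result holds one memory entry per rank while timing is still reported once, which is exactly the intentional baseline shift called out in the summary.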