diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index c93bbba59..3010ba9d1 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -40,7 +40,11 @@
     TestTowerCollectionSparseNNConfig,
     TestTowerSparseNNConfig,
 )
-from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
+from torchrec.distributed.benchmark.benchmark_utils import (
+    benchmark_func,
+    BenchmarkResult,
+    cmd_conf,
+)
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import Topology
@@ -201,15 +205,7 @@ def main(
     table_config: EmbeddingTablesConfig,
     model_selection: ModelSelectionConfig,
     pipeline_config: PipelineConfig,
-    model_config: Optional[
-        Union[
-            TestSparseNNConfig,
-            TestTowerCollectionSparseNNConfig,
-            TestTowerSparseNNConfig,
-            DeepFMConfig,
-            DLRMConfig,
-        ]
-    ] = None,
+    model_config: Optional[BaseModelConfig] = None,
 ) -> None:
     tables, weighted_tables = generate_tables(
         num_unweighted_features=table_config.num_unweighted_features,
@@ -254,6 +250,30 @@ def main(
     )
 
 
+def run_pipeline(
+    run_option: RunOptions,
+    table_config: EmbeddingTablesConfig,
+    pipeline_config: PipelineConfig,
+    model_config: BaseModelConfig,
+) -> List[BenchmarkResult]:
+
+    tables, weighted_tables = generate_tables(
+        num_unweighted_features=table_config.num_unweighted_features,
+        num_weighted_features=table_config.num_weighted_features,
+        embedding_feature_dim=table_config.embedding_feature_dim,
+    )
+
+    return run_multi_process_func(
+        func=runner,
+        world_size=run_option.world_size,
+        tables=tables,
+        weighted_tables=weighted_tables,
+        run_option=run_option,
+        model_config=model_config,
+        pipeline_config=pipeline_config,
+    )
+
+
 def runner(
     rank: int,
     world_size: int,
@@ -262,7 +282,7 @@ def runner(
     run_option: RunOptions,
     model_config: BaseModelConfig,
     pipeline_config: PipelineConfig,
-) -> None:
+) -> BenchmarkResult:
     # Ensure GPUs are available and we have enough of them
     assert (
         torch.cuda.is_available() and torch.cuda.device_count() >= world_size
@@ -356,48 +376,34 @@ def _func_to_benchmark(
             except StopIteration:
                 break
 
-    # Run comparison if apply_jit is True, otherwise run single benchmark
-    jit_configs = (
-        [(True, "WithJIT"), (False, "WithoutJIT")]
-        if pipeline_config.apply_jit
-        else [(False, "")]
+    pipeline = generate_pipeline(
+        pipeline_type=pipeline_config.pipeline,
+        emb_lookup_stream=pipeline_config.emb_lookup_stream,
+        model=sharded_model,
+        opt=optimizer,
+        device=ctx.device,
+        apply_jit=pipeline_config.apply_jit,
+    )
+    pipeline.progress(iter(bench_inputs))
+
+    result = benchmark_func(
+        name=type(pipeline).__name__,
+        bench_inputs=bench_inputs,  # pyre-ignore
+        prof_inputs=bench_inputs,  # pyre-ignore
+        num_benchmarks=5,
+        num_profiles=2,
+        profile_dir=run_option.profile,
+        world_size=run_option.world_size,
+        func_to_benchmark=_func_to_benchmark,
+        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+        rank=rank,
+        export_stacks=run_option.export_stacks,
     )
-    results = []
-
-    for apply_jit, jit_suffix in jit_configs:
-        pipeline = generate_pipeline(
-            pipeline_type=pipeline_config.pipeline,
-            emb_lookup_stream=pipeline_config.emb_lookup_stream,
-            model=sharded_model,
-            opt=optimizer,
-            device=ctx.device,
-            apply_jit=apply_jit,
-        )
-        pipeline.progress(iter(bench_inputs))
-
-        name = (
-            f"{type(pipeline).__name__}{jit_suffix}"
-            if jit_suffix
-            else type(pipeline).__name__
-        )
-        result = benchmark_func(
-            name=name,
-            bench_inputs=bench_inputs,  # pyre-ignore
-            prof_inputs=bench_inputs,  # pyre-ignore
-            num_benchmarks=5,
-            num_profiles=2,
-            profile_dir=run_option.profile,
-            world_size=run_option.world_size,
-            func_to_benchmark=_func_to_benchmark,
-            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-            rank=rank,
-            export_stacks=run_option.export_stacks,
-        )
-        results.append(result)
 
     if rank == 0:
-        for result in results:
-            print(result)
+        print(result)
+
+    return result
 
 
 if __name__ == "__main__":
diff --git a/torchrec/distributed/test_utils/multi_process.py b/torchrec/distributed/test_utils/multi_process.py
index 8742227ed..cb2faf1df 100644
--- a/torchrec/distributed/test_utils/multi_process.py
+++ b/torchrec/distributed/test_utils/multi_process.py
@@ -189,17 +189,27 @@ def _run_multi_process_test_per_rank(
             self.assertEqual(0, p.exitcode)
 
 
+def _wrapper_func_for_multiprocessing(args):  # pyre-ignore[2, 3]
+    """Wrapper function that unpacks arguments and calls the original func"""
+    func, rank, world_size, kwargs = args
+    kwargs["rank"] = rank
+    kwargs["world_size"] = world_size
+    return func(**kwargs)
+
+
+# pyre-ignore[3]
 def run_multi_process_func(
+    # pyre-ignore[2]
     func: Callable[
         [int, int, ...],  # rank, world_size, ...
-        None,
+        Any,  # Changed from None to Any to allow return values
     ],
     multiprocessing_method: str = "spawn",
     use_deterministic_algorithms: bool = True,
     world_size: int = 2,
     # pyre-ignore
     **kwargs,
-) -> None:
+) -> List[Any]:
     """ """
     os.environ["MASTER_ADDR"] = str("localhost")
     os.environ["MASTER_PORT"] = str(get_free_port())
@@ -215,22 +225,16 @@ def run_multi_process_func(
     if world_size == 1:
         kwargs["world_size"] = 1
         kwargs["rank"] = 0
-        func(**kwargs)
-        return
+        result = func(**kwargs)
+        return [result]
+
     ctx = multiprocessing.get_context(multiprocessing_method)
-    processes = []
-    for rank in range(world_size):
-        kwargs["rank"] = rank
-        kwargs["world_size"] = world_size
-        p = ctx.Process(
-            target=func,
-            name=f"rank{rank}",
-            kwargs=kwargs,
-        )
-        p.start()
-        processes.append(p)
-    for p in processes:
-        p.join()
-        if p.exitcode != 0:
-            print(p)
+    # Prepare arguments for each process
+    args_list = [(func, rank, world_size, kwargs.copy()) for rank in range(world_size)]
+
+    # Create a pool of worker processes for each rank
+    with ctx.Pool(processes=world_size) as pool:
+        results = pool.map(_wrapper_func_for_multiprocessing, args_list)
+
+    return results
 