Created run_pipeline API function to get benchmark results #3237

Closed · wants to merge 3 commits
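
This PR adds a run_pipeline() entry point to benchmark_train_pipeline.py that returns the benchmark results to the caller instead of only printing them. A minimal usage sketch, assuming the names visible in the diff below are importable from that module; the constructor arguments are illustrative placeholders, not taken from this PR:

# Hypothetical usage sketch; config constructor arguments are assumptions.
from torchrec.distributed.benchmark.benchmark_train_pipeline import (
    EmbeddingTablesConfig,
    PipelineConfig,
    RunOptions,
    TestTowerSparseNNConfig,
    run_pipeline,
)

results = run_pipeline(
    run_option=RunOptions(),                 # e.g. world_size, profiling options
    table_config=EmbeddingTablesConfig(),    # table counts and embedding dim
    pipeline_config=PipelineConfig(),        # which train pipeline to benchmark
    model_config=TestTowerSparseNNConfig(),  # any BaseModelConfig subclass
)

# One BenchmarkResult per rank.
for result in results:
    print(result)
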
106 changes: 56 additions & 50 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -40,7 +40,11 @@
TestTowerCollectionSparseNNConfig,
TestTowerSparseNNConfig,
)
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
from torchrec.distributed.benchmark.benchmark_utils import (
benchmark_func,
BenchmarkResult,
cmd_conf,
)
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import Topology
@@ -201,15 +205,7 @@ def main(
table_config: EmbeddingTablesConfig,
model_selection: ModelSelectionConfig,
pipeline_config: PipelineConfig,
model_config: Optional[
Union[
TestSparseNNConfig,
TestTowerCollectionSparseNNConfig,
TestTowerSparseNNConfig,
DeepFMConfig,
DLRMConfig,
]
] = None,
model_config: Optional[BaseModelConfig] = None,
) -> None:
tables, weighted_tables = generate_tables(
num_unweighted_features=table_config.num_unweighted_features,
@@ -254,6 +250,30 @@
)


def run_pipeline(
run_option: RunOptions,
table_config: EmbeddingTablesConfig,
pipeline_config: PipelineConfig,
model_config: BaseModelConfig,
) -> List[BenchmarkResult]:
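"""Run the configured train-pipeline benchmark across all ranks and return one BenchmarkResult per rank."""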

tables, weighted_tables = generate_tables(
num_unweighted_features=table_config.num_unweighted_features,
num_weighted_features=table_config.num_weighted_features,
embedding_feature_dim=table_config.embedding_feature_dim,
)

return run_multi_process_func(
func=runner,
world_size=run_option.world_size,
tables=tables,
weighted_tables=weighted_tables,
run_option=run_option,
model_config=model_config,
pipeline_config=pipeline_config,
)


def runner(
rank: int,
world_size: int,
@@ -262,7 +282,7 @@ def runner(
run_option: RunOptions,
model_config: BaseModelConfig,
pipeline_config: PipelineConfig,
) -> None:
) -> BenchmarkResult:
# Ensure GPUs are available and we have enough of them
assert (
torch.cuda.is_available() and torch.cuda.device_count() >= world_size
@@ -356,48 +376,34 @@ def _func_to_benchmark(
except StopIteration:
break

# Run comparison if apply_jit is True, otherwise run single benchmark
jit_configs = (
[(True, "WithJIT"), (False, "WithoutJIT")]
if pipeline_config.apply_jit
else [(False, "")]
pipeline = generate_pipeline(
pipeline_type=pipeline_config.pipeline,
emb_lookup_stream=pipeline_config.emb_lookup_stream,
model=sharded_model,
opt=optimizer,
device=ctx.device,
apply_jit=pipeline_config.apply_jit,
)
pipeline.progress(iter(bench_inputs))

result = benchmark_func(
name=type(pipeline).__name__,
bench_inputs=bench_inputs, # pyre-ignore
prof_inputs=bench_inputs, # pyre-ignore
num_benchmarks=5,
num_profiles=2,
profile_dir=run_option.profile,
world_size=run_option.world_size,
func_to_benchmark=_func_to_benchmark,
benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
rank=rank,
export_stacks=run_option.export_stacks,
)
results = []

for apply_jit, jit_suffix in jit_configs:
pipeline = generate_pipeline(
pipeline_type=pipeline_config.pipeline,
emb_lookup_stream=pipeline_config.emb_lookup_stream,
model=sharded_model,
opt=optimizer,
device=ctx.device,
apply_jit=apply_jit,
)
pipeline.progress(iter(bench_inputs))

name = (
f"{type(pipeline).__name__}{jit_suffix}"
if jit_suffix
else type(pipeline).__name__
)
result = benchmark_func(
name=name,
bench_inputs=bench_inputs, # pyre-ignore
prof_inputs=bench_inputs, # pyre-ignore
num_benchmarks=5,
num_profiles=2,
profile_dir=run_option.profile,
world_size=run_option.world_size,
func_to_benchmark=_func_to_benchmark,
benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
rank=rank,
export_stacks=run_option.export_stacks,
)
results.append(result)

if rank == 0:
for result in results:
print(result)
print(result)

return result


if __name__ == "__main__":
42 changes: 23 additions & 19 deletions torchrec/distributed/test_utils/multi_process.py
@@ -189,17 +189,27 @@ def _run_multi_process_test_per_rank(
self.assertEqual(0, p.exitcode)


def _wrapper_func_for_multiprocessing(args): # pyre-ignore[2, 3]
"""Wrapper function that unpacks arguments and calls the original func"""
func, rank, world_size, kwargs = args
kwargs["rank"] = rank
kwargs["world_size"] = world_size
return func(**kwargs)


# pyre-ignore[3]
def run_multi_process_func(
# pyre-ignore[2]
func: Callable[
[int, int, ...], # rank, world_size, ...
None,
Any, # Changed from None to Any to allow return values
],
multiprocessing_method: str = "spawn",
use_deterministic_algorithms: bool = True,
world_size: int = 2,
# pyre-ignore
**kwargs,
) -> None:
) -> List[Any]:
""" """
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())
@@ -215,22 +225,16 @@
if world_size == 1:
kwargs["world_size"] = 1
kwargs["rank"] = 0
func(**kwargs)
return
result = func(**kwargs)
return [result]

ctx = multiprocessing.get_context(multiprocessing_method)
processes = []
for rank in range(world_size):
kwargs["rank"] = rank
kwargs["world_size"] = world_size
p = ctx.Process(
target=func,
name=f"rank{rank}",
kwargs=kwargs,
)
p.start()
processes.append(p)

for p in processes:
p.join()
if p.exitcode != 0:
print(p)
# Prepare arguments for each process
args_list = [(func, rank, world_size, kwargs.copy()) for rank in range(world_size)]

# Create a pool of worker processes for each rank
with ctx.Pool(processes=world_size) as pool:
results = pool.map(_wrapper_func_for_multiprocessing, args_list)

return results
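
The change to run_multi_process_func swaps the per-rank Process objects for a multiprocessing Pool so that each rank's return value is collected and returned; the module-level _wrapper_func_for_multiprocessing exists because the "spawn" context pickles the mapped callable, so it must be importable at module scope. A standalone sketch of the same pattern (illustrative only, not TorchRec code):

# Minimal sketch of the Pool-based return-value pattern used above.
import multiprocessing


def _wrapper(args):
    # Unpack (func, rank, world_size, kwargs) and inject the per-rank fields.
    func, rank, world_size, kwargs = args
    kwargs["rank"] = rank
    kwargs["world_size"] = world_size
    return func(**kwargs)


def per_rank_job(rank: int, world_size: int, payload: str) -> str:
    return f"{payload}: rank {rank} of {world_size}"


if __name__ == "__main__":
    world_size = 2
    ctx = multiprocessing.get_context("spawn")
    args_list = [
        (per_rank_job, rank, world_size, {"payload": "demo"})
        for rank in range(world_size)
    ]
    with ctx.Pool(processes=world_size) as pool:
        results = pool.map(_wrapper, args_list)  # one result per rank, in rank order
    print(results)  # ['demo: rank 0 of 2', 'demo: rank 1 of 2']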