Skip to content

Commit 4f1f62d

Browse files
TroyGardenmeta-codesync[bot]
authored andcommitted
fix missing function in blocking_copy (meta-pytorch#3492)
Summary: Pull Request resolved: meta-pytorch#3492 # context * somehow the blocking_copy function is missed from previous diff: D85508674 * added it back Reviewed By: aporialiao Differential Revision: D85774050 fbshipit-source-id: 58e798d611eb0000bf2fde815745f834e3c5cd46
1 parent 0d3c048 commit 4f1f62d

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,14 @@ def non_blocking_copy(
483483
num_concat: int,
484484
ctx: MultiProcessContext,
485485
preallocated: bool = False,
486+
use_data_copy_stream: bool = True,
486487
**_kwargs: Dict[str, Any],
487488
) -> None:
488489
with record_function("## setup ##"):
489490
main_stream = torch.cuda.current_stream()
490-
data_copy_stream = torch.cuda.Stream()
491+
data_copy_stream = (
492+
torch.cuda.Stream() if use_data_copy_stream else nullcontext()
493+
)
491494
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
492495

493496
# the host to device data transfer will block cuda execution without the `pin_memory()`
@@ -519,7 +522,8 @@ def non_blocking_copy(
519522

520523
with record_function("## pre-comms compute ##"):
521524
# make sure the data copy is done before the pre-comms compute
522-
main_stream.wait_stream(data_copy_stream)
525+
if use_data_copy_stream:
526+
main_stream.wait_stream(data_copy_stream)
523527
pre_comms = _compute(
524528
dim=dim, num_mul=num_mul, num_concat=1, ctx=ctx, x=device_data
525529
)
@@ -543,6 +547,24 @@ def preallocated_non_blocking_copy(
543547
)
544548

545549

550+
def blocking_copy(
551+
_batch_inputs: List[Dict[str, Any]],
552+
dim: int,
553+
num_mul: int,
554+
num_concat: int,
555+
ctx: MultiProcessContext,
556+
**_kwargs: Dict[str, Any],
557+
) -> None:
558+
return non_blocking_copy(
559+
_batch_inputs=_batch_inputs,
560+
dim=dim,
561+
num_mul=num_mul,
562+
num_concat=num_concat,
563+
ctx=ctx,
564+
use_data_copy_stream=False,
565+
)
566+
567+
546568
# single-rank runner
547569
def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
548570
# Ensure GPUs are available and we have enough of them
@@ -576,6 +598,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
576598
func = non_blocking_copy
577599
case "preallocated_non_blocking_copy":
578600
func = preallocated_non_blocking_copy
601+
case "blocking_copy":
602+
func = blocking_copy
579603
case _:
580604
raise ValueError(f"Unknown benchmark name: {arg.name}")
581605

0 commit comments

Comments
 (0)