Skip to content

Commit 9a05d32

Browse files
gchalump authored and meta-codesync[bot] committed
Add kineto tracing to bench:jagged_tensor (#5061)
Summary: Pull Request resolved: #5061 X-link: https://github.com/facebookresearch/FBGEMM/pull/2063 Pull Request resolved: #5039 X-link: https://github.com/facebookresearch/FBGEMM/pull/2048 Add kineto tracing to bench:jagged_tensor - only to keyed-jagged-index-select-dim1 for now Reviewed By: spcyppt Differential Revision: D85169086 fbshipit-source-id: 309a28eb93553196949e98d864ecc4b683b12e1f
1 parent b0dffd3 commit 9a05d32

File tree

1 file changed

+44
-17
lines changed

1 file changed

+44
-17
lines changed

fbgemm_gpu/bench/jagged_tensor_benchmark.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
import functools
1212
import logging
13+
import os
1314
import random
15+
from contextlib import nullcontext
1416
from dataclasses import dataclass
17+
from typing import Callable
1518

1619
import click
1720
import fbgemm_gpu
@@ -542,6 +545,17 @@ def ref(
542545
@click.option("--has-weights", is_flag=True, default=False)
543546
@click.option("--weight-type", type=str, default="float")
544547
@click.option("--use-selected-lengths-sum", is_flag=True, default=False)
548+
@click.option(
549+
"--export-trace",
550+
is_flag=True,
551+
default=False,
552+
help="Enable export of trace for profiling. Default is False.",
553+
)
554+
@click.option(
555+
"--trace-url",
556+
type=str,
557+
default="keyed_jagged_index_select_dim1_{phase}_trace_{ospid}.json",
558+
)
545559
def keyed_jagged_index_select_dim1(
546560
num_batches: int,
547561
max_seq_length: int,
@@ -551,6 +565,8 @@ def keyed_jagged_index_select_dim1(
551565
has_weights: bool,
552566
weight_type: str,
553567
use_selected_lengths_sum: bool,
568+
export_trace: bool,
569+
trace_url: str,
554570
) -> None:
555571
jagged_tensor_types = {
556572
"float": torch.float,
@@ -622,20 +638,28 @@ def keyed_jagged_index_select_dim1(
622638
if is_float:
623639
values.requires_grad = True
624640

625-
time, output = benchmark_torch_function(
626-
torch.ops.fbgemm.keyed_jagged_index_select_dim1,
627-
(
628-
values,
629-
lengths,
630-
offsets,
631-
indices,
632-
input_batch_size,
633-
weights,
634-
selected_lengths_sum,
635-
),
636-
iters=1000,
637-
)
638-
output = output[0]
641+
def _kineto_trace_handler(p: profile, phase: str) -> None:
642+
p.export_chrome_trace(trace_url.format(phase=phase, ospid=os.getpid()))
643+
644+
# pyre-ignore[3]
645+
def context_factory(on_trace_ready: Callable[[profile], None]):
646+
return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
647+
648+
with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
649+
time, output = benchmark_torch_function(
650+
torch.ops.fbgemm.keyed_jagged_index_select_dim1,
651+
(
652+
values,
653+
lengths,
654+
offsets,
655+
indices,
656+
input_batch_size,
657+
weights,
658+
selected_lengths_sum,
659+
),
660+
iters=1000,
661+
)
662+
output = output[0]
639663

640664
# Prepare inputs for the reference run
641665
ref_inputs = []
@@ -687,9 +711,12 @@ def keyed_jagged_index_select_dim1_ref(
687711
return
688712

689713
grad = torch.rand_like(output)
690-
time, _ = benchmark_torch_function(
691-
functools.partial(output.backward, retain_graph=True), (grad,), iters=1000
692-
)
714+
715+
with context_factory(lambda p: _kineto_trace_handler(p, "bwd")):
716+
time, _ = benchmark_torch_function(
717+
functools.partial(output.backward, retain_graph=True), (grad,), iters=1000
718+
)
719+
693720
time_ref, _ = benchmark_torch_function(
694721
functools.partial(output_ref.backward, retain_graph=True), (grad,), iters=1000
695722
)

0 commit comments

Comments (0)