1010
1111import functools
1212import logging
13+ import os
1314import random
15+ from contextlib import nullcontext
1416from dataclasses import dataclass
17+ from typing import Callable
1518
1619import click
1720import fbgemm_gpu
@@ -542,6 +545,17 @@ def ref(
542545@click .option ("--has-weights" , is_flag = True , default = False )
543546@click .option ("--weight-type" , type = str , default = "float" )
544547@click .option ("--use-selected-lengths-sum" , is_flag = True , default = False )
548+ @click .option (
549+ "--export-trace" ,
550+ is_flag = True ,
551+ default = False ,
552+ help = "Enable export of trace for profiling. Default is False." ,
553+ )
554+ @click .option (
555+ "--trace-url" ,
556+ type = str ,
557+ default = "keyed_jagged_index_select_dim1_{phase}_trace_{ospid}.json" ,
558+ )
545559def keyed_jagged_index_select_dim1 (
546560 num_batches : int ,
547561 max_seq_length : int ,
@@ -551,6 +565,8 @@ def keyed_jagged_index_select_dim1(
551565 has_weights : bool ,
552566 weight_type : str ,
553567 use_selected_lengths_sum : bool ,
568+ export_trace : bool ,
569+ trace_url : str ,
554570) -> None :
555571 jagged_tensor_types = {
556572 "float" : torch .float ,
@@ -622,20 +638,28 @@ def keyed_jagged_index_select_dim1(
622638 if is_float :
623639 values .requires_grad = True
624640
625- time , output = benchmark_torch_function (
626- torch .ops .fbgemm .keyed_jagged_index_select_dim1 ,
627- (
628- values ,
629- lengths ,
630- offsets ,
631- indices ,
632- input_batch_size ,
633- weights ,
634- selected_lengths_sum ,
635- ),
636- iters = 1000 ,
637- )
638- output = output [0 ]
641+ def _kineto_trace_handler (p : profile , phase : str ) -> None :
642+ p .export_chrome_trace (trace_url .format (phase = phase , ospid = os .getpid ()))
643+
644+ # pyre-ignore[3]
645+ def context_factory (on_trace_ready : Callable [[profile ], None ]):
646+ return profile (on_trace_ready = on_trace_ready ) if export_trace else nullcontext ()
647+
648+ with context_factory (lambda p : _kineto_trace_handler (p , "fwd" )):
649+ time , output = benchmark_torch_function (
650+ torch .ops .fbgemm .keyed_jagged_index_select_dim1 ,
651+ (
652+ values ,
653+ lengths ,
654+ offsets ,
655+ indices ,
656+ input_batch_size ,
657+ weights ,
658+ selected_lengths_sum ,
659+ ),
660+ iters = 1000 ,
661+ )
662+ output = output [0 ]
639663
640664 # Prepare inputs for the reference run
641665 ref_inputs = []
@@ -687,9 +711,12 @@ def keyed_jagged_index_select_dim1_ref(
687711 return
688712
689713 grad = torch .rand_like (output )
690- time , _ = benchmark_torch_function (
691- functools .partial (output .backward , retain_graph = True ), (grad ,), iters = 1000
692- )
714+
715+ with context_factory (lambda p : _kineto_trace_handler (p , "bwd" )):
716+ time , _ = benchmark_torch_function (
717+ functools .partial (output .backward , retain_graph = True ), (grad ,), iters = 1000
718+ )
719+
693720 time_ref , _ = benchmark_torch_function (
694721 functools .partial (output_ref .backward , retain_graph = True ), (grad ,), iters = 1000
695722 )
0 commit comments