@@ -510,8 +510,12 @@ def convert_call_function_op(self, node: torch._C.Node):
         # TODO: covnert sourceRange() into stack_trace
         # fx_node.meta["stack_trace"] = node.sourceRange()

-        output_name = node.output().debugName()
-        self.name_to_node[output_name] = fx_node
+        outs = tuple(node.outputs())
+        if len(outs) == 1:
+            output_name = node.output().debugName()
+            self.name_to_node[output_name] = fx_node
+        elif len(outs) > 1:
+            raise RuntimeError("Number of outputs > 1 is not supported yet")

     def convert_prim_TupleConstruct(self, node: torch._C.Node):
         self._convert_prim_iterator(node)
@@ -743,12 +747,27 @@ def _identify_inputs_as_arguments(entry):

         cond_node = self.fx_graph.call_function(torch.cond, args, {})

-        output_name = node.output().debugName()
-        self.name_to_node[output_name] = cond_node
+        outs = tuple(node.outputs())
+        if len(outs) == 1:
+            output_name = node.output().debugName()
+            self.name_to_node[output_name] = cond_node
+        elif len(outs) > 1:
+            raise RuntimeError("Number of outputs > 1 is not supported yet")

     def convert_aten_Bool(self, node: torch._C.Node):
         self._convert_as_noop(node)

+    def convert_prim_Enter(self, node: torch._C.Node):
+        # export generally treats prim::Enter as noop
+        # The only context manager export supports is aten::enable_grad.
+        # Unfortunately, TorchScript does not support aten::enable_grad yet.
+        # TODO: support aten::enable_grad in both TorchScript and Converter.
+        return
+
+    def convert_prim_Exit(self, node: torch._C.Node):
+        # export treats prim::Exit as noop
+        return
+
     def _convert_as_noop(self, node: torch._C.Node):
         # Converts the node as a no-op by mapping its output node as arg[0]

@@ -760,13 +779,6 @@ def _convert_as_noop(self, node: torch._C.Node):
         output_name = node.output().debugName()
         self.name_to_node[output_name] = args[0]

-    def convert_profiler__record_function_enter_new(self, node: torch._C.Node):
-        target = torch.ops.profiler._record_function_enter_new
-        args = tuple(self.get_fx_value(input) for input in node.inputs())
-        fx_node = self.fx_graph.call_function(target, args)
-        output_name = node.output().debugName()
-        self.name_to_node[output_name] = fx_node
-
     def convert_profiler__record_function_exit(self, node: torch._C.Node):
         # _record_function_exit has side effect so we keep it in fx.graph
         # currently, _record_function_enter_new and _record_function_exit are
@@ -784,6 +796,14 @@ def convert_prim_tolist(self, node: torch._C.Node):
         output_name = node.output().debugName()
         self.name_to_node[output_name] = fx_node

+    def convert_prim_Uninitialized(self, node: torch._C.Node):
+        # `prim::Uninitialized` is inserted by the compiler when it can prove
+        # the value will never be used. It can be introduced by exceptions,
+        # breaks, continues, and returns.
+        # So we add a dummy constant to the graph.
+        output_name = node.output().debugName()
+        self.constant_map[output_name] = torch.Tensor()
+
     def _convert_standard_operators(self, node: torch._C.Node):
         target = kind_to_standard_operators[node.kind()]
         args = tuple(self.get_fx_value(input) for input in node.inputs())
@@ -836,10 +856,10 @@ def convert_graph_outputs(self):
                 )
             else:
                 raise ValueError(f"Output {output_name} not found")
-
-        self.fx_graph.output(
-            args[0]
-        )  # Get rid of an extra list wrapped around final output.
+        if args:
+            self.fx_graph.output(
+                args[0]
+            )  # Get rid of an extra list wrapped around final output.


 class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
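For reference, a minimal usage sketch of the converter these hunks modify. The TS2EPConverter entry point, its import path, and its signature are assumptions based on the same torch/_export/converter.py module, not part of this diff:

import torch
from torch._export.converter import TS2EPConverter  # assumed entry point

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

ts_model = torch.jit.script(M())
# Convert the TorchScript IR into an ExportedProgram by way of TS2FXGraphConverter.
ep = TS2EPConverter(ts_model, (torch.randn(3),)).convert()  # assumed signature
print(ep)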