
Commit 09f6bd6

BoyuanFeng authored and DiweiSun committed
[ts-migration] Support RaiseException, prim::Uninitialized, prim::Enter, and prim::Exit (pytorch#129416)

- Support raising exceptions. The behavior now matches non-strict export, thanks to @ydwu4's [PR](pytorch#128709).
- Support prim::Uninitialized, prim::Enter, and prim::Exit.

Pull Request resolved: pytorch#129416
Approved by: https://github.com/angelayi
1 parent e5c6a51 commit 09f6bd6
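For orientation, a minimal sketch of exercising the new RaiseException path end to end. This is illustrative only, not part of the commit: TS2EPConverter is a private entry point in torch/_export/converter.py, and its exact signature may differ across versions.

import torch
from torch._export.converter import TS2EPConverter

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
        if y > 0:
            raise RuntimeError("test")  # scripts to prim::RaiseException
        return x + y

ts_model = torch.jit.script(M())
# Per the tests below, conversion matches non-strict export for an if whose
# branch raises: it follows the branch the sample input takes, so y=0
# exercises the non-raising path.
ep = TS2EPConverter(ts_model, (torch.randn(3, 2), 0)).convert()
print(ep.graph_module.graph)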

File tree

2 files changed: +98 −16 lines changed

test/export/test_converter.py

Lines changed: 63 additions & 1 deletion
@@ -2,7 +2,7 @@
 
 import unittest
 from collections import OrderedDict
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -942,6 +942,68 @@ def forward(self, x: torch.Tensor):
         inp = (torch.ones(1),)
         self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True)
 
+    def test_raise_exception(self):
+        class Module(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
+                if y > 0:
+                    raise RuntimeError("test")
+                return x + y
+
+        # Match non-strict export behavior, which errors when the given
+        # input leads to RaiseException.
+        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
+            inp = (torch.randn(3, 2), 1)
+            self._check_equal_ts_ep_converter(Module(), inp, ["script"])
+
+        # Match non-strict export behavior, which only executes one
+        # if-branch according to the given input.
+        inp = (torch.randn(3, 2), 0)
+        self._check_equal_ts_ep_converter(Module(), inp, ["script"])
+
+        class Module(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
+                z = x
+                if y > 0:
+                    raise RuntimeError("test")
+                    # z = x
+                else:
+                    z = x + y
+                return x + y + z
+
+        # Match non-strict export behavior, which errors when the given
+        # input leads to RaiseException.
+        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
+            inp = (torch.randn(3, 2), 1)
+            self._check_equal_ts_ep_converter(Module(), inp, ["script"])
+
+        # Match non-strict export behavior, which only executes one
+        # if-branch according to the given input.
+        inp = (torch.randn(3, 2), 0)
+        self._check_equal_ts_ep_converter(Module(), inp, ["script"])
+
+    def test_context_manager(self):
+        class ContextManager:
+            def __init__(self):
+                self.count = 0
+                return
+
+            def __enter__(self):
+                self.count += 1
+                return
+
+            def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+                self.count -= 1
+                return
+
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                with ContextManager():
+                    res = x + y
+                return res
+
+        inp = (torch.ones(3, 3), torch.ones(3, 3))
+        self._check_equal_ts_ep_converter(M(), inp)
+
     def test_hidden_input_name(self):
         @torch.jit.script
         def func1(x):
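To see the nodes the new prim::Enter/prim::Exit handlers cover, one can script the context-manager module from the test above and dump its TorchScript graph. A sketch (the exact graph text varies across PyTorch versions):

from typing import Any

import torch

class ContextManager:
    def __init__(self) -> None:
        self.count = 0

    def __enter__(self) -> None:
        self.count += 1

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        with ContextManager():
            res = x + y
        return res

# The with-block should appear as prim::Enter / prim::Exit in the dump.
print(torch.jit.script(M()).graph)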
torch/_export/converter.py

Lines changed: 35 additions & 15 deletions
@@ -510,8 +510,12 @@ def convert_call_function_op(self, node: torch._C.Node):
         # TODO: convert sourceRange() into stack_trace
         # fx_node.meta["stack_trace"] = node.sourceRange()
 
-        output_name = node.output().debugName()
-        self.name_to_node[output_name] = fx_node
+        outs = tuple(node.outputs())
+        if len(outs) == 1:
+            output_name = node.output().debugName()
+            self.name_to_node[output_name] = fx_node
+        elif len(outs) > 1:
+            raise RuntimeError("Number of outputs > 1 is not supported yet")
 
     def convert_prim_TupleConstruct(self, node: torch._C.Node):
         self._convert_prim_iterator(node)
@@ -743,12 +747,27 @@ def _identify_inputs_as_arguments(entry):
 
         cond_node = self.fx_graph.call_function(torch.cond, args, {})
 
-        output_name = node.output().debugName()
-        self.name_to_node[output_name] = cond_node
+        outs = tuple(node.outputs())
+        if len(outs) == 1:
+            output_name = node.output().debugName()
+            self.name_to_node[output_name] = cond_node
+        elif len(outs) > 1:
+            raise RuntimeError("Number of outputs > 1 is not supported yet")
 
     def convert_aten_Bool(self, node: torch._C.Node):
         self._convert_as_noop(node)
 
+    def convert_prim_Enter(self, node: torch._C.Node):
+        # Export generally treats prim::Enter as a no-op. The only context
+        # manager export supports is aten::enable_grad; unfortunately,
+        # TorchScript does not support aten::enable_grad yet.
+        # TODO: support aten::enable_grad in both TorchScript and Converter.
+        return
+
+    def convert_prim_Exit(self, node: torch._C.Node):
+        # Export treats prim::Exit as a no-op.
+        return
+
     def _convert_as_noop(self, node: torch._C.Node):
         # Converts the node as a no-op by mapping its output node as arg[0]
 
@@ -760,13 +779,6 @@ def _convert_as_noop(self, node: torch._C.Node):
         output_name = node.output().debugName()
         self.name_to_node[output_name] = args[0]
 
-    def convert_profiler__record_function_enter_new(self, node: torch._C.Node):
-        target = torch.ops.profiler._record_function_enter_new
-        args = tuple(self.get_fx_value(input) for input in node.inputs())
-        fx_node = self.fx_graph.call_function(target, args)
-        output_name = node.output().debugName()
-        self.name_to_node[output_name] = fx_node
-
     def convert_profiler__record_function_exit(self, node: torch._C.Node):
         # _record_function_exit has side effect so we keep it in fx.graph
         # currently, _record_function_enter_new and _record_function_exit are
@@ -784,6 +796,14 @@ def convert_prim_tolist(self, node: torch._C.Node):
         output_name = node.output().debugName()
         self.name_to_node[output_name] = fx_node
 
+    def convert_prim_Uninitialized(self, node: torch._C.Node):
+        # `prim::Uninitialized` is inserted by the compiler when it can prove
+        # the value will never be used. It can be introduced by exceptions,
+        # breaks, continues, and returns, so we add a dummy constant to the
+        # graph in its place.
+        output_name = node.output().debugName()
+        self.constant_map[output_name] = torch.Tensor()
+
     def _convert_standard_operators(self, node: torch._C.Node):
         target = kind_to_standard_operators[node.kind()]
         args = tuple(self.get_fx_value(input) for input in node.inputs())
@@ -836,10 +856,10 @@ def convert_graph_outputs(self):
                 )
             else:
                 raise ValueError(f"Output {output_name} not found")
-
-        self.fx_graph.output(
-            args[0]
-        )  # Get rid of an extra list wrapped around final output.
+        if args:
+            self.fx_graph.output(
+                args[0]
+            )  # Get rid of an extra list wrapped around final output.
 
 
 class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
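Where prim::Uninitialized comes from can be observed directly. A sketch (whether it appears in the dump depends on the PyTorch version and compiler passes): when one if-branch raises, a variable assigned only in the other branch still needs a block output, which the compiler fills with prim::Uninitialized; the new convert_prim_Uninitialized handler maps that value to a dummy torch.Tensor() in constant_map.

import torch

@torch.jit.script
def f(x: torch.Tensor, y: int) -> torch.Tensor:
    if y > 0:
        raise RuntimeError("test")
    else:
        z = x + y
    return z

# The raising branch must still yield a value for z; look for
# prim::Uninitialized in the dumped graph.
print(f.graph)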
