Skip to content

Commit b0d5787

Browse files
authored
fix: Fix unbacked sym int not found issue (#3617)
1 parent 56e6867 commit b0d5787

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def trace(
8181
tuple(torch_arg_inputs),
8282
kwargs=torch_kwarg_inputs,
8383
dynamic_shapes=dynamic_shapes,
84+
strict=kwargs.get("strict", False),
8485
)
8586

8687
return exp_program

py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def fake_tensorrt_execute_engine(
6060
output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
6161
# Update var to val (hint)
6262
output_sym_int_shape_env = output_sym_int.node.shape_env
63-
output_sym_int_shape_env.add_var_to_val(
63+
output_sym_int_shape_env.set_unbacked_var_to_val(
6464
output_sym_int.node.expr, opt_val
6565
)
6666
output_shape.append(output_sym_int)
@@ -152,7 +152,7 @@ def __getstate__(self) -> Any:
152152
pass
153153

154154

155-
@torch.library.custom_op(
155+
@torch.library.custom_op( # type: ignore
156156
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
157157
)
158158
def no_op_placeholder_for_execute_engine(

py/torch_tensorrt/dynamo/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,10 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
375375
# https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy
376376
# expr.xreplace replaces the symbolic variables with their current values and computes the expression.
377377
var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(expr)
378-
var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
379-
shape_env.var_to_val
378+
var_val = (
379+
shape_env.var_to_val.get(expr, None)
380+
or shape_env.unbacked_var_to_val.get(expr, None)
381+
or expr.xreplace(shape_env.var_to_val)
380382
)
381383
assert var_range, var_val
382384
min_val, max_val = int(var_range.lower), int(var_range.upper)
@@ -385,8 +387,9 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
385387
min_max_opt = {}
386388
min_max_opt["min"] = min_val
387389
min_max_opt["max"] = max_val
388-
if isinstance(var_val, sympy.core.numbers.Integer):
390+
if isinstance(var_val, (sympy.core.numbers.Integer, int)):
389391
min_max_opt["opt"] = int(var_val)
392+
390393
return min_max_opt
391394

392395

@@ -447,9 +450,9 @@ def get_graph_io_attrs(
447450
metadata = node.meta["val"]
448451
if isinstance(metadata, (tuple, list)):
449452
for tensor in metadata:
450-
graph_io_attrs.append(attr_fn(tensor)) # type: ignore
453+
graph_io_attrs.append(attr_fn(tensor))
451454
else:
452-
graph_io_attrs.append(attr_fn(metadata)) # type: ignore
455+
graph_io_attrs.append(attr_fn(metadata))
453456

454457
return graph_io_attrs
455458

tests/py/dynamo/models/test_reexport.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def test_resnet18_dynamic(ir):
458458

459459
dyn_batch = torch.export.Dim("batch", min=1, max=8)
460460
exp_program = torch.export.export(
461-
model, (input_bs2,), dynamic_shapes=({0: dyn_batch},)
461+
model, (input_bs2,), dynamic_shapes=({0: dyn_batch},), strict=False
462462
)
463463
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
464464

@@ -532,8 +532,9 @@ def test_resnet18_dynamic_fallback(ir):
532532
}
533533

534534
dyn_batch = torch.export.Dim("batch", min=1, max=8)
535+
535536
exp_program = torch.export.export(
536-
model, (input_bs2,), dynamic_shapes=({0: dyn_batch},)
537+
model, (input_bs2,), dynamic_shapes=({0: dyn_batch},), strict=False
537538
)
538539
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
539540

@@ -610,6 +611,7 @@ def forward(self, lhs_val, rhs_val):
610611
model,
611612
inputs_4,
612613
dynamic_shapes={"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}},
614+
strict=False,
613615
)
614616
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
615617

@@ -699,13 +701,16 @@ def forward(self, x):
699701

700702
dyn_dim = torch.export.Dim("batch", min=1, max=64)
701703
exp_program = torch.export.export(
702-
model, torch_inputs_bs50, dynamic_shapes=({0: dyn_dim},)
704+
model, torch_inputs_bs50, dynamic_shapes=({0: dyn_dim},), strict=False
703705
)
704706
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
705707

706708
# Reexport with dynamic dimensions
707709
trt_exp_program = torch.export.export(
708-
trt_module, torch_inputs_bs50, strict=False, dynamic_shapes=({0: dyn_dim},)
710+
trt_module,
711+
torch_inputs_bs50,
712+
strict=False,
713+
dynamic_shapes=({0: dyn_dim},),
709714
)
710715
torch.export.save(trt_exp_program, trt_ep_path)
711716

0 commit comments

Comments (0)