Skip to content

Commit 4239a7b

Browse files
committed
fix/feat: Add Dynamo-only converter registry
- Add Dynamo converter registry which functions as a superset of the standard FX converter registry - For use with new + experimental converters - Uses custom decorator `dynamo_tensorrt_converter` - Update references within Dynamo functions to use the converter registry `DYNAMO_CONVERTERS`
1 parent 44e4ffa commit 4239a7b

File tree

4 files changed

+30
-2
lines changed

4 files changed

+30
-2
lines changed

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
1+
from .converter_registry import (
2+
DYNAMO_CONVERTERS,
3+
dynamo_tensorrt_converter,
4+
)
5+
16
from torch_tensorrt.dynamo import fx_ts_compat
27
from .backend import compile

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.fx.node import _get_qualified_name
1111
from torch.fx.passes.operator_support import OperatorSupport
1212

13-
from torch_tensorrt.fx.converter_registry import CONVERTERS
13+
from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
1414

1515

1616
logger = logging.getLogger(__name__)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Any, Callable, Dict
2+
3+
from torch.fx.node import Target
4+
from torch_tensorrt.fx.converter_registry import CONVERTERS
5+
6+
DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS)
7+
8+
9+
def dynamo_tensorrt_converter(
10+
key: Target,
11+
enabled: bool = True,
12+
) -> Callable[[Any], Any]:
13+
def register_converter(converter):
14+
DYNAMO_CONVERTERS[key] = converter
15+
return converter
16+
17+
def disable_converter(converter):
18+
return converter
19+
20+
if enabled:
21+
return register_converter
22+
else:
23+
return disable_converter

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch.fx.node import _get_qualified_name
1515
from torch.fx.passes.shape_prop import TensorMetadata
1616

17-
from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
17+
from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
1818
from .input_tensor_spec import InputTensorSpec
1919
from torch_tensorrt.fx.observer import Observer
2020
from torch_tensorrt.fx.utils import (

0 commit comments

Comments
 (0)