Skip to content

Commit 53378e9

Browse files
committed
feat: Data Structure update for Dynamo Registry
- Add custom class overriding default Dictionary class to access converters from various registries - Add new dictionary type `Dict[Target, Sequence[ConverterSupport]]` as well as ConverterSupport class which stores a converter and its validation implementation - Add unified `DYNAMO_CONVERTERS` dictionary which coalesces both the FX and Dynamo converter dictionaries and acts as a single unified dictionary - Streamline dictionary accesses via get/contains accessors - Add priority converter decorator enum to prioritize user-provided converters and name argument checking "capability validation" to clarify utility - Add boilerplate `no_dynamic` converter capability validator for easy use in specifying converters as not-able to handle dynamic shapes
1 parent 4239a7b commit 53378e9

File tree

4 files changed

+352
-16
lines changed

4 files changed

+352
-16
lines changed

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Dict, List, Optional, Sequence, Set
2+
from typing import Callable, Dict, List, Optional, Sequence, Set
33

44
import torch
55

@@ -55,6 +55,10 @@ def __init__(
5555
)
5656

5757
self.min_block_size = min_block_size
58+
logger.debug(
59+
"Initialized Capability-Based Partitioner with available Converters:\n"
60+
+ f"{CONVERTERS.display_all_available_converters()}"
61+
)
5862

5963
def propose_partitions(self) -> List[Partition]:
6064
# Propose partitions using the default, then refine the results
@@ -123,10 +127,7 @@ def is_node_supported(
123127
else node.target
124128
)
125129

126-
if (
127-
node.target in CONVERTERS.keys()
128-
and node_name not in self.torch_executed_ops
129-
):
130+
if node in CONVERTERS and node_name not in self.torch_executed_ops:
130131
# If node is a proper, supported computational node, store the operator
131132
if not node.is_impure():
132133
self.supported_operators.add(node_name)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import torch
2+
3+
4+
def dynamic_unsupported(node: torch.fx.Node) -> bool:
5+
# Validate that none of the inputs to the node have Dynamic shapes
6+
assert isinstance(
7+
node, torch.fx.Node
8+
), "Inputs to validator functions must be FX Nodes"
9+
10+
# Check node value itself
11+
if node.meta["val"]._has_symbolic_sizes_strides:
12+
return False
13+
14+
# Check node arguments individually
15+
if any(
16+
arg.meta["val"]._has_symbolic_sizes_strides
17+
for arg in node.args
18+
if isinstance(arg, torch.fx.Node)
19+
):
20+
return False
21+
22+
# Check node keyword arguments individually
23+
if any(
24+
kwarg.meta["val"]._has_symbolic_sizes_strides
25+
for kwarg in node.kwargs.values()
26+
if isinstance(kwarg, torch.fx.Node)
27+
):
28+
return False
29+
30+
return True
+308-4
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,327 @@
1-
from typing import Any, Callable, Dict
1+
from dataclasses import dataclass, field
2+
from typing import Any, Callable, Dict, Optional, Sequence, Union
3+
from enum import Enum, auto
24

3-
from torch.fx.node import Target
5+
from torch.fx.node import Target, Node, _get_qualified_name
46
from torch_tensorrt.fx.converter_registry import CONVERTERS
57

6-
DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS)
8+
9+
class ConverterPriority(Enum):
10+
"""Enum to set a converter's priority in the registry"""
11+
12+
STANDARD = auto()
13+
HIGH = auto()
14+
15+
16+
@dataclass(frozen=True)
17+
class ConverterSupport:
18+
"""Class representing a converter implementation and support function
19+
20+
Args:
21+
converter_implementation: Function which converts said node to a TRT equivalent
22+
capability_validator: Function which takes in a Node and returns a bool indicating
23+
whether that node can be supported by its companion converter. Note that
24+
this function must not modify the node or its graph
25+
"""
26+
27+
converter_implementation: Callable
28+
capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
29+
30+
31+
# Dictionary representing Dynamo aten-only converters
32+
# Each converter maps to a sequence of at least one ConverterSupport object(s)
33+
DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}
734

835

936
def dynamo_tensorrt_converter(
1037
key: Target,
1138
enabled: bool = True,
39+
capability_validator: Optional[Callable[[Node], bool]] = None,
40+
priority: ConverterPriority = ConverterPriority.STANDARD,
1241
) -> Callable[[Any], Any]:
42+
"""Decorator for Dynamo TensorRT Converter
43+
44+
Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
45+
46+
Args:
47+
key: Node target for which the converter is implemented for
48+
(for example, torch.ops.add.Tensor)
49+
enabled: Whether the converter should be enabled/cached or not
50+
capability_validator: Function which evaluates whether a node is valid for conversion
51+
by the decorated converter. See ConverterSupport for more details.
52+
Defaults to None, implying the capability_validator function is always true -
53+
this means all nodes of "key" kind can be supported by this converter
54+
priority: Converter's level of priority relative to other converters with the
55+
same target
56+
Returns:
57+
The converter being decorated
58+
"""
59+
1360
def register_converter(converter):
14-
DYNAMO_CONVERTERS[key] = converter
61+
"""Helper function to register the converter, then return it"""
62+
assert callable(converter), "Converter function must be callable"
63+
64+
# If no capability_validator function is specified, use the default function - always return true
65+
if capability_validator is None:
66+
converter_support = ConverterSupport(converter_implementation=converter)
67+
else:
68+
assert callable(
69+
capability_validator
70+
), "Argument checking function must be callable"
71+
converter_support = ConverterSupport(
72+
converter_implementation=converter,
73+
capability_validator=capability_validator,
74+
)
75+
76+
# If a converter for this operator already exists, append the new converter to the list
77+
# Otherwise, start a new list
78+
if key in DYNAMO_ATEN_CONVERTERS:
79+
# High priority converters are inserted at the front of the list,
80+
# so they can be checked first by the registry
81+
if priority is ConverterPriority.HIGH:
82+
DYNAMO_ATEN_CONVERTERS[key].insert(0, converter_support)
83+
else:
84+
DYNAMO_ATEN_CONVERTERS[key].append(converter_support)
85+
else:
86+
DYNAMO_ATEN_CONVERTERS[key] = [converter_support]
87+
1588
return converter
1689

1790
def disable_converter(converter):
1891
return converter
1992

93+
# Select whether to cache/enable the converter
2094
if enabled:
2195
return register_converter
2296
else:
2397
return disable_converter
98+
99+
100+
class ConverterRegistry:
101+
"""Registry for storing multiple converter dictionaries
102+
103+
Capable of storing dictionaries with the following signature:
104+
Dict[Target, Union[Callable, Sequence[ConverterSupport]]]
105+
106+
Also able to validate converter implementations against user-provided
107+
argument-checking functions
108+
109+
Args:
110+
registries: List of dictionaries representing converter registries.
111+
The order of the provided dictionaries is the order in which they
112+
will be traversed. This is only significant when using non-validated
113+
methods.
114+
"""
115+
116+
def __init__(
117+
self,
118+
registries: Sequence[Dict[Target, Union[Callable, Sequence[ConverterSupport]]]],
119+
registry_names: Optional[Sequence[str]] = None,
120+
):
121+
# Copy reference to each dictionary object into attribute list
122+
self.registries = [registry for registry in registries]
123+
124+
if registry_names is not None:
125+
assert len(self.registries) == len(registry_names)
126+
self.registry_names = [name for name in registry_names]
127+
else:
128+
self.registry_names = [
129+
f"Registry {i + 1}" for i in range(len(self.registries))
130+
]
131+
132+
self.validate_invariants()
133+
134+
def validate_invariants(self):
135+
"""Validates the invariants required of the dictionaries in the registries
136+
137+
Raises AssertionError if any invariants have been violated
138+
"""
139+
# All registries must be dictionaries
140+
assert all(isinstance(elt, dict) for elt in self.registries)
141+
142+
# Every dictionary in the registry must have one of two signatures:
143+
# Dict[Target, Callable] or Dict[Target, Sequence[ConverterSupport]]
144+
# Where, for the latter, the sequence must be non-empty
145+
for registry in self.registries:
146+
for converters in registry.values():
147+
if isinstance(converters, (list, tuple)):
148+
assert (
149+
all(isinstance(c, ConverterSupport) for c in converters)
150+
and len(converters) > 0
151+
)
152+
else:
153+
assert callable(converters), "Converter function must be callable"
154+
155+
def __getitem_without_validation__(self, key: Target):
156+
"""Get the first-found converter in any registry
157+
158+
Searches all registries in order and returns the first converter encountered
159+
"""
160+
if isinstance(key, Node):
161+
raise KeyError(
162+
"Unvalidated accesses to the Converter registry can only be "
163+
+ "made with node targets. Try accessing the registry with node.target"
164+
)
165+
166+
self.validate_invariants()
167+
168+
# Iterate over all registries and return the first converter found
169+
for registry in self.registries:
170+
if key in registry:
171+
converters = registry[key]
172+
173+
if isinstance(converters, (list, tuple)):
174+
return converters[0].converter_implementation
175+
else:
176+
return converters
177+
178+
raise KeyError(f"None of the converter registries have an entry for {key}")
179+
180+
def __getitem__(self, node: Node):
181+
"""Get the first-found validated converter in any registry
182+
183+
Searches all registries in order and returns the first converter
184+
which passes validation on the input node
185+
"""
186+
if not isinstance(node, Node):
187+
raise KeyError(
188+
"Validated accesses to the Converter registry can only be "
189+
+ "made with node inputs. Try accessing the registry with a node "
190+
+ "or use get_unvalidated to access without node validation."
191+
)
192+
193+
self.validate_invariants()
194+
key = node.target
195+
196+
# Iterate over all registries, validating the converter on the input node
197+
# If no capability_validator function is found, assume full coverage
198+
for registry in self.registries:
199+
if key in registry:
200+
converters = registry[key]
201+
202+
if isinstance(converters, (list, tuple)):
203+
for candidate in converters:
204+
if candidate.capability_validator(node):
205+
return candidate.converter_implementation
206+
else:
207+
return converters
208+
209+
raise KeyError(
210+
f"None of the converter registries have a validated entry for {key}, with node {node}"
211+
)
212+
213+
def keys(self):
214+
"""Get all unique targets across all dictionaries"""
215+
return self.unique_targets()
216+
217+
def get_unvalidated(self, key: Target, value=None):
218+
"""Get unvalidated converter for input target with a default return"""
219+
try:
220+
return self.__getitem_without_validation__(key)
221+
except KeyError:
222+
return value
223+
224+
def get(self, node: Node, value=None):
225+
"""Get validated converter for input node with a default return"""
226+
try:
227+
return self.__getitem__(node)
228+
except KeyError:
229+
return value
230+
231+
def __contains__(self, key: Union[Target, Node]):
232+
"""Check whether a converter for an input node or target exists"""
233+
try:
234+
# Attempt to access the item in the registry
235+
if isinstance(key, Node):
236+
self.__getitem__(key)
237+
else:
238+
self.__getitem_without_validation__(key)
239+
240+
return True
241+
except KeyError:
242+
return False
243+
244+
def get_all_converters_with_target(
245+
self, key: Target, return_registry_info: bool = False
246+
):
247+
"""Get all converters across all registries for the target
248+
249+
Returns a list of all converterts having the specified target
250+
"""
251+
self.validate_invariants()
252+
converters_with_target = []
253+
254+
# Store count of number of registered converters per registry
255+
if return_registry_info:
256+
registry_data = {name: 0 for name in self.registry_names}
257+
258+
for index, registry in enumerate(self.registries):
259+
if key in registry:
260+
converters = registry[key]
261+
262+
if isinstance(converters, (list, tuple)):
263+
converters_with_target.extend(
264+
[c.converter_implementation for c in converters]
265+
)
266+
# Add converter count to registry name storage
267+
if return_registry_info:
268+
registry_data[self.registry_names[index]] += len(converters)
269+
else:
270+
converters_with_target.append(converters)
271+
# Add converter count to registry name storage
272+
if return_registry_info:
273+
registry_data[self.registry_names[index]] += 1
274+
275+
if return_registry_info:
276+
return converters_with_target, registry_data
277+
else:
278+
return converters_with_target
279+
280+
def __setitem__(self, key, value):
281+
raise AssertionError(
282+
f"Do not set registry members directly through the ConverterRegistry object. "
283+
+ f"Attempted to set {key}: {value} via direct assignment to ConverterRegistry."
284+
)
285+
286+
def __delitem__(self, key):
287+
raise AssertionError(
288+
f"Do not delete registry members directly through the ConverterRegistry object. "
289+
+ f"Attempted to delete {key} via direct del on ConverterRegistry."
290+
)
291+
292+
def __len__(self):
293+
"""Returns the sum of lengths of all registries stored"""
294+
return sum(len(registry) for registry in self.registries)
295+
296+
def unique_targets(self):
297+
"""Returns the set of unique converter targets stored across all registries"""
298+
return set.union(*[set(registry.keys()) for registry in self.registries])
299+
300+
def qualified_name_or_str(self, target: Target) -> str:
301+
"""Returns string representation of an FX Node target"""
302+
if isinstance(target, str):
303+
return target
304+
else:
305+
return _get_qualified_name(target)
306+
307+
def display_all_available_converters(self) -> str:
308+
"""Returns a string with all converters and their source, separated by newlines"""
309+
available_converters = "Available converters in ATen registries with counts:\n"
310+
311+
for target in sorted(
312+
self.unique_targets(), key=lambda target: self.qualified_name_or_str(target)
313+
):
314+
_, registry_data = self.get_all_converters_with_target(
315+
target, return_registry_info=True
316+
)
317+
available_converters += f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}\n"
318+
319+
return available_converters
320+
321+
322+
# Initialize dynamo converter registry with the FX and Dynamo aten registries
323+
# Note the Dynamo registry is listed first, for precedence
324+
DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry(
325+
[DYNAMO_ATEN_CONVERTERS, CONVERTERS],
326+
["Dynamo ATen Converters Registry", "FX ATen Converters Registry"],
327+
)

0 commit comments

Comments
 (0)