Commit eefc1d4
WIP: Create MLIR functions for ONNX operators that are functions
Resolves #3384.
Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.
This commit changes onnx_importer.py to systematically perform this
expansion for all ONNX operators that are not explicitly denylisted.
When importing a node, the schema for the node's operation is retrieved.
If the schema provides a function for the operator, a specialized
version for the node's types and attributes is created and imported
as an MLIR function with private visibility, and a call to that function
is emitted instead of a normal operator node. Caching is used to
avoid generating redundant functions within the same module.
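A minimal sketch of this schema-and-function lookup using the public
onnx Python API (`onnx.defs`); it mirrors the mechanism described above
but is not the importer's actual code, and the helper name is ours:

```python
from typing import List, Optional

import onnx
from onnx import defs

def find_function_for_node(
    node: onnx.NodeProto,
    opset_version: int,
    input_type_protos: List[onnx.TypeProto],
) -> Optional[onnx.FunctionProto]:
    """Return the function defining this op, if its schema provides one."""
    try:
        schema = defs.get_schema(node.op_type, opset_version, node.domain)
    except defs.SchemaError:
        return None
    if schema.has_function:
        # Fixed body: the same expansion works for every node.
        return schema.function_body
    if schema.has_context_dependent_function:
        # The body depends on this node's attributes and input types, so
        # it is built per node; any cache key must include both.
        fn_bytes = schema.get_context_dependent_function(
            node.SerializeToString(),
            [tp.SerializeToString() for tp in input_type_protos],
        )
        fn = onnx.FunctionProto()
        fn.ParseFromString(fn_bytes)
        return fn
    return None
```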
Note that previously, all MLIR functions generated by the importer had
no visibility specified. This commit changes that: the main function for
a model is now public. This is so that the MLIR inliner pass will
automatically discard the (private) operator functions after inlining.
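For illustration, here is a hedged sketch of how public/private
visibility can be attached with the MLIR Python bindings that torch-mlir
ships (the function names and types are made up; the importer's real
code differs):

```python
from torch_mlir import ir
from torch_mlir.dialects import func

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f32 = ir.F32Type.get()
    fn_type = ir.FunctionType.get([f32], [f32])
    with ir.InsertionPoint(module.body):
        # Expanded operator body: private, so the inliner can discard it
        # once every call site has been inlined.
        op_fn = func.FuncOp("Gelu_expanded", fn_type, visibility="private")
        with ir.InsertionPoint(op_fn.add_entry_block()):
            func.ReturnOp(op_fn.arguments)
        # Model entry point: public, so it survives inlining.
        main_fn = func.FuncOp("main_graph", fn_type, visibility="public")
        with ir.InsertionPoint(main_fn.add_entry_block()):
            call = func.CallOp(op_fn, list(main_fn.arguments))
            func.ReturnOp(call.results)
    print(module)
```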
Some consequences for things downstream of the importer:
- Inlining should now be done before any lowering, for example
`torch-mlir-opt --inline --convert-onnx-to-torch` (a Python equivalent
is sketched after this list).
- Some lowerings in TorchOnnxToTorch are now redundant and perhaps can
be removed.
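For downstream Python users, the same ordering can be expressed with
torch-mlir's PassManager bindings. A sketch; the exact conversion pass
name and nesting should be checked against torch-mlir:

```python
from torch_mlir import ir
from torch_mlir.passmanager import PassManager

def inline_then_lower(module: ir.Module) -> None:
    with module.context:
        # Run the MLIR inliner first so calls into the private operator
        # functions are inlined and the functions themselves are dropped.
        PassManager.parse("builtin.module(inline)").run(module.operation)
        # Then lower to Torch (pass name assumed here).
        PassManager.parse(
            "builtin.module(func.func(convert-torch-onnx-to-torch))"
        ).run(module.operation)
```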
Explanations for subtle code changes:
- Looking up the correct schema and function for an operator requires
knowing the opset version. NodeImporter retrieves this from the opset
imports on the ModelProto retained by the GraphInfo. Previously, the
model_proto field on GraphInfo was None when importing a subgraph in
import_regions, which conflicts with the new need for opset version
info. Since the apparent purpose of setting it to None was to control
how GraphInfo generates its input map, a new flag (is_subgraph) is added
to GraphInfo to control that behavior, so the actual ModelProto can now
always be provided. This also turned out to be useful for reaching the
Config through GraphInfo's ModelInfo (see the first sketch after this
list).
- Some operators' functions are context-dependent, meaning the function
body depends on the types of the node's inputs. Node importing therefore
now needs to look up the types of a node's inputs, not just its outputs
as was previously the case. Consequently, the name passed to
find_type_proto_for_name() may now refer to a graph input or initializer
in some cases, so that function has been updated to handle them (see the
second sketch after this list).
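First sketch: the shape of the GraphInfo change, with illustrative names
only (the real class carries more state than shown here):

```python
from dataclasses import dataclass
import onnx

@dataclass
class GraphInfo:
    """Illustrative only; not the real torch-mlir class."""
    graph_proto: onnx.GraphProto
    model_proto: onnx.ModelProto  # now always set, even for subgraphs
    is_subgraph: bool = False     # replaces the old model_proto=None signal

    @property
    def opset_imports(self):
        # Opset version info, needed to pick the right schema/function,
        # is now reachable even while importing a subgraph.
        return self.model_proto.opset_import

    def input_map(self) -> dict:
        # Subgraphs capture values from the enclosing scope, so their
        # formal inputs are not treated as top-level graph inputs.
        if self.is_subgraph:
            return {}
        return {i.name: i for i in self.graph_proto.input}
```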
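Second sketch: the widened lookup in a find_type_proto_for_name()-style
helper, now also consulting graph inputs and initializers (simplified
relative to the real function):

```python
import onnx

def find_type_proto_for_name(
    graph: onnx.GraphProto, name: str
) -> onnx.TypeProto:
    # value_info covers intermediate values; inputs and outputs are
    # ValueInfoProtos too, so they share the same lookup.
    for vi in list(graph.value_info) + list(graph.input) + list(graph.output):
        if vi.name == name:
            return vi.type
    # New: the name may refer to an initializer, which carries a
    # TensorProto rather than a ValueInfoProto.
    for init in graph.initializer:
        if init.name == name:
            return onnx.helper.make_tensor_type_proto(
                init.data_type, list(init.dims)
            )
    raise KeyError(f"No type information found for '{name}'")
```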
2 files changed: 368 additions, 29 deletions
- projects/pt1/python/torch_mlir_e2e_test/configs
- python/torch_mlir/extras