Commit cdfc7e1

Create MLIR functions for ONNX operators that are functions
Resolves llvm#3384.

Many ONNX operators are defined by functions and therefore could be expanded into simpler ONNX operations during importing, avoiding the need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a node, the schema for the node's operation is retrieved. If the schema provides a function for the operator, a specialized version for the node's types and attributes will be created and imported as an MLIR function with private visibility. An MLIR function call will then be emitted instead of a normal operator node. Caching is used to avoid generating redundant functions within the same module.

In order to avoid a disruptive change to the importer output for a large number of operations that already have TorchOnnxToTorch support, an allowlist strategy is used by default. With this commit, only two operations are allowlisted for expansion: MeanVarianceNormalization and NegativeLogLikelihoodLoss. Hopefully this list can be gradually expanded. It is possible to disable the allowlist in the configuration, in which case all functions are expanded (useful for testing).

Tools downstream of the importer may now need to do inlining when consuming the output of the importer, e.g.:

```shell
cat imported.mlir | torch-mlir-opt --inline --convert-onnx-to-torch
```

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires knowing the opset version. NodeImporter retrieves this from the opset imports on the ModelProto retained by the GraphInfo. Previously, the model_proto field on GraphInfo was None when importing a subgraph in import_regions, but this conflicts with the new need for opset version info. Since the apparent purpose of setting it to None was to control how GraphInfo generates its input map, a new flag is added to GraphInfo (is_subgraph) to control this behavior, so that the actual ModelProto can now be provided without breaking this. This also turned out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the function definition depends on the types of the inputs. Therefore node importing now needs to look up the types of a node's inputs, not just its outputs as was the case previously. Consequently, the operand to find_type_proto_for_name() may now be a graph input or initializer in some cases, so it has to be updated.
Parent: a02e14e
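To make the schema-lookup step described above concrete, here is a minimal sketch using the `onnx` package. This is not the importer's actual code: the helper name is illustrative and error handling is omitted.

```python
import onnx
from onnx import defs


def find_operator_function(node: onnx.NodeProto, model: onnx.ModelProto):
    """Return the FunctionProto defining `node`'s operator, if any (a sketch)."""
    # The opset version for the node's domain comes from the model's imports.
    opset_version = next(
        (oi.version for oi in model.opset_import if oi.domain == node.domain),
        None,
    )
    if opset_version is None:
        return None
    schema = defs.get_schema(node.op_type, opset_version, node.domain)
    # Operators that are functions carry a FunctionProto in their schema.
    if schema.has_function:
        return schema.function_body
    # Context-dependent functions additionally need the node's input types to
    # build a body, which is why the importer now looks up input types too.
    return None
```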

File tree: 9 files changed (+547, −29 lines)
projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -97,7 +97,10 @@ def _module_lowering(
     # Lower from ONNX to Torch
     run_pipeline_with_repro_report(
         torch_mod,
-        f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
+        # The importer may produce additional MLIR functions corresponding to
+        # ONNX operators that are functions. In some cases they need to be
+        # inlined to avoid the backend choking on them.
+        f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
         "Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
     )
```

python/torch_mlir/extras/onnx_importer.py

Lines changed: 469 additions & 27 deletions
Large diffs are not rendered by default.
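Since the main diff is not rendered here, the sketch below illustrates the caching strategy the commit message describes: each specialized function is keyed by the operator plus the types and attributes it was specialized for, so repeated nodes share one private MLIR function. All names here (FunctionExpansionCache, the mangling scheme) are hypothetical, not the importer's actual code.

```python
from typing import Callable, Dict, Tuple


class FunctionExpansionCache:
    """Hypothetical cache: one specialized MLIR function per unique key."""

    def __init__(self) -> None:
        # Maps (domain, op_type, opset, type/attr signature) -> MLIR symbol name.
        self._cache: Dict[Tuple, str] = {}

    def get_or_create(
        self,
        key: Tuple,
        mangle: Callable[[Tuple], str],
        import_function: Callable[[str], None],
    ) -> str:
        name = self._cache.get(key)
        if name is None:
            # The exact mangling scheme is unspecified; the new lit tests
            # deliberately only check that the operator name appears in it.
            name = mangle(key)
            # Import the specialized body as `func.func private @<name>`.
            import_function(name)
            self._cache[key] = name
        return name
```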

python/torch_mlir/tools/import_onnx/__main__.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -31,10 +31,14 @@
 
 
 def main(args: argparse.Namespace):
+    config = onnx_importer.Config()
+    if args.disable_function_expansion_allowlist:
+        config.function_expansion_allowlists_by_domain = None
+
     model_proto = load_onnx_model(args)
     context = Context()
     torch_d.register_dialect(context)
-    model_info = onnx_importer.ModelInfo(model_proto)
+    model_info = onnx_importer.ModelInfo(model_proto, config=config)
     m = model_info.create_module(context=context).operation
     imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
     imp.import_all()
@@ -195,6 +199,12 @@ def parse_arguments(argv=None) -> argparse.Namespace:
         " to before importing to MLIR. This can sometime assist with shape inference.",
         type=int,
     )
+    parser.add_argument(
+        "--disable-function-expansion-allowlist",
+        action="store_true",
+        help="Disable the allowlist for ONNX function expansion,"
+        " allowing non-allowlisted functions to be expanded.",
+    )
     args = parser.parse_args(argv)
     return args
```
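For users of the Python API rather than the CLI, the equivalent of the updated main() above might look like the following sketch, which mirrors the diff (the input file name is hypothetical):

```python
import onnx
from torch_mlir.ir import Context
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras import onnx_importer

model_proto = onnx.load("model.onnx")  # hypothetical input file

config = onnx_importer.Config()
config.function_expansion_allowlists_by_domain = None  # expand all functions

context = Context()
torch_d.register_dialect(context)
model_info = onnx_importer.ModelInfo(model_proto, config=config)
m = model_info.create_module(context=context).operation
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
imp.import_all()
print(m.get_asm())
```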

Lines changed: 18 additions & 0 deletions
```diff
@@ -0,0 +1,18 @@
+# Test that expansion of ONNX operators that are functions works for a simple
+# example. The exact name mangling scheme used is not matched against, all that
+# matters is that it has the name of the operator (GreaterOrEqual here) in it.
+# Attributes are also not checked here. What we are interested in is the types
+# and operations.
+#
+# The model comes from an upstream ONNX test: backend/test/data/node/test_greater_equal/model.onnx
+
+# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
+
+# CHECK-LABEL: func.func @test_greater_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
+# CHECK: %0 = call @"{{.*}}GreaterOrEqual{{.*}}"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
+
+# CHECK-LABEL: func.func private @"{{.*}}GreaterOrEqual{{.*}}"(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
+# CHECK: %0 = torch.operator "onnx.Greater"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
+# CHECK: %1 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
+# CHECK: %2 = torch.operator "onnx.Or"(%0, %1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1>
+# CHECK: return %2 : !torch.vtensor<[3,4,5],i1>
```
Binary file not shown (the companion .onnx model that the RUN line imports).
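For illustration, a functionally similar model could be built with onnx's helpers as follows. This is a hedged sketch: opset 16 is an arbitrary version recent enough to define GreaterOrEqual, and the actual upstream test model may differ in detail.

```python
import onnx
from onnx import TensorProto, helper

# One GreaterOrEqual node over two [3,4,5] float inputs, as in the test.
node = helper.make_node("GreaterOrEqual", ["x", "y"], ["z"])
graph = helper.make_graph(
    [node],
    "test_greater_equal",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4, 5]),
        helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4, 5]),
    ],
    [helper.make_tensor_value_info("z", TensorProto.BOOL, [3, 4, 5])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
onnx.save(model, "model.onnx")
```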
Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+# Test the expansion of ONNX operators that are functions, specifically the
+# propagation of attribute values from the call-site to nodes within the
+# expanded function.
+#
+# In this case, the model has a ReduceSumSquare node with the attribute
+# 'keepdims' set to 0, and the definition of this version of ReduceSumSquare
+# contains a ReduceSum node that references the value of 'keepdims', so we
+# expect to see this value propagated to the ReduceSum node in the expansion.
+#
+# This also tests that the absence of 'axes' (as an optional attribute with no
+# default value) is propagated in the same way.
+#
+# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_do_not_keepdims_example/model.onnx

+# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
+#
+# CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
+# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
+#
+# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
+# CHECK: %0 = torch.operator "onnx.Mul"
+# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 0 : si64}
```
Binary file not shown (the companion .onnx model that the RUN line imports).
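What "references the value of 'keepdims'" means at the protobuf level: inside a FunctionProto body, a node attribute can carry a ref_attr_name instead of a literal value, and expansion substitutes the call-site value. A sketch built with raw onnx protobufs; the tensor names are illustrative, not taken from the actual ReduceSumSquare definition:

```python
from onnx import AttributeProto, helper

# The ReduceSum node inside the function body, with no literal 'keepdims'.
reduce_sum = helper.make_node("ReduceSum", ["squared"], ["reduced"])

# An attribute that refers to the caller's 'keepdims' rather than a value;
# it is resolved at each call site (0 at this test's call site).
ref = AttributeProto()
ref.name = "keepdims"
ref.ref_attr_name = "keepdims"
ref.type = AttributeProto.INT
reduce_sum.attribute.append(ref)
```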
Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+# Test the expansion of ONNX operators that are functions, specifically the
+# propagation of attribute values from the call-site to nodes within the
+# expanded function.
+#
+# In this case, the model has a ReduceSumSquare node with no attributes, but the
+# definition of this version of ReduceSumSquare contains a ReduceSum node that
+# references the value of 'keepdims', and the definition says its default value
+# is 1, so we expect to see this value propagated to the ReduceSum node in the
+# expansion.
+#
+# This also tests that the absence of 'axes' (as an optional attribute with no
+# default value) is propagated in the same way.
+#
+# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_empty_set/model.onnx

+# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
+#
+# CHECK-LABEL: func.func @test_reduce_sum_square_empty_set
+# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
+#
+# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
+# CHECK: %0 = torch.operator "onnx.Mul"
+# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 1 : si64}
```
Binary file not shown (the companion .onnx model that the RUN line imports).
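Where the default of 1 comes from: the operator schema records attribute defaults, which the expansion falls back to when the call site omits the attribute. This can be inspected with onnx's schema bindings, as in the sketch below (opset 18 is an arbitrary recent version defining the op):

```python
import onnx.defs

# ReduceSumSquare's schema records the default for 'keepdims'.
schema = onnx.defs.get_schema("ReduceSumSquare", 18)
keepdims = schema.attributes["keepdims"]
print(keepdims.default_value.i)  # expected: 1
```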
