Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for onnx export to models with multiple inputs/outputs #1223

Merged
merged 3 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_pytorch_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime onnxruntime-extensions
pip install pytest
- name: Run unittests
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def export(self) -> None:
else:
Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")

model_input = to_torch_tensor(next(self.repr_dataset())[0])
model_input = to_torch_tensor(next(self.repr_dataset()))

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
Expand All @@ -107,7 +107,7 @@ def export(self) -> None:
onnx.save_model(onnx_model, self.save_model_path)
else:
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from onnx import numpy_helper

import model_compression_toolkit as mct
from mct_quantizers import get_ort_session_options
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION

from tests.pytorch_tests.exporter_tests.base_pytorch_export_test import BasePytorchExportTest
Expand All @@ -40,16 +41,29 @@ def load_exported_model(self, filepath):
onnx.checker.check_model(filepath)
return onnx.load(filepath)

def infer(self, model, images):
ort_session = onnxruntime.InferenceSession(model.SerializeToString())
def infer(self, model, inputs):
ort_session = onnxruntime.InferenceSession(
model.SerializeToString(),
get_ort_session_options(),
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(images)}
onnx_output = ort_session.run(None, ort_inputs)
return onnx_output[0]
# Prepare inputs
if isinstance(inputs, list):
ort_inputs = {input.name: to_numpy(tensor) for input, tensor in zip(ort_session.get_inputs(), inputs)}
elif isinstance(inputs, dict):
ort_inputs = {name: to_numpy(tensor) for name, tensor in inputs.items()}
else:
raise ValueError("Inputs must be a list or a dictionary")

output_names = [output.name for output in ort_session.get_outputs()]
onnx_outputs = ort_session.run(output_names, ort_inputs)
output_dict = dict(zip(output_names, onnx_outputs))

return output_dict

def _get_onnx_node_by_type(self, onnx_model, op_type):
return [n for n in onnx_model.graph.node if n.op_type == op_type]
Expand Down
38 changes: 38 additions & 0 deletions tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest
from torch import nn

class TestExportONNXMultipleInputs(BasePytorchONNXExportTest):
def get_model(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)

def forward(self, input1, input2):
x1 = self.conv1(input1)
x2 = self.conv2(input2)
return x1 + x2

return Model()

def get_input_shapes(self):
return [(1, 3, 8, 8), (1, 3, 8, 8)]

def compare(self, loaded_model, quantized_model, quantization_info):
assert len(loaded_model.graph.input)==2, f"Model expected to have two inputs but has {len(loaded_model.graph.input)}"
self.infer(loaded_model, next(self.get_dataset()))
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest
from torch import nn

class TestExportONNXMultipleInputsAndOutputs(BasePytorchONNXExportTest):
def get_model(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)

def forward(self, input1, input2):
x1 = self.conv1(input1)
x2 = self.conv2(input2)
return x1, x2

return Model()

def get_input_shapes(self):
return [(1, 3, 8, 8), (1, 3, 8, 8)]

def compare(self, loaded_model, quantized_model, quantization_info):
assert len(loaded_model.graph.input) == 2, f"Model expected to have two inputs but has {len(loaded_model.graph.input)}"
assert len(loaded_model.graph.output) == 2, f"Model expected to have two outputs but has {len(loaded_model.graph.output)}"
self.infer(loaded_model, next(self.get_dataset()))
38 changes: 38 additions & 0 deletions tests/pytorch_tests/exporter_tests/test_onnx_multiple_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest
from torch import nn

class TestExportONNXMultipleOutputs(BasePytorchONNXExportTest):
def get_model(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)

def forward(self, input1):
x1 = self.conv1(input1)
x2 = self.conv2(input1)
return x1, x2

return Model()

def get_input_shapes(self):
return [(1, 3, 8, 8)]

def compare(self, loaded_model, quantized_model, quantization_info):
assert len(loaded_model.graph.output)==2, f"Model expected to have two outputs but has {len(loaded_model.graph.output)}"
self.infer(loaded_model, next(self.get_dataset()))
12 changes: 12 additions & 0 deletions tests/pytorch_tests/exporter_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
TestExportONNXWeightSymmetric2BitsQuantizers
from tests.pytorch_tests.exporter_tests.custom_ops_tests.test_export_uniform_onnx_quantizers import \
TestExportONNXWeightUniform2BitsQuantizers
from tests.pytorch_tests.exporter_tests.test_onnx_multiple_inputs import TestExportONNXMultipleInputs
from tests.pytorch_tests.exporter_tests.test_onnx_multiple_inputs_and_outputs import \
TestExportONNXMultipleInputsAndOutputs
from tests.pytorch_tests.exporter_tests.test_onnx_multiple_outputs import TestExportONNXMultipleOutputs


class PytorchExporterTestsRunner(unittest.TestCase):
Expand Down Expand Up @@ -55,4 +59,12 @@ def test_lut_sym2bits_custom_quantizer_onnx(self):
TestExportONNXWeightLUTSymmetric2BitsQuantizers().run_test()
TestExportONNXWeightLUTSymmetric2BitsQuantizers(onnx_opset_version=16).run_test()

def test_multiple_inputs_onnx(self):
TestExportONNXMultipleInputs().run_test()

def test_multiple_outputs_onnx(self):
TestExportONNXMultipleOutputs().run_test()

def test_multiple_inputs_and_outputs_onnx(self):
TestExportONNXMultipleInputsAndOutputs().run_test()

Loading