diff --git a/.github/workflows/run_pytorch_tests.yml b/.github/workflows/run_pytorch_tests.yml index 9d2913161..28b3312c2 100644 --- a/.github/workflows/run_pytorch_tests.yml +++ b/.github/workflows/run_pytorch_tests.yml @@ -23,7 +23,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime + pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime onnxruntime-extensions pip install pytest - name: Run unittests run: | diff --git a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py index a39a2da50..ed269d152 100644 --- a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +++ b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py @@ -89,12 +89,12 @@ def export(self) -> None: else: Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}") - model_input = to_torch_tensor(next(self.repr_dataset())[0]) + model_input = to_torch_tensor(next(self.repr_dataset())) if hasattr(self.model, 'metadata'): onnx_bytes = BytesIO() torch.onnx.export(self.model, - model_input, + tuple(model_input) if isinstance(model_input, list) else model_input, onnx_bytes, opset_version=self._onnx_opset_version, verbose=False, @@ -107,7 +107,7 @@ def export(self) -> None: onnx.save_model(onnx_model, self.save_model_path) else: torch.onnx.export(self.model, - model_input, + tuple(model_input) if isinstance(model_input, list) else model_input, self.save_model_path, opset_version=self._onnx_opset_version, verbose=False, diff --git a/tests/pytorch_tests/exporter_tests/base_pytorch_onnx_export_test.py b/tests/pytorch_tests/exporter_tests/base_pytorch_onnx_export_test.py index 659b3fa25..cdefbea4f 100644 --- a/tests/pytorch_tests/exporter_tests/base_pytorch_onnx_export_test.py +++ b/tests/pytorch_tests/exporter_tests/base_pytorch_onnx_export_test.py @@ -20,6 +20,7 @@ from onnx import numpy_helper import model_compression_toolkit as mct +from mct_quantizers import get_ort_session_options from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION from tests.pytorch_tests.exporter_tests.base_pytorch_export_test import BasePytorchExportTest @@ -40,16 +41,29 @@ def load_exported_model(self, filepath): onnx.checker.check_model(filepath) return onnx.load(filepath) - def infer(self, model, images): - ort_session = onnxruntime.InferenceSession(model.SerializeToString()) + def infer(self, model, inputs): + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), + get_ort_session_options(), + providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] + ) def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() - # compute ONNX Runtime output prediction - ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(images)} - onnx_output = ort_session.run(None, ort_inputs) - return onnx_output[0] + # Prepare inputs + if isinstance(inputs, list): + ort_inputs = {input.name: to_numpy(tensor) for input, tensor in zip(ort_session.get_inputs(), inputs)} + elif isinstance(inputs, dict): + ort_inputs = {name: to_numpy(tensor) for name, tensor in inputs.items()} + else: + raise ValueError("Inputs must be a list or a dictionary") + + output_names = [output.name for output in ort_session.get_outputs()] + onnx_outputs = ort_session.run(output_names, ort_inputs) + output_dict = dict(zip(output_names, onnx_outputs)) + + return output_dict def _get_onnx_node_by_type(self, onnx_model, op_type): return [n for n in onnx_model.graph.node if n.op_type == op_type] diff --git a/tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs.py b/tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs.py new file mode 100644 index 000000000..c5f8e1b26 --- /dev/null +++ b/tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs.py @@ -0,0 +1,38 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest +from torch import nn + +class TestExportONNXMultipleInputs(BasePytorchONNXExportTest): + def get_model(self): + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1) + + def forward(self, input1, input2): + x1 = self.conv1(input1) + x2 = self.conv2(input2) + return x1 + x2 + + return Model() + + def get_input_shapes(self): + return [(1, 3, 8, 8), (1, 3, 8, 8)] + + def compare(self, loaded_model, quantized_model, quantization_info): + assert len(loaded_model.graph.input)==2, f"Model expected to have two inputs but has {len(loaded_model.graph.input)}" + self.infer(loaded_model, next(self.get_dataset())) \ No newline at end of file diff --git a/tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs_and_outputs.py b/tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs_and_outputs.py new file mode 100644 index 000000000..0b3063774 --- /dev/null +++ b/tests/pytorch_tests/exporter_tests/test_onnx_multiple_inputs_and_outputs.py @@ -0,0 +1,39 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest +from torch import nn + +class TestExportONNXMultipleInputsAndOutputs(BasePytorchONNXExportTest): + def get_model(self): + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1) + + def forward(self, input1, input2): + x1 = self.conv1(input1) + x2 = self.conv2(input2) + return x1, x2 + + return Model() + + def get_input_shapes(self): + return [(1, 3, 8, 8), (1, 3, 8, 8)] + + def compare(self, loaded_model, quantized_model, quantization_info): + assert len(loaded_model.graph.input) == 2, f"Model expected to have two inputs but has {len(loaded_model.graph.input)}" + assert len(loaded_model.graph.output) == 2, f"Model expected to have two outputs but has {len(loaded_model.graph.output)}" + self.infer(loaded_model, next(self.get_dataset())) \ No newline at end of file diff --git a/tests/pytorch_tests/exporter_tests/test_onnx_multiple_outputs.py b/tests/pytorch_tests/exporter_tests/test_onnx_multiple_outputs.py new file mode 100644 index 000000000..086e7cea4 --- /dev/null +++ b/tests/pytorch_tests/exporter_tests/test_onnx_multiple_outputs.py @@ -0,0 +1,38 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from tests.pytorch_tests.exporter_tests.base_pytorch_onnx_export_test import BasePytorchONNXExportTest +from torch import nn + +class TestExportONNXMultipleOutputs(BasePytorchONNXExportTest): + def get_model(self): + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1) + + def forward(self, input1): + x1 = self.conv1(input1) + x2 = self.conv2(input1) + return x1, x2 + + return Model() + + def get_input_shapes(self): + return [(1, 3, 8, 8)] + + def compare(self, loaded_model, quantized_model, quantization_info): + assert len(loaded_model.graph.output)==2, f"Model expected to have two outputs but has {len(loaded_model.graph.output)}" + self.infer(loaded_model, next(self.get_dataset())) \ No newline at end of file diff --git a/tests/pytorch_tests/exporter_tests/test_runner.py b/tests/pytorch_tests/exporter_tests/test_runner.py index 686bb8173..1f4d2a53a 100644 --- a/tests/pytorch_tests/exporter_tests/test_runner.py +++ b/tests/pytorch_tests/exporter_tests/test_runner.py @@ -24,6 +24,10 @@ TestExportONNXWeightSymmetric2BitsQuantizers from tests.pytorch_tests.exporter_tests.custom_ops_tests.test_export_uniform_onnx_quantizers import \ TestExportONNXWeightUniform2BitsQuantizers +from tests.pytorch_tests.exporter_tests.test_onnx_multiple_inputs import TestExportONNXMultipleInputs +from tests.pytorch_tests.exporter_tests.test_onnx_multiple_inputs_and_outputs import \ + TestExportONNXMultipleInputsAndOutputs +from tests.pytorch_tests.exporter_tests.test_onnx_multiple_outputs import TestExportONNXMultipleOutputs class PytorchExporterTestsRunner(unittest.TestCase): @@ -55,4 +59,12 @@ def test_lut_sym2bits_custom_quantizer_onnx(self): TestExportONNXWeightLUTSymmetric2BitsQuantizers().run_test() TestExportONNXWeightLUTSymmetric2BitsQuantizers(onnx_opset_version=16).run_test() + def test_multiple_inputs_onnx(self): + TestExportONNXMultipleInputs().run_test() + + def test_multiple_outputs_onnx(self): + TestExportONNXMultipleOutputs().run_test() + + def test_multiple_inputs_and_outputs_onnx(self): + TestExportONNXMultipleInputsAndOutputs().run_test()