diff --git a/.ci/scripts/setup-openvino.sh b/.ci/scripts/setup-openvino.sh new file mode 100755 index 00000000000..ff667619125 --- /dev/null +++ b/.ci/scripts/setup-openvino.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# shellcheck source=/dev/null +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +git clone https://github.com/openvinotoolkit/openvino.git +cd openvino && git checkout releases/2025/1 +git submodule update --init --recursive +sudo ./install_build_dependencies.sh +mkdir build && cd build +cmake .. -DCMAKE_BUILD_TYPE=Release -DENABLE_PYTHON=ON +make -j$(nproc) + +cd .. +cmake --install build --prefix dist + +source dist/setupvars.sh +cd ../backends/openvino +pip install -r requirements.txt +cd scripts +./openvino_build.sh --enable_python diff --git a/.ci/scripts/test_openvino.sh b/.ci/scripts/test_openvino.sh new file mode 100755 index 00000000000..85884a6475b --- /dev/null +++ b/.ci/scripts/test_openvino.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# shellcheck source=/dev/null +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +source openvino/dist/setupvars.sh +cd backends/openvino/tests +python test_runner.py --test_type ops +python test_runner.py --test_type models diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index c3eafc02c39..c1c145b5acd 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -736,3 +736,25 @@ jobs: conda activate "${CONDA_ENV}" # placeholder for mediatek to add more tests + + test-openvino-linux: + name: test-openvino-linux + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + fail-fast: false + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-gcc9 + submodules: 'true' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-openvino.sh + PYTHON_EXECUTABLE=python bash .ci/scripts/test_openvino.sh diff --git a/.lintrunner.toml b/.lintrunner.toml index 842b4b1c6cb..c2bbc05ae12 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -299,12 +299,14 @@ include_patterns = [ # TODO(https://github.com/pytorch/executorch/issues/7441): Gradually start enabling all folders. 
# 'backends/**/*.py', 'backends/arm/**/*.py', + 'backends/openvino/**/*.py', 'build/**/*.py', 'codegen/**/*.py', # 'devtools/**/*.py', 'devtools/visualization/**/*.py', 'docs/**/*.py', # 'examples/**/*.py', + 'examples/openvino/**/*.py', # 'exir/**/*.py', # 'extension/**/*.py', 'kernels/**/*.py', diff --git a/CMakeLists.txt b/CMakeLists.txt index 6509d4adeef..c88e1743b83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,6 +204,8 @@ option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF) option(EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" OFF) +option(EXECUTORCH_BUILD_OPENVINO "Build the Openvino backend" OFF) + option(EXECUTORCH_BUILD_PYBIND "Build the Python Bindings" OFF) option(EXECUTORCH_BUILD_QNN "Build the Qualcomm backend" OFF) @@ -715,6 +717,10 @@ if(EXECUTORCH_BUILD_NEURON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek) endif() +if(EXECUTORCH_BUILD_OPENVINO) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/openvino) +endif() + if(EXECUTORCH_BUILD_QNN) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/qualcomm) endif() @@ -817,6 +823,10 @@ if(EXECUTORCH_BUILD_PYBIND) list(APPEND _dep_libs mpsdelegate) endif() + if(EXECUTORCH_BUILD_OPENVINO) + list(APPEND _dep_libs openvino_backend) + endif() + if(EXECUTORCH_BUILD_XNNPACK) # need to explicitly specify XNNPACK and microkernels-prod # here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu diff --git a/README.md b/README.md index a6b6afe9d62..dd1fafe715b 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ Platform Support: - Arm - Cadence - MediaTek + - OpenVINO - Qualcomm - Vulkan - XNNPACK diff --git a/backends/openvino/CMakeLists.txt b/backends/openvino/CMakeLists.txt new file mode 100644 index 00000000000..8d07cd9a366 --- /dev/null +++ b/backends/openvino/CMakeLists.txt @@ -0,0 +1,75 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. + +# Set minimum required CMake version +cmake_minimum_required(VERSION 3.19) + +# Set project name +project(openvino_backend_project) + +# Set C++ standard +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Ensure compile_commands.json is generated +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Set up EXECUTORCH_ROOT if not already set +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +# Define common include directories +set(COMMON_INCLUDE_DIRS ${EXECUTORCH_ROOT}/..) + +# Include utility CMake scripts from ExecuteTorch +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +# Find OpenVINO libraries +find_package(OpenVINO REQUIRED) + +# Define OpenVINO backend as a static library +add_library(openvino_backend STATIC .) 
+ +# Enable exceptions and RTTI for OpenVINO backend +target_compile_options(openvino_backend PRIVATE -frtti -fexceptions) + +# Include Executorch directories +target_include_directories(openvino_backend PUBLIC ${COMMON_INCLUDE_DIRS}) + +# Link OpenVINO and ExecuteTorch core libraries +target_link_libraries(openvino_backend PRIVATE openvino::runtime executorch_core) + +# Add source files for OpenVINO backend +target_sources(openvino_backend PRIVATE ${CMAKE_CURRENT_LIST_DIR}/runtime/OpenvinoBackend.cpp) + +target_link_options_shared_lib(openvino_backend) + +if(EXECUTORCH_BUILD_OPENVINO_EXECUTOR_RUNNER) + # Build executor runner binary for openvino backend + list(APPEND openvino_executor_runner_libs openvino_backend executorch) + + set(_openvino_executor_runner__srcs + ${EXECUTORCH_ROOT}/examples/portable/executor_runner/executor_runner.cpp + ${EXECUTORCH_ROOT}/extension/data_loader/file_data_loader.cpp + ${EXECUTORCH_ROOT}/extension/evalue_util/print_evalue.cpp + ${EXECUTORCH_ROOT}/extension/runner_util/inputs.cpp + ${EXECUTORCH_ROOT}/extension/runner_util/inputs_portable.cpp + ) + add_executable(openvino_executor_runner ${_openvino_executor_runner__srcs}) + + list(APPEND openvino_executor_runner_libs) + + target_link_libraries( + openvino_executor_runner gflags portable_ops_lib ${openvino_executor_runner_libs} + ) + target_compile_options(openvino_executor_runner PUBLIC ${_common_compile_options}) +endif() + + + +# Install OpenVINO backend library to the lib directory +install(TARGETS openvino_backend DESTINATION lib) diff --git a/backends/openvino/README.md b/backends/openvino/README.md new file mode 100644 index 00000000000..95a5f4c364e --- /dev/null +++ b/backends/openvino/README.md @@ -0,0 +1,89 @@ +# OpenVINO Backend for ExecuTorch +The OpenVINO backend enables optimized execution of deep learning models on Intel hardware, leveraging Intel's [OpenVINO toolkit](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) for inference acceleration. + +## Supported Hardware + +OpenVINO backend supports the following hardware: + +- Intel CPUs +- Intel integrated GPUs +- Intel discrete GPUs +- Intel NPUs + +For more information on the supported hardware, please refer to [OpenVINO System Requirements](https://docs.openvino.ai/2025/about-openvino/release-notes-openvino/system-requirements.html) page. + +## Directory Structure + +``` +executorch +├── backends +│ └── openvino +│ ├── runtime +│ ├── OpenvinoBackend.cpp +│ └── OpenvinoBackend.h +│ ├── scripts +│ └── openvino_build.sh +│ ├── tests +│ ├── CMakeLists.txt +│ ├── README.md +│ ├── __init__.py +│ ├── partitioner.py +│ ├── preprocess.py +│ └── requirements.txt +└── examples + └── openvino + ├── aot_optimize_and_infer.py + └── README.md +``` + +## Build Instructions + +### Prerequisites + +Before you begin, ensure you have openvino installed and configured on your system: + +```bash +git clone https://github.com/openvinotoolkit/openvino.git +cd openvino && git checkout releases/2025/1 +git submodule update --init --recursive +sudo ./install_build_dependencies.sh +mkdir build && cd build +cmake .. -DCMAKE_BUILD_TYPE=Release -DENABLE_PYTHON=ON +make -j$(nproc) + +cd .. +cmake --install build --prefix +cd +source setupvars.sh +``` +Note: The OpenVINO backend is not yet supported with the current OpenVINO release packages. It is recommended to build from source. The instructions for using OpenVINO release packages will be added soon. 
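If the Python bindings were built (the `-DENABLE_PYTHON=ON` flag above), a quick optional check after sourcing `setupvars.sh` is to ask the runtime which devices it can see. The snippet below is only a sanity check and assumes the `openvino` Python package is importable from the current environment:

```python
# Optional sanity check: list the devices the OpenVINO runtime detects.
import openvino as ov

core = ov.Core()
print(core.available_devices)  # e.g. ['CPU', 'GPU', 'NPU'], depending on the machine
```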
+For more information about OpenVINO build, refer to the [OpenVINO Build Instructions](https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build_linux.md). + +### Setup + +Follow the steps below to setup your build environment: + +1. **Setup ExecuTorch Environment**: Refer to the [Environment Setup](https://pytorch.org/executorch/stable/getting-started-setup#environment-setup) guide for detailed instructions on setting up the ExecuTorch environment. + +2. **Setup OpenVINO Backend Environment** +- Install the dependent libs. Ensure that you are inside `executorch/backends/openvino/` directory + ```bash + pip install -r requirements.txt + ``` + Note: To achieve optimal performance with NNCF quantization, you should install the latest development version of NNCF (version 2.16.0.dev0+191b53d9 or higher). +3. Navigate to `scripts/` directory. + +4. **Build OpenVINO Backend C++ Libraries and Executor Runner**: Once the prerequisites are in place, run the `openvino_build.sh` script to start the build process. By default, OpenVINO backend will be built under `cmake-out/backends/openvino/` as `libopenvino_backend.a` + + ```bash + ./openvino_build.sh + ``` + **Build OpenVINO Backend Python Package with Pybindings**: To build and install the OpenVINO backend Python package with Python bindings, run the `openvino_build.sh` script with the `--enable_python` argument. This will compile and install the ExecuTorch Python package with the OpenVINO backend into your Python environment. This option will also enable python bindings required to execute OpenVINO backend tests and `export_and_infer_openvino.py` script inside `executorch/examples/openvino` folder. + + ```bash + ./openvino_build.sh --enable_python + ``` + +### Run + +Please refer to [README.md](../../examples/openvino/README.md) for instructions on running examples of various of models with openvino backend. diff --git a/backends/openvino/__init__.py b/backends/openvino/__init__.py new file mode 100644 index 00000000000..05c2ff7c0b9 --- /dev/null +++ b/backends/openvino/__init__.py @@ -0,0 +1,5 @@ +from .partitioner import OpenvinoPartitioner +from .preprocess import OpenvinoBackend +from .quantizer.quantizer import OpenVINOQuantizer + +__all__ = ["OpenvinoBackend", "OpenvinoPartitioner", "OpenVINOQuantizer"] diff --git a/backends/openvino/partitioner.py b/backends/openvino/partitioner.py new file mode 100644 index 00000000000..bc3fde573e2 --- /dev/null +++ b/backends/openvino/partitioner.py @@ -0,0 +1,148 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. 
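
# Typical usage of the partitioner (a minimal sketch based on the unit tests in
# backends/openvino/tests; `model` and `sample_inputs` below are placeholders):
#
#   from executorch.backends.openvino.partitioner import OpenvinoPartitioner
#   from executorch.exir import to_edge_transform_and_lower
#   from executorch.exir.backend.backend_details import CompileSpec
#   from torch.export import export
#
#   exported_program = export(model, sample_inputs)
#   compile_spec = [CompileSpec("device", b"CPU")]
#   exec_prog = to_edge_transform_and_lower(
#       exported_program, partitioner=[OpenvinoPartitioner(compile_spec)]
#   ).to_executorch()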
+ +# mypy: disable-error-code=import-not-found + +from typing import Callable, final, List, Optional, Tuple + +import torch +from executorch.backends.openvino.preprocess import OpenvinoBackend +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data +from openvino.frontend.pytorch.torchdynamo.op_support import ( # type: ignore[import-untyped] + OperatorSupport, +) + +from torch.export.exported_program import ExportedProgram +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupportBase + + +class OpenvinoOperatorsSupport(OperatorSupportBase): + + def __init__( + self, + op_types_to_skip: Optional[set] = None, + op_names_to_skip: Optional[set] = None, + ) -> None: + """ + Initializes the OpenvinoOperatorsSupport class. + + :param op_types_to_skip: A set of operator types to skip during support checking. + :param op_names_to_skip: A set of operator names to skip during support checking. + """ + if op_types_to_skip is None: + op_types_to_skip = set() + if op_names_to_skip is None: + op_names_to_skip = set() + + self._op_types_to_skip = op_types_to_skip + self._op_names_to_skip = op_names_to_skip + + def is_node_supported(self, _, node: torch.fx.Node) -> bool: + """ + Checks if a given node is supported by OpenVINO. + + :param node: The FX graph node representing an operation. + :return: True if the node is supported, otherwise False. + """ + if node.op != "call_function": + return False + + options: list[str] = [] + if not isinstance(node.target, str): + op_type = node.target.__name__ + else: + op_type = str(node.target) + supported_ops = OperatorSupport(options)._support_dict + if op_type == "getitem": + return True + + if "torch.ops." + str(op_type) in supported_ops: + return True + else: + print("Op not supported: ", "torch.ops." + str(op_type)) + + if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip: + print( + f"[OpenVINO Backend] The {op_type} operator with name '{node.name}' is skipped." + ) + return False + + return False + + +@final +class OpenvinoPartitioner(Partitioner): + + def __init__( + self, + compile_spec: List[CompileSpec], + op_types_to_skip: Optional[set] = None, + op_names_to_skip: Optional[set] = None, + ) -> None: + """ + Initializes the OpenvinoPartitioner class. + + :param compile_spec: A list of compile specifications for OpenVINO. + :param op_types_to_skip: A set of operator types to skip during partitioning. + :param op_names_to_skip: A set of operator names to skip during partitioning. + """ + self.delegation_spec = DelegationSpec(OpenvinoBackend.__name__, compile_spec) + self._op_types_to_skip = op_types_to_skip + self._op_names_to_skip = op_names_to_skip + + def ops_to_not_decompose( + self, + ep: ExportedProgram, + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Returns a tuple containing a list of operations that should not be decomposed + and an optional function to filter nodes. + + :param ep: The exported program. + :return: A tuple consisting of a list of ops to keep and an optional filtering function. 
+ """ + ops_not_decompose = [ + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.upsample_bilinear2d.default, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.default, + torch.ops.aten.upsample_nearest2d.vec, + ] + return (ops_not_decompose, None) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """ + Partitions an exported program into supported and unsupported segments. + + :param exported_program: The exported program. + :return: A PartitionResult containing the partitioned graph and delegation tags. + """ + partitioner = CapabilityBasedPartitioner( + exported_program.graph_module, + OpenvinoOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip), + allows_single_node_partition=True, + ) + partition_list = partitioner.propose_partitions() + + partition_tags = {} + for partition in partition_list: + for node in partition.nodes: + tag = f"tag{partition.id}" + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) diff --git a/backends/openvino/preprocess.py b/backends/openvino/preprocess.py new file mode 100644 index 00000000000..c343f44a8b5 --- /dev/null +++ b/backends/openvino/preprocess.py @@ -0,0 +1,54 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. + +# mypy: disable-error-code=import-not-found + +from typing import final, List + +from executorch.exir.backend.backend_details import ( + BackendDetails, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from openvino.frontend.pytorch.torchdynamo.compile import ( # type: ignore[import-untyped] + openvino_compile, +) + + +@final +class OpenvinoBackend(BackendDetails): + + @classmethod + def preprocess( + cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec] + ) -> PreprocessResult: + """ + Preprocesses the exported program and compiles it for the OpenVINO backend. + + Args: + edge_program (ExportedProgram): The exported program representing the model. + module_compile_spec (List[CompileSpec]): A list of compile specifications for the OpenVINO backend. + + Returns: + PreprocessResult: The result of preprocessing, including the compiled model bytes. 
+ """ + input_names = edge_program.graph_signature.user_inputs + args = [] + for node in edge_program.graph.nodes: + if node.target in input_names: + args.append(node.meta["val"]) + + compile_options = {} + for spec in module_compile_spec: + compile_options[spec.key] = spec.value.decode() + + compiled = openvino_compile( + edge_program.module(), *args, options=compile_options + ) + model_bytes = compiled.export_model() + + return PreprocessResult(processed_bytes=model_bytes.getvalue()) diff --git a/backends/openvino/quantizer/__init__.py b/backends/openvino/quantizer/__init__.py new file mode 100644 index 00000000000..df038483f2f --- /dev/null +++ b/backends/openvino/quantizer/__init__.py @@ -0,0 +1,3 @@ +from .quantizer import OpenVINOQuantizer, quantize_model + +__all__ = ["OpenVINOQuantizer", "quantize_model"] diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py new file mode 100644 index 00000000000..5532235f573 --- /dev/null +++ b/backends/openvino/quantizer/quantizer.py @@ -0,0 +1,414 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. + +# mypy: disable-error-code=import-not-found + +from collections import defaultdict +from enum import Enum +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, Type + +import nncf # type: ignore[import-untyped] +import nncf.common.quantization as quantization # type: ignore[import-untyped] +import nncf.experimental.torch.fx as nncf_fx # type: ignore[import-untyped] + +import torch.fx + +from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped] +from torch.ao.quantization.observer import ( + HistogramObserver, + PerChannelMinMaxObserver, + UniformQuantizationObserverBase, +) +from torch.ao.quantization.quantizer.quantizer import ( + EdgeOrNode, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) + +QUANT_ANNOTATION_KEY = "quantization_annotation" + + +class QuantizationMode(Enum): + """ + Defines special quantization modes. + + - INT8_SYM: INT8 symmetric quantization for both activations and weights. + - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. + - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + """ + + INT8_SYM = "int8_sym" + INT8_MIXED = "int8_mixed" + INT8_TRANSFORMER = "int8_transformer" + + +class OpenVINOQuantizer(Quantizer): + """ + Implementation of the Torch AO quantizer which annotates models with quantization annotations + optimally for the inference via OpenVINO. + """ + + def __init__( + self, + *, + mode: Optional[QuantizationMode] = QuantizationMode.INT8_SYM, + **kwargs, + ): + """ + :param mode: Defines special quantization modes. + - INT8_SYM: INT8 symmetric quantization for both activations and weights. + - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. + - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + Default value is INT8_SYM. + :param kwargs: Arguments to pass to the NNCF MinMaxQuantization algorithm. 
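
        Example (a sketch; `captured_model` and `calib_loader` are placeholders, and
        the quantize_model() helper in this module wraps the same flow):

            quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8_TRANSFORMER)
            quantized = nncf_fx.quantize_pt2e(
                captured_model, quantizer, calibration_dataset=nncf.Dataset(calib_loader)
            )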
+ """ + if mode == QuantizationMode.INT8_SYM: + preset = quantization.structs.QuantizationPreset.PERFORMANCE + model_type = None + elif mode == QuantizationMode.INT8_MIXED: + preset = quantization.structs.QuantizationPreset.MIXED + model_type = None + else: + preset = None + model_type = nncf.parameters.ModelType.TRANSFORMER + self._min_max_algo = ( + nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization( + preset=preset, model_type=model_type, **kwargs + ) + ) + + def set_ignored_scope( + self, + names: Optional[List[str]] = None, + patterns: Optional[List[str]] = None, + types: Optional[List[str]] = None, + subgraphs: Optional[List[Tuple[List[str], List[str]]]] = None, + validate: bool = True, + ) -> None: + """ + Provides an option to specify portions of model to be excluded from compression. + The ignored scope defines model sub-graphs that should be excluded from the quantization process. + + :param names: List of ignored node names. + :param patterns: List of regular expressions that define patterns for names of ignored nodes. + :param types: List of ignored operation types. + :param subgraphs: List of ignored subgraphs. + :param validate: If set to True, then a RuntimeError will be raised if any ignored scope does not match + in the model graph. + """ + self._min_max_algo.set_ignored_scope( + nncf.IgnoredScope( + names=names or [], + patterns=patterns or [], + types=types or [], + subgraphs=subgraphs or [], + validate=validate, + ) + ) + + def get_nncf_quantization_setup( + self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph + ) -> quantization.quantizer_setup.SingleConfigQuantizerSetup: + self._min_max_algo._set_backend_entity(model) + return self._min_max_algo.find_quantization_setup(model, nncf_graph) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) + quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) + + graph = model.graph + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = ( + defaultdict(QuantizationAnnotation) + ) + + for qp in quantization_setup.quantization_points.values(): + edge_or_node, annotation = self._get_edge_or_node_and_annotation( + graph, nncf_graph, qp, node_vs_torch_annotation + ) + qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_qp(qp) + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + for quantizer_ids in quantization_setup.unified_scale_groups.values(): + + root_quantizer_id = self._get_unified_scales_root_quantizer_id( + nncf_graph, quantizer_ids, quantization_setup + ) + root_qp = quantization_setup.quantization_points[root_quantizer_id] + + if any( + root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig + for q_id in quantizer_ids + ): + qps = [ + quantization_setup.quantization_points[q_id] + for q_id in quantizer_ids + ] + msg = ( + "Different quantization configs are set to one unified scale group:" + f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" + ) + raise nncf.InternalError(msg) + + root_target_node = nncf_fx.node_utils.get_graph_node_by_name( + graph, root_qp.insertion_point.target_node_name + ) + root_edge_or_node = self._get_edge_or_node( + root_target_node, root_qp, nncf_graph + ) + + for quantizer_id in quantizer_ids: + if quantizer_id == root_quantizer_id: + continue + + qspec = SharedQuantizationSpec(root_edge_or_node) + qp = quantization_setup.quantization_points[quantizer_id] + edge_or_node, 
annotation = self._get_edge_or_node_and_annotation( + graph, nncf_graph, qp, node_vs_torch_annotation + ) + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + for node, annotation in node_vs_torch_annotation.items(): + assert QUANT_ANNOTATION_KEY not in node.meta + node.meta[QUANT_ANNOTATION_KEY] = annotation + return model + + @staticmethod + def _get_unified_scales_root_quantizer_id( + nncf_graph: NNCFGraph, + quantizer_ids: List[int], + quantizer_setup: quantization.quantizer_setup.SingleConfigQuantizerSetup, + ) -> int: + """ + Identifies the earliest quantizer node ID based on the corresponding `nncf_node.node_id` + in the given NNCFGraph. This is required by the `_get_obs_or_fq_map` function. + Refer to: https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/prepare.py#L291 + + :param nncf_graph: The NNCFGraph instance. + :param quantizer_ids: The list of quantizer IDs to evaluate. + :param quantizer_setup: The instance of SingleConfigQuantizerSetup. + :return: The ID of the earliest quantizer node in terms of `nncf_node.node_id`. + """ + nncf_node_quantizer_id = None + root_quantizer_id = None + for quantizer_id in quantizer_ids: + target_node_name = quantizer_setup.quantization_points[ + quantizer_id + ].insertion_point.target_node_name + nncf_node = nncf_graph.get_node_by_name(target_node_name) + if ( + nncf_node_quantizer_id is None + or nncf_node.node_id < nncf_node_quantizer_id + ): + root_quantizer_id = quantizer_id + nncf_node_quantizer_id = nncf_node.node_id + if root_quantizer_id is None: + msg = "Root quantizer ids can't be None" + raise nncf.InternalError(msg) + return root_quantizer_id + + @staticmethod + def _get_edge_or_node_and_annotation( + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + qp: quantization.quantizer_setup.QuantizationPointBase, + node_vs_torch_annotation: Dict[torch.fx.Node, QuantizationAnnotation], + ) -> Tuple[EdgeOrNode, QuantizationAnnotation]: + """ + Retrieves the edge or node and its corresponding QuantizationAnnotation based on the given graph, + quantization point, and node-to-annotation mapping. + + :param graph: torch.fx.Graph instance. + :param nncf_graph: NNCFGraph instance. + :param qp: QuantizationPointBase instance. + :param node_vs_torch_annotation: A dictionary mapping torch.fx.GraphNode objects to their respective + QuantizationAnnotations. + :return: A tuple containing the EdgeOrNode and its associated QuantizationAnnotation. + """ + target_node = nncf_fx.node_utils.get_graph_node_by_name( + graph, qp.insertion_point.target_node_name + ) + annotation = node_vs_torch_annotation[target_node] + edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph) + return edge_or_node, annotation + + @staticmethod + def _get_edge_or_node( + target_node: torch.fx.Node, + qp: quantization.quantizer_setup.QuantizationPointBase, + nncf_graph: NNCFGraph, + ) -> EdgeOrNode: + """ + Returns the edge or node based on the given target node and quantization point. + + :param target_node: Target node instance. + :param qp: QuantizationPointBase instance. + :param graph: NNCFGraph instance. + :return: The corresponding EdgeOrNode derived from the target node and quantization point. 
+ """ + ip = qp.insertion_point + if qp.is_weight_quantization_point(): + nncf_node = nncf_graph.get_node_by_name(target_node.name) + weights_ports_ids = ( + nncf.torch.model_graph_manager.get_weight_tensor_port_ids( + nncf_node, nncf_graph + ) + ) + if len(weights_ports_ids) > 1: + # TODO(dlyakhov): support quantization for nodes with several weights + nncf.common.logging.nncf_logger.warning( + f"Quantization of the weighted node {target_node.name}" + " is not yet supported by the OpenVINOQuantizer." + f" Only the weight on port ID {weights_ports_ids[0]} will be quantized." + f" Quantizable weights are located on ports: {weights_ports_ids}." + ) + weight_node = target_node.all_input_nodes[weights_ports_ids[0]] + return (weight_node, target_node) + + if ip.input_port_id is None: + return target_node + + node = target_node.all_input_nodes[ip.input_port_id] + return (node, target_node) + + @staticmethod + def _fill_torch_ao_annotation( + edge_or_node: EdgeOrNode, + qspec: QuantizationSpecBase, + annotation_to_update: QuantizationAnnotation, + ) -> None: + """ + Helper method to update the annotation_to_update based on the specified edge_or_node and qspec. + + :param edge_or_node: The target EdgeOrNode to be used for the update. + :param qspec: An instance of QuantizationSpecBase representing the quantization specification to apply. + :param annotation_to_update: The annotation to update based on the edge_or_node and qspec. + """ + if isinstance(edge_or_node, torch.fx.Node): + annotation_to_update.output_qspec = qspec + else: + annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec + + @staticmethod + def _get_torch_ao_qspec_from_qp( + qp: quantization.quantizer_setup.QuantizationPointBase, + ) -> QuantizationSpec: + """ + Retrieves the quantization configuration from the given quantization point and + converts it into a QuantizationSpec. + + :param qp: An instance of QuantizationPointBase. + :return: A QuantizationSpec retrieved and converted from the quantization point. 
+ """ + # Eps value is copied from nncf/torch/quantization/layers.py + extra_args = {"eps": 1e-16} + qconfig = qp.qconfig + is_weight = qp.is_weight_quantization_point() + + observer: Type[UniformQuantizationObserverBase] + + if qconfig.per_channel: + torch_qscheme = ( + torch.per_channel_symmetric + if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC + else torch.per_channel_affine + ) + else: + torch_qscheme = ( + torch.per_tensor_symmetric + if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC + else torch.per_tensor_affine + ) + if is_weight: + observer = PerChannelMinMaxObserver + quant_min = -128 + quant_max = 127 + dtype = torch.int8 + channel_axis = 0 + else: + observer = ( + HistogramObserver + if torch_qscheme + in [torch.per_tensor_symmetric, torch.per_tensor_affine] + else PerChannelMinMaxObserver + ) + quant_min = 0 + quant_max = 255 + dtype = torch.int8 if qconfig.signedness_to_force else torch.uint8 + channel_axis = 1 # channel dim for activations + return QuantizationSpec( + dtype=dtype, + observer_or_fake_quant_ctr=observer.with_args(**extra_args), + quant_min=quant_min, + quant_max=quant_max, + qscheme=torch_qscheme, + ch_axis=channel_axis, + is_dynamic=False, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + +def quantize_model( + captured_model: torch.fx.GraphModule, + calibration_dataset: torch.utils.data.DataLoader, + *, + mode: QuantizationMode = QuantizationMode.INT8_SYM, + subset_size: int = 300, + fast_bias_correction: Optional[bool] = True, + smooth_quant: bool = False, + transform_fn: Optional[Callable[[Any], Any]] = None, + extra_quantizer_options: Optional[Dict[str, Any]] = None, + **kwargs, +) -> torch.fx.GraphModule: + """ + Quantizes a model using NNCF quantize_pt2e API. + + :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule. + :param calibration_dataset: A DataLoader containing calibration data for quantization. + :param mode: Defines special quantization modes. + - INT8_SYM: INT8 symmetric quantization for both activations and weights. + - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. + - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + Default value is INT8_SYM. + :param subset_size: Size of a subset to calculate activations + statistics used for quantization. + :param fast_bias_correction: Setting this option to `False` enables a different + bias correction method which is more accurate, in general, and takes + more time but requires less memory. None disables the bias correction algorithm. + :param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm. + :param extra_quantizer_options: A dictionary containing additional configuration options + for the OpenVINOQuantizer. + :param kwargs: The keyword arguments for the nncf quantize_pt2e function. + :return: The quantized model as a torch.fx.GraphModule. + """ + extra_quantizer_options = extra_quantizer_options or {} + if "mode" in extra_quantizer_options: + print( + f'Ignoring "mode" from the quantizer_config. 
Using parameter mode = {mode}' + ) + del extra_quantizer_options["mode"] + + quantizer = OpenVINOQuantizer(mode=mode, **extra_quantizer_options) + + print("PTQ: Quantize the model") + + if "fold_quantize" not in kwargs: + kwargs["fold_quantize"] = False + + quantized_model = nncf_fx.quantize_pt2e( + captured_model, + quantizer, + subset_size=subset_size, + calibration_dataset=nncf.Dataset(calibration_dataset, transform_fn), + fast_bias_correction=fast_bias_correction, + smooth_quant=smooth_quant, + **kwargs, + ) + return quantized_model diff --git a/backends/openvino/requirements.txt b/backends/openvino/requirements.txt new file mode 100644 index 00000000000..316633e9004 --- /dev/null +++ b/backends/openvino/requirements.txt @@ -0,0 +1,2 @@ +transformers +git+https://github.com/openvinotoolkit/nncf@6b0fc1c#egg=nncf diff --git a/backends/openvino/runtime/OpenvinoBackend.cpp b/backends/openvino/runtime/OpenvinoBackend.cpp new file mode 100644 index 00000000000..a3134f72b4b --- /dev/null +++ b/backends/openvino/runtime/OpenvinoBackend.cpp @@ -0,0 +1,189 @@ +/* Copyright (c) Intel Corporation + * + * Licensed under the BSD License (the "License"); you may not use this file + * except in compliance with the License. See the license file found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "OpenvinoBackend.h" + +namespace executorch { +namespace backends { +namespace openvino { + +OpenvinoBackend::OpenvinoBackend() {} + +bool OpenvinoBackend::is_available() const { + try { + // Create an OpenVINO Core object to verify runtime availability + ov::Core core; + + // Check if at least one device is available + auto devices = core.get_available_devices(); + if (!devices.empty()) { + return true; // OpenVINO is available + } + } catch (const std::exception& e) { + // Log the exception if OpenVINO runtime is not available + ET_LOG(Error, "OpenVINO is not available: %s", e.what()); + } catch (...) 
{ + // Handle any unexpected errors + ET_LOG( + Error, "OpenVINO availability check failed due to an unknown error."); + } + + return false; // OpenVINO is not available +} + +exr::Result OpenvinoBackend::init( + exr::BackendInitContext& context, + exr::FreeableBuffer* processed, + exr::ArrayRef compile_specs) const { + ET_LOG(Info, "OpenvinoBackend::init %p", processed->data()); + + ov::Core core; + const char* data_ptr = static_cast(processed->data()); + size_t data_size = processed->size(); + + // Copy data to a string or vector + std::string data_string(data_ptr, data_size); + + // Wrap the data in a stream + std::istringstream compiled_stream(data_string); + + auto device = "CPU"; + // Get the device value, if provided in compile sepcs + for (auto& compile_spec : compile_specs) { + if (std::strcmp(compile_spec.key, "device") == 0) + device = static_cast(compile_spec.value.buffer); + } + + // Import the model + auto compiled_model = core.import_model(compiled_stream, device); + + // The processed data can be freed since the model is compiled + processed->Free(); + + // Allocate an infer request + std::shared_ptr infer_request = + std::make_shared(compiled_model.create_infer_request()); + + // Allocate execution handle + exr::MemoryAllocator* allocator = context.get_runtime_allocator(); + ExecutionHandle* handle = allocator->allocateInstance(); + new (handle) ExecutionHandle; + handle->compiled_model = std::make_shared(compiled_model); + handle->infer_request = infer_request; + + return handle; +} + +exr::Error OpenvinoBackend::execute( + exr::BackendExecutionContext& context, + exr::DelegateHandle* input_handle, + exr::EValue** args) const { + ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle; + + auto infer_request = execution_handle->infer_request; + + size_t num_inputs = infer_request->get_compiled_model().inputs().size(); + size_t num_outputs = infer_request->get_compiled_model().outputs().size(); + + // Set inputs + for (size_t i = 0; i < num_inputs; i++) { + auto input_tensor = args[i]->toTensor(); + ov::Shape input_shape( + input_tensor.sizes().begin(), input_tensor.sizes().end()); + + // Convert input tensor to OpenVINO tensor + ov::element::Type ov_type = + convert_to_openvino_type(input_tensor.scalar_type()); + ov::Tensor ov_input_tensor( + ov_type, input_shape, input_tensor.mutable_data_ptr()); + + infer_request->set_input_tensor(i, ov_input_tensor); + } + + // Set outputs + for (size_t i = 0; i < num_outputs; i++) { + auto output_tensor = args[num_inputs + i]->toTensor(); + ov::Shape output_shape( + output_tensor.sizes().begin(), output_tensor.sizes().end()); + + // Convert input tensor to OpenVINO tensor + ov::element::Type ov_type = + convert_to_openvino_type(output_tensor.scalar_type()); + ov::Tensor ov_output_tensor( + ov_type, output_shape, output_tensor.mutable_data_ptr()); + + infer_request->set_output_tensor(i, ov_output_tensor); + } + + // Execute the inference + infer_request->infer(); + + return exr::Error::Ok; +} + +void OpenvinoBackend::destroy(exr::DelegateHandle* handle) const { + if (!handle) { + ET_LOG(Info, "Attempted to destroy a null handle."); + return; + } + + // Cast the handle to the appropriate type + ExecutionHandle* execution_handle = static_cast(handle); + + // Clean up resources + if (execution_handle->infer_request) { + execution_handle->infer_request.reset(); // Release the infer request + ET_LOG(Info, "Infer request successfully destroyed."); + } + + if (execution_handle->compiled_model) { + 
execution_handle->compiled_model.reset(); // Release the compiled model + ET_LOG(Info, "Compiled model successfully destroyed."); + } + + ET_LOG(Info, "Delegate handle destroyed successfully."); +} + +ov::element::Type OpenvinoBackend::convert_to_openvino_type( + exa::ScalarType scalar_type) const { + switch (scalar_type) { + case exa::ScalarType::Float: + return ov::element::f32; + case exa::ScalarType::Int: + return ov::element::i32; + case exa::ScalarType::Char: + return ov::element::i8; + case exa::ScalarType::Long: + return ov::element::i64; + case exa::ScalarType::Bool: + return ov::element::boolean; + default: + throw std::runtime_error("Unsupported scalar type"); + } +} + +} // namespace openvino +} // namespace backends +} // namespace executorch + +namespace { +auto backend = executorch::backends::openvino::OpenvinoBackend(); +executorch::runtime::Backend backend_id{"OpenvinoBackend", &backend}; +static auto registered = executorch::runtime::register_backend(backend_id); +} // namespace diff --git a/backends/openvino/runtime/OpenvinoBackend.h b/backends/openvino/runtime/OpenvinoBackend.h new file mode 100644 index 00000000000..069e4659d37 --- /dev/null +++ b/backends/openvino/runtime/OpenvinoBackend.h @@ -0,0 +1,59 @@ +/* Copyright (c) Intel Corporation + * + * Licensed under the BSD License (the "License"); you may not use this file + * except in compliance with the License. See the license file found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef OPENVINO_BACKEND_H +#define OPENVINO_BACKEND_H + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace exr = executorch::runtime; +namespace exa = executorch::aten; + +using namespace std; + +namespace executorch { +namespace backends { +namespace openvino { + +typedef struct { + std::shared_ptr compiled_model; + std::shared_ptr infer_request; +} ExecutionHandle; + +class OpenvinoBackend final : public ::exr::BackendInterface { + public: + OpenvinoBackend(); + ~OpenvinoBackend() = default; + + virtual bool is_available() const override; + exr::Result init( + exr::BackendInitContext& context, + exr::FreeableBuffer* processed, + exr::ArrayRef compile_specs) const override; + exr::Error execute( + exr::BackendExecutionContext& context, + exr::DelegateHandle* input_handle, + exr::EValue** args) const override; + void destroy(exr::DelegateHandle* handle) const override; + + private: + ov::element::Type convert_to_openvino_type(exa::ScalarType scalar_type) const; +}; + +} // namespace openvino +} // namespace backends +} // namespace executorch + +#endif // OPENVINO_BACKEND_H diff --git a/backends/openvino/scripts/openvino_build.sh b/backends/openvino/scripts/openvino_build.sh new file mode 100755 index 00000000000..83ffd7542f3 --- /dev/null +++ b/backends/openvino/scripts/openvino_build.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +# Exit immediately if a command exits with a non-zero status. 
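# Usage (as described in backends/openvino/README.md):
#   ./openvino_build.sh                  # default (--cpp_runtime): build the C++ runtime libraries under cmake-out/
#   ./openvino_build.sh --enable_python  # build and install the ExecuTorch Python package with OpenVINO pybindings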
+set -e + +# Define the directory where CMakeLists.txt is located +EXECUTORCH_ROOT=$(realpath "$(dirname "$0")/../../..") +echo EXECUTORCH_ROOT=${EXECUTORCH_ROOT} + +main() { + build_type=${1:-"--cpp_runtime"} + + # If the first arguments is --cpp_runtime (default), build libraries for C++ runtime + if [[ -z "$build_type" || "$build_type" == "--cpp_runtime" ]]; then + echo "Building C++ Runtime Libraries" + + # Set build directory + local build_dir="cmake-out" + + # Create and enter the build directory + cd "$EXECUTORCH_ROOT" + rm -rf "${build_dir}" + + # Configure the project with CMake + # Note: Add any additional configuration options you need here + cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_OPENVINO=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_OPENVINO_EXECUTOR_RUNNER=ON \ + -B"${build_dir}" + + + # Build the project + cmake --build ${build_dir} --target install --config Release -j$(nproc) + + # If the first arguments is --enable_python, build python package with python bindings + elif [[ "$build_type" == "--enable_python" ]]; then + echo "Building Python Package with Pybinding" + + # Create and enter the build directory + cd "$EXECUTORCH_ROOT" + ./install_executorch.sh --clean + + # Set parameters to configure the project with CMake + # Note: Add any additional configuration options you need here + export CMAKE_ARGS="-DEXECUTORCH_BUILD_OPENVINO=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_PYBIND=ON" + export CMAKE_BUILD_ARGS="--target openvino_backend" + + # Build the package + pip install . --no-build-isolation + + else + echo "Error: Argument is not valid: $build_type" + exit 1 # Exit the script with an error code + fi + + # Switch back to the original directory + cd - > /dev/null + + # Print a success message + echo "Build successfully completed." + +} + +main "$@" diff --git a/backends/openvino/tests/README.md b/backends/openvino/tests/README.md new file mode 100644 index 00000000000..0aad14e04a0 --- /dev/null +++ b/backends/openvino/tests/README.md @@ -0,0 +1,60 @@ +# Unit Tests for OpenVINO Backend + +## Directory Structure + +Below is the layout of the `backends/openvino/tests` directory, which includes the necessary files for the example applications: + +``` +backends/openvino/tests +├── ops # Directory with base op test script and individual op tests. + ├── base_openvino_op_test.py # Script which contains the base class for all op tests. + └── test_.py # Individual op tests scripts. +├── models # Directory with model test scripts. + └── test_classification.py # Test script for classification models. +├── README.md # Documentation for unit tests (this file) +└── test_runner.py # Script to execute unit tests. +``` + +## Executing Unit Tests + +### Prerequisites + +Before you begin, refer to instructions provided in [OpenVINO Backend for ExecuTorch](../README.md) to install OpenVINO and ExecuTorch Python package with the OpenVINO backend into your Python environment. + +### Usage + +`test_runner.py` allows to run op or model tests for openvino backend. + +### **Arguments** +- **`--test_type`** (optional): + Type of the tests to run. 
+ Supported values: + - `ops` (default) + - `models` + +- **`--pattern`** (optional): + Pattern to match test files. Provide complete file name to run individual tests. The default value is `test_*.py` + Examples: + - `test_convolution.py` (Assuming `--test_type` parameter is provided as `ops`, this will run only convolution tests) + - `test_add*.py` (Assuming `--test_type` parameter is provided as `ops`, this will run add and addmm op tests) + +- **`--device`** (optional): + Target device to compile and run tests. Default is `CPU`. + Examples: `CPU`, `GPU` + + +## **Examples** + +### Execute Tests for All Ops on CPU +```bash +python test_runner.py --device CPU --test_type ops +``` + +### Execute Convolution Op Tests on CPU +```bash +python test_runner.py --device CPU --test_type ops --pattern test_convolution.py +``` + +### Execute Tests for all Models on GPU +```bash +python test_runner.py --device GPU --test_type models diff --git a/backends/openvino/tests/models/test_classification.py b/backends/openvino/tests/models/test_classification.py new file mode 100644 index 00000000000..78ce6a2777f --- /dev/null +++ b/backends/openvino/tests/models/test_classification.py @@ -0,0 +1,38 @@ +import timm # type: ignore[import-untyped] +import torch +import torchvision.models as torchvision_models # type: ignore[import-untyped] +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) +from transformers import AutoModel # type: ignore[import-untyped] + +classifier_params = [ + {"model": ["torchvision", "resnet50", (1, 3, 224, 224)]}, + {"model": ["torchvision", "mobilenet_v2", (1, 3, 224, 224)]}, +] + + +# Function to load a model based on the selected suite +def load_model(suite: str, model_name: str): + if suite == "timm": + return timm.create_model(model_name, pretrained=True) + elif suite == "torchvision": + if not hasattr(torchvision_models, model_name): + raise ValueError(f"Model {model_name} not found in torchvision.") + return getattr(torchvision_models, model_name)(pretrained=True) + elif suite == "huggingface": + return AutoModel.from_pretrained(model_name) + else: + raise ValueError(f"Unsupported model suite: {suite}") + + +class TestClassifier(BaseOpenvinoOpTest): + + def test_classifier(self): + for params in classifier_params: + with self.subTest(params=params): + module = load_model(params["model"][0], params["model"][1]) + + sample_input = (torch.randn(params["model"][2]),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/base_openvino_op_test.py b/backends/openvino/tests/ops/base_openvino_op_test.py new file mode 100644 index 00000000000..c429845548a --- /dev/null +++ b/backends/openvino/tests/ops/base_openvino_op_test.py @@ -0,0 +1,86 @@ +import unittest + +import executorch +import torch +from executorch.backends.openvino.partitioner import OpenvinoPartitioner +from executorch.backends.openvino.preprocess import OpenvinoBackend +from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower +from executorch.exir.backend.backend_details import CompileSpec +from executorch.runtime import Runtime + +from torch.export import export, ExportedProgram + + +class BaseOpenvinoOpTest(unittest.TestCase): + device = "CPU" + + atol = 1e-3 + rtol = 1e-3 + + def execute_layer_test( + self, + module: torch.nn.Module, + sample_inputs: tuple[torch.Tensor], + expected_partitions: int = 1, + assert_output_equal: bool = True, + ): + + module = module.eval() + # Export to aten dialect using torch.export + 
aten_dialect: ExportedProgram = export(module, sample_inputs) + + # Convert to edge dialect and lower the module to the backend with a custom partitioner + compile_spec = [CompileSpec("device", self.device.encode())] + lowered_module: EdgeProgramManager = to_edge_transform_and_lower( + aten_dialect, + partitioner=[ + OpenvinoPartitioner(compile_spec), + ], + ) + + # Apply backend-specific passes + exec_prog = lowered_module.to_executorch( + config=executorch.exir.ExecutorchBackendConfig() + ) + + # Check if the number of partitions created matches the expected number of partitions + self.assertEqual( + len(exec_prog.executorch_program.execution_plan[0].delegates), + expected_partitions, + ) + # Check if the individual partitions are assigned to Openvino backend + for i in range(expected_partitions): + self.assertEqual( + exec_prog.executorch_program.execution_plan[0].delegates[i].id, + OpenvinoBackend.__name__, + ) + + # Execute the model and compare the outputs with the reference outputs + if assert_output_equal: + # Execute the module in eager mode to calculate the reference outputs + ref_output = module(*sample_inputs) + if isinstance(ref_output, torch.Tensor): + ref_output = [ + ref_output, + ] + + # Load model from buffer and execute + runtime = Runtime.get() + program = runtime.load_program(exec_prog.buffer) + method = program.load_method("forward") + assert method is not None + outputs = method.execute(sample_inputs) + + # Compare the outputs with the reference outputs + self.assertTrue(len(ref_output) == len(outputs)) + for i in range(len(ref_output)): + self.assertTrue( + torch.allclose( + outputs[i], + ref_output[i], + atol=self.atol, + rtol=self.rtol, + equal_nan=True, + ), + msg=f"ref_output:\n{ref_output[i]}\n\ntest_output:\n{outputs[i]}", + ) diff --git a/backends/openvino/tests/ops/test_add.py b/backends/openvino/tests/ops/test_add.py new file mode 100644 index 00000000000..5b68d0ff149 --- /dev/null +++ b/backends/openvino/tests/ops/test_add.py @@ -0,0 +1,22 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + + +class TestAddOperator(BaseOpenvinoOpTest): + + def create_model(self): + class Add(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.add(x, y) + + return Add() + + def test_add(self): + module = self.create_model() + sample_input = (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_addmm.py b/backends/openvino/tests/ops/test_addmm.py new file mode 100644 index 00000000000..51c1314db0d --- /dev/null +++ b/backends/openvino/tests/ops/test_addmm.py @@ -0,0 +1,28 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + + +class TestAddMMOperator(BaseOpenvinoOpTest): + + def create_model(self): + class AddMM(torch.nn.Module): + def __init__(self): + super().__init__() + self.alpha = 1.0 + self.beta = 1.0 + + def forward(self, x, y, z): + # return torch.add(x, y) + return torch.addmm(x, y, z, alpha=self.alpha, beta=self.beta) + + return AddMM() + + def test_addmm(self): + module = self.create_model() + input_x = torch.randn(4, 4, dtype=torch.float32) + input_y = torch.randn(4, 4, dtype=torch.float32) + input_z = torch.randn(4, 4, dtype=torch.float32) + sample_input = (input_x, input_y, input_z) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_arange.py 
b/backends/openvino/tests/ops/test_arange.py new file mode 100644 index 00000000000..b2aeb9c2100 --- /dev/null +++ b/backends/openvino/tests/ops/test_arange.py @@ -0,0 +1,23 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + + +class TestArangeOperator(BaseOpenvinoOpTest): + + def create_model(self, x): + class Arange(torch.nn.Module): + def __init__(self, x): + super().__init__() + self.x = x + + def forward(self, y): + return torch.arange(self.x, dtype=torch.float32) + y + + return Arange(5) + + def test_arange(self): + module = self.create_model(5) + sample_input = (torch.randn(5),) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_batch_norm.py b/backends/openvino/tests/ops/test_batch_norm.py new file mode 100644 index 00000000000..05d529163f9 --- /dev/null +++ b/backends/openvino/tests/ops/test_batch_norm.py @@ -0,0 +1,61 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + +op_params = [ + {"weights": True, "bias": True, "eps": 1.0}, + {"weights": True, "bias": True, "eps": 0.00005}, + {"weights": True, "bias": True, "eps": 0.5}, + {"weights": True, "bias": True, "eps": 0.042}, + {"weights": True, "bias": False, "eps": 1.0}, + {"weights": True, "bias": False, "eps": 0.00005}, + {"weights": True, "bias": False, "eps": 0.5}, + {"weights": True, "bias": False, "eps": 0.042}, + {"weights": False, "bias": True, "eps": 1.0}, + {"weights": False, "bias": True, "eps": 0.00005}, + {"weights": False, "bias": True, "eps": 0.5}, + {"weights": False, "bias": True, "eps": 0.042}, + {"weights": False, "bias": False, "eps": 1.0}, + {"weights": False, "bias": False, "eps": 0.00005}, + {"weights": False, "bias": False, "eps": 0.5}, + {"weights": False, "bias": False, "eps": 0.042}, +] + + +class TestBatchNormOperator(BaseOpenvinoOpTest): + + def create_model(self, weights, bias, eps): + + class BatchNorm(torch.nn.Module): + def __init__(self, weights=True, bias=True, eps=1e-05): + super(BatchNorm, self).__init__() + self.weight = torch.nn.Parameter(torch.randn(6)) if weights else None + self.bias = torch.nn.Parameter(torch.randn(6)) if bias else None + self.running_mean = torch.randn(6) + self.running_var = torch.randn(6) + self.eps = eps + + def forward(self, x): + return torch.nn.functional.batch_norm( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + eps=self.eps, + training=False, + ) + + return BatchNorm(weights, bias, eps) + + def test_batch_norm(self): + for params in op_params: + with self.subTest(params=params): + module = self.create_model( + weights=params["weights"], bias=params["bias"], eps=params["eps"] + ) + + sample_input = (torch.randn(20, 6, 10),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_convolution.py b/backends/openvino/tests/ops/test_convolution.py new file mode 100644 index 00000000000..45d785d3612 --- /dev/null +++ b/backends/openvino/tests/ops/test_convolution.py @@ -0,0 +1,296 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + +d2_params = [ + { + "weights_shape": [3, 3, 2, 2], + "strides": [1, 1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 3, 2, 2], + "strides": [1, 1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + 
"transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 3, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 3, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "bias_shape": [1], + "pads": [1, 1], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 3, 1, 1], + "strides": [1, 1], + "pads": [1, 1], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "bias_shape": [1], + "pads": [3, 1], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 3, 1, 1], + "strides": [1, 1], + "pads": [3, 1], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "bias_shape": [1], + "pads": [1, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 3, 1, 1], + "strides": [1, 1], + "pads": [0, 1], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "pads": [1, 0], + "dilations": [1, 1], + "groups": 3, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "pads": [0, 1], + "dilations": [1, 1], + "groups": 3, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "pads": [1, 0], + "dilations": [2, 2], + "groups": 3, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [1, 1], + "pads": [0, 0], + "dilations": [2, 2], + "groups": 3, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [2, 1], + "bias_shape": [1], + "pads": [1, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 3, 1, 1], + "strides": [2, 1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [2, 2], + "bias_shape": [1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 3, 1, 1], + "strides": [2, 2], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 3, 1, 1], + "strides": [2, 1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": False, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [2, 2], + "bias_shape": [1], + "pads": [0, 0], + "dilations": [1, 1], + "groups": 1, + "output_padding": [0, 0], + "transposed": True, + }, + { + "weights_shape": [3, 1, 1, 1], + "strides": [2, 2], + "bias_shape": [1], + "pads": [1, 1], + "dilations": [2, 2], + "groups": 1, + "output_padding": [1, 1], + "transposed": True, + }, +] + + +class TestConvolutionOperator(BaseOpenvinoOpTest): + + def create_model( + self, + weights_shape, + strides, + pads, + dilations, + groups, + bias, + 
transposed, + output_padding=0, + bias_shape=None, + underscore=False, + ): + + bias_dim = 0 + + class Convolution(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(weights_shape)) + self.bias_shape = bias_shape + if self.bias_shape is None: + self.bias_shape = weights_shape[bias_dim] + self.bias = ( + torch.nn.Parameter(torch.randn(self.bias_shape)) if bias else None + ) + self.strides = strides + self.pads = pads + self.dilations = dilations + self.groups = groups + self.transposed = transposed + self.output_padding = output_padding + if underscore: + self.forward = self.forward_ + + def forward(self, x): + return torch.convolution( + x, + self.weight, + self.bias, + self.strides, + self.pads, + self.dilations, + self.transposed, + self.output_padding, + self.groups, + ) + + def forward_(self, x): + return torch._convolution( + x, + self.weight, + self.bias, + self.strides, + self.pads, + self.dilations, + self.transposed, + self.output_padding, + self.groups, + False, + False, + False, + False, + ) + + return Convolution() + + def test_convolution(self): + bias_underscore_config = [(False, False), (True, False)] + for bias, underscore in bias_underscore_config: + for params in d2_params: + with self.subTest(params=params, bias=bias, underscore=underscore): + bias_shape = None + if "bias_shape" in params: + bias_shape = params["bias_shape"] + module = self.create_model( + weights_shape=params["weights_shape"], + strides=params["strides"], + pads=params["pads"], + dilations=params["dilations"], + groups=params["groups"], + output_padding=params["output_padding"], + transposed=params["transposed"], + bias_shape=bias_shape, + bias=bias, + underscore=underscore, + ) + sample_input = (torch.randn(1, 3, 10, 10),) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_mean.py b/backends/openvino/tests/ops/test_mean.py new file mode 100644 index 00000000000..9050ceb90af --- /dev/null +++ b/backends/openvino/tests/ops/test_mean.py @@ -0,0 +1,90 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + +op_params = [ + { + "axes": None, + "keep_dim": None, + "dtype": None, + }, + { + "axes": None, + "keep_dim": None, + "dtype": "float64", + }, + { + "axes": None, + "keep_dim": None, + "dtype": "float32", + }, + { + "axes": None, + "keep_dim": None, + "dtype": "int32", + }, + { + "axes": 0, + "keep_dim": False, + "dtype": None, + }, + { + "axes": 0, + "keep_dim": False, + "dtype": None, + }, +] + +dtypes = { + "float32": torch.float32, + "float64": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "int8": torch.int8, + "uint8": torch.uint8, +} + + +class TestMeanOperator(BaseOpenvinoOpTest): + + def create_model(self, axes, keep_dims, dtype): + + pt_dtype = dtypes.get(dtype) + + class Mean(torch.nn.Module): + def __init__(self, axes=None, keep_dims=None, dtype=None): + super(Mean, self).__init__() + self.axes = axes + self.keep_dims = keep_dims + self.dtype = dtype + + def forward(self, x): + if self.axes is None and self.keep_dims is None: + if self.dtype is None: + return torch.mean(x, dtype=self.dtype) + return torch.mean(x) + if self.axes is not None and self.keep_dims is None: + if self.dtype is None: + return torch.mean(x, self.axes) + return torch.mean(x, self.axes, dtype=self.dtype) + if self.dtype is None: + return torch.mean(x, self.axes, self.keep_dims) + return torch.mean(x, self.axes, self.keep_dims, 
dtype=self.dtype) + + return Mean(axes, keep_dims, pt_dtype) + + def test_mean(self): + for params in op_params: + with self.subTest(params=params): + module = self.create_model( + axes=params["axes"], + keep_dims=params["keep_dim"], + dtype=params["dtype"], + ) + + sample_input = ( + torch.randint(-10, 10, (1, 3, 224, 224)).to(dtype=torch.float32), + ) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_permute.py b/backends/openvino/tests/ops/test_permute.py new file mode 100644 index 00000000000..28ef5ab4369 --- /dev/null +++ b/backends/openvino/tests/ops/test_permute.py @@ -0,0 +1,33 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + +op_params = [ + {"order": [0, 2, 3, 1]}, + {"order": [0, 3, 1, 2]}, +] + + +class TestPermuteOperator(BaseOpenvinoOpTest): + + def create_model(self, order): + + class Permute(torch.nn.Module): + def __init__(self, order): + super(Permute, self).__init__() + self.order = order + + def forward(self, x): + return torch.permute(x, self.order) + + return Permute(order) + + def test_permute(self): + for params in op_params: + with self.subTest(params=params): + module = self.create_model(order=params["order"]) + + sample_input = (torch.randn(1, 3, 224, 224),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_pooling.py b/backends/openvino/tests/ops/test_pooling.py new file mode 100644 index 00000000000..bc42b52faaa --- /dev/null +++ b/backends/openvino/tests/ops/test_pooling.py @@ -0,0 +1,91 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + +d2_params = [ + {"kernel_size": [3, 3], "stride": 1, "padding": 0}, + {"kernel_size": [3, 3], "stride": [1, 1], "padding": 1}, + {"kernel_size": [3, 3], "stride": [1, 1], "padding": [0, 1]}, + {"kernel_size": [3, 3], "stride": [1, 1], "padding": [1, 0]}, + {"kernel_size": [3, 3], "stride": [2, 1], "padding": 0}, + {"kernel_size": [2, 1], "stride": [2, 1], "padding": 0}, + {"kernel_size": [2, 1], "stride": None, "padding": 0}, + {"kernel_size": [2, 1], "stride": [], "padding": 0}, + {"kernel_size": [8, 8], "stride": [8, 4], "padding": 1}, +] + + +class TestPoolingOperator(BaseOpenvinoOpTest): + + def create_model( + self, + op_type, + kernel_size, + stride, + padding, + dilation=1, + ceil_mode=True, + count_include_pad=True, + dtype=torch.float32, + ): + + class MaxPoolingBase(torch.nn.Module): + def __init__(self): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.dtype = dtype + + def forward(self, x): + pass + + class MaxPool2D(MaxPoolingBase): + def forward(self, x): + return torch.nn.functional.max_pool2d( + x.to(self.dtype), + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.ceil_mode, + ) + + class MaxPool2DIndices(MaxPoolingBase): + def forward(self, x): + return torch.nn.functional.max_pool2d( + x, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.ceil_mode, + return_indices=True, + ) + + ops = { + "MaxPool2D": MaxPool2D, + "MaxPool2DIndices": MaxPool2DIndices, + } + + aten_pooling = ops[op_type] + + return aten_pooling() + + def test_pooling2d(self): + for params in d2_params: + with self.subTest(params=params): + module = self.create_model( + op_type="MaxPool2D", + kernel_size=params["kernel_size"], + stride=params["stride"], + 
padding=params["padding"], + dilation=1, + ceil_mode=True, + count_include_pad=True, + ) + sample_input = (torch.randn(1, 3, 15, 15),) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_unary_ops.py b/backends/openvino/tests/ops/test_unary_ops.py new file mode 100644 index 00000000000..99787e587b3 --- /dev/null +++ b/backends/openvino/tests/ops/test_unary_ops.py @@ -0,0 +1,37 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + + +OPS = [ + torch.relu, +] + + +class TestUnaryOperator(BaseOpenvinoOpTest): + + def create_model(self, op, dtype): + + class UnaryOp(torch.nn.Module): + def __init__(self, op, dtype): + super().__init__() + self.dtype = dtype + self.op = op + + def forward(self, x): + x1 = x.to(self.dtype) + y = self.op(x1) + return y, x1 + + return UnaryOp(op, dtype) + + def test_unary_op(self): + for op in OPS: + with self.subTest(op=OPS): + + module = self.create_model(op, dtype=torch.float32) + + sample_input = (torch.rand(2, 10) * 10 + 1,) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_view.py b/backends/openvino/tests/ops/test_view.py new file mode 100644 index 00000000000..8aef13fffa0 --- /dev/null +++ b/backends/openvino/tests/ops/test_view.py @@ -0,0 +1,35 @@ +import torch +from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, +) + +op_params = [ + {"input_shape": [2, 3, 2], "target_shape": [2, 6]}, + {"input_shape": [4], "target_shape": [2, 2]}, +] + + +class TestViewOperator(BaseOpenvinoOpTest): + + def create_model(self, target_shape): + + class View(torch.nn.Module): + + def __init__(self, target_shape) -> None: + super().__init__() + self.target_shape = target_shape + + def forward(self, input_tensor): + return input_tensor.view(self.target_shape) + + return View(target_shape) + + def test_view(self): + for params in op_params: + with self.subTest(params=params): + + module = self.create_model(params["target_shape"]) + + sample_input = (torch.randn(params["input_shape"]),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/test_runner.py b/backends/openvino/tests/test_runner.py new file mode 100644 index 00000000000..021c372db25 --- /dev/null +++ b/backends/openvino/tests/test_runner.py @@ -0,0 +1,76 @@ +import argparse +import unittest + +import nncf.torch # type: ignore[import-untyped,import-not-found] + + +class OpenvinoTestSuite(unittest.TestSuite): + + test_params: dict[str, str] = {} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def addTest(self, test): + # Set test parameters if this is an instance of TestOpenvino + from executorch.backends.openvino.tests.ops.base_openvino_op_test import ( + BaseOpenvinoOpTest, + ) + + if isinstance(test, BaseOpenvinoOpTest): + if "device" in self.test_params: + test.device = self.test_params["device"] + # Call the original addTest method to actually add the test to the suite + super().addTest(test) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-s", + "--device", + help="OpenVINO device to execute the model on", + type=str, + default="CPU", + ) + parser.add_argument( + "-p", + "--pattern", + help="Pattern to match test files. 
Provide complete file name to run individual tests", + type=str, + default="test_*.py", + ) + parser.add_argument( + "-t", + "--test_type", + help="Specify the type of tests ('ops' or 'models')", + type=str, + default="ops", + choices={"ops", "models"}, + ) + + args, ns_args = parser.parse_known_args(namespace=unittest) + test_params: dict[str, str] = {} + test_params["device"] = args.device + test_params["pattern"] = args.pattern + test_params["test_type"] = args.test_type + return test_params + + +if __name__ == "__main__": + loader = unittest.TestLoader() + # Replace the default test suite with a custom test suite to be able to + # pass test parameter to the test cases + loader.suiteClass = OpenvinoTestSuite + test_params = parse_arguments() + loader.suiteClass.test_params = test_params + # Discover all existing op tests in "ops" folder + suite = loader.discover(test_params["test_type"], pattern=test_params["pattern"]) + # Start running tests + with nncf.torch.disable_patching(): + result = unittest.TextTestRunner().run(suite) + if result.wasSuccessful(): + print("OpenVINO backend tests completed successfully") + else: + print("OpenVINO backend tests completed with failures") diff --git a/docs/source/build-run-openvino.md b/docs/source/build-run-openvino.md new file mode 100644 index 00000000000..f9ea5df0862 --- /dev/null +++ b/docs/source/build-run-openvino.md @@ -0,0 +1,114 @@ +# Building and Running ExecuTorch with OpenVINO Backend + +In this tutorial we will walk you through the process of setting up the prerequisites, building OpenVINO backend library, exporting `.pte` models with OpenVINO optimizations, and executing the exported models on Intel hardware. + + +::::{grid} 2 +:::{grid-item-card} What you will learn in this tutorial: +:class-card: card-prerequisites +* In this tutorial you will learn how to lower and deploy a model with OpenVINO. +::: +:::{grid-item-card} Tutorials we recommend you complete before this: +:class-card: card-prerequisites +* [Introduction to ExecuTorch](intro-how-it-works.md) +* [Setting up ExecuTorch](getting-started-setup.md) +* [Building ExecuTorch with CMake](runtime-build-and-cross-compilation.md) +::: +:::: + +## Introduction to OpenVINO + +[OpenVINO](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) is an open-source toolkit designed to enhance AI inference on Intel hardware by reducing latency and increasing throughput while preserving accuracy. It optimizes hardware utilization and simplifies AI development and deep learning integration across domains such as computer vision, large language models (LLMs), and generative AI. + +OpenVINO is integrated as an Executorch delegate to accelerate AI applications deployed with Executorch APIs. + +## Supported Hardware + +OpenVINO backend supports the following hardware: + +- Intel CPUs +- Intel integrated GPUs +- Intel discrete GPUs +- Intel NPUs + +For more information on the supported hardware, please refer to [OpenVINO System Requirements](https://docs.openvino.ai/2025/about-openvino/release-notes-openvino/system-requirements.html) page. + +## Instructions for Building OpenVINO Backend + +### Prerequisites + +Before you begin, ensure you have openvino installed and configured on your system: + + +```bash +git clone https://github.com/openvinotoolkit/openvino.git +cd openvino && git checkout releases/2025/1 +git submodule update --init --recursive +sudo ./install_build_dependencies.sh +mkdir build && cd build +cmake .. 
-DCMAKE_BUILD_TYPE=Release -DENABLE_PYTHON=ON
+make -j$(nproc)
+
+cd ..
+cmake --install build --prefix <your_preferred_install_location>
+cd <your_preferred_install_location>
+source setupvars.sh
+```
+Note: The OpenVINO backend is not yet supported with the current OpenVINO release packages. It is recommended to build OpenVINO from source. Instructions for using OpenVINO release packages will be added soon.
+For more information about building OpenVINO, refer to the [OpenVINO Build Instructions](https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build_linux.md).
+
+### Setup
+
+Follow the steps below to set up your build environment:
+
+1. **Set up the ExecuTorch environment**: Refer to the [Environment Setup](https://pytorch.org/executorch/stable/getting-started-setup#environment-setup) guide for detailed instructions on setting up the ExecuTorch environment.
+
+2. **Set up the OpenVINO backend environment**: Install the required dependencies. Ensure that you are inside the `executorch/backends/openvino/` directory:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+3. Navigate to the `scripts/` directory.
+
+4. **Build the OpenVINO backend**: Once the prerequisites are in place, run the `openvino_build.sh` script to start the build process. The OpenVINO backend will be built as `libopenvino_backend.a` under `cmake-out/backends/openvino/`:
+
+   ```bash
+   ./openvino_build.sh
+   ```
+
+## Build Instructions for Examples
+
+### AOT step:
+Refer to the [README.md](../../examples/openvino/README.md) in the `executorch/examples/openvino` folder for detailed instructions on exporting deep learning models from various model suites (TIMM, Torchvision, Hugging Face) to the OpenVINO backend using ExecuTorch. Users can dynamically specify the model, input shape, and target device.
+
+Below is an example of exporting a ResNet50 model from the Torchvision model suite for the CPU device with an input shape of `[1, 3, 256, 256]`:
+
+```bash
+cd executorch/examples/openvino
+python aot_optimize_and_infer.py --export --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device CPU
+```
+The exported model will be saved as `resnet50.pte` in the current directory.
+
+### Build C++ OpenVINO Examples
+
+After building the OpenVINO backend following the [instructions](#setup) above, the executable will be saved in `<executorch_root>/cmake-out/backends/openvino/`.
+
+The executable requires a model file (the `.pte` file generated in the AOT step) and the number of inference executions.
+
+#### Example Usage
+
+Run inference with a given model for 10 executions:
+
+```
+./openvino_executor_runner \
+    --model_path=model.pte \
+    --num_executions=10
+```
+
+## Support
+
+If you encounter any issues while reproducing this tutorial, please file a GitHub issue on the ExecuTorch repo and tag it with `#openvino`.
diff --git a/examples/openvino/README.md b/examples/openvino/README.md
new file mode 100644
index 00000000000..8856ccdce4e
--- /dev/null
+++ b/examples/openvino/README.md
@@ -0,0 +1,185 @@
+# OpenVINO Backend Examples
+
+This guide provides detailed instructions on how to export models for ExecuTorch and execute them on the OpenVINO backend. The examples demonstrate how to export a model, load a model, prepare input tensors, execute inference, and save the output results.
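+
+At a glance, the export flow implemented by the example script boils down to a few ExecuTorch API calls. The snippet below is a minimal sketch (assuming a Torchvision ResNet50, a `[1, 3, 256, 256]` input, and the CPU device); the full script additionally handles argument parsing, optional quantization, validation, and timed inference:
+
+```python
+import torch
+import torchvision.models as torchvision_models
+from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+from executorch.exir import to_edge_transform_and_lower
+from executorch.exir.backend.backend_details import CompileSpec
+from torch.export import export
+
+# Load an eager PyTorch model and switch it to inference mode.
+model = torchvision_models.resnet50(pretrained=True).eval()
+example_args = (torch.randn(1, 3, 256, 256),)
+
+# Export to the ATen dialect.
+aten_dialect = export(model, example_args)
+
+# Lower to the OpenVINO backend; the compile spec carries the target device name.
+compile_spec = [CompileSpec("device", "CPU".encode())]
+lowered = to_edge_transform_and_lower(
+    aten_dialect,
+    partitioner=[OpenvinoPartitioner(compile_spec)],
+)
+
+# Convert to an ExecuTorch program and serialize it to a .pte file.
+exec_prog = lowered.to_executorch()
+with open("resnet50.pte", "wb") as f:
+    exec_prog.write_to_file(f)
+```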
+
+## Directory Structure
+
+Below is the layout of the `examples/openvino` directory, which includes the necessary files for the example applications:
+
+```
+examples/openvino
+├── README.md                   # Documentation for examples (this file)
+└── aot_optimize_and_infer.py   # Example script to export and execute models
+```
+
+# Build Instructions for Examples
+
+## Environment Setup
+Follow the **Prerequisites** and **Setup** [instructions](../../backends/openvino/README.md) in `backends/openvino/README.md` to set up the OpenVINO backend.
+
+## AOT step:
+
+The Python script `aot_optimize_and_infer.py` allows users to export deep learning models from various model suites (TIMM, Torchvision, Hugging Face) to the OpenVINO backend using **ExecuTorch**. Users can dynamically specify the model, input shape, and target device.
+
+### **Usage**
+
+#### **Arguments**
+- **`--suite`** (required):
+  Specifies the model suite to use.
+  Supported values:
+  - `timm` (e.g., VGG16, ResNet50)
+  - `torchvision` (e.g., resnet18, mobilenet_v2)
+  - `huggingface` (e.g., bert-base-uncased). NB: Quantization and validation are not supported yet.
+
+- **`--model`** (required):
+  Name of the model to export.
+  Examples:
+  - For `timm`: `vgg16`, `resnet50`
+  - For `torchvision`: `resnet18`, `mobilenet_v2`
+  - For `huggingface`: `bert-base-uncased`, `distilbert-base-uncased`
+
+- **`--input_shape`** (optional):
+  Input shape for the model. Provide this as a **list** or **tuple**.
+  Examples:
+  - `[1, 3, 224, 224]` (Zsh users: wrap in quotes)
+  - `(1, 3, 224, 224)`
+
+- **`--export`** (optional):
+  Save the exported model as a `.pte` file.
+
+- **`--model_file_name`** (optional):
+  Specify a custom file name to save the exported model.
+
+- **`--batch_size`** (optional):
+  Batch size for the validation. The default batch size is 1.
+  The dataset length must be evenly divisible by the batch size.
+
+- **`--quantize`** (optional):
+  Enable model quantization. The `--dataset` argument is required for quantization. The `huggingface` suite is not supported yet.
+
+- **`--validate`** (optional):
+  Enable model validation. The `--dataset` argument is required for validation. The `huggingface` suite is not supported yet.
+
+- **`--dataset`** (optional):
+  Path to an ImageNet-like calibration dataset.
+
+- **`--infer`** (optional):
+  Execute inference with the compiled model and report the average inference time.
+
+- **`--num_iter`** (optional):
+  Number of iterations to execute inference. The default is `1`.
+
+- **`--warmup_iter`** (optional):
+  Number of warmup iterations to execute before timing begins. The default is `0`.
+
+- **`--input_tensor_path`** (optional):
+  Path to the raw tensor file to be used as input for inference. If this argument is not provided, a random input tensor will be generated.
+
+- **`--output_tensor_path`** (optional):
+  Path to the raw tensor file to which the output of the inference will be saved.
+
+- **`--device`** (optional):
+  Target device for the compiled model. Default is `CPU`.
+  Examples: `CPU`, `GPU`
+
+#### **Examples**
+
+##### Export a TIMM VGG16 model for the CPU
+```bash
+python aot_optimize_and_infer.py --export --suite timm --model vgg16 --input_shape "[1, 3, 224, 224]" --device CPU
+```
+
+##### Export a Torchvision ResNet50 model for the GPU
+```bash
+python aot_optimize_and_infer.py --export --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device GPU
+```
+
+##### Export a Hugging Face BERT model for the CPU
+```bash
+python aot_optimize_and_infer.py --export --suite huggingface --model bert-base-uncased --input_shape "(1, 512)" --device CPU
+```
+
+##### Export and validate a TIMM VGG16 model for the CPU
+```bash
+python aot_optimize_and_infer.py --export --suite timm --model vgg16 --input_shape "[1, 3, 224, 224]" --device CPU --validate --dataset /path/to/dataset
+```
+
+##### Export, quantize and validate a TIMM VGG16 model for the CPU
+```bash
+python aot_optimize_and_infer.py --export --suite timm --model vgg16 --input_shape "[1, 3, 224, 224]" --device CPU --validate --dataset /path/to/dataset --quantize
+```
+
+##### Execute inference with a Torchvision Inception V3 model for the CPU
+```bash
+python aot_optimize_and_infer.py --suite torchvision --model inception_v3 --infer --warmup_iter 10 --num_iter 100 --input_shape "(1, 3, 256, 256)" --device CPU
+```
+
+### **Notes**
+1. **Input Shape in Zsh**:
+   If you are using Zsh, wrap `--input_shape` in quotes or use a tuple:
+   ```bash
+   --input_shape '[1, 3, 224, 224]'
+   --input_shape "(1, 3, 224, 224)"
+   ```
+
+2. **Model Compatibility**:
+   Ensure the specified `model_name` exists in the selected `suite`. Use the corresponding library's documentation to verify model availability.
+
+3. **Output File**:
+   The exported model will be saved as a `.pte` file in the current directory.
+
+4. **Dependencies**:
+   - Python 3.8+
+   - PyTorch
+   - ExecuTorch
+   - TIMM (`pip install timm`)
+   - Torchvision
+   - Transformers (`pip install transformers`)
+
+### **Error Handling**
+- **Model Not Found**:
+  If the script raises an error such as:
+  ```bash
+  ValueError: Model not found
+  ```
+  verify that the model name is correct for the chosen suite.
+
+- **Unsupported Input Shape**:
+  Ensure `--input_shape` is provided as a valid list or tuple.
+
+## Build OpenVINO Examples
+Build the backend libraries and executor runner by executing the script below in the `<executorch_root>/backends/openvino/scripts` folder:
+```bash
+./openvino_build.sh
+```
+The executable is saved in `<executorch_root>/cmake-out/backends/openvino/`.
+
+### Run the Example with Executor Runner
+
+Now, run the example using the executable generated in the step above. The executable requires a model file (the `.pte` file generated in the AOT step) and, optionally, the number of inference executions.
+
+#### Command Syntax:
+
+```
+cd ../../cmake-out/backends/openvino
+
+./openvino_executor_runner \
+    --model_path=<model_path> \
+    --num_executions=<num_executions>
+```
+
+#### Command-Line Arguments
+
+- `--model_path`: (Required) Path to the model serialized in `.pte` format.
+- `--num_executions`: (Optional) Number of times to run inference (default: 1).
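+
+If you want to sanity-check an exported `.pte` file without building the C++ runner, the ExecuTorch Python runtime used by `aot_optimize_and_infer.py` can load and execute it directly. The snippet below is a minimal sketch; the file name and input shape are assumptions and should match whatever was used during export:
+
+```python
+import torch
+from executorch.runtime import Runtime
+
+# Load the serialized program produced in the AOT step.
+runtime = Runtime.get()
+with open("resnet50.pte", "rb") as f:
+    program = runtime.load_program(f.read())
+method = program.load_method("forward")
+
+# Run a single forward pass with a random input of the export-time shape.
+inputs = (torch.randn(1, 3, 256, 256),)
+outputs = method.execute(inputs)
+print(outputs[0].shape)
+```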
+ +#### Example Usage + +Run inference with a given model for 10 iterations: + +``` +./openvino_executor_runner \ + --model_path=model.pte \ + --num_executions=10 +``` diff --git a/examples/openvino/aot_optimize_and_infer.py b/examples/openvino/aot_optimize_and_infer.py new file mode 100644 index 00000000000..acd9c896f42 --- /dev/null +++ b/examples/openvino/aot_optimize_and_infer.py @@ -0,0 +1,429 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. + +# mypy: disable-error-code="import-untyped,import-not-found" + +import argparse +import time +from typing import cast, List, Optional + +import executorch + +import nncf.torch +import timm +import torch +import torchvision.models as torchvision_models +from executorch.backends.openvino.partitioner import OpenvinoPartitioner +from executorch.backends.openvino.quantizer import quantize_model +from executorch.exir import ( + EdgeProgramManager, + ExecutorchProgramManager, + to_edge_transform_and_lower, +) +from executorch.exir.backend.backend_details import CompileSpec +from executorch.runtime import Runtime +from sklearn.metrics import accuracy_score +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform +from torch.export import export +from torch.export.exported_program import ExportedProgram +from torchvision import datasets +from transformers import AutoModel + + +# Function to load a model based on the selected suite +def load_model(suite: str, model_name: str): + """ + Loads a pre-trained model from the specified model suite. + + :param suite: The suite from which to load the model. Supported values are: + - "timm": Uses `timm.create_model` to load the model. + - "torchvision": Loads a model from `torchvision.models`. Raises an error if the model does not exist. + - "huggingface": Loads a transformer model using `AutoModel.from_pretrained`. + :param model_name: The name of the model to load. + :return: The loaded model instance. + :raises ValueError: If the specified model suite is unsupported or the model is not found. + """ + if suite == "timm": + return timm.create_model(model_name, pretrained=True) + elif suite == "torchvision": + if not hasattr(torchvision_models, model_name): + msg = f"Model {model_name} not found in torchvision." + raise ValueError(msg) + return getattr(torchvision_models, model_name)(pretrained=True) + elif suite == "huggingface": + return AutoModel.from_pretrained(model_name) + else: + msg = f"Unsupported model suite: {suite}" + raise ValueError(msg) + + +def load_calibration_dataset( + dataset_path: str, + batch_size: int, + suite: str, + model: torch.nn.Module, + model_name: str, +): + """ + Loads a calibration dataset for model quantization. + + :param dataset_path: Path to the dataset directory. + :param batch_size: Number of samples per batch. + :param suite: The model suite used for preprocessing transformations. Supported values are: + - "torchvision": Uses predefined transformations for torchvision models. + - "timm": Uses dataset transformations based on the model's pretrained configuration. + :param model: The model instance, required for timm transformation resolution. + :param model_name: The model name, required for torchvision transformations. + :return: A DataLoader instance for the calibration dataset. 
+ :raises ValueError: If the suite is unsupported for validation. + """ + val_dir = f"{dataset_path}/val" + + if suite == "torchvision": + transform = torchvision_models.get_model_weights( + model_name + ).DEFAULT.transforms() + elif suite == "timm": + transform = create_transform( + **resolve_data_config(model.pretrained_cfg, model=model) + ) + else: + msg = f"Validation is not supported yet for the suite {suite}" + raise ValueError(msg) + + val_dataset = datasets.ImageFolder(val_dir, transform=transform) + + calibration_dataset = torch.utils.data.DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + ) + + return calibration_dataset + + +def infer_model( + exec_prog: ExecutorchProgramManager, + inputs, + num_iter: int, + warmup_iter: int, + output_path: str, +) -> float: + """ + Executes inference and reports the average timing. + + :param exec_prog: ExecutorchProgramManager of the lowered model + :param inputs: The inputs for the model. + :param num_iter: The number of iterations to execute inference for timing. + :param warmup_iter: The number of iterations to execute inference for warmup before timing. + :param output_path: Path to the output tensor file to save the output of inference.. + :return: The average inference timing. + """ + # Load model from buffer + runtime = Runtime.get() + program = runtime.load_program(exec_prog.buffer) + method = program.load_method("forward") + if method is None: + raise ValueError("Load method failed") + + # Execute warmup + out = None + for _i in range(warmup_iter): + out = method.execute(inputs) + + # Execute inference and measure timing + time_total = 0.0 + for _i in range(num_iter): + time_start = time.time() + out = method.execute(inputs) + time_end = time.time() + time_total += time_end - time_start + + # Save output tensor as raw tensor file + if output_path: + assert out is not None + torch.save(out, output_path) + + # Return average inference timing + return time_total / float(num_iter) + + +def validate_model( + exec_prog: ExecutorchProgramManager, + calibration_dataset: torch.utils.data.DataLoader, +) -> float: + """ + Validates the model using the calibration dataset. + + :param exec_prog: ExecutorchProgramManager of the lowered model + :param calibration_dataset: A DataLoader containing calibration data. + :return: The accuracy score of the model. + """ + # Load model from buffer + runtime = Runtime.get() + program = runtime.load_program(exec_prog.buffer) + method = program.load_method("forward") + if method is None: + raise ValueError("Load method failed") + + # Iterate over the dataset and run the executor + predictions: List[int] = [] + targets = [] + for _idx, data in enumerate(calibration_dataset): + feature, target = data + targets.extend(target) + out = list(method.execute((feature,))) + predictions.extend(torch.stack(out).reshape(-1, 1000).argmax(-1)) + + # Check accuracy + return accuracy_score(predictions, targets) + + +def main( # noqa: C901 + suite: str, + model_name: str, + input_shape, + save_model: bool, + model_file_name: str, + quantize: bool, + validate: bool, + dataset_path: str, + device: str, + batch_size: int, + infer: bool, + num_iter: int, + warmup_iter: int, + input_path: str, + output_path: str, +): + """ + Main function to load, quantize, and validate a model. + + :param suite: The model suite to use (e.g., "timm", "torchvision", "huggingface"). + :param model_name: The name of the model to load. + :param input_shape: The input shape for the model. 
+ :param save_model: Whether to save the compiled model as a .pte file. + :param model_file_name: Custom file name to save the exported model. + :param quantize: Whether to quantize the model. + :param validate: Whether to validate the model. + :param dataset_path: Path to the dataset for calibration/validation. + :param device: The device to run the model on (e.g., "cpu", "gpu"). + :param batch_size: Batch size for dataset loading. + :param infer: Whether to execute inference and report timing. + :param num_iter: The number of iterations to execute inference for timing. + :param warmup_iter: The number of iterations to execute inference for warmup before timing. + :param input_path: Path to the input tensor file to read the input for inference. + :param output_path: Path to the output tensor file to save the output of inference.. + + """ + + # Load the selected model + model = load_model(suite, model_name) + model = model.eval() + + calibration_dataset: Optional[torch.utils.data.DataLoader] = None + + if dataset_path: + calibration_dataset = load_calibration_dataset( + dataset_path, batch_size, suite, model, model_name + ) + if calibration_dataset is not None: + input_shape = tuple(next(iter(calibration_dataset))[0].shape) + print(f"Input shape retrieved from the model config: {input_shape}") + else: + msg = "Quantization requires a valid calibration dataset" + raise ValueError(msg) + # Ensure input_shape is a tuple + elif isinstance(input_shape, (list, tuple)): + input_shape = tuple(input_shape) + else: + msg = "Input shape must be a list or tuple." + raise ValueError(msg) + # Provide input + if input_path: + example_args = (torch.load(input_path, weights_only=False),) + elif suite == "huggingface": + if hasattr(model, "config") and hasattr(model.config, "vocab_size"): + vocab_size = model.config.vocab_size + else: + vocab_size = 30522 + example_args = (torch.randint(0, vocab_size, input_shape, dtype=torch.int64),) + else: + example_args = (torch.randn(*input_shape),) + + # Export the model to the aten dialect + aten_dialect: ExportedProgram = export(model, example_args) + + if quantize and calibration_dataset: + if suite == "huggingface": + msg = f"Quantization of {suite} models did not support yet." + raise ValueError(msg) + + # Quantize model + if not dataset_path: + msg = "Quantization requires a calibration dataset." 
+ raise ValueError(msg) + + subset_size = 300 + batch_size = calibration_dataset.batch_size or 1 + subset_size = (subset_size // batch_size) + int(subset_size % batch_size > 0) + + def transform_fn(x): + return x[0] + + quantized_model = quantize_model( + cast(torch.fx.GraphModule, aten_dialect.module()), + calibration_dataset, + subset_size=subset_size, + transform_fn=transform_fn, + ) + + aten_dialect = export(quantized_model, example_args) + + # Convert to edge dialect and lower the module to the backend with a custom partitioner + compile_spec = [CompileSpec("device", device.encode())] + lowered_module: EdgeProgramManager = to_edge_transform_and_lower( + aten_dialect, + partitioner=[ + OpenvinoPartitioner(compile_spec), + ], + ) + + # Apply backend-specific passes + exec_prog = lowered_module.to_executorch( + config=executorch.exir.ExecutorchBackendConfig() + ) + + # Serialize and save it to a file + if save_model: + if not model_file_name: + model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte" + with open(model_file_name, "wb") as file: + exec_prog.write_to_file(file) + print(f"Model exported and saved as {model_file_name} on {device}.") + + if validate and calibration_dataset: + if suite == "huggingface": + msg = f"Validation of {suite} models did not support yet." + raise ValueError(msg) + + if not dataset_path: + msg = "Validation requires a calibration dataset." + raise ValueError(msg) + + print("Start validation of the model:") + acc_top1 = validate_model(exec_prog, calibration_dataset) + print(f"acc@1: {acc_top1}") + + if infer: + print("Start inference of the model:") + avg_time = infer_model( + exec_prog, example_args, num_iter, warmup_iter, output_path + ) + print(f"Average inference time: {avg_time}") + + +if __name__ == "__main__": + # Argument parser for dynamic inputs + parser = argparse.ArgumentParser(description="Export models with executorch.") + parser.add_argument( + "--suite", + type=str, + required=True, + choices=["timm", "torchvision", "huggingface"], + help="Select the model suite (timm, torchvision, huggingface).", + ) + parser.add_argument( + "--model", type=str, required=True, help="Model name to be loaded." + ) + parser.add_argument( + "--input_shape", + type=eval, + help="Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224)).", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for the validation. Default batch_size == 1." + " The dataset length must be evenly divisible by the batch size.", + ) + parser.add_argument( + "--export", action="store_true", help="Export the compiled model as .pte file." + ) + parser.add_argument( + "--model_file_name", + type=str, + help="Custom file name to save the exported model.", + ) + parser.add_argument( + "--quantize", action="store_true", help="Enable model quantization." + ) + parser.add_argument( + "--validate", + action="store_true", + help="Enable model validation. 
--dataset argument is required for the validation.", + ) + parser.add_argument( + "--infer", + action="store_true", + help="Run inference and report timing.", + ) + parser.add_argument( + "--num_iter", + type=int, + default=1, + help="The number of iterations to execute inference for timing.", + ) + parser.add_argument( + "--warmup_iter", + type=int, + default=0, + help="The number of iterations to execute inference for warmup before timing.", + ) + parser.add_argument( + "--input_tensor_path", + type=str, + help="Path to the input tensor file to read the input for inference.", + ) + parser.add_argument( + "--output_tensor_path", + type=str, + help="Path to the output tensor file to save the output of inference.", + ) + parser.add_argument("--dataset", type=str, help="Path to the validation dataset.") + parser.add_argument( + "--device", + type=str, + default="CPU", + help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.", + ) + + args = parser.parse_args() + + # Run the main function with parsed arguments + # Disable nncf patching as export of the patched model is not supported. + with nncf.torch.disable_patching(): + main( + args.suite, + args.model, + args.input_shape, + args.export, + args.model_file_name, + args.quantize, + args.validate, + args.dataset, + args.device, + args.batch_size, + args.infer, + args.num_iter, + args.warmup_iter, + args.input_tensor_path, + args.output_tensor_path, + ) diff --git a/install_executorch.py b/install_executorch.py index 85703903ffc..6863ed2c0fc 100644 --- a/install_executorch.py +++ b/install_executorch.py @@ -53,7 +53,7 @@ def clean(): # Please keep this insync with `ShouldBuild.pybindings` in setup.py. -VALID_PYBINDS = ["coreml", "mps", "xnnpack", "training"] +VALID_PYBINDS = ["coreml", "mps", "xnnpack", "training", "openvino"] ################################################################################ diff --git a/setup.py b/setup.py index 871fdf329c2..76fbbbd9025 100644 --- a/setup.py +++ b/setup.py @@ -121,6 +121,7 @@ def pybindings(cls) -> bool: [ cls.coreml(), cls.mps(), + cls.openvino(), cls.xnnpack(), cls.training(), ] @@ -135,6 +136,10 @@ def coreml(cls) -> bool: def mps(cls) -> bool: return cls._is_cmake_arg_enabled("EXECUTORCH_BUILD_MPS", default=False) + @classmethod + def openvino(cls) -> bool: + return cls._is_cmake_arg_enabled("EXECUTORCH_BUILD_OPENVINO", default=False) + @classmethod def xnnpack(cls) -> bool: return cls._is_cmake_arg_enabled("EXECUTORCH_BUILD_XNNPACK", default=False) diff --git a/tools/cmake/Utils.cmake b/tools/cmake/Utils.cmake index b66a4eb9cf5..8f3e37d9a9e 100644 --- a/tools/cmake/Utils.cmake +++ b/tools/cmake/Utils.cmake @@ -115,6 +115,10 @@ function(executorch_print_configuration_summary) STATUS " EXECUTORCH_BUILD_NEURON : ${EXECUTORCH_BUILD_NEURON}" ) + message( + STATUS + " EXECUTORCH_BUILD_OPENVINO : ${EXECUTORCH_BUILD_OPENVINO}" + ) message( STATUS " EXECUTORCH_BUILD_PTHREADPOOL : ${EXECUTORCH_BUILD_PTHREADPOOL}"