From 978c40d85310a1d9b8a6069be853bc3dbec44e18 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 12 Mar 2024 10:55:49 +1000 Subject: [PATCH] Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723) ### Description If the EP handles QDQ node units, we need to make sure we do not split those into different partitions. Update the partitioning utils to be QDQ aware. If there are node units we process the logical nodes they represent instead of individual nodes. This ensure we process all nodes in a QDQ node unit at the same time so that they are always in the same partition. ### Motivation and Context Fix one of the issues in #19590 --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../external/onnxruntime_external_deps.cmake | 9 +- cmake/onnxruntime_providers_coreml.cmake | 4 +- cmake/onnxruntime_providers_nnapi.cmake | 6 +- cmake/onnxruntime_providers_qnn.cmake | 8 +- cmake/onnxruntime_providers_xnnpack.cmake | 3 - onnxruntime/core/framework/node_unit.cc | 351 ++++++++++++++++++ .../node_unit => framework}/node_unit.h | 51 ++- .../selectors_actions/qdq_selectors.cc | 12 +- .../selectors_actions/qdq_selectors.h | 8 +- .../selectors_actions/shared/utils.cc | 62 +++- .../selectors_actions/shared/utils.h | 16 +- .../providers/js/js_execution_provider.cc | 1 - .../nnapi/nnapi_builtin/builders/helper.cc | 2 +- .../builders/impl/base_op_builder.h | 2 +- .../nnapi_builtin/builders/model_builder.cc | 12 +- .../builders/op_builder_helpers.cc | 1 - .../builders/op_builder_helpers.h | 2 +- .../nnapi_builtin/nnapi_execution_provider.cc | 8 +- .../core/providers/partitioning_utils.cc | 105 ++++-- .../core/providers/partitioning_utils.h | 14 +- .../core/providers/qnn/builder/op_builder.h | 2 +- .../core/providers/qnn/builder/qnn_model.cc | 4 +- .../core/providers/qnn/builder/qnn_model.h | 2 +- .../providers/qnn/builder/qnn_model_wrapper.h | 2 +- .../providers/qnn/qnn_execution_provider.cc | 69 ++-- .../providers/shared/node_unit/node_unit.cc | 319 ---------------- .../core/providers/shared/utils/utils.cc | 10 +- onnxruntime/core/providers/utils.cc | 3 +- .../xnnpack/detail/node_support_checker.cc | 2 +- .../core/providers/xnnpack/detail/utils.cc | 6 +- .../core/providers/xnnpack/detail/utils.h | 2 +- .../xnnpack/xnnpack_execution_provider.cc | 18 +- .../mlas/unittest/test_fp16_activation.cpp | 1 + .../test/optimizer/qdq_transformer_test.cc | 14 +- .../internal_testing_execution_provider.cc | 1 + .../test/providers/partitioning_utils_test.cc | 174 +++++++++ .../test/testdata/ort_github_issue_19590.onnx | Bin 0 -> 599 bytes .../test/testdata/ort_github_issue_19590.py | 77 ++++ 38 files changed, 890 insertions(+), 493 deletions(-) create mode 100644 onnxruntime/core/framework/node_unit.cc rename onnxruntime/core/{providers/shared/node_unit => framework}/node_unit.h (54%) delete mode 100644 onnxruntime/core/providers/shared/node_unit/node_unit.cc create mode 100644 onnxruntime/test/providers/partitioning_utils_test.cc create mode 100644 onnxruntime/test/testdata/ort_github_issue_19590.onnx create mode 100644 onnxruntime/test/testdata/ort_github_issue_19590.py diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index cb75b0b8751bb..e4fefdbf86369 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -37,8 +37,13 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - if (CMAKE_SYSTEM_NAME STREQUAL "iOS") - # Needs to update onnxruntime/test/xctest/xcgtest.mm + if (IOS OR ANDROID) + # on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing + # any args to gtest executables, such as using --gtest_filter to debug a specific test. + # Processing of compile definitions: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21 + # If set, this code throws away the flag and does nothing on registration, which results in no flags being known: + # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217 set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) else() set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index 8f3b1828e1c61..b8ebc4ca53239 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -70,8 +70,8 @@ list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$") source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs}) # These are shared utils, -# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML -file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS +# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML +file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" ) diff --git a/cmake/onnxruntime_providers_nnapi.cmake b/cmake/onnxruntime_providers_nnapi.cmake index 5ac25a3b76efb..b718a976eb26f 100644 --- a/cmake/onnxruntime_providers_nnapi.cmake +++ b/cmake/onnxruntime_providers_nnapi.cmake @@ -49,12 +49,10 @@ endif() # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML + # TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML list(APPEND onnxruntime_provider_nnapi_cc_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns}) @@ -81,4 +79,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index a93a06e960c81..b68d84c23bb32 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -4,12 +4,10 @@ add_compile_definitions(USE_QNN=1) # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + # TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML + file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) file(GLOB_RECURSE @@ -42,4 +40,4 @@ # ignore the warning unknown-pragmas on "pragma region" if(NOT MSVC) target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 6342c24b2917e..796536ac9d12b 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -7,9 +7,6 @@ "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h" "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc" - # utils for handling QDQ models - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc" ) source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc new file mode 100644 index 0000000000000..4dee1c14b3761 --- /dev/null +++ b/onnxruntime/core/framework/node_unit.cc @@ -0,0 +1,351 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +#include "node_unit.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { + +namespace { + +enum class QLinearOpType : uint8_t { + Unknown, // Unknown or not a linear quantized op + DequantizeLinear, + QuantizeLinear, + QLinearConv, + QLinearMatMul, + QLinearAdd, + QLinearSigmoid, + QLinearAveragePool, + QLinearMul, + QLinearReduceMean, + QLinearConcat, + QLinearGlobalAveragePool, + QLinearLeakyRelu, +}; + +QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { + const auto& op_type = node.OpType(); + if (op_type == "DequantizeLinear") + return QLinearOpType::DequantizeLinear; + else if (op_type == "QuantizeLinear") + return QLinearOpType::QuantizeLinear; + else if (op_type == "QLinearConv") + return QLinearOpType::QLinearConv; + else if (op_type == "QLinearMatMul") + return QLinearOpType::QLinearMatMul; + else if (op_type == "QLinearAdd") + return QLinearOpType::QLinearAdd; + else if (op_type == "QLinearSigmoid") + return QLinearOpType::QLinearSigmoid; + else if (op_type == "QLinearAveragePool") + return QLinearOpType::QLinearAveragePool; + else if (op_type == "QLinearMul") + return QLinearOpType::QLinearMul; + else if (op_type == "QLinearReduceMean") + return QLinearOpType::QLinearReduceMean; + else if (op_type == "QLinearConcat") + return QLinearOpType::QLinearConcat; + else if (op_type == "QLinearGlobalAveragePool") + return QLinearOpType::QLinearGlobalAveragePool; + else if (op_type == "QLinearLeakyRelu") + return QLinearOpType::QLinearLeakyRelu; + + return QLinearOpType::Unknown; +} + +// Ops have 1 input +bool IsUnaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearSigmoid || + type == QLinearOpType::QLinearAveragePool || + type == QLinearOpType::QLinearGlobalAveragePool || + type == QLinearOpType::QLinearLeakyRelu || + type == QLinearOpType::QLinearReduceMean; +} + +// Ops have 2 inputs +bool IsBinaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConv || + type == QLinearOpType::QLinearMatMul || + type == QLinearOpType::QLinearAdd || + type == QLinearOpType::QLinearMul; +} + +// Ops have 1 or more inputs +bool IsVariadicQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConcat; +} + +const std::vector GetQDQIONodes(const GraphViewer& graph_viewer, + const QDQ::NodeGroup& node_group, bool is_input) { + std::vector io_nodes; + const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + io_nodes.reserve(src_nodes.size()); + for (const auto& node_idx : src_nodes) { + io_nodes.push_back(graph_viewer.GetNode(node_idx)); + } + + return io_nodes; +} + +// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup +std::vector GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) { + const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; + const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs(); + const size_t target_node_io_defs_size = target_node_io_defs.size(); + + // Find all the quantized IO defs and indices (for the input/output of the target node) + std::unordered_map quantized_io_defs; + quantized_io_defs.reserve(target_node_io_defs_size); + + auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin(); + auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd(); + + for (; cur != end; ++cur) { + const Node& node = cur->GetNode(); + + // If we can find the node index in the dq or q nodes this is a quantized input/output + if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) { + const auto node_inputs = node.InputDefs(); + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr}; + + if (is_input) { + // DQ is input to the target node, use the DstArgIndex + auto idx = cur->GetDstArgIndex(); + // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2]) + quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}}); + } else { + // Q is output of the target node, use the SrcArgIndex + auto idx = cur->GetSrcArgIndex(); + // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2]) + const auto node_outputs = node.OutputDefs(); + quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}}); + } + } + } + + // Construct the IODefs for this QDQ NodeGroup + std::vector io_defs; + io_defs.reserve(target_node_io_defs_size); + for (size_t i = 0; i < target_node_io_defs_size; i++) { + // If we can find the NodeUnitIODef for this index, this is a quantized input/output + if (quantized_io_defs.find(i) != quantized_io_defs.cend()) { + io_defs.push_back(std::move(quantized_io_defs.at(i))); + } else { + // This is a regular input + io_defs.push_back({*target_node_io_defs[i], std::nullopt}); + } + } + + return io_defs; +} + +} // namespace + +Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes) { + // Within a QDQ node group, a target node input is the only consumer of each DQ. + // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications + // may have happened since. Verify that this is still true. + for (const auto* dq_node : dq_nodes) { + const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); + ORT_RETURN_IF(dq_produces_graph_output, + "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), + ", target node: ", target_node.Name()); + + const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); + ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, + "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " + "DQ node: ", + dq_node->Name(), ", target node: ", target_node.Name()); + } + + // an output from the target node can have either Q consumers or direct consumers. it cannot have both. + // this must be checked on a per output basis. + // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ + // node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output. + if (!q_nodes.empty()) { + auto cur_edge = target_node.OutputEdgesBegin(); + auto end_edge = target_node.OutputEdgesEnd(); + std::vector output_consumers(target_node.OutputDefs().size(), nullptr); + + for (; cur_edge != end_edge; ++cur_edge) { + auto output_idx = cur_edge->GetSrcArgIndex(); + const Node& this_consumer = cur_edge->GetNode(); + const Node* existing_consumer = output_consumers[output_idx]; + + if (existing_consumer != nullptr) { + // another edge for this output. either both are Q or both are not. + bool valid = true; + if (existing_consumer->OpType() == "QuantizeLinear") { + valid = this_consumer.OpType() == "QuantizeLinear"; + } else { + valid = this_consumer.OpType() != "QuantizeLinear"; + } + + ORT_RETURN_IF_NOT(valid, + "QDQ node group cannot have an output from the target node being consumed by a Q node and " + "a non-Q node. target node: ", + target_node.Name()); + } else { + output_consumers[output_idx] = &this_consumer; + } + } + + const auto& graph_outputs = graph_viewer.GetOutputs(); + for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) { + // any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to + // a quantized op. + if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") { + const auto& output_name = target_node.OutputDefs()[idx]->Name(); + bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(), + [&output_name](const NodeArg* node_arg) { + return node_arg->Name() == output_name; + }); + ORT_RETURN_IF(is_graph_output, + "QDQ node group cannot have an output from the target node that is consumed by a Q node and " + "a graph output. target node: ", + target_node.Name(), " output idx:", idx); + } + } + } + + return Status::OK(); +} +NodeUnit::NodeUnit(const Node& node) + : target_node_(node), + type_(Type::SingleNode), + input_edge_count_(node.GetInputEdgesCount()) { + InitForSingleNode(); +} + +NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group) + : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, + target_node_(*graph_viewer.GetNode(node_group.target_node)), + q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, + type_(Type::QDQGroup), + inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, + outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { + ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_)); + + input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), + [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); + + // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node. + // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge). + input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size(); + + // create output edges. each target node output either goes to Q node/s or non-Q node/s. + // ValidateNodeGroupQDQNodes ensures this. + auto cur_edge = target_node_.OutputEdgesBegin(); + auto end_edge = target_node_.OutputEdgesEnd(); + for (; cur_edge != end_edge; ++cur_edge) { + const Node& node = cur_edge->GetNode(); + + // if node is in q_nodes we hide the Q node. + if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) { + auto src_idx = cur_edge->GetSrcArgIndex(); + auto q_cur_edge = node.OutputEdgesBegin(); + auto q_end_edge = node.OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()}); + } + } else { + // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is. + output_edges_.insert(*cur_edge); + } + } +} + +const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } +const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } +const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } +int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } +NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); } +const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } +ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } + +void NodeUnit::InitForSingleNode() { + const auto& input_defs = target_node_.InputDefs(); + const auto& output_defs = target_node_.OutputDefs(); + auto qlinear_type = GetQLinearOpType(target_node_); + if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support + // Not a Qlinear op, add all inputs / outputs + auto add_all_io = [](std::vector& defs, + const ConstPointerContainer>& node_defs) { + defs.reserve(node_defs.size()); + + for (const auto def : node_defs) { + defs.push_back(NodeUnitIODef{*def, std::nullopt}); + } + }; + + add_all_io(inputs_, input_defs); + add_all_io(outputs_, output_defs); + } else if (IsUnaryQLinearOp(qlinear_type)) { + // Unary QLinear Op has 5 inputs + // x, x_scale, x_zp, y_scale, y_zp (optional) + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], + NodeUnitIODef::QuantParam{*input_defs[3], + input_defs.size() > 4 ? input_defs[4] : nullptr}}); + + } else if (IsBinaryQLinearOp(qlinear_type)) { + // Binary QLinear Op has 9 inputs + // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}}); + + if (input_defs.size() == 9) { // has Bias + inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt}); // for Bias the scale and zp are optional + } + + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}}); + + } else if (qlinear_type == QLinearOpType::DequantizeLinear) { + // DequantizeLinear has 3 inputs + // x, x_scale, x_zp + // output is not quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt}); + + } else if (qlinear_type == QLinearOpType::QuantizeLinear) { + // QuantizeLinear the input is not quantized and has 3 inputs + // x, y_scale, y_zp (optional) + // The output is quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + } else { + ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); + } +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const { + return (type_ == Type::SingleNode) ? target_node_.OutputEdgesBegin() : output_edges_.begin(); +} + +Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const { + return (type_ == Type::SingleNode) ? target_node_.OutputEdgesEnd() : output_edges_.end(); +} + +std::vector NodeUnit::GetAllNodesInGroup() const noexcept { + std::vector all_nodes = dq_nodes_; + all_nodes.push_back(&target_node_); + all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end()); + return all_nodes; +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.h b/onnxruntime/core/framework/node_unit.h similarity index 54% rename from onnxruntime/core/providers/shared/node_unit/node_unit.h rename to onnxruntime/core/framework/node_unit.h index b47204ca3c42d..66afaec8ee1e2 100644 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -3,6 +3,9 @@ #pragma once +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include #include @@ -18,8 +21,21 @@ class NodeArg; class Path; namespace QDQ { -struct NodeGroup; -} +// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group +struct NodeGroup { + std::vector dq_nodes; + std::vector q_nodes; + NodeIndex target_node; + + // Validator to check if the set of nodes can form a valid QDQ NodeGroup. + // Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to + // be converted into a single node with a quantized operator. + static Status CanCreateNodeGroup(const GraphViewer& graph_viewer, + const Node& target_node, + gsl::span dq_nodes, + gsl::span q_nodes); +}; +} // namespace QDQ // Definition of one input or output // If the optional quant_param is present, then this is a quantized input, @@ -69,26 +85,33 @@ class NodeUnit { const std::vector& GetQNodes() const noexcept { return q_nodes_; } std::vector GetAllNodesInGroup() const noexcept; - Node::EdgeConstIterator OutputEdgesBegin(size_t index) const; - Node::EdgeConstIterator OutputEdgesEnd(size_t index) const; + /// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes + /// plus any other edges to the target node for inputs that are not via a DQ node. + size_t InputEdgeCount() const { return input_edge_count_; } + + // output edges. src index is for outputs of the target node. dest index and node is for consumer of node unit + // output. any Q nodes are hidden. + Node::EdgeConstIterator OutputEdgesBegin() const; + Node::EdgeConstIterator OutputEdgesEnd() const; private: - const std::vector q_nodes_; // q-nodes for this NodeUnit - const std::vector dq_nodes_; // dq nodes for this NodeUnit, not all inputs + // Initialization for a NodeUnit that contains a single node + void InitForSingleNode(); + + const std::vector dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs const Node& target_node_; + const std::vector q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs const Type type_; std::vector inputs_; std::vector outputs_; - // Initializing for a single Node - void InitForSingleNode(); -}; + size_t input_edge_count_; // total number of input edges -// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) -// And return a map to quick query the NodeUnit which contains the given Node, -// Note, the value of the map is owned by the vector of std::unique_ptr -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); + // output edges, hiding any Q nodes involved. src_idx will be value from target node. only used for QDQ node group. + Node::EdgeSet output_edges_; +}; } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 8535b8c9a944a..6b4f62ae1343d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -58,8 +58,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } @@ -153,8 +153,8 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } @@ -544,8 +544,8 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, return false; } - if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); - !dq_validation_status.IsOK()) { + if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + !qdq_validation_status.IsOK()) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index deee6e7f25f1a..c90a42a36483d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -5,6 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +#include "core/framework/node_unit.h" #include "core/optimizer/selectors_actions/selector_action_transformer.h" namespace onnxruntime { @@ -13,13 +14,6 @@ class Node; namespace QDQ { -// Struct to represent a DQ->Op->Q node group -struct NodeGroup { - std::vector dq_nodes; - std::vector q_nodes; - NodeIndex target_node; -}; - class NodeGroupSelector { public: // This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 544fe82a268c8..1876f7826c968 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -13,6 +13,7 @@ #include #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" namespace onnxruntime { namespace QDQ { @@ -43,6 +44,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Tile", {}}}; } +// These produce int64 indices output, which can't be quantized, so there's no downstream Q node. static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { return {{"ArgMax", {}}, {"ArgMin", {}}}; @@ -324,28 +326,48 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap return qdq_selections; } -Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes) { - // Within a QDQ node group, a target node input is the only consumer of each DQ. - // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications - // may have happened since. Verify that this is still true. - for (const auto* dq_node : dq_nodes) { - const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node); - ORT_RETURN_IF(dq_produces_graph_output, - "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(), - ", target node: ", target_node.Name()); - - const bool dq_has_single_output_edge_to_target = - dq_node->GetOutputEdgesCount() == 1 && - dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index(); - ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, - "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. " - "DQ node: ", - dq_node->Name(), ", target node: ", target_node.Name()); +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer) { + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + + const auto add_node_unit_to_map = [&](const std::vector& node_indices, const NodeUnit* node_unit) { + for (const auto& node_idx : node_indices) { + const auto* node = graph_viewer.GetNode(node_idx); + node_unit_map.insert({node, node_unit}); + } + }; + + // Get QDQ NodeUnits first + QDQ::SelectorManager selector_mgr; + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + + for (const auto& qdq_selection : qdq_selections) { + auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); + + // Fill the node to node_unit map for all nodes in the QDQ Group + add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); + add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); + add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); + + node_unit_holder.push_back(std::move(qdq_unit)); + } + + // Get the left over SingleNode NodeUnits + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + const auto* node(graph_viewer.GetNode(node_idx)); + + // This is already part of a QDQ NodeUnit + if (node_unit_map.find(node) != node_unit_map.cend()) + continue; + + auto node_unit = std::make_unique(*node); + node_unit_map[node] = node_unit.get(); + node_unit_holder.push_back(std::move(node_unit)); } - return Status::OK(); + return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map)); } } // namespace QDQ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index 246f26c1760ec..de36202afff29 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -7,6 +7,7 @@ #include "core/common/common.h" #include "core/common/gsl.h" #include "core/common/inlined_containers.h" +#include "core/framework/node_unit.h" #include "core/graph/basic_types.h" #if !defined(ORT_MINIMAL_BUILD) @@ -78,11 +79,16 @@ class SelectorManager { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager); }; -// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node. -// Returns successful status if so, failed status with reason otherwise. -Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer, - const Node& target_node, - gsl::span dq_nodes); +// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup) +// And return a map to quick query the NodeUnit which contains the given Node, +// Note, the value of the map is owned by the vector of std::unique_ptr +// +// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific +// functionality. +// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer +// library whereas it should be able to be used by an EP with no dependency on optimizers. +std::pair>, std::unordered_map> +GetAllNodeUnits(const GraphViewer& graph_viewer); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 2d2c89f36f1a7..038423104d92e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -21,7 +21,6 @@ #include "core/framework/kernel_registry.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "data_transfer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 0b32508a5bb38..745504ca04941 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -11,6 +11,7 @@ #include "core/common/logging/logging.h" #include "core/common/safeint.h" +#include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" #include "core/graph/graph.h" @@ -18,7 +19,6 @@ #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/shared/utils/utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h index 6a54bf7bdb938..0c0bc7b2e4674 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/common.h" -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/framework/node_unit.h" #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h" diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 6962a7be94bb6..d0ae32378379d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -11,17 +11,19 @@ #include "core/common/safeint.h" #include "core/common/status.h" #include "core/framework/execution_provider.h" +#include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h" -#include "core/providers/shared/node_unit/node_unit.h" -#include "core/providers/shared/utils/utils.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" -#include "core/optimizer/initializer.h" +#include "core/providers/shared/utils/utils.h" using namespace android::nn::wrapper; @@ -119,7 +121,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { } void ModelBuilder::PreprocessNodeUnits() { - std::tie(node_unit_holder_, node_unit_map_) = GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); } // Help to get all quantized operators' input and the NodeUnit(s) using the input @@ -664,7 +666,7 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) { int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE; bool fuse_code_assigned_from_activation = false; - for (auto it = node_unit.OutputEdgesBegin(0), end = node_unit.OutputEdgesEnd(0); it != end; ++it) { + for (auto it = node_unit.OutputEdgesBegin(), end = node_unit.OutputEdgesEnd(); it != end; ++it) { const auto& dst_node = it->GetNode(); const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index 466865f23f49a..dab7bccf43396 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -21,7 +21,6 @@ #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h" namespace onnxruntime::nnapi::op_builder_helpers { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h index 61a16ceff752f..0844857a06d61 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h @@ -7,12 +7,12 @@ #include #include "core/common/common.h" +#include "core/framework/node_unit.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h" #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h" -#include "core/providers/shared/node_unit/node_unit.h" namespace onnxruntime::nnapi::op_builder_helpers { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index b04703d7611ee..4d2888222ff0f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -7,7 +7,10 @@ #include "core/common/logging/logging.h" #include "core/common/string_utils.h" #include "core/framework/compute_capability.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" @@ -17,7 +20,6 @@ #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/session/onnxruntime_cxx_api.h" namespace onnxruntime { @@ -119,7 +121,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -181,7 +183,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view }; result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, - gen_metadef_name, NNAPI, kNnapiExecutionProvider); + gen_metadef_name, NNAPI, kNnapiExecutionProvider, &node_unit_map); // Generally, NNAPI supports sub-graphs with at least one non-constant initializer input and one output. // So far, we have a few cases that sub-graph has zero valid inputs, like `CastLike` diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index d537a4cf58b2d..c45f5cd0848dd 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include "core/providers/partitioning_utils.h" #include @@ -10,6 +13,7 @@ #include "core/framework/compute_capability.h" #include "core/framework/execution_provider.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" #include "core/providers/common.h" @@ -76,6 +80,11 @@ When selecting the next node to process, we first take: The remaining unsupported nodes mark the border of the current group so they will be processed later when we consider the next group. +If node_unit_map is provided, we process NodeUnit instances (a logical 'Node' that can be a single node or a +QDQ node group) instead of individual Node instances. As an EP must take complete NodeUnit instances (i.e. it +must not break up a QDQ node group by taking a subset of nodes in it), this granularity of processing is valid. +It is required to ensure we do not break up a QDQ node unit during partitioning. + @param graph_viewer GraphViewer that IExecutionProvider::GetCapability is called with. @param is_node_supported_fn Callback to check whether a node is supported. @param on_group_closed_fn Callback to indicate a completed partition node group. @@ -88,6 +97,7 @@ std::vector> CreateSupportedPartitionNodeGroups( const IsNodeSupportedFn& is_node_supported_fn, const OnGroupClosedFn& on_group_closed_fn, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, bool debug_output) { #ifdef NDEBUG ORT_UNUSED_PARAMETER(debug_output); @@ -111,7 +121,18 @@ std::vector> CreateSupportedPartitionNodeGroups( // initialize in-degrees and find root nodes for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) { const auto& node = *graph_viewer.GetNode(node_index); - const auto node_input_edge_count = node.GetInputEdgesCount(); + auto node_input_edge_count = node.GetInputEdgesCount(); + + if (node_unit_map != nullptr) { + const auto& node_unit = node_unit_map->at(&node); + if (&node_unit->GetNode() != &node) { + // only process the target node + continue; + } + + node_input_edge_count = node_unit->InputEdgeCount(); + } + in_degree.insert({node.Index(), node_input_edge_count}); if (node_input_edge_count == 0) { nodes_to_process.push_back(&node); @@ -151,6 +172,8 @@ std::vector> CreateSupportedPartitionNodeGroups( } }; + size_t num_nodes_processed = 0; + while (!nodes_to_process.empty() || !nodes_to_process_with_next_group.empty()) { if (nodes_to_process.empty()) { // we have processed all the nodes that we can while building this partition node group, start a new one @@ -162,9 +185,13 @@ std::vector> CreateSupportedPartitionNodeGroups( const Node& node = *nodes_to_process.front(); nodes_to_process.pop_front(); + const NodeUnit* node_unit = node_unit_map ? node_unit_map->at(&node) : nullptr; + const bool is_qdq_node_unit = node_unit && node_unit->UnitType() == NodeUnit::Type::QDQGroup; + // a node that is already assigned to an EP other than current EP is unsupported - const bool is_node_supported = - (node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node); + const bool is_node_supported = (node.GetExecutionProviderType().empty() || + node.GetExecutionProviderType() == execution_provider_type) && + is_node_supported_fn(node); if (!is_node_supported && Contains(supported_group_border, &node)) { // an unsupported node on the border will be processed after the current partition node group @@ -173,34 +200,62 @@ std::vector> CreateSupportedPartitionNodeGroups( } if (is_node_supported) { - // add node to the partition node group - supported_group.push_back(&node); + if (is_qdq_node_unit) { + // add DQ -> node -> Q for the node unit. must be in topological order + for (const auto& dq : node_unit->GetDQNodes()) { + supported_group.push_back(dq); + } - // remove node from the border and add its outputs to the border + supported_group.push_back(&node); + + for (const auto& q : node_unit->GetQNodes()) { + supported_group.push_back(q); + } + } else { + supported_group.push_back(&node); + } + + // remove node from the border supported_group_border.erase(&node); + } - std::for_each( - node.OutputNodesBegin(), node.OutputNodesEnd(), - [&supported_group_border](const Node& output) { - supported_group_border.insert(&output); - }); + // For each downstream node: + // 1: add the downstream node to the border if the current node is supported + // 2: adjust in-degrees of the nodes consuming the current node's outputs, and add any new nodes to process + const auto process_downstream_node = [&](const Node& downstream_node) { + if (is_node_supported) { + supported_group_border.insert(&downstream_node); + } + + auto& downstream_node_in_degree = in_degree[downstream_node.Index()]; + --downstream_node_in_degree; + + if (downstream_node_in_degree == 0) { + nodes_to_process.push_back(&downstream_node); + } + }; + + if (node_unit_map) { + std::for_each(node_unit->OutputEdgesBegin(), node_unit->OutputEdgesEnd(), + [&](const Node::EdgeEnd& edge_end) { + const Node& n = edge_end.GetNode(); + const NodeUnit& downstream_node_unit = *node_unit_map->at(&n); + const Node& output = downstream_node_unit.GetNode(); + + process_downstream_node(output); + }); + } else { + std::for_each(node.OutputNodesBegin(), node.OutputNodesEnd(), process_downstream_node); } - // adjust in-degrees of the node outputs and add any new nodes to process - std::for_each( - node.OutputNodesBegin(), node.OutputNodesEnd(), - [&](const Node& output) { - auto& output_node_in_degree = in_degree[output.Index()]; - --output_node_in_degree; - - if (output_node_in_degree == 0) { - nodes_to_process.push_back(&output); - } - }); + ++num_nodes_processed; } close_group(); + ORT_ENFORCE(num_nodes_processed == in_degree.size(), + "Processed ", num_nodes_processed, " nodes. Expected to process ", in_degree.size()); + return supported_groups; } } // namespace @@ -318,11 +373,13 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, bool debug_output) { const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer, is_node_supported_fn, on_partition_closed_fn, execution_provider_type, + node_unit_map, debug_output); std::vector> partitions{}; @@ -346,6 +403,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, bool debug_output) { const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops); const bool check_excluded_nodes = !excluded_nodes.empty(); @@ -360,8 +418,11 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, generate_metadef_name_fn, execution_provider_name, execution_provider_type, + node_unit_map, debug_output); } } // namespace utils } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h index 136725c2f7250..c3f6b104e3f6a 100644 --- a/onnxruntime/core/providers/partitioning_utils.h +++ b/onnxruntime/core/providers/partitioning_utils.h @@ -3,6 +3,9 @@ #pragma once +// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include #include @@ -14,8 +17,9 @@ namespace onnxruntime { struct ComputeCapability; class GraphViewer; -class NodeArg; class Node; +class NodeArg; +class NodeUnit; namespace utils { @@ -56,6 +60,8 @@ Create the supported partitions for the execution provider. @param generate_metadef_name_fn Callback to create the name for the MetaDef. @param execution_provider_name Name of execution provider creating the ComputeCapability instance. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. +@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. + Should be created by EP calling GetAllNodeUnits. @param debug_output Print diagnostic output about the partitions and reasons for partition breaks. No-op in a release build. @@ -68,6 +74,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map = nullptr, bool debug_output = false); /** @@ -79,6 +86,8 @@ Create the supported partitions for the execution provider. @param generate_metadef_name Functor to create the name for the MetaDef. @param execution_provider_name Name of execution provider creating the ComputeCapability instance. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. +@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. + Should be created by EP calling GetAllNodeUnits. @param debug_output Print diagnostic output about the partitions and reasons for partition breaks. No-op in a release build. @@ -91,6 +100,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name, const std::string& execution_provider_name, const std::string& execution_provider_type, + const std::unordered_map* node_unit_map = nullptr, bool debug_output = false); /** @@ -125,3 +135,5 @@ InlinedHashSet CreateExcludedNodeSet(const GraphViewer& graph_viewe const std::unordered_set& stop_ops); } // namespace utils } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/qnn/builder/op_builder.h b/onnxruntime/core/providers/qnn/builder/op_builder.h index 018d9a2797a66..05398c3f22ea2 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder.h @@ -4,7 +4,7 @@ #pragma once #include "core/graph/graph_viewer.h" -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index dc91b9dfa199e..b3501dfec1ba8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -9,6 +9,8 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { @@ -95,7 +97,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index d0dd091cb1688..8fed2f364ba5a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -6,13 +6,13 @@ #include #include "core/common/status.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" #include "core/platform/ort_mutex.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/providers/shared/node_unit/node_unit.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 8ae489c749f31..1e2993f246ae4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -11,8 +11,8 @@ #include "QnnInterface.h" #include "qnn_def.h" #include "core/common/logging/logging.h" +#include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/shared/utils/utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3d9cfd92b7922..5c4fa3e0fb88b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -10,6 +10,8 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" @@ -494,7 +496,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), is_qnn_ctx_model, logger); @@ -534,44 +536,39 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer size_t num_of_supported_nodes = 0; // Create partitions from supported nodes. - { - std::vector> partitions = utils::CreateSupportedPartitions(graph_viewer, - supported_nodes, {}, - gen_metadef_name, QNN, - kQnnExecutionProvider, - true); - - // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node. - // We also count the number of supported nodes in all valid partitions. - for (auto& partition : partitions) { - bool is_valid_partition = true; - size_t nodes_in_partition = 0; - - if (partition && partition->sub_graph) { - nodes_in_partition = partition->sub_graph->nodes.size(); - - if (nodes_in_partition == 1 && !is_qnn_ctx_model) { - const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); - - if (!node) { - LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node."; - is_valid_partition = false; - } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") { - LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition."; - is_valid_partition = false; - } + std::vector> partitions = utils::CreateSupportedPartitions( + graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map, true); + + // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node. + // We also count the number of supported nodes in all valid partitions. + for (auto& partition : partitions) { + bool is_valid_partition = true; + size_t nodes_in_partition = 0; + + if (partition && partition->sub_graph) { + nodes_in_partition = partition->sub_graph->nodes.size(); + + if (nodes_in_partition == 1 && !is_qnn_ctx_model) { + const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); + + if (!node) { + LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node."; + is_valid_partition = false; + } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") { + LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition."; + is_valid_partition = false; } - } else { - LOGS(logger, ERROR) << "QNN EP: Invalid partition."; - is_valid_partition = false; } + } else { + LOGS(logger, ERROR) << "QNN EP: Invalid partition."; + is_valid_partition = false; + } - if (is_valid_partition) { - result.push_back(std::move(partition)); - num_of_supported_nodes += nodes_in_partition; - } - } // for - } + if (is_valid_partition) { + result.push_back(std::move(partition)); + num_of_supported_nodes += nodes_in_partition; + } + } // for const size_t num_of_partitions = result.size(); const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions, diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.cc b/onnxruntime/core/providers/shared/node_unit/node_unit.cc deleted file mode 100644 index 10dd58ba28375..0000000000000 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.cc +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "node_unit.h" -#include "core/graph/graph_viewer.h" -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" -#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" - -namespace onnxruntime { - -namespace { - -enum class QLinearOpType : uint8_t { - Unknown, // Unknown or not a linear quantized op - DequantizeLinear, - QuantizeLinear, - QLinearConv, - QLinearMatMul, - QLinearAdd, - QLinearSigmoid, - QLinearAveragePool, - QLinearMul, - QLinearReduceMean, - QLinearConcat, - QLinearGlobalAveragePool, - QLinearLeakyRelu, -}; - -QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { - const auto& op_type = node.OpType(); - if (op_type == "DequantizeLinear") - return QLinearOpType::DequantizeLinear; - else if (op_type == "QuantizeLinear") - return QLinearOpType::QuantizeLinear; - else if (op_type == "QLinearConv") - return QLinearOpType::QLinearConv; - else if (op_type == "QLinearMatMul") - return QLinearOpType::QLinearMatMul; - else if (op_type == "QLinearAdd") - return QLinearOpType::QLinearAdd; - else if (op_type == "QLinearSigmoid") - return QLinearOpType::QLinearSigmoid; - else if (op_type == "QLinearAveragePool") - return QLinearOpType::QLinearAveragePool; - else if (op_type == "QLinearMul") - return QLinearOpType::QLinearMul; - else if (op_type == "QLinearReduceMean") - return QLinearOpType::QLinearReduceMean; - else if (op_type == "QLinearConcat") - return QLinearOpType::QLinearConcat; - else if (op_type == "QLinearGlobalAveragePool") - return QLinearOpType::QLinearGlobalAveragePool; - else if (op_type == "QLinearLeakyRelu") - return QLinearOpType::QLinearLeakyRelu; - - return QLinearOpType::Unknown; -} - -// Ops have 1 input -bool IsUnaryQLinearOp(QLinearOpType type) { - return type == QLinearOpType::QLinearSigmoid || - type == QLinearOpType::QLinearAveragePool || - type == QLinearOpType::QLinearGlobalAveragePool || - type == QLinearOpType::QLinearLeakyRelu || - type == QLinearOpType::QLinearReduceMean; -} - -// Ops have 2 inputs -bool IsBinaryQLinearOp(QLinearOpType type) { - return type == QLinearOpType::QLinearConv || - type == QLinearOpType::QLinearMatMul || - type == QLinearOpType::QLinearAdd || - type == QLinearOpType::QLinearMul; -} - -// Ops have 1 or more inputs -bool IsVariadicQLinearOp(QLinearOpType type) { - return type == QLinearOpType::QLinearConcat; -} - -const std::vector GetQDQIONodes(const GraphViewer& graph_viewer, - const QDQ::NodeGroup& node_group, bool is_input) { - std::vector io_nodes; - const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; - io_nodes.reserve(src_nodes.size()); - for (const auto& node_idx : src_nodes) { - io_nodes.push_back(graph_viewer.GetNode(node_idx)); - } - return io_nodes; -} - -// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup -std::vector GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, - bool is_input) { - const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes; - const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs(); - const size_t target_node_io_defs_size = target_node_io_defs.size(); - - // Find all the quantized IO defs and indices (for the input to the target node) - std::unordered_map quantized_io_defs; - quantized_io_defs.reserve(target_node_io_defs_size); - - auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin(); - auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd(); - for (; cur != end; ++cur) { - const Node& node = cur->GetNode(); - - // If we can find the node index in the dq or q nodes, then this is a quantize node (can be DQ or Q depends on is_input) - if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) { - const auto node_inputs = node.InputDefs(); - // quantization scale and zp are always the input[1, 2] - NodeUnitIODef::QuantParam quant_param{ - *node_inputs[1], - node_inputs.size() == 3 ? node_inputs[2] : nullptr}; - if (is_input) { - // DQ is input to the target node, use the DstArgIndex - auto idx = cur->GetDstArgIndex(); - // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2]) - quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}}); - } else { - // Q is output of the target node, use the SrcArgIndex - auto idx = cur->GetSrcArgIndex(); - // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2]) - const auto node_outputs = node.OutputDefs(); - quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}}); - } - } - } - - // Construct the IODefs for this QDQ NodeGroup - std::vector io_defs; - io_defs.reserve(target_node_io_defs_size); - for (size_t i = 0; i < target_node_io_defs_size; i++) { - // If we can find the NodeUnitIODef for this index, this is a quantized input - if (quantized_io_defs.find(i) != quantized_io_defs.cend()) { - io_defs.push_back(std::move(quantized_io_defs.at(i))); - } else { - // This is a regular input - io_defs.push_back({*target_node_io_defs[i], std::nullopt}); - } - } - - return io_defs; -} - -} // namespace - -NodeUnit::NodeUnit(const Node& node) - : target_node_(node), - type_(Type::SingleNode) { - InitForSingleNode(); -} - -NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group) - : q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, - dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, - target_node_(*graph_viewer.GetNode(node_group.target_node)), - type_(Type::QDQGroup), - inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, - outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { - ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupDQNodes(graph_viewer, target_node_, dq_nodes_)); -} - -const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } -const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } -const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } -int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } -NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); } -const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } -ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } - -void NodeUnit::InitForSingleNode() { - const auto& input_defs = target_node_.InputDefs(); - const auto& output_defs = target_node_.OutputDefs(); - auto qlinear_type = GetQLinearOpType(target_node_); - if (qlinear_type == QLinearOpType::Unknown || - IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support - // Not a Qlinear op, add all inputs / outputs - auto add_all_io = [](std::vector& defs, - const ConstPointerContainer>& node_defs) { - defs.reserve(node_defs.size()); - - for (const auto def : node_defs) { - defs.push_back(NodeUnitIODef{*def, std::nullopt}); - } - }; - add_all_io(inputs_, input_defs); - add_all_io(outputs_, output_defs); - } else if (IsUnaryQLinearOp(qlinear_type)) { - // Unary QLinear Op has 5 inputs - // x, x_scale, x_zp, y_scale, y_zp (optional) - inputs_.push_back(NodeUnitIODef{ - *input_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); - - outputs_.push_back(NodeUnitIODef{ - *output_defs[0], - NodeUnitIODef::QuantParam{*input_defs[3], - input_defs.size() > 4 - ? input_defs[4] - : nullptr}}); - } else if (IsBinaryQLinearOp(qlinear_type)) { - // Binary QLinear Op has 9 inputs - // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B - inputs_.push_back(NodeUnitIODef{ - *input_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); - inputs_.push_back(NodeUnitIODef{ - *input_defs[3], - NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}}); - - if (input_defs.size() == 9) { // has Bias - inputs_.push_back(NodeUnitIODef{ - *input_defs[8], - std::nullopt}); // for Bias the scale and zp are optional - } - - outputs_.push_back(NodeUnitIODef{ - *output_defs[0], - NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}}); - } else if (qlinear_type == QLinearOpType::DequantizeLinear) { - // DequantizeLinear has 3 inputs - // x, x_scale, x_zp - // output is not quantized - inputs_.push_back(NodeUnitIODef{ - *input_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], - input_defs.size() == 3 - ? input_defs[2] - : nullptr}}); - outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt}); - } else if (qlinear_type == QLinearOpType::QuantizeLinear) { - // QuantizeLinear the input is not quantized and has 3 inputs - // x, y_scale, y_zp (optional) - // The output is quantized - inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt}); - outputs_.push_back(NodeUnitIODef{ - *output_defs[0], - NodeUnitIODef::QuantParam{*input_defs[1], - input_defs.size() == 3 - ? input_defs[2] - : nullptr}}); - } else { - ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); - } -} - -Node::EdgeConstIterator NodeUnit::OutputEdgesBegin(size_t index) const { - if (type_ == Type::SingleNode) { - ORT_ENFORCE(index == 0, "invalid output node index"); - return target_node_.OutputEdgesBegin(); - } else { - ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index"); - return q_nodes_[index]->OutputEdgesBegin(); - } -} - -Node::EdgeConstIterator NodeUnit::OutputEdgesEnd(size_t index) const { - if (type_ == Type::SingleNode) { - ORT_ENFORCE(index == 0, "invalid output node index"); - return target_node_.OutputEdgesEnd(); - } else { - ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index"); - return q_nodes_[index]->OutputEdgesEnd(); - } -} - -std::vector NodeUnit::GetAllNodesInGroup() const noexcept { - std::vector all_nodes = dq_nodes_; - all_nodes.push_back(&target_node_); - all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end()); - return all_nodes; -} - -std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer) { - std::vector> node_unit_holder; - std::unordered_map node_unit_map; - - const auto add_node_unit_to_map = [&](const std::vector& node_indices, const NodeUnit* node_unit) { - for (const auto& node_idx : node_indices) { - const auto* node = graph_viewer.GetNode(node_idx); - node_unit_map.insert({node, node_unit}); - } - }; - - // Get QDQ NodeUnits first - QDQ::SelectorManager selector_mgr; - const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); - - for (const auto& qdq_selection : qdq_selections) { - auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); - - // Fill the node to node_unit map for all nodes in the QDQ Group - add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); - add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); - add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); - - node_unit_holder.push_back(std::move(qdq_unit)); - } - - // Get the left over SingleNode NodeUnits - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (const auto node_idx : node_indices) { - const auto* node(graph_viewer.GetNode(node_idx)); - - // This is already part of a QDQ NodeUnit - if (node_unit_map.find(node) != node_unit_map.cend()) - continue; - - auto node_unit = std::make_unique(*node); - node_unit_map[node] = node_unit.get(); - node_unit_holder.push_back(std::move(node_unit)); - } - - return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map)); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index c07a0929353b1..2088618538de5 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -4,12 +4,12 @@ #include "utils.h" -#include -#include -#include -#include -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/common/safeint.h" +#include "core/framework/node_unit.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph.h" #include "core/optimizer/initializer.h" +#include "core/providers/common.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/utils.cc b/onnxruntime/core/providers/utils.cc index ca3fc4fc1972b..b2f9d265ca053 100644 --- a/onnxruntime/core/providers/utils.cc +++ b/onnxruntime/core/providers/utils.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/framework/tensorprotoutils.h" -#include "utils.h" +#include "core/providers/utils.h" namespace onnxruntime { namespace utils { @@ -23,6 +23,5 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& return Status::OK(); } #endif - } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc index 8e7e228f974e6..e2d71cda68ec4 100644 --- a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc +++ b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc @@ -6,12 +6,12 @@ #include #include "core/common/common.h" +#include "core/framework/node_unit.h" #include "core/framework/op_node_proto_helper.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/providers/common.h" #include "core/providers/cpu/nn/pool_attributes.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/xnnpack/detail/utils.h" // each operator provides a helper to check if supported diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 1a32612981120..f9cb45ebc8abc 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -6,14 +6,14 @@ #include #include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/node_attr_utils.h" +#include "core/optimizer/initializer.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "onnx/defs/attr_proto_util.h" -#include "core/common/safeint.h" -#include "core/optimizer/initializer.h" namespace onnxruntime { namespace xnnpack { diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index 2bbf3ac8c2cb5..d555ee2286b84 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -10,10 +10,10 @@ #include #include +#include "core/framework/node_unit.h" #include "core/framework/op_kernel.h" #include "core/graph/indexed_sub_graph.h" #include "core/providers/common.h" -#include "core/providers/shared/node_unit/node_unit.h" #include "xnnpack.h" diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index eafbfae6f01e1..12e567e7080b3 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -6,17 +6,17 @@ #include #include -#include "core/graph/function_utils.h" -#include "xnnpack_execution_provider.h" -#include "detail/utils.h" -#include "detail/node_support_checker.h" - #include "core/framework/compute_capability.h" #include "core/framework/kernel_registry.h" -#include "core/providers/shared/node_unit/node_unit.h" +#include "core/framework/node_unit.h" +#include "core/graph/function_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" - -#include "xnnpack_init.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/xnnpack/xnnpack_execution_provider.h" +#include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/detail/node_support_checker.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { @@ -268,7 +268,7 @@ std::vector> XnnpackExecutionProvider::GetCap // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit being checked for multiple times diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp index 484a9a22429d5..969997d2b84ec 100644 --- a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp +++ b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "test_fp16.h" +#include #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 13333f1558cc6..fbd5c9b5a137b 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/compute_capability.h" +#include "core/framework/node_unit.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" @@ -9,7 +10,6 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" -#include "core/optimizer/utils.h" #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -30,10 +30,6 @@ #pragma warning(disable : 4127) #endif // #if defined(_MSC_VER) -#ifdef USE_NNAPI -#include "core/providers/shared/node_unit/node_unit.h" -#endif // #ifdef USE_NNAPI - struct QDQOpKeys { const char* quantize_linear; const char* dequantize_linear; @@ -3243,14 +3239,14 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { ASSERT_EQ(std::vector({4}), qdq_group.q_nodes); } -// The function GetAllNodeUnits is enabled for NNAPI EP only for now -#ifdef USE_NNAPI +// The function GetAllNodeUnits is used by NNAPI, XNNPACK and QNN +#if defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK) { // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(whole_graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer); // We should get a single QDQ Node unit in the result ASSERT_EQ(1, node_unit_holder.size()); @@ -3288,7 +3284,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { verify_io_def(qdq_node_unit.Inputs()[2], *whole_graph_viewer.GetNode(2)); // DQ_bias verify_io_def(qdq_node_unit.Outputs()[0], *whole_graph_viewer.GetNode(4)); // Q_output } -#endif // #ifdef USE_NNAPI +#endif // defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK) // Create a graph viewer covers part of the graph // Make sure the qdq conv selector will fail for the partial graph diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 0167f7a7718b1..2e073def5d643 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -220,6 +220,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_, generate_metadef_name, ep_name_, onnxruntime::utils::kInternalTestingExecutionProvider, + /*QDQ NodeUnit map*/ nullptr, debug_output_); if (!static_capabilities.empty()) { diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc new file mode 100644 index 0000000000000..5db69489afaef --- /dev/null +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/common/common.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/framework/node_unit.h" +#include "core/framework/compute_capability.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/partitioning_utils.h" + +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/test_utils.h" +#include "test/util/include/test/test_environment.h" + +namespace onnxruntime { +namespace test { + +// Test handling of a DQ node that is connected to an initializer at the start of the graph, but not used +// in a QDQ node group until after an unsupported node in the graph. If we do not process QDQ node units +// correctly this DQ will incorrectly be in the first partition, with the rest of the QDQ node group in +// the second partition. +TEST(PartitioningUtilsTest, TestQDQHandling) { + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/ort_github_issue_19590.onnx"); + auto& logger = DefaultLoggingManager().DefaultLogger(); + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, logger)); + Graph& graph = p_model->MainGraph(); + GraphViewer graph_viewer = GraphViewer(graph); + + // we want everything but the Cast in the test model to be supported + const auto is_node_supported = [&](const Node& node) -> bool { + return node.OpType() != "Cast"; + }; + + const auto on_group_closed = [&](const std::vector& /*group*/) -> bool { + return true; + }; + + const auto gen_metadef_name = [&]() { + static int metadef_id = 0; + return "TestMetaDef_" + std::to_string(metadef_id++); + }; + + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + + auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, + true); + + // we should have 2 supported partitions, split by the Cast node. + // the first should have the Mul and NOT the DQ for the initializer if everything worked correctly. + ASSERT_EQ(result.size(), size_t(2)) << "Expected 2 partitions"; + ASSERT_EQ(result[0]->sub_graph->nodes.size(), size_t(1)) << "First partition should only have the Mul and not a DQ"; + ASSERT_EQ(result[1]->sub_graph->nodes.size(), size_t(5)); // everything else except the unsupported Cast +} + +/// Check that CreateSupportedPartitions processes all nodes without error. +static void CheckAllNodesProcessed(const std::function& build_model) { + auto& logger = DefaultLoggingManager().DefaultLogger(); + const std::unordered_map domain_to_version = {{"", 15}}; + + Model model("PartitioningUtils_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, logger); + + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + build_model(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + GraphViewer graph_viewer = GraphViewer(graph); + + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + + const auto is_node_supported = [&](const Node& /*node*/) -> bool { + return true; + }; + + const auto on_group_closed = [&](const std::vector& /*group*/) -> bool { + return true; + }; + + const auto gen_metadef_name = [&]() { + static int metadef_id = 0; + return "TestMetaDef_" + std::to_string(metadef_id++); + }; + + auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, + true); + + // the 'real' test is that CreateSupportedPartitions doesn't throw due to a mismatch with expected vs processed nodes + // as all ops are supported there should only ever be 1 partition + ASSERT_EQ(result.size(), size_t(1)) << "Expected 1 partition"; +} + +TEST(PartitioningUtilsTest, TestHandlingQDQNodeUnitWithNoQNodes) { + // build graph with QDQ node unit for logical operator (Equal) that has no Q node and a downstream node (Cast). + auto build_model = [](ModelTestBuilder& builder) { + constexpr uint8_t zero_point = 0; + constexpr float qdq_scale = 0.0038f; + const std::vector input_shape = {1, 3, 8, 8}; + + auto* input0 = builder.MakeInput(input_shape, -1.0f, 1.0f); + auto* input1 = builder.MakeInput(input_shape, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + // input -> Q -> DQ -> Op + auto* qdq0_output = AddQDQNodePair(builder, input0, qdq_scale, zero_point); + auto* qdq1_output = AddQDQNodePair(builder, input1, qdq_scale, zero_point); + + // Equal -> + auto* equal_output = builder.MakeIntermediate(); + builder.AddNode("Equal", {qdq0_output, qdq1_output}, {equal_output}); + + // -> Cast -> output + Node& cast_node = builder.AddNode("Cast", {equal_output}, {output}); + cast_node.AddAttribute("to", + static_cast(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + }; + + CheckAllNodesProcessed(build_model); +} + +// TopK produces 2 outputs, one of which is used in a QDQ node group (Q of values output) +// and the other (indices output) is not. A downstream node consuming the indices output has an edge from the target +// node and not a Q node. +// To process this correctly, the QDQ NodeUnit must return output edges for both the Q node/s of the values output, +// and the downstream node (Cast in this case) of the indices output. +TEST(PartitioningUtilsTest, TestQDQNodeGroupWithOutputFromTargetNode) { + const auto build_model = [](ModelTestBuilder& builder) { + constexpr uint8_t zero_point = 0; + constexpr float qdq_scale = 0.0038f; + const std::vector input_shape = {1, 3, 8, 8}; + + auto* input0 = builder.MakeInput(input_shape, -1.0f, 1.0f); + + // input -> Q -> DQ -> + auto* qdq0_output = AddQDQNodePair(builder, input0, qdq_scale, zero_point); + + // K input + NodeArg* k_input = builder.MakeInput({1}, {10}); + + // TopK op + NodeArg* values_output = builder.MakeIntermediate(); + NodeArg* indices_output = builder.MakeIntermediate(); + builder.AddNode("TopK", {qdq0_output, k_input}, {values_output, indices_output}); + + // values -> Q -> DQ -> graph output + AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, qdq_scale, zero_point); + + // indices -> Cast -> graph output + auto* i_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {i_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + }; + + CheckAllNodesProcessed(build_model); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/ort_github_issue_19590.onnx b/onnxruntime/test/testdata/ort_github_issue_19590.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa07b624780bb0244d04e0f8ec17f10c8d9d6d2c GIT binary patch literal 599 zcmZuuyH3L}6iu9@I5((777`E`kSCNW?NmiV4WT0$SXm;cL=Y)WU(UkBr}T^X7Cr&z zZ3=^BpL?!-j?cxW|E0qM#w*5GWgaJSnH78nqy3WQsYk5WZQN0g45+V~IJgP-y4lXn`Vpf{9qGK%Co_kb(6q{=T;_FLv zP!Y^wXliXuWLr$O#s0B11Iag&K_c{rbs!GN$P7D_Il`N=U6BK!OY DAeX8E literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/ort_github_issue_19590.py b/onnxruntime/test/testdata/ort_github_issue_19590.py new file mode 100644 index 0000000000000..9be07134fd8ad --- /dev/null +++ b/onnxruntime/test/testdata/ort_github_issue_19590.py @@ -0,0 +1,77 @@ +import onnx +from onnx import TensorProto, helper + +# graph with a QDQ MatMul node unit where one input is and initializer -> DQ and the other is on a path that +# contains a supported node followed by an unsupported node followed by the DQ -> MatMul. +# The DQ of the initializer is prior to the unsupported node. If the partitioning utils do not process the QDQ node +# unit together, the DQ for the initializer and the first supported node will be in the first partition, which +# incorrectly breaks up the QDQ node unit. +graph_proto = helper.make_graph( + [ + # DQ of initializer for MatMul B input + helper.make_node( + "DequantizeLinear", + inputs=["matmul_b_uint8", "scale0"], + outputs=["dq_matmul_b"], + name="dq_matmul_b", + ), + # Treat as supported + helper.make_node( + "Mul", + inputs=["input:0", "scale_input"], + outputs=["mul:0"], + name="mul0", + ), + # Treat as unsupported + helper.make_node("Cast", inputs=["mul:0"], outputs=["mul_uint8"], name="cast0", to=2), + # DQ of MatMul A input + helper.make_node( + "DequantizeLinear", + inputs=["mul_uint8", "scale1"], + outputs=["dq_matmul_a"], + name="dq_matmul_a", + ), + # MatMul + helper.make_node( + "MatMul", + inputs=[ + "dq_matmul_a", + "dq_matmul_b", + ], + outputs=["matmul_ab"], + name="matmul_ab", + ), + # Q + helper.make_node( + "QuantizeLinear", + inputs=["matmul_ab", "scale2"], + outputs=["q_matmul_ab"], + name="q_matmul_ab", + ), + # DQ for model output + helper.make_node( + "DequantizeLinear", + inputs=["q_matmul_ab", "scale2"], + outputs=["out:0"], + name="dq_graph_output", + ), + ], + "Main_graph", + [ + helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor_value_info("out:0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor("scale0", TensorProto.FLOAT, [1], [20.0]), + helper.make_tensor("scale1", TensorProto.FLOAT, [1], [30.0]), + helper.make_tensor("scale2", TensorProto.FLOAT, [1], [40.0]), + helper.make_tensor("matmul_b_uint8", TensorProto.UINT8, [2, 2], [1, 2, 3, 4]), + helper.make_tensor("scale_input", TensorProto.FLOAT, [2], [3.0, 4.0]), + ], +) + +model = helper.make_model(graph_proto) +onnx.checker.check_model(model, True) +onnx.save(model, "ort_github_issue_19590.onnx")