From 992764c9e161b6506f3fbefabe455006ff258817 Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Mon, 19 May 2025 14:00:26 -0700 Subject: [PATCH] [QNN EP] Add ONNX ScatterElements support - Translate ONNX ScatterElements as QNN's ScatterElements Op - Handle unsupported reduction value i.e., "min" - Add unit tests to verify ScatterElements Op support on HTP --- .../qnn/builder/op_builder_factory.cc | 2 + .../qnn/builder/opbuilder/base_op_builder.h | 2 + .../builder/opbuilder/simple_op_builder.cc | 49 +++++++++++ .../test/providers/qnn/simple_op_htp_test.cc | 88 +++++++++++++++++++ 4 files changed, 141 insertions(+) diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index efb4afcb88c85..b3a464ba7204d 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -65,6 +65,8 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("GridSample", *this); CreateSimpleOpBuilder("LpNormalization", *this); + + CreateSimpleOpBuilder("ScatterElements", *this); } { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 5474db0590f92..214132f5c820d 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -215,6 +215,8 @@ class BaseOpBuilder : public IOpBuilder { {"Pad", QNN_OP_PAD}, + {"ScatterElements", QNN_OP_SCATTER_ELEMENTS}, + {"Expand", QNN_OP_ELEMENT_WISE_MULTIPLY}}; auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type); ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end()); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index ab022df063c96..6549542238753 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -40,6 +40,7 @@ class SimpleOpBuilder : public BaseOpBuilder { static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; + static constexpr std::array scatterelements_supported_reduction = {"none", "add", "mul", "max"}; }; Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, @@ -101,6 +102,14 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, } } + // QNN ScatterElements doesn't support MIN reduction + if (op_type == "ScatterElements") { + NodeAttrHelper node_helper(node_unit); + std::string reduction = node_helper.Get("reduction", "none"); + ORT_RETURN_IF_NOT(utils::ArrayHasString(scatterelements_supported_reduction, reduction), "ScatterElements does not support reduction ", + reduction.c_str()); + } + return Status::OK(); } @@ -254,6 +263,33 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +// Process Reduction attribute of ScatterElements op +Status ProcessReductionAttribute(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector& param_tensor_names) { + NodeAttrHelper node_helper(node_unit); + std::string reduction = node_helper.Get("reduction", "none"); + Qnn_Scalar_t reduction_qnn_scalar = QNN_SCALAR_INIT; + reduction_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; + if ("none" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_NONE; + } else if ("add" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_ADD; + } else if ("mul" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_MUL; + } else if ("max" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_MAX; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ScatterElements support only reduction:{none, add, mul, max}."); + } + QnnParamWrapper reduction_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ELEMENTS_PARAM_REDUCTION, + reduction_qnn_scalar); + param_tensor_names.push_back(reduction_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(reduction_param)); + + return Status::OK(); +} + Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -358,6 +394,19 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names)); } + if (op_type == "ScatterElements") { + // Process axis attribute + int32_t default_axis = 0; + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis)); + QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ELEMENTS_PARAM_AXIS, axis_qnn_scalar); + param_tensor_names.push_back(axis_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); + + // Process reduction attribute + ORT_RETURN_IF_ERROR(ProcessReductionAttribute(qnn_model_wrapper, node_unit, param_tensor_names)); + } + return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index bfdb1a1a6afdd..1addd1d7de9b8 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1017,6 +1017,94 @@ TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) { ExpectedEPNodeAssignment::All); } +// Test ScatterElements with default attributes on HTP +TEST_F(QnnHTPBackendTests, ScatterElements_int64_int64) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterElements", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + {}, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterElements with reduction ADD on HTP +TEST_F(QnnHTPBackendTests, ScatterElements_int64_int64_reduction_add) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterElements", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "add"), + }, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterElements with reduction Mul on HTP +TEST_F(QnnHTPBackendTests, ScatterElements_int64_int64_reduction_mul) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterElements", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "mul"), + }, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterElements with reduction Max on HTP +TEST_F(QnnHTPBackendTests, ScatterElements_int64_int64_reduction_max) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterElements", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "max"), + }, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterElements with reduction Min on CPU Fallback +TEST_F(QnnHTPBackendTests, ScatterElements_int64_int64_reduction_min) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterElements", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "min"), + }, + 17, + ExpectedEPNodeAssignment::None); +} + // Test 8-bit QDQ GridSample with bilinear TEST_F(QnnHTPBackendTests, GridSample_Bilinear) { RunQDQOpTest("GridSample",