Skip to content

[WebNN] Refactor op mappings and add input name mapping between ONNX and WebNN #24830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 14 additions & 170 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "core/common/inlined_containers.h"
#include <core/graph/basic_types.h>
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "map_info.h"

Check warning on line 13 in onnxruntime/core/providers/webnn/builders/helper.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/webnn/builders/helper.h:13: Include the directory when naming header files [build/include_subdir] [4]

#include <emscripten.h>
#include <emscripten/val.h>
Expand Down Expand Up @@ -201,183 +201,27 @@
const emscripten::val& wnn_limits,
const logging::Logger& logger);

// Some ONNX ops have no single WebNN equivalent and are emulated by a sequence
// of decomposed WebNN ops. This maps each such ONNX op type to the set of WebNN
// ops its decomposition uses, so support checks can validate every required op.
// `inline` (C++17) guarantees a single definition across all translation units
// that include this header, instead of one internal-linkage copy per TU.
inline const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
    {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}},
    {"GroupQueryAttention",
     {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND",
      "softmax", "transpose", "where"}},
    {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
    {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}},
    {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}},
    {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}},
    {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}},
    {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
    {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
};
// ONNX op type to WebNN op type mapping.
// Several ONNX ops share one WebNN op (e.g. Flatten/Squeeze/Unsqueeze -> reshape,
// Dropout/Identity -> identity); absence from this map means the op is unsupported
// (see GetWebNNOpType, which returns "" for unmapped types).
// `inline` (C++17) guarantees a single definition across all translation units
// that include this header, instead of one internal-linkage copy per TU.
inline const std::map<std::string_view, std::string_view> op_map = {
    {"Abs", "abs"},
    {"Add", "add"},
    {"And", "logicalAnd"},
    {"ArgMax", "argMax"},
    {"ArgMin", "argMin"},
    {"AveragePool", "averagePool2d"},
    {"BatchNormalization", "batchNormalization"},
    {"Cast", "cast"},
    {"Ceil", "ceil"},
    {"Clip", "clamp"},
    {"Concat", "concat"},
    {"Conv", "conv2d"},
    {"ConvTranspose", "convTranspose2d"},
    {"Cos", "cos"},
    {"CumSum", "cumulativeSum"},
    {"Div", "div"},
    {"DequantizeLinear", "dequantizeLinear"},
    {"Dropout", "identity"},
    {"DynamicQuantizeLinear", "dynamicQuantizeLinear"},
    {"Einsum", "matmul"},
    {"Elu", "elu"},
    {"Equal", "equal"},
    {"Erf", "erf"},
    {"Exp", "exp"},
    {"Expand", "expand"},
    {"Flatten", "reshape"},
    {"Floor", "floor"},
    {"Gather", "gather"},
    {"GatherElements", "gatherElements"},
    {"GatherND", "gatherND"},
    {"Gelu", "gelu"},
    {"Gemm", "gemm"},
    {"GlobalAveragePool", "averagePool2d"},
    {"GlobalMaxPool", "maxPool2d"},
    {"GlobalLpPool", "l2Pool2d"},
    {"Greater", "greater"},
    {"GreaterOrEqual", "greaterOrEqual"},
    {"GRU", "gru"},
    {"HardSigmoid", "hardSigmoid"},
    {"HardSwish", "hardSwish"},
    {"Identity", "identity"},
    {"InstanceNormalization", "instanceNormalization"},
    {"LayerNormalization", "layerNormalization"},
    {"LeakyRelu", "leakyRelu"},
    {"Less", "lesser"},
    {"LessOrEqual", "lesserOrEqual"},
    {"Log", "log"},
    {"LpPool", "l2Pool2d"},
    {"LSTM", "lstm"},
    {"MatMul", "matmul"},
    {"Max", "max"},
    {"MaxPool", "maxPool2d"},
    {"Min", "min"},
    {"Mul", "mul"},
    {"Neg", "neg"},
    {"Not", "logicalNot"},
    {"Or", "logicalOr"},
    {"Pad", "pad"},
    {"Pow", "pow"},
    {"PRelu", "prelu"},
    {"QuantizeLinear", "quantizeLinear"},
    {"Reciprocal", "reciprocal"},
    {"ReduceL1", "reduceL1"},
    {"ReduceL2", "reduceL2"},
    {"ReduceLogSum", "reduceLogSum"},
    {"ReduceLogSumExp", "reduceLogSumExp"},
    {"ReduceMax", "reduceMax"},
    {"ReduceMean", "reduceMean"},
    {"ReduceMin", "reduceMin"},
    {"ReduceProd", "reduceProduct"},
    {"ReduceSum", "reduceSum"},
    {"ReduceSumSquare", "reduceSumSquare"},
    {"Relu", "relu"},
    {"Reshape", "reshape"},
    {"Resize", "resample2d"},
    {"ScatterElements", "scatterElements"},
    {"ScatterND", "scatterND"},
    {"Shape", "slice"},
    {"Sigmoid", "sigmoid"},
    {"Sign", "sign"},
    {"Softplus", "softplus"},
    {"Softsign", "softsign"},
    {"Sin", "sin"},
    {"Slice", "slice"},
    {"Softmax", "softmax"},
    {"Split", "split"},
    {"Sqrt", "sqrt"},
    {"Squeeze", "reshape"},
    {"Sub", "sub"},
    {"Tan", "tan"},
    {"Tanh", "tanh"},
    {"Tile", "tile"},
    {"Transpose", "transpose"},
    {"Trilu", "triangular"},
    {"Unsqueeze", "reshape"},
    {"Where", "where"},
    {"Xor", "logicalXor"},
};

// WebNN op name to its first input name mapping, only recording names that
// differ from the default "input" (e.g. binary ops use "a", concat uses
// "inputs", where uses "condition"). Used by OpSupportLimits to look up the
// data-type limits of a WebNN op's first input.
// `inline` (C++17) guarantees a single definition across all translation units
// that include this header, instead of one internal-linkage copy per TU.
inline const std::map<std::string_view, std::string_view> webnn_op_first_input_name_map = {
    {"add", "a"},
    {"concat", "inputs"},
    {"div", "a"},
    {"equal", "a"},
    {"gemm", "a"},
    {"greater", "a"},
    {"greaterOrEqual", "a"},
    {"lesser", "a"},
    {"lesserOrEqual", "a"},
    {"logicalAnd", "a"},
    {"logicalNot", "a"},
    {"logicalOr", "a"},
    {"logicalXor", "a"},
    {"matmul", "a"},
    {"max", "a"},
    {"min", "a"},
    {"mul", "a"},
    {"pow", "a"},
    {"sub", "a"},
    {"where", "condition"},
};

// Retrieve the first input name of a WebNN op, used for validating supported
// input data types against OpSupportLimits.
// WebNN ops have various first input names such as 'a', 'input', 'inputs', etc.
// All WebNN op inputs are recorded in op_inputs_map (declared in map_info.h).
// Returns "input" for ops not present in the map (the common default).
// NOTE: the scraped diff had merged the old and new implementations here,
// declaring `it` twice and leaving the op_inputs_map lookup unreachable; only
// the new op_inputs_map-based implementation is kept.
inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) {
  auto it = op_inputs_map.find(webnn_op_type);
  if (it != op_inputs_map.end()) {
    // Entries are not guaranteed to be ordered by index, so scan for index 0.
    for (const auto& input : it->second.inputs) {
      if (input.index == 0) {
        return input.name;
      }
    }
  }
  return "input";
}

// Map an ONNX op type to its WebNN op type via op_inputs_map (map_info.h).
// Returns an empty string if the op_type is not listed in op_inputs_map,
// which callers treat as "unsupported".
// NOTE: the scraped diff had merged the old op_map-based body with the new
// op_inputs_map-based body, declaring `it` twice; only the new body is kept.
inline std::string_view GetWebNNOpType(const std::string_view op_type) {
  auto it = op_inputs_map.find(op_type);
  return (it != op_inputs_map.end()) ? it->second.opType : "";
}

// ONNX tensor element type to WebNN operand data-type string mapping.
// Note: BOOL intentionally maps to "uint8" — WebNN has no boolean operand
// type, so ONNX booleans are represented as uint8 tensors.
// `inline` (C++17) guarantees a single definition across all translation units
// that include this header, instead of one internal-linkage copy per TU.
inline const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {
    {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
    {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
    {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
    {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
    {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
    {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
    {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
    {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
    {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
};

// Input/output data types of a WebNN graph that are allowed to fall back to
// int32 when the backend lacks native support for the original integer type.
// `inline constexpr` (C++17) gives one shared compile-time definition across
// all translation units that include this header.
inline constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 5> supported_fallback_integer_data_types = {
    ONNX_NAMESPACE::TensorProto_DataType_BOOL,
    ONNX_NAMESPACE::TensorProto_DataType_INT8,
    ONNX_NAMESPACE::TensorProto_DataType_UINT8,
    ONNX_NAMESPACE::TensorProto_DataType_UINT32,
    ONNX_NAMESPACE::TensorProto_DataType_INT64,
};

bool AreDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
Expand Down
Loading
Loading