diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 072273a137557..f124e90580353 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -9,8 +9,8 @@
 #include "core/common/inlined_containers.h"
 #include <core/graph/basic_types.h>
 #include "core/optimizer/initializer.h"
-#include "core/providers/common.h"
 #include "core/providers/shared/utils/utils.h"
+#include "map_info.h"

 #include <emscripten.h>
 #include <emscripten/val.h>

@@ -201,183 +201,27 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
                                                   const emscripten::val& wnn_limits,
                                                   const logging::Logger& logger);

-// Some ONNX ops are supported by decomposed WebNN ops.
-const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
-    {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}},
-    {"GroupQueryAttention",
-     {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND",
-      "softmax", "transpose", "where"}},
-    {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
-    {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}},
-    {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}},
-    {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}},
-    {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}},
-    {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
-    {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
-};
-// ONNX op type to WebNN op type mapping.
-const std::map<std::string_view, std::string_view> op_map = {
-    {"Abs", "abs"},
-    {"Add", "add"},
-    {"And", "logicalAnd"},
-    {"ArgMax", "argMax"},
-    {"ArgMin", "argMin"},
-    {"AveragePool", "averagePool2d"},
-    {"BatchNormalization", "batchNormalization"},
-    {"Cast", "cast"},
-    {"Ceil", "ceil"},
-    {"Clip", "clamp"},
-    {"Concat", "concat"},
-    {"Conv", "conv2d"},
-    {"ConvTranspose", "convTranspose2d"},
-    {"Cos", "cos"},
-    {"CumSum", "cumulativeSum"},
-    {"Div", "div"},
-    {"DequantizeLinear", "dequantizeLinear"},
-    {"Dropout", "identity"},
-    {"DynamicQuantizeLinear", "dynamicQuantizeLinear"},
-    {"Einsum", "matmul"},
-    {"Elu", "elu"},
-    {"Equal", "equal"},
-    {"Erf", "erf"},
-    {"Exp", "exp"},
-    {"Expand", "expand"},
-    {"Flatten", "reshape"},
-    {"Floor", "floor"},
-    {"Gather", "gather"},
-    {"GatherElements", "gatherElements"},
-    {"GatherND", "gatherND"},
-    {"Gelu", "gelu"},
-    {"Gemm", "gemm"},
-    {"GlobalAveragePool", "averagePool2d"},
-    {"GlobalMaxPool", "maxPool2d"},
-    {"GlobalLpPool", "l2Pool2d"},
-    {"Greater", "greater"},
-    {"GreaterOrEqual", "greaterOrEqual"},
-    {"GRU", "gru"},
-    {"HardSigmoid", "hardSigmoid"},
-    {"HardSwish", "hardSwish"},
-    {"Identity", "identity"},
-    {"InstanceNormalization", "instanceNormalization"},
-    {"LayerNormalization", "layerNormalization"},
-    {"LeakyRelu", "leakyRelu"},
-    {"Less", "lesser"},
-    {"LessOrEqual", "lesserOrEqual"},
-    {"Log", "log"},
-    {"LpPool", "l2Pool2d"},
-    {"LSTM", "lstm"},
-    {"MatMul", "matmul"},
-    {"Max", "max"},
-    {"MaxPool", "maxPool2d"},
-    {"Min", "min"},
-    {"Mul", "mul"},
-    {"Neg", "neg"},
-    {"Not", "logicalNot"},
-    {"Or", "logicalOr"},
-    {"Pad", "pad"},
-    {"Pow", "pow"},
-    {"PRelu", "prelu"},
-    {"QuantizeLinear", "quantizeLinear"},
-    {"Reciprocal", "reciprocal"},
-    {"ReduceL1", "reduceL1"},
-    {"ReduceL2", "reduceL2"},
-    {"ReduceLogSum", "reduceLogSum"},
-    {"ReduceLogSumExp", "reduceLogSumExp"},
-    {"ReduceMax", "reduceMax"},
-    {"ReduceMean", "reduceMean"},
{"ReduceMin", "reduceMin"}, - {"ReduceProd", "reduceProduct"}, - {"ReduceSum", "reduceSum"}, - {"ReduceSumSquare", "reduceSumSquare"}, - {"Relu", "relu"}, - {"Reshape", "reshape"}, - {"Resize", "resample2d"}, - {"ScatterElements", "scatterElements"}, - {"ScatterND", "scatterND"}, - {"Shape", "slice"}, - {"Sigmoid", "sigmoid"}, - {"Sign", "sign"}, - {"Softplus", "softplus"}, - {"Softsign", "softsign"}, - {"Sin", "sin"}, - {"Slice", "slice"}, - {"Softmax", "softmax"}, - {"Split", "split"}, - {"Sqrt", "sqrt"}, - {"Squeeze", "reshape"}, - {"Sub", "sub"}, - {"Tan", "tan"}, - {"Tanh", "tanh"}, - {"Tile", "tile"}, - {"Transpose", "transpose"}, - {"Trilu", "triangular"}, - {"Unsqueeze", "reshape"}, - {"Where", "where"}, - {"Xor", "logicalXor"}, -}; - -// WebNN op name to its first input name mapping, only record the name that is different from "input". -// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits. -const std::map webnn_op_first_input_name_map = { - {"add", "a"}, - {"concat", "inputs"}, - {"div", "a"}, - {"equal", "a"}, - {"gemm", "a"}, - {"greater", "a"}, - {"greaterOrEqual", "a"}, - {"lesser", "a"}, - {"lesserOrEqual", "a"}, - {"logicalAnd", "a"}, - {"logicalNot", "a"}, - {"logicalOr", "a"}, - {"logicalXor", "a"}, - {"matmul", "a"}, - {"max", "a"}, - {"min", "a"}, - {"mul", "a"}, - {"pow", "a"}, - {"sub", "a"}, - {"where", "condition"}, -}; - // Retrieve the first input name of a WebNN op used for validating supported input data types. // WebNN ops have various first input names such as 'a', 'input', 'inputs', etc. -// Special names other than 'input' are recorded in the webnn_op_first_input_name_map. +// All WebNN op inputs are recorded in op_inputs_map. inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) { - auto it = webnn_op_first_input_name_map.find(webnn_op_type); - return (it != webnn_op_first_input_name_map.end()) ? it->second : "input"; + auto it = op_inputs_map.find(webnn_op_type); + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == 0) { + return input.name; + } + } + } + return "input"; } inline std::string_view GetWebNNOpType(const std::string_view op_type) { - auto it = op_map.find(op_type); - // Return an empty string if the op_type is not listed in the op_map. - return (it != op_map.end()) ? it->second : ""; + auto it = op_inputs_map.find(op_type); + // Return an empty string if the op_type is not listed in the op_inputs_map. + return (it != op_inputs_map.end()) ? it->second.opType : ""; } -const std::map onnx_to_webnn_data_type_map = { - {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"}, - {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, -}; - -// This array contains the input/output data types of a WebNN graph that are allowed to be fallback to int32. 
-const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {
-    {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
-    {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
-    {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
-    {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
-    {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
-    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
-    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
-    {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
-    {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
-    {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
-    {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
-};
-
-// This array contains the input/output data types of a WebNN graph that are allowed to be fallback to int32.
-constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 5> supported_fallback_integer_data_types = {
-    ONNX_NAMESPACE::TensorProto_DataType_BOOL,
-    ONNX_NAMESPACE::TensorProto_DataType_INT8,
-    ONNX_NAMESPACE::TensorProto_DataType_UINT8,
-    ONNX_NAMESPACE::TensorProto_DataType_UINT32,
-    ONNX_NAMESPACE::TensorProto_DataType_INT64,
-};
-
 bool AreDataTypesSame(const std::string_view op_type,
                       gsl::span<const int32_t> input_types,
                       const logging::Logger& logger);
diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h
new file mode 100644
index 0000000000000..59408ba244842
--- /dev/null
+++ b/onnxruntime/core/providers/webnn/builders/map_info.h
@@ -0,0 +1,205 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Intel Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <array>
+#include <initializer_list>
+#include <map>
+#include <string_view>
+#include <unordered_map>
+#include <vector>
+
+#include "core/providers/common.h"
+
+/**
+ * This file defines mappings and structures to facilitate the translation of ONNX operations
+ * and data types to their corresponding WebNN representations.
+ *
+ * It includes:
+ * - Data type mappings between ONNX and WebNN.
+ * - Lists of supported fallback integer types for WebNN.
+ * - Decomposition of certain ONNX operations into sequences of WebNN operations.
+ * - Structures and maps for input index-to-name translation for ONNX to WebNN ops.
+ */
+namespace onnxruntime {
+namespace webnn {
+const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {
+    {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
+    {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
+    {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
+    {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
+    {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
+    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
+    {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
+    {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
+    {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
+    {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
+    {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
+};
+
+// This array contains the input/output data types of a WebNN graph that are allowed to fall back to int32.
+constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 5> supported_fallback_integer_data_types = {
+    ONNX_NAMESPACE::TensorProto_DataType_BOOL,
+    ONNX_NAMESPACE::TensorProto_DataType_INT8,
+    ONNX_NAMESPACE::TensorProto_DataType_UINT8,
+    ONNX_NAMESPACE::TensorProto_DataType_UINT32,
+    ONNX_NAMESPACE::TensorProto_DataType_INT64,
+};
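+// For example, an int64 graph input or output may be handled as int32 when the underlying WebNN
+// backend lacks int64 support; data types not listed above are not eligible for this fallback.
+// (Illustrative note only; the actual fallback logic lives in the WebNN EP, not in this header.)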
+
+// Some ONNX ops are supported by decomposed WebNN ops.
+const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
+    {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}},
+    {"GroupQueryAttention",
+     {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND",
+      "softmax", "transpose", "where"}},
+    {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
+    {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}},
+    {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}},
+    {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}},
+    {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}},
+    {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
+    {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
+};
+
+/**
+ * Represents information about an input to a WebNN operation.
+ *
+ * This structure is used to map ONNX operation inputs to their corresponding
+ * WebNN operation inputs. It contains the index of the input as specified
+ * in the ONNX operation and the name of the input in the WebNN operation.
+ *
+ * InputInfo::index
+ *   The index of this input as defined in the ONNX operation specification.
+ *
+ * InputInfo::name
+ *   The name of this input in the WebNN operation.
+ */
+struct InputInfo {
+  int index;
+  std::string_view name;
+};
+
+struct WebnnOpInfo {
+  std::string_view opType;
+  std::vector<InputInfo> inputs;
+  WebnnOpInfo(std::string_view op, std::initializer_list<InputInfo> in)
+      : opType(op), inputs(in) {}
+};
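+
+// Illustrative usage sketch (comments only). For the "Gemm" entry below,
+// {"Gemm", {"gemm", {{0, "a"}, {1, "b"}, {2, "c"}}}}, ONNX Gemm maps to WebNN gemm and ONNX inputs
+// 0, 1 and 2 correspond to the WebNN operand names "a", "b" and "c". A consumer might read it as:
+//   const auto it = op_inputs_map.find("Gemm");
+//   if (it != op_inputs_map.end()) {
+//     // it->second.opType == "gemm"
+//     for (const auto& input : it->second.inputs) {
+//       // input.index is the position in Node.InputDefs; input.name is the WebNN operand name.
+//     }
+//   }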
{{0, "input"}}}}, + {"Less", {"lesser", {{0, "a"}, {1, "b"}}}}, + {"MaxPool", {"maxPool2d", {{0, "input"}}}}, + {"ReduceL1", {"reduceL1", {{0, "input"}}}}, + {"ReduceL2", {"reduceL2", {{0, "input"}}}}, + {"Resize", {"resample2d", {{0, "input"}}}}, + {"Softplus", {"softplus", {{0, "input"}}}}, + {"Softsign", {"softsign", {{0, "input"}}}}, + {"Unsqueeze", {"reshape", {{0, "input"}}}}, + {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, + {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, + {"HardSwish", {"hardSwish", {{0, "input"}}}}, + {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, + {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, + {"ReduceMax", {"reduceMax", {{0, "input"}}}}, + {"ReduceMin", {"reduceMin", {{0, "input"}}}}, + {"ReduceSum", {"reduceSum", {{0, "input"}}}}, + {"Transpose", {"transpose", {{0, "input"}}}}, + {"And", {"logicalAnd", {{0, "a"}, {1, "b"}}}}, + {"CumSum", {"cumulativeSum", {{0, "input"}}}}, + {"Xor", {"logicalXor", {{0, "a"}, {1, "b"}}}}, + {"GlobalLpPool", {"l2Pool2d", {{0, "input"}}}}, + {"Greater", {"greater", {{0, "a"}, {1, "b"}}}}, + {"Reciprocal", {"reciprocal", {{0, "input"}}}}, + {"ReduceMean", {"reduceMean", {{0, "input"}}}}, + {"GlobalMaxPool", {"maxPool2d", {{0, "input"}}}}, + {"HardSigmoid", {"hardSigmoid", {{0, "input"}}}}, + {"ReduceProd", {"reduceProduct", {{0, "input"}}}}, + {"AveragePool", {"averagePool2d", {{0, "input"}}}}, + {"Gemm", {"gemm", {{0, "a"}, {1, "b"}, {2, "c"}}}}, + {"PRelu", {"prelu", {{0, "input"}, {1, "slope"}}}}, + {"ReduceLogSum", {"reduceLogSum", {{0, "input"}}}}, + {"Gather", {"gather", {{0, "input"}, {1, "indices"}}}}, + {"LessOrEqual", {"lesserOrEqual", {{0, "a"}, {1, "b"}}}}, + {"GlobalAveragePool", {"averagePool2d", {{0, "input"}}}}, + {"ReduceLogSumExp", {"reduceLogSumExp", {{0, "input"}}}}, + {"ReduceSumSquare", {"reduceSumSquare", {{0, "input"}}}}, + {"GatherND", {"gatherND", {{0, "input"}, {1, "indices"}}}}, + {"GreaterOrEqual", {"greaterOrEqual", {{0, "a"}, {1, "b"}}}}, + {"Conv", {"conv2d", {{0, "input"}, {1, "filter"}, {2, "bias"}}}}, + {"DynamicQuantizeLinear", {"dynamicQuantizeLinear", {{0, "input"}}}}, + {"GatherElements", {"gatherElements", {{0, "input"}, {1, "indices"}}}}, + {"ScatterND", {"scatterND", {{0, "input"}, {1, "indices"}, {2, "updates"}}}}, + {"Where", {"where", {{0, "condition"}, {1, "trueValue"}, {2, "falseValue"}}}}, + {"ConvTranspose", {"convTranspose2d", {{0, "input"}, {1, "filter"}, {2, "bias"}}}}, + {"QuantizeLinear", {"quantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, + {"ScatterElements", {"scatterElements", {{0, "input"}, {1, "indices"}, {2, "updates"}}}}, + {"LayerNormalization", {"layerNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, + {"DequantizeLinear", {"dequantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, + {"InstanceNormalization", {"instanceNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, + {"GRU", {"gru", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}}}}, + {"BatchNormalization", {"batchNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}, {3, "input_mean"}, {4, "input_var"}}}}, + {"LSTM", {"lstm", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}, {6, "initialCellState"}, {7, "peepholeWeight"}}}}, +}; +} // namespace webnn +} // namespace onnxruntime