diff --git a/services/webnn/ort/context_impl_ort.cc b/services/webnn/ort/context_impl_ort.cc index bb9c1868d4cace..0f75afadeeeb8b 100644 --- a/services/webnn/ort/context_impl_ort.cc +++ b/services/webnn/ort/context_impl_ort.cc @@ -59,6 +59,10 @@ ContextProperties ContextImplOrt::GetContextProperties() { static constexpr SupportedRanks kNonScalarMaxRank = SupportedRanks::NonScalarUpTo(8); + static constexpr SupportedDataTypes kDequantizeLinearInputSupportedDataTypes{ + OperandDataType::kInt4, OperandDataType::kUint4, OperandDataType::kUint8, + OperandDataType::kInt8, OperandDataType::kInt32}; + return ContextProperties( InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst, /*tensor_byte_length_limit=*/kTensorByteLengthLimit, @@ -74,8 +78,8 @@ ContextProperties ContextImplOrt::GetContextProperties() { /*conv2d_input=*/DataTypeConstraint::kFloat16To32, /*conv_transpose2d_input=*/DataTypeConstraint::kFloat16To32, /*cumulative_sum_input=*/{}, - /*dequantize_linear_input=*/{}, - /*dequantize_linear_scale=*/{}, + /*dequantize_linear_input=*/kDequantizeLinearInputSupportedDataTypes, + /*dequantize_linear_scale=*/DataTypeConstraint::kFloat16To32, /*add_input=*/ {DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank}, /*sub_input=*/ diff --git a/services/webnn/ort/graph_builder_ort.cc b/services/webnn/ort/graph_builder_ort.cc index eb4bc1cc5ed452..974124d2dc0fc9 100644 --- a/services/webnn/ort/graph_builder_ort.cc +++ b/services/webnn/ort/graph_builder_ort.cc @@ -69,6 +69,7 @@ constexpr char kOpTypeClamp[] = "Clip"; constexpr char kOpTypeConcat[] = "Concat"; constexpr char kOpTypeConv2d[] = "Conv"; constexpr char kOpTypeConvTranspose2d[] = "ConvTranspose"; +constexpr char kOpTypeDequantizeLinear[] = "DequantizeLinear"; constexpr char kOpTypeExpand[] = "Expand"; constexpr char kOpTypeGather[] = "Gather"; constexpr char kOpTypeGelu[] = "Gelu"; @@ -300,6 +301,60 @@ void GraphBuilderOrt::AppendCast(std::string_view input_name, ADD_CAST_NODE(node_name, input_name, 
output_name, to_data_type); } +std::string GraphBuilderOrt::PrependTranspose( + std::string_view input_name, + base::span permutation) { + const std::string node_name = GenerateNextOperationName("inserted_transpose"); + const std::string output_name = GenerateNextOperandName(); + + std::array input_names = {input_name.data()}; + std::array output_names = {output_name.data()}; + + std::vector perm(permutation.begin(), permutation.end()); + std::array attributes = { + model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()}; + + model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names, + attributes); + return output_name; +} + +void GraphBuilderOrt::AppendTranspose(std::string_view input_name, + std::string_view output_name, + base::span permutation) { + const std::string node_name = GenerateNextOperationName("inserted_transpose"); + std::array input_names = {input_name.data()}; + std::array output_names = {output_name.data()}; + + std::vector perm(permutation.begin(), permutation.end()); + std::array attributes = { + model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()}; + + model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names, + attributes); +} + +[[nodiscard]] base::expected +GraphBuilderOrt::PrependReshape(std::string_view input_name, + base::span new_shape) { + const std::string node_name = GenerateNextOperationName("inserted_reshape"); + const std::string output_name = GenerateNextOperandName(); + + // Shape is an operand with data type int64, not an attribute. 
+ std::vector new_shape_dims = { + base::checked_cast(new_shape.size())}; + ASSIGN_OR_RETURN(const std::string shape_name, + CreateInitializer(new_shape_dims, new_shape)); + + std::array input_names = {input_name.data(), + shape_name.c_str()}; + std::array output_names = {output_name.c_str()}; + + model_builder_.AddNode(kOpTypeReshape, node_name, input_names, output_names); + + return output_name; +} + void GraphBuilderOrt::AddInput(uint64_t input_id) { const mojom::Operand& operand = GetOperand(input_id); std::string name = GetOperandNameById(input_id); @@ -952,6 +1007,148 @@ GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) { return base::ok(); } +[[nodiscard]] base::expected +GraphBuilderOrt::AddDequantizeLinearOperation( + const mojom::DequantizeLinear& dequantize_linear) { + const std::string node_name = + GenerateNextOperationName(dequantize_linear.label); + std::string input_name = + GetOperandNameById(dequantize_linear.input_operand_id); + std::string scale_name = + GetOperandNameById(dequantize_linear.scale_operand_id); + std::string zero_point_name = + GetOperandNameById(dequantize_linear.zero_point_operand_id); + std::string output_name = + GetOperandNameById(dequantize_linear.output_operand_id); + + const OperandDescriptor& input_descriptor = + GetOperand(dequantize_linear.input_operand_id).descriptor; + std::vector input_shape = input_descriptor.shape(); + + const OperandDescriptor& scale_descriptor = + GetOperand(dequantize_linear.scale_operand_id).descriptor; + // ZeroPoint has the same shape as the scale. 
+ std::vector scale_shape = scale_descriptor.shape(); + + std::optional axis; + uint32_t not_one_value_dim_count = 0; + bool found_same_size = false; + CHECK_LE(scale_shape.size(), input_shape.size()); + for (size_t i = 0; i < scale_shape.size(); i++) { + if (scale_shape[scale_shape.size() - i - 1] != 1) { + not_one_value_dim_count++; + if (scale_shape[scale_shape.size() - i - 1] == + input_shape[input_shape.size() - i - 1]) { + axis = input_shape.size() - i - 1; + found_same_size = true; + } + } + } + // TODO(https://github.com/shiyi9801/chromium/issues/139): Consider to add + // emulation to support multiple axes case, e.g. input shape is [2, 3, 4, 5] + // and scale shape is [1, 3, 4, 1]. + bool is_per_axis = found_same_size && not_one_value_dim_count == 1; + + std::optional block_size; + bool need_transpose = false; + if (scale_shape.empty()) { + // For per-tensor/layer dequantization the scale is a scalar. + } else if (not_one_value_dim_count == 0) { + // The numbers in scale shape are all 1., scale and zeroPoint should be + // reshaped to a scalar. + ASSIGN_OR_RETURN(scale_name, PrependReshape(scale_name, {})); + ASSIGN_OR_RETURN(zero_point_name, PrependReshape(zero_point_name, {})); + } else if (is_per_axis) { + // For per-axis dequantization, scale and zeroPoint must be a 1-D + // Tensor. + CHECK(axis.has_value()); + ASSIGN_OR_RETURN(scale_name, + PrependReshape(scale_name, {input_shape[axis.value()]})); + ASSIGN_OR_RETURN( + zero_point_name, + PrependReshape(zero_point_name, {input_shape[axis.value()]})); + } else if (scale_shape.size() == input_shape.size()) { + // For blocked dequantization it has the same shape as the input, except for + // one dimension in which blocking is performed. 
+ uint32_t blocked_axis_count = 0; + axis = 0; + block_size = 1; + for (size_t i = 0; i < input_shape.size(); i++) { + if (scale_shape[i] != input_shape[i]) { + CHECK_EQ(input_shape[i] % scale_shape[i], 0u); + block_size = input_shape[i] / scale_shape[i]; + axis = i; + blocked_axis_count++; + // TODO(https://github.com/shiyi9801/chromium/issues/135): Consider to + // add emulation to support multi-dimensions blockwise. + if (blocked_axis_count > 1) { + return NewNotSupportedError( + "For blocked dequantization scale has the same shape as the " + "input or except for one dimension in which blocking is " + "performed"); + } + } + } + + // Currently, OpenVINO only supports axis == 0 when scale.size == 2. + // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L228. + if (base::CommandLine::ForCurrentProcess()->HasSwitch( + switches::kWebNNOrtUseOpenvino)) { + if (scale_shape.size() != 2) { + // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L220 + return NewNotSupportedError( + "Currently ORT OpenVINO only support 2D scale for block_wise " + "dequantizeLinear."); + } else if (axis == 1) { + input_name = PrependTranspose(input_name, {1, 0}); + scale_name = PrependTranspose(scale_name, {1, 0}); + zero_point_name = PrependTranspose(zero_point_name, {1, 0}); + axis = 0; + need_transpose = true; + } + } + } else { + // The proposal of requiring scale and zeroPoint to be the same rank as + // the input is under discussion- + // https://github.com/webmachinelearning/webnn/pull/805#discussion_r1919498405 + return NewNotSupportedError( + "Currently, ONNX only supports per-tensor, per-axis and block-wise " + "dequantizeLinear"); + } + + const std::string transposed_output_name = + need_transpose ? 
GenerateNextOperandName() : output_name; + + base::FixedArray input_names = { + input_name.c_str(), scale_name.c_str(), zero_point_name.c_str()}; + base::FixedArray output_names = {transposed_output_name.c_str()}; + + std::vector attributes; + if (axis.has_value()) { + attributes.push_back( + model_builder_ + .CreateAttribute(/*name=*/"axis", + base::checked_cast(axis.value())) + .Release()); + } + + if (block_size.has_value()) { + attributes.push_back( + model_builder_ + .CreateAttribute(/*name=*/"block_size", + base::checked_cast(block_size.value())) + .Release()); + } + + model_builder_.AddNode(kOpTypeDequantizeLinear, node_name, input_names, + output_names, attributes); + + if (need_transpose) { + AppendTranspose(transposed_output_name, output_name, {1, 0}); + } + return base::ok(); +} + void GraphBuilderOrt::AddGatherOperation(const mojom::Gather& gather) { const std::string node_name = GenerateNextOperationName(gather.label); const std::string input_name = GetOperandNameById(gather.input_operand_id); @@ -1780,6 +1977,11 @@ GraphBuilderOrt::BuildModel() { RETURN_IF_ERROR(AddConv2dOperation(*operation->get_conv2d())); break; } + case mojom::Operation::Tag::kDequantizeLinear: { + RETURN_IF_ERROR( + AddDequantizeLinearOperation(*operation->get_dequantize_linear())); + break; + } case mojom::Operation::Tag::kExpand: { RETURN_IF_ERROR(AddExpandOperation(*operation->get_expand())); break; @@ -1863,7 +2065,6 @@ GraphBuilderOrt::BuildModel() { break; } case mojom::Operation::Tag::kCumulativeSum: - case mojom::Operation::Tag::kDequantizeLinear: case mojom::Operation::Tag::kElu: case mojom::Operation::Tag::kGatherElements: case mojom::Operation::Tag::kGatherNd: diff --git a/services/webnn/ort/graph_builder_ort.h b/services/webnn/ort/graph_builder_ort.h index ca9e2a5953a93c..9b1fb00620b91f 100644 --- a/services/webnn/ort/graph_builder_ort.h +++ b/services/webnn/ort/graph_builder_ort.h @@ -131,6 +131,13 @@ class GraphBuilderOrt { std::string PrependCast(std::string_view 
input_name, ONNXTensorElementDataType to_data_type); + [[nodiscard]] base::expected PrependReshape( + std::string_view input_name, + base::span new_shape); + + std::string PrependTranspose(std::string_view input_name, + base::span permutation); + + // Insert a cast operation after an operation to convert its output to the // target `to_data_type`. The `input_name` specifies the cast operation's // input (the output of the operation to be casted), and the `output_name` @@ -139,6 +146,10 @@ class GraphBuilderOrt { std::string_view output_name, ONNXTensorElementDataType to_data_type); + void AppendTranspose(std::string_view input_name, + std::string_view output_name, + base::span permutation); + + void AddInput(uint64_t input_id); void AddOutput(uint64_t output_id); @@ -175,6 +186,9 @@ class GraphBuilderOrt { const mojom::Conv2d& conv2d); [[nodiscard]] base::expected AddExpandOperation( const mojom::Expand& expand); + [[nodiscard]] base::expected + AddDequantizeLinearOperation( + const mojom::DequantizeLinear& dequantize_linear); void AddGatherOperation(const mojom::Gather& gather); void AddGemmOperation(const mojom::Gemm& gemm); [[nodiscard]] base::expected diff --git a/services/webnn/ort/ort_model_builder.cc b/services/webnn/ort/ort_model_builder.cc index be4818b816658d..090858998e13bf 100644 --- a/services/webnn/ort/ort_model_builder.cc +++ b/services/webnn/ort/ort_model_builder.cc @@ -101,7 +101,7 @@ void OrtModelBuilder::AddOutput(std::string_view name, void* ort_tensor_raw_data = nullptr; RETURN_STATUS_IF_FAILED( GetOrtApi()->GetTensorMutableData(initializer, &ort_tensor_raw_data)); - CHECK(ort_tensor_raw_data); + // ort_tensor_raw_data can be nullptr when `data` is empty. UNSAFE_BUFFERS( base::span(static_cast(ort_tensor_raw_data), data.size())) .copy_from(data);