Skip to content

Commit

Permalink
Implement dequantizeLinear
Browse files Browse the repository at this point in the history
  • Loading branch information
lisa0314 committed Feb 11, 2025
1 parent a97dfb2 commit 5f0efb1
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 3 deletions.
8 changes: 6 additions & 2 deletions services/webnn/ort/context_impl_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ ContextProperties ContextImplOrt::GetContextProperties() {
static constexpr SupportedRanks kNonScalarMaxRank =
SupportedRanks::NonScalarUpTo(8);

static constexpr SupportedDataTypes kDequantizeLinearInputSupportedDataTypes{
OperandDataType::kInt4, OperandDataType::kUint4, OperandDataType::kUint8,
OperandDataType::kInt8, OperandDataType::kInt32};

return ContextProperties(
InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst,
/*tensor_byte_length_limit=*/kTensorByteLengthLimit,
Expand All @@ -74,8 +78,8 @@ ContextProperties ContextImplOrt::GetContextProperties() {
/*conv2d_input=*/DataTypeConstraint::kFloat16To32,
/*conv_transpose2d_input=*/DataTypeConstraint::kFloat16To32,
/*cumulative_sum_input=*/{},
/*dequantize_linear_input=*/{},
/*dequantize_linear_scale=*/{},
/*dequantize_linear_input=*/kDequantizeLinearInputSupportedDataTypes,
/*dequantize_linear_scale=*/DataTypeConstraint::kFloat16To32,
/*add_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*sub_input=*/
Expand Down
157 changes: 156 additions & 1 deletion services/webnn/ort/graph_builder_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ constexpr char kOpTypeClamp[] = "Clip";
constexpr char kOpTypeConcat[] = "Concat";
constexpr char kOpTypeConv2d[] = "Conv";
constexpr char kOpTypeConvTranspose2d[] = "ConvTranspose";
constexpr char kOpTypeDequantizeLinear[] = "DequantizeLinear";
constexpr char kOpTypeExpand[] = "Expand";
constexpr char kOpTypeGather[] = "Gather";
constexpr char kOpTypeGelu[] = "Gelu";
Expand Down Expand Up @@ -300,6 +301,39 @@ void GraphBuilderOrt::AppendCast(std::string_view input_name,
ADD_CAST_NODE(node_name, input_name, output_name, to_data_type);
}

std::string GraphBuilderOrt::PrependTranspose(
std::string_view input_name,
base::span<const uint32_t> permutation) {
const std::string node_name = GenerateNextOperationName("inserted_transpose");
const std::string output_name = GenerateNextOperandName();

std::array<const char*, 1> input_names = {input_name.data()};
std::array<const char*, 1> output_names = {output_name.data()};

std::vector<int64_t> perm(permutation.begin(), permutation.end());
std::array<OrtOpAttr*, 1> attributes = {
model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};

model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
attributes);
return output_name;
}

// Inserts a standalone Transpose node mapping `input_name` to `output_name`
// with the given axis `permutation`.
void GraphBuilderOrt::AppendTranspose(std::string_view input_name,
                                      std::string_view output_name,
                                      base::span<const uint32_t> permutation) {
  // ONNX Transpose carries the permutation as an int64 "perm" attribute.
  std::vector<int64_t> perm_attr(permutation.begin(), permutation.end());
  std::array<OrtOpAttr*, 1> attributes = {
      model_builder_.CreateAttribute(/*name=*/"perm", perm_attr).Release()};

  const std::string node_name = GenerateNextOperationName("inserted_transpose");
  std::array<const char*, 1> inputs = {input_name.data()};
  std::array<const char*, 1> outputs = {output_name.data()};

  model_builder_.AddNode(kOpTypeTranspose, node_name, inputs, outputs,
                         attributes);
}

void GraphBuilderOrt::AddInput(uint64_t input_id) {
const mojom::Operand& operand = GetOperand(input_id);
std::string name = GetOperandNameById(input_id);
Expand Down Expand Up @@ -952,6 +986,123 @@ GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) {
return base::ok();
}

// Maps WebNN dequantizeLinear onto an ONNX DequantizeLinear node. WebNN has
// no explicit axis/blockSize, so both attributes are inferred here from the
// relationship between the input shape and the scale shape, mirroring how
// OpenVINO's ONNX frontend interprets DequantizeLinear.
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
GraphBuilderOrt::AddDequantizeLinearOperation(
    const mojom::DequantizeLinear& dequantize_linear) {
  const std::string node_name =
      GenerateNextOperationName(dequantize_linear.label);
  std::string input_name =
      GetOperandNameById(dequantize_linear.input_operand_id);
  std::string scale_name =
      GetOperandNameById(dequantize_linear.scale_operand_id);
  std::string zero_point_name =
      GetOperandNameById(dequantize_linear.zero_point_operand_id);
  std::string output_name =
      GetOperandNameById(dequantize_linear.output_operand_id);

  const OperandDescriptor& input_descriptor =
      GetOperand(dequantize_linear.input_operand_id).descriptor;
  std::vector<uint32_t> input_shape = input_descriptor.shape();

  const OperandDescriptor& scale_descriptor =
      GetOperand(dequantize_linear.scale_operand_id).descriptor;
  std::vector<uint32_t> scale_shape = scale_descriptor.shape();

  // ONNX DequantizeLinear attributes, inferred below from the shapes.
  int64_t axis = 1;
  int64_t block_size = 0;
  bool need_transpose = false;

  // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L220
  if (scale_shape.size() > 2) {
    return NewNotSupportedError(
        "OpenVINO dequantizeLinear cannot operate with more than 2D scales");
  }

  if (scale_shape.empty()) {
    // For per-tensor/layer dequantization the scale is a scalar.
    axis = 0;
  } else if (scale_shape.size() == 1) {
    bool is_valid = false;
    // For per-axis dequantization the scale is a 1-D tensor whose size must
    // match the input dimension along the (inferred) axis. If several input
    // dimensions match, the last matching one is used.
    for (size_t i = 0; i < input_shape.size(); i++) {
      if (scale_shape[0] == input_shape[i]) {
        axis = i;
        is_valid = true;
      }
    }
    if (!is_valid) {
      return NewNotSupportedError(
          "For 1D scale, the size of scale must be the same as the size of the "
          "input dim specified by the axis.");
    }
  } else {
    CHECK_EQ(scale_shape.size(), 2u);
    // For blocked dequantization the scale has the same shape as the input,
    // except for one dimension in which blocking is performed, so the ranks
    // must match. Without this check a rank mismatch (e.g. 2-D scale with a
    // 4-D input) would silently fall through with axis = 1, block_size = 0
    // and build an invalid node.
    if (scale_shape.size() != input_shape.size()) {
      return NewNotSupportedError(
          "For blocked dequantization, the rank of scale must be the same as "
          "the rank of input.");
    }
    uint32_t diff_count = 0;
    for (size_t i = 0; i < input_shape.size(); i++) {
      if (scale_shape[i] != input_shape[i]) {
        // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L230
        if (input_shape[i] % scale_shape[i] != 0) {
          return NewNotSupportedError(
              "For blocked dequantization, OpenVINO DequantizeLinear doesn't "
              "support case when input cannot be divided by scale.");
        }
        block_size = input_shape[i] / scale_shape[i];
        axis = i;
        diff_count++;
        if (diff_count > 1) {
          return NewNotSupportedError(
              "For blocked dequantization it has the same shape as the "
              "input, except for one dimension in which blocking is "
              "performed");
        }
      }
    }
    // The shape of scale is the same as the shape of input.
    if (diff_count == 0) {
      axis = 0;
      block_size = 1;
    }

    // Currently, OpenVINO only supports axis == 0 when scale.size == 2, so
    // transpose the (rank-2) inputs to move the blocked axis to 0 and
    // transpose the result back afterwards.
    // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L228.
    if (axis != 0) {
      input_name = PrependTranspose(input_name, {1, 0});
      scale_name = PrependTranspose(scale_name, {1, 0});
      zero_point_name = PrependTranspose(zero_point_name, {1, 0});
      axis = 0;
      need_transpose = true;
    }
  }

  // When the inputs were transposed, DequantizeLinear writes to a temporary
  // operand that is transposed back into the requested output below.
  const std::string transposed_output_name =
      need_transpose ? GenerateNextOperandName() : output_name;

  base::FixedArray<const char*> input_names = {
      input_name.c_str(), scale_name.c_str(), zero_point_name.c_str()};
  base::FixedArray<const char*> output_names = {transposed_output_name.c_str()};

  std::array<OrtOpAttr*, 2> attributes = {
      model_builder_
          .CreateAttribute(/*name=*/"axis", base::checked_cast<int64_t>(axis))
          .Release(),
      model_builder_
          .CreateAttribute(/*name=*/"block_size",
                           base::checked_cast<int64_t>(block_size))
          .Release()};

  model_builder_.AddNode(kOpTypeDequantizeLinear, node_name, input_names,
                         output_names, attributes);

  if (need_transpose) {
    AppendTranspose(transposed_output_name, output_name, {1, 0});
  }
  return base::ok();
}

void GraphBuilderOrt::AddGatherOperation(const mojom::Gather& gather) {
const std::string node_name = GenerateNextOperationName(gather.label);
const std::string input_name = GetOperandNameById(gather.input_operand_id);
Expand Down Expand Up @@ -1780,6 +1931,11 @@ GraphBuilderOrt::BuildModel() {
RETURN_IF_ERROR(AddConv2dOperation(*operation->get_conv2d()));
break;
}
case mojom::Operation::Tag::kDequantizeLinear: {
RETURN_IF_ERROR(
AddDequantizeLinearOperation(*operation->get_dequantize_linear()));
break;
}
case mojom::Operation::Tag::kExpand: {
RETURN_IF_ERROR(AddExpandOperation(*operation->get_expand()));
break;
Expand Down Expand Up @@ -1863,7 +2019,6 @@ GraphBuilderOrt::BuildModel() {
break;
}
case mojom::Operation::Tag::kCumulativeSum:
case mojom::Operation::Tag::kDequantizeLinear:
case mojom::Operation::Tag::kElu:
case mojom::Operation::Tag::kGatherElements:
case mojom::Operation::Tag::kGatherNd:
Expand Down
10 changes: 10 additions & 0 deletions services/webnn/ort/graph_builder_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class GraphBuilderOrt {
std::string PrependCast(std::string_view input_name,
ONNXTensorElementDataType to_data_type);

std::string PrependTranspose(std::string_view input_name,
base::span<const uint32_t> permutation);

// Insert a cast operation after an operation to convert its output to the
// target `to_data_type`. The `input_name` specifies the cast operation's
// input (the output of the operation to be casted), and the `output_name`
Expand All @@ -139,6 +142,10 @@ class GraphBuilderOrt {
std::string_view output_name,
ONNXTensorElementDataType to_data_type);

void AppendTranspose(std::string_view input_name,
std::string_view output_name,
base::span<const uint32_t> permutation);

void AddInput(uint64_t input_id);
void AddOutput(uint64_t output_id);

Expand Down Expand Up @@ -175,6 +182,9 @@ class GraphBuilderOrt {
const mojom::Conv2d& conv2d);
[[nodiscard]] base::expected<void, mojom::ErrorPtr> AddExpandOperation(
const mojom::Expand& expand);
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
AddDequantizeLinearOperation(
const mojom::DequantizeLinear& dequantize_linear);
void AddGatherOperation(const mojom::Gather& gather);
void AddGemmOperation(const mojom::Gemm& gemm);
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
Expand Down

0 comments on commit 5f0efb1

Please sign in to comment.