Implement dequantizeLinear #134

Merged (7 commits, Feb 18, 2025)
Changes from 4 commits
services/webnn/ort/context_impl_ort.cc (8 changes: 6 additions & 2 deletions)

@@ -59,6 +59,10 @@ ContextProperties ContextImplOrt::GetContextProperties() {
static constexpr SupportedRanks kNonScalarMaxRank =
SupportedRanks::NonScalarUpTo(8);

static constexpr SupportedDataTypes kDequantizeLinearInputSupportedDataTypes{
OperandDataType::kInt4, OperandDataType::kUint4, OperandDataType::kUint8,
OperandDataType::kInt8, OperandDataType::kInt32};

return ContextProperties(
InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst,
/*tensor_byte_length_limit=*/kTensorByteLengthLimit,
@@ -74,8 +78,8 @@ ContextProperties ContextImplOrt::GetContextProperties() {
/*conv2d_input=*/DataTypeConstraint::kFloat16To32,
/*conv_transpose2d_input=*/DataTypeConstraint::kFloat16To32,
/*cumulative_sum_input=*/{},
/*dequantize_linear_input=*/{},
/*dequantize_linear_scale=*/{},
/*dequantize_linear_input=*/kDequantizeLinearInputSupportedDataTypes,
/*dequantize_linear_scale=*/DataTypeConstraint::kFloat16To32,
/*add_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*sub_input=*/
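For reference, dequantizeLinear maps a low-precision integer tensor back to floating point as output = (input - zeroPoint) * scale, with scale and zeroPoint broadcast over the input. A minimal per-tensor sketch of that formula (standalone illustration, not part of this patch):

#include <cstdint>
#include <vector>

// Per-tensor dequantization: one scale and one zero point shared by every
// element. Per-axis and blocked modes apply the same formula, choosing a
// scale/zero point per slice or per block instead.
std::vector<float> DequantizePerTensor(const std::vector<int8_t>& input,
                                       float scale, int8_t zero_point) {
  std::vector<float> output;
  output.reserve(input.size());
  for (int8_t value : input) {
    output.push_back((static_cast<float>(value) - zero_point) * scale);
  }
  return output;
}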
services/webnn/ort/graph_builder_ort.cc (198 changes: 197 additions & 1 deletion)

@@ -69,6 +69,7 @@ constexpr char kOpTypeClamp[] = "Clip";
constexpr char kOpTypeConcat[] = "Concat";
constexpr char kOpTypeConv2d[] = "Conv";
constexpr char kOpTypeConvTranspose2d[] = "ConvTranspose";
constexpr char kOpTypeDequantizeLinear[] = "DequantizeLinear";
constexpr char kOpTypeExpand[] = "Expand";
constexpr char kOpTypeGather[] = "Gather";
constexpr char kOpTypeGelu[] = "Gelu";
@@ -300,6 +301,60 @@ void GraphBuilderOrt::AppendCast(std::string_view input_name,
ADD_CAST_NODE(node_name, input_name, output_name, to_data_type);
}

std::string GraphBuilderOrt::PrependTranspose(
std::string_view input_name,
base::span<const uint32_t> permutation) {
const std::string node_name = GenerateNextOperationName("inserted_transpose");
const std::string output_name = GenerateNextOperandName();

std::array<const char*, 1> input_names = {input_name.data()};
std::array<const char*, 1> output_names = {output_name.data()};

std::vector<int64_t> perm(permutation.begin(), permutation.end());
std::array<OrtOpAttr*, 1> attributes = {
model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};

model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
attributes);
return output_name;
}

void GraphBuilderOrt::AppendTranspose(std::string_view input_name,
std::string_view output_name,
base::span<const uint32_t> permutation) {
const std::string node_name = GenerateNextOperationName("inserted_transpose");
std::array<const char*, 1> input_names = {input_name.data()};
std::array<const char*, 1> output_names = {output_name.data()};

std::vector<int64_t> perm(permutation.begin(), permutation.end());
std::array<OrtOpAttr*, 1> attributes = {
model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};

model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
attributes);
}
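Both helpers emit an ONNX Transpose node; with permutation {1, 0}, as used for dequantizeLinear below, the node simply swaps the two axes of a rank-2 tensor. A standalone sketch of the equivalent computation (illustration only):

#include <cstddef>
#include <vector>

// Equivalent of a Transpose node with perm = {1, 0} on a row-major
// [rows, cols] tensor: element (r, c) moves to (c, r).
std::vector<float> Transpose2D(const std::vector<float>& input,
                               size_t rows, size_t cols) {
  std::vector<float> output(input.size());
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      output[c * rows + r] = input[r * cols + c];
    }
  }
  return output;
}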

[[nodiscard]] base::expected<std::string, mojom::ErrorPtr>
GraphBuilderOrt::PrependReshape(std::string_view input_name,
base::span<const int64_t> new_shape) {
const std::string node_name = GenerateNextOperationName("inserted_reshape");
const std::string output_name = GenerateNextOperandName();

// In ONNX, Reshape takes the target shape as a second int64 tensor input,
// not as an attribute.
std::vector<uint32_t> new_shape_dims = {
base::checked_cast<uint32_t>(new_shape.size())};
ASSIGN_OR_RETURN(const std::string shape_name,
CreateInitializer<int64_t>(new_shape_dims, new_shape));

std::array<const char*, 2> input_names = {input_name.data(),
shape_name.c_str()};
std::array<const char*, 1> output_names = {output_name.c_str()};

model_builder_.AddNode(kOpTypeReshape, node_name, input_names, output_names);

return output_name;
}
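Because the element buffer of a contiguous tensor is unaffected by reshaping, the inserted Reshape node only rewrites shape metadata; for example, a {1, n, 1} scale becomes the 1-D {n} tensor that per-axis DequantizeLinear expects. A standalone sketch (TensorView is a hypothetical type, for illustration only):

#include <cstdint>
#include <utility>
#include <vector>

struct TensorView {
  std::vector<float> data;     // contiguous element storage, reused as-is
  std::vector<int64_t> shape;  // e.g. {1, 6, 1}
};

// Reshape keeps the data and swaps in the new shape, e.g. {1, 6, 1} -> {6}.
TensorView Reshape(TensorView tensor, std::vector<int64_t> new_shape) {
  tensor.shape = std::move(new_shape);
  return tensor;
}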

void GraphBuilderOrt::AddInput(uint64_t input_id) {
const mojom::Operand& operand = GetOperand(input_id);
std::string name = GetOperandNameById(input_id);
@@ -952,6 +1007,143 @@ GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) {
return base::ok();
}

[[nodiscard]] base::expected<void, mojom::ErrorPtr>
GraphBuilderOrt::AddDequantizeLinearOperation(
const mojom::DequantizeLinear& dequantize_linear) {
const std::string node_name =
GenerateNextOperationName(dequantize_linear.label);
std::string input_name =
GetOperandNameById(dequantize_linear.input_operand_id);
std::string scale_name =
GetOperandNameById(dequantize_linear.scale_operand_id);
std::string zero_point_name =
GetOperandNameById(dequantize_linear.zero_point_operand_id);
std::string output_name =
GetOperandNameById(dequantize_linear.output_operand_id);

const OperandDescriptor& input_descriptor =
GetOperand(dequantize_linear.input_operand_id).descriptor;
std::vector<uint32_t> input_shape = input_descriptor.shape();

const OperandDescriptor& scale_descriptor =
GetOperand(dequantize_linear.scale_operand_id).descriptor;
// ZeroPoint has the same shape as the scale.
std::vector<uint32_t> scale_shape = scale_descriptor.shape();

int64_t axis = 0;
int64_t block_size = 0;
bool need_transpose = false;

uint32_t not_one_value_count = 0;
[Review comment from a collaborator, suggested change:]
-  uint32_t not_one_value_count = 0;
+  uint32_t not_one_value_dim_count = 0;
bool is_per_axis = false;
CHECK_LE(scale_shape.size(), input_shape.size());
for (size_t i = 0; i < scale_shape.size(); i++) {
if (scale_shape[scale_shape.size() - i - 1] != 1) {
not_one_value_count++;
if (not_one_value_count == 1 &&
scale_shape[scale_shape.size() - i - 1] ==
input_shape[input_shape.size() - i - 1]) {
axis = input_shape.size() - i - 1;
is_per_axis = true;
}
}
}
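// Worked example (shapes assumed for illustration): for input_shape
// {2, 3, 4} and scale_shape {1, 3, 1}, the only non-1 scale dimension
// matches input_shape[1], so axis = 1 and is_per_axis = true.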

// Every value of scale_shape is 1; treat this as per-axis along the first
// size-1 input dimension, if any.
if (not_one_value_count == 0) {
auto it = std::find(input_shape.begin(), input_shape.end(), 1);
if (it != input_shape.end()) {
axis = std::distance(input_shape.begin(), it);
is_per_axis = true;
}
}

if (scale_shape.empty()) {
// For per-tensor/layer dequantization, the scale is a scalar.
axis = 0;
// block_size must be 0 for per-tensor quantization.
block_size = 0;
} else if (not_one_value_count <= 1 && is_per_axis) {
// For per-axis dequantization, scale and zeroPoint must be 1-D tensors.
ASSIGN_OR_RETURN(scale_name,
PrependReshape(scale_name, {input_shape[axis]}));
ASSIGN_OR_RETURN(zero_point_name,
PrependReshape(zero_point_name, {input_shape[axis]}));
// block_size must be 0 for per-axis quantization.
block_size = 0;
} else {
// For blocked dequantization, the scale has the same shape as the input,
// except for one dimension in which blocking is performed.
if (scale_shape.size() == input_shape.size()) {
uint32_t blocked_axis_count = 0;
for (size_t i = 0; i < input_shape.size(); i++) {
if (scale_shape[i] != input_shape[i]) {
block_size = input_shape[i] / scale_shape[i];
axis = i;
blocked_axis_count++;
// TODO(https://github.com/shiyi9801/chromium/issues/135): Consider adding
// emulation to support multi-dimensional blockwise dequantization.
if (blocked_axis_count > 1) {
return NewNotSupportedError(
    "For blocked dequantization, the scale must have the same shape "
    "as the input except for one dimension in which blocking is "
    "performed");
}
}
}
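// Worked example (shapes assumed for illustration): for input_shape
// {4, 6} and scale_shape {4, 2}, the shapes differ only at dimension 1,
// so axis = 1 and block_size = 6 / 2 = 3; each scale value covers a block
// of 3 consecutive elements along that axis.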
// The scale has exactly the same shape as the input, so each block
// contains a single element.
if (blocked_axis_count == 0) {
axis = 0;
block_size = 1;
}

// Currently, OpenVINO only supports axis == 0 when scale.size == 2.
// https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L228.
if (!base::CommandLine::ForCurrentProcess()->HasSwitch(
switches::kWebNNOrtUseOpenvino) &&
axis != 0) {
input_name = PrependTranspose(input_name, {1, 0});
scale_name = PrependTranspose(scale_name, {1, 0});
zero_point_name = PrependTranspose(zero_point_name, {1, 0});
axis = 0;
need_transpose = true;
}
} else {
// The proposal requiring scale and zeroPoint to have the same rank as
// the input is under discussion:
// https://github.com/webmachinelearning/webnn/pull/805#discussion_r1919498405
return NewNotSupportedError(
    "Currently, ONNX only supports per-tensor, per-axis, and "
    "block-wise dequantizeLinear");
}
}

const std::string transposed_output_name =
need_transpose ? GenerateNextOperandName() : output_name;

base::FixedArray<const char*> input_names = {
input_name.c_str(), scale_name.c_str(), zero_point_name.c_str()};
base::FixedArray<const char*> output_names = {transposed_output_name.c_str()};

std::array<OrtOpAttr*, 2> attributes = {
model_builder_
.CreateAttribute(/*name=*/"axis", base::checked_cast<int64_t>(axis))
.Release(),
model_builder_
.CreateAttribute(/*name=*/"block_size",
base::checked_cast<int64_t>(block_size))
.Release()};

model_builder_.AddNode(kOpTypeDequantizeLinear, node_name, input_names,
output_names, attributes);

if (need_transpose) {
AppendTranspose(transposed_output_name, output_name, {1, 0});
}
return base::ok();
}
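To summarize the cases above, a sketch of how assumed example shapes map onto the emitted DequantizeLinear node (illustration only):

// per-tensor: input {2, 3}, scale {}     -> axis = 0, block_size = 0
// per-axis:   input {2, 3}, scale {1, 3} -> scale/zeroPoint reshaped to {3},
//                                           axis = 1, block_size = 0
// blocked:    input {4, 6}, scale {4, 2} -> axis = 1, block_size = 3,
//                                           transposed to axis = 0 (and back
//                                           afterwards) when the execution
//                                           provider path requires it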

void GraphBuilderOrt::AddGatherOperation(const mojom::Gather& gather) {
const std::string node_name = GenerateNextOperationName(gather.label);
const std::string input_name = GetOperandNameById(gather.input_operand_id);
@@ -1780,6 +1972,11 @@ GraphBuilderOrt::BuildModel() {
RETURN_IF_ERROR(AddConv2dOperation(*operation->get_conv2d()));
break;
}
case mojom::Operation::Tag::kDequantizeLinear: {
RETURN_IF_ERROR(
AddDequantizeLinearOperation(*operation->get_dequantize_linear()));
break;
}
case mojom::Operation::Tag::kExpand: {
RETURN_IF_ERROR(AddExpandOperation(*operation->get_expand()));
break;
@@ -1863,7 +2060,6 @@ GraphBuilderOrt::BuildModel() {
break;
}
case mojom::Operation::Tag::kCumulativeSum:
case mojom::Operation::Tag::kDequantizeLinear:
case mojom::Operation::Tag::kElu:
case mojom::Operation::Tag::kGatherElements:
case mojom::Operation::Tag::kGatherNd:
services/webnn/ort/graph_builder_ort.h (14 changes: 14 additions & 0 deletions)

@@ -131,6 +131,13 @@ class GraphBuilderOrt {
std::string PrependCast(std::string_view input_name,
ONNXTensorElementDataType to_data_type);

[[nodiscard]] base::expected<std::string, mojom::ErrorPtr> PrependReshape(
std::string_view input_name,
base::span<const int64_t> new_shape);

std::string PrependTranspose(std::string_view input_name,
base::span<const uint32_t> permutation);

// Insert a cast operation after an operation to convert its output to the
// target `to_data_type`. The `input_name` specifies the cast operation's
// input (the output of the operation to be casted), and the `output_name`
@@ -139,6 +146,10 @@ class GraphBuilderOrt {
std::string_view output_name,
ONNXTensorElementDataType to_data_type);

void AppendTranspose(std::string_view input_name,
std::string_view output_name,
base::span<const uint32_t> permutation);

void AddInput(uint64_t input_id);
void AddOutput(uint64_t output_id);

@@ -175,6 +186,9 @@ class GraphBuilderOrt {
const mojom::Conv2d& conv2d);
[[nodiscard]] base::expected<void, mojom::ErrorPtr> AddExpandOperation(
const mojom::Expand& expand);
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
AddDequantizeLinearOperation(
const mojom::DequantizeLinear& dequantize_linear);
void AddGatherOperation(const mojom::Gather& gather);
void AddGemmOperation(const mojom::Gemm& gemm);
[[nodiscard]] base::expected<void, mojom::ErrorPtr>