[QNN] Implement 'qnn.softmax' (apache#14536)
* [QNN] Implement 'qnn.softmax'

Co-authored-by: Toshiki Maekawa <[email protected]>

* Disable fq2i for nn.softmax by default

Co-authored-by: Toshiki Maekawa <[email protected]>

* Add assertion for input scale

Co-authored-by: Toshiki Maekawa <[email protected]>

* Use clip to prevent too large bitshift

Co-authored-by: Toshiki Maekawa <[email protected]>

* Test multiple input scales

Co-authored-by: Toshiki Maekawa <[email protected]>

* Follow linter

Co-authored-by: Toshiki Maekawa <[email protected]>

* Add comment

Co-authored-by: Toshiki Maekawa <[email protected]>

---------

Co-authored-by: Toshiki Maekawa <[email protected]>
maekawatoshiki and Toshiki Maekawa authored May 15, 2023
1 parent 0274930 commit eb1ea97
Showing 9 changed files with 311 additions and 23 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
@@ -29,6 +29,7 @@
debug,
register_external_compiler,
register_fake_quantization_to_integer,
register_optional_fake_quantization_to_integer,
register_mixed_precision_conversion,
)
from . import strategy
21 changes: 21 additions & 0 deletions python/tvm/relay/op/op.py
@@ -475,6 +475,27 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10):
return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)


def register_optional_fake_quantization_to_integer(op_name, func=None, level=10):
"""Register optional quantize function for an op
Given an op and Affine Types on it's inputs, this function should return the op
in affine space/integer operators and the new type of the output, where affine
denotes the transformation x_real = (x_affine - zero_point) * scale
Parameters
----------
op_name : str
The name of the operator
func: function (expr: Expr, map: Map<Expr, AffineType>) -> new_expr: Expr
The function for translating the op into affine space and integer operators
level : int
The priority level
"""
return tvm.ir.register_op_attr(op_name, "FTVMOptionalFakeQuantizationToInteger", func, level)


def register_mixed_precision_conversion(op_name, func=None, level=10):
"""Register mixed precision conversion function for an op
4 changes: 4 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -1245,3 +1245,7 @@ def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zer
return _make.leaky_relu(
x, alpha, input_scale, input_zero_point, output_scale, output_zero_point
)


def softmax(x, scale, zero_point, output_scale, output_zero_point, axis=-1):
return _make.softmax(x, axis, scale, zero_point, output_scale, output_zero_point)
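
A minimal usage sketch of the new Python wrapper (not part of this commit; the scale and zero-point values are illustrative assumptions, and the canonicalization below requires the input scale to be at most 1):

from tvm import relay

# Quantized int8 input tensor.
x = relay.var("x", shape=(1, 16), dtype="int8")
# scale/zero_point describe the input quantization; output_scale/output_zero_point
# describe the desired quantization of the int8 result.
out = relay.qnn.op.softmax(
    x,
    scale=relay.const(0.08, "float32"),
    zero_point=relay.const(0, "int32"),
    output_scale=relay.const(1.0 / 256, "float32"),
    output_zero_point=relay.const(-128, "int32"),
    axis=-1,
)
func = relay.Function([x], out)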
18 changes: 17 additions & 1 deletion python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -25,7 +25,10 @@
from tvm.relay.qnn.op import canonicalizations
from tvm.tir import bijective_layout

from ..op import register_fake_quantization_to_integer
from ..op import (
register_fake_quantization_to_integer,
register_optional_fake_quantization_to_integer,
)


def fold_constant(expr):
@@ -619,3 +622,16 @@ def take(expr, type_map):

out = relay.op.take(arg, indices, **expr.attrs)
return [out, t]


@register_optional_fake_quantization_to_integer("nn.softmax")
def softmax(expr, type_map):
"""Rewrite a softmax op"""
arg = expr.args[0]
arg_t = type_map[arg]
out_t = type_map[expr]

out = relay.qnn.op.softmax(
arg, arg_t.scale, arg_t.zero_point, out_t.scale, out_t.zero_point, **expr.attrs
)
return [out, out_t]
11 changes: 9 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -1251,7 +1251,7 @@ def AnnotateSpans():
return _ffi_api.AnnotateSpans()


def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
def FakeQuantizationToInteger(hard_fail=False, use_qat=False, optional_qnn_ops=None):
# pylint: disable=anomalous-backslash-in-string
"""
Find regions of the graph of the form
@@ -1298,12 +1298,19 @@ def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
|
q
optional_qnn_ops : List[str]
List of operator names whose conversion is disabled by default and must be
enabled explicitly through this argument.
Example: ['nn.softmax']
Returns
-------
ret : tvm.transform.Pass
The registered FakeQuantizationToInteger pass.
"""
return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat)
if optional_qnn_ops is None:
optional_qnn_ops = []
return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat, optional_qnn_ops)


def FlattenAtrousConv():
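
A sketch of how the new optional_qnn_ops argument is used (assuming mod is a Relay module containing fake-quantized regions; by default the pass keeps nn.softmax in floating point, and listing it here opts in to the qnn.softmax rewrite added by this commit):

import tvm

# mod: a tvm.IRModule with quantize/dequantize pairs around float operators.
mod = tvm.relay.transform.FakeQuantizationToInteger(
    hard_fail=False,
    use_qat=False,
    optional_qnn_ops=["nn.softmax"],
)(mod)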
154 changes: 154 additions & 0 deletions src/relay/qnn/op/softmax.cc
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/qnn/op/softmax.cc
* \brief QNN softmax operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>

#include "op_common.h"
#include "tvm/ir/expr.h"
#include "tvm/relay/attrs/nn.h"
#include "tvm/relay/type.h"
#include "tvm/runtime/data_type.h"
#include "tvm/runtime/logging.h"
#include "tvm/topi/reduction.h"

namespace tvm {
namespace relay {
namespace qnn {

bool QnnSoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: input, scale, zero_point, output_scale, output_zero_point, output
ICHECK_EQ(types.size(), 6);
const auto* x = types[0].as<TensorTypeNode>();
if (x == nullptr) return false;
ICHECK(x->dtype == DataType::Int(8))
<< "Expected quantized softmax type(int8) for input but was " << x->dtype;

// Check the types of scale and zero points.
for (size_t i = 1; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}

ICHECK(IsScalarType(types[1], DataType::Float(32))); // scale
ICHECK(IsScalarType(types[2], DataType::Int(32))); // zero_point
ICHECK(IsScalarType(types[3], DataType::Float(32)));   // output_scale
ICHECK(IsScalarType(types[4], DataType::Int(32)));      // output_zero_point

// Assign types for scale and zero points.
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // scale
reporter->Assign(types[2], TensorType({}, DataType::Int(32))); // zero_point
reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// IdentityRel infer type function.
Array<Type> tensor_types = {types[0], types[5]};
return IdentityRel(tensor_types, 2, attrs, reporter);
}

// Positional relay function to create quantized softmax operator used by frontend FFI.
Expr MakeQuantizedSoftmax(Expr x, int axis, Expr scale, Expr zero_point, Expr output_scale,
Expr output_zero_point) {
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("qnn.softmax");
return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(attrs), {});
}

/*
* \brief Canonicalizes the QNN softmax op.
* \param attrs The Softmax attrs.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for softmax op.
* \note This op is highly experimental and sometimes lacks accuracy.
* Be aware that the input scale must be in the range of 0 to 1.
*/
Expr QnnSoftmaxCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// Expected: input, scale, zero_point, output_scale, output_zero_point
ICHECK_EQ(new_args.size(), 5);

const auto const_i32 = [&](int32_t val) { return MakeConstantScalar(DataType::Int(32), val); };
const auto const_f32 = [&](float val) { return MakeConstantScalar(DataType::Float(32), val); };

const auto const_input_scale = new_args[1].as<ConstantNode>();
ICHECK(const_input_scale) << "Input scale should be constant.";
ICHECK(const_input_scale->is_scalar()) << "Input scale should be scalar.";
const float input_scale = static_cast<float*>(const_input_scale->data->data)[0];
ICHECK(input_scale <= 1.f) << "Input scale should be less than or equal to 1.";

const Expr input_zero_point = new_args[2];
const Expr output_scale = new_args[3];
const Expr output_zero_point = new_args[4];
const int axis = attrs.as<SoftmaxAttrs>()->axis;

// Refer to the Algorithm 1 in https://arxiv.org/pdf/2207.01405.pdf

const Expr quantized_data = Subtract(Cast(new_args[0], DataType::Int(32)), input_zero_point);

const Expr x_0 = ConvertDtype(const_f32(std::round(1.f / input_scale)), DataType::Int(32));
const Expr max = Max(quantized_data, {axis}, true, false);
const Expr x = Subtract(quantized_data, max);

const int m = 30;
const int bits = 8;
const Expr x_p = Subtract(Add(x, RightShift(x, const_i32(1))), RightShift(x, const_i32(4)));
const Expr q = Clip(Divide(x_p, Negative(x_0)), 0, 20);
const Expr max_q = Max(q, {axis}, true, false);
const Expr r = Subtract(x_p, Multiply(q, Negative(x_0)));
const Expr x_b = Add(RightShift(r, const_i32(1)), x_0);
const Expr exps = LeftShift(x_b, Subtract(max_q, q));
const Expr sums = Sum(exps, {axis}, true, false);
const Expr output =
RightShift(Multiply(Divide(const_i32(1 << m), sums), exps), const_i32(m - (bits - 1)));
const Expr requantized = Requantize(output, arg_types[0].as<TensorTypeNode>()->shape,
const_f32(1.f / (1 << (bits - 1))), const_i32(0),
output_scale, output_zero_point, DataType::Int(bits), 0);

return requantized;
}

RELAY_REGISTER_OP("qnn.softmax")
.describe("Softmax for quantized tensors.")
.set_attrs_type<SoftmaxAttrs>()
.set_num_inputs(5)
.add_argument("data", "Quantized Tensor", "The input data.")
.add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
.add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
.add_argument("output_zero_point", "Tensor",
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QSoftmax", QnnSoftmaxRel)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSoftmaxCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.softmax").set_body_typed(MakeQuantizedSoftmax);

} // namespace qnn
} // namespace relay
} // namespace tvm
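
For reference, a rough numpy sketch of the computation QnnSoftmaxCanonicalize emits (Algorithm 1 of arXiv:2207.01405). It is illustrative only; the shift and integer-division rounding conventions are assumptions and may differ slightly from the generated Relay ops:

import numpy as np

def qnn_softmax_reference(x_q, input_scale, input_zp, output_scale, output_zp, axis=-1):
    bits, m = 8, 30
    data = x_q.astype(np.int64) - input_zp               # remove the input zero point
    x_0 = np.int64(round(1.0 / input_scale))             # 1/scale as an integer
    x = data - data.max(axis=axis, keepdims=True)        # x <= 0 for numerical stability

    x_p = x + (x >> 1) - (x >> 4)                         # ~ x * log2(e) via shifts
    q = np.clip(x_p // -x_0, 0, 20)                       # integer part of the exponent, clipped
    r = x_p - q * -x_0                                    # remainder in (-x_0, 0]
    x_b = (r >> 1) + x_0                                  # linear approximation of x_0 * 2**(r / x_0)
    max_q = q.max(axis=axis, keepdims=True)
    exps = x_b << (max_q - q)                             # 2**(-q) terms on a common scale
    sums = exps.sum(axis=axis, keepdims=True)
    out = ((1 << m) // sums * exps) >> (m - (bits - 1))   # softmax scaled by 2**(bits - 1)

    # Requantize from scale 1/2**(bits - 1), zero point 0, to the output parameters.
    out_f = out * (1.0 / (1 << (bits - 1)))
    return np.clip(np.round(out_f / output_scale) + output_zp, -128, 127).astype(np.int8)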