From eb1ea972b3cc2f7199cd161e0e0decd8d0ed074a Mon Sep 17 00:00:00 2001
From: uint256_t
Date: Tue, 16 May 2023 08:01:11 +0900
Subject: [PATCH] [QNN] Implement 'qnn.softmax' (#14536)

* [QNN] Implement 'qnn.softmax'

Co-authored-by: Toshiki Maekawa

* Disable fq2i for nn.softmax by default

Co-authored-by: Toshiki Maekawa

* Add assertion for input scale

Co-authored-by: Toshiki Maekawa

* Use clip to prevent too large bitshift

Co-authored-by: Toshiki Maekawa

* Test multiple input scales

Co-authored-by: Toshiki Maekawa

* Follow linter

Co-authored-by: Toshiki Maekawa

* Add comment

Co-authored-by: Toshiki Maekawa

---------

Co-authored-by: Toshiki Maekawa
---
 python/tvm/relay/op/__init__.py               |   1 +
 python/tvm/relay/op/op.py                     |  21 +++
 python/tvm/relay/qnn/op/qnn.py                |   4 +
 .../transform/fake_quantization_to_integer.py |  18 +-
 python/tvm/relay/transform/transform.py       |  11 +-
 src/relay/qnn/op/softmax.cc                   | 154 ++++++++++++++++++
 .../fake_quantization_to_integer.cc           |  73 ++++++---
 src/relay/transforms/pattern_utils.h          |   4 +
 .../test_pass_fake_quantization_to_integer.py |  48 ++++++
 9 files changed, 311 insertions(+), 23 deletions(-)
 create mode 100644 src/relay/qnn/op/softmax.cc

diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 825bd1f627ca..9a996838c46e 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -29,6 +29,7 @@
     debug,
     register_external_compiler,
     register_fake_quantization_to_integer,
+    register_optional_fake_quantization_to_integer,
     register_mixed_precision_conversion,
 )
 from . import strategy
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index 5f37845cebf0..d897a68f2056 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -475,6 +475,27 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10):
     return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)
 
 
+def register_optional_fake_quantization_to_integer(op_name, func=None, level=10):
+    """Register an optional quantize function for an op.
+
+    Given an op and Affine Types on its inputs, this function should return the op
+    in affine space/integer operators and the new type of the output, where affine
+    denotes the transformation x_real = (x_affine - zero_point) * scale
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    func: function (expr: Expr, map: Map) -> new_expr: Expr
+        The function for translating the op into affine space and integer operators
+
+    level : int
+        The priority level
+    """
+    return tvm.ir.register_op_attr(op_name, "FTVMOptionalFakeQuantizationToInteger", func, level)
+
+
 def register_mixed_precision_conversion(op_name, func=None, level=10):
     """Register mixed precision conversion function for an op
 
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index e2c251ec7850..0e73a6889fcd 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -1245,3 +1245,7 @@ def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zer
     return _make.leaky_relu(
         x, alpha, input_scale, input_zero_point, output_scale, output_zero_point
     )
+
+
+def softmax(x, scale, zero_point, output_scale, output_zero_point, axis=-1):
+    return _make.softmax(x, axis, scale, zero_point, output_scale, output_zero_point)
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 82255c5663be..4c9a3f7cd09c 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -25,7 +25,10 @@
 from tvm.relay.qnn.op import canonicalizations
 from tvm.tir import bijective_layout
 
-from ..op import register_fake_quantization_to_integer
+from ..op import (
+    register_fake_quantization_to_integer,
+    register_optional_fake_quantization_to_integer,
+)
 
 
 def fold_constant(expr):
@@ -619,3 +622,16 @@ def take(expr, type_map):
     out = relay.op.take(arg, indices, **expr.attrs)
 
     return [out, t]
+
+
+@register_optional_fake_quantization_to_integer("nn.softmax")
+def softmax(expr, type_map):
+    """Rewrite a softmax op"""
+    arg = expr.args[0]
+    arg_t = type_map[arg]
+    out_t = type_map[expr]
+
+    out = relay.qnn.op.softmax(
+        arg, arg_t.scale, arg_t.zero_point, out_t.scale, out_t.zero_point, **expr.attrs
+    )
+    return [out, out_t]
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 4c609620cbb7..b8af0518b29c 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1251,7 +1251,7 @@ def AnnotateSpans():
     return _ffi_api.AnnotateSpans()
 
 
-def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
+def FakeQuantizationToInteger(hard_fail=False, use_qat=False, optional_qnn_ops=None):
     # pylint: disable=anomalous-backslash-in-string
     """
     Find regions of the graph of the form
@@ -1298,12 +1298,19 @@
           |
           q
 
+    optional_qnn_ops : List[str]
+        Specify a list of operator names to explicitly enable conversion for
+        specific ops disabled by default.
+        Example: ['nn.softmax']
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered FakeQuantizationToInteger pass.
     """
-    return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat)
+    if optional_qnn_ops is None:
+        optional_qnn_ops = []
+    return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat, optional_qnn_ops)
 
 
 def FlattenAtrousConv():
diff --git a/src/relay/qnn/op/softmax.cc b/src/relay/qnn/op/softmax.cc
new file mode 100644
index 000000000000..f848ba9384e3
--- /dev/null
+++ b/src/relay/qnn/op/softmax.cc
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/qnn/op/softmax.cc
+ * \brief QNN softmax operator.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "op_common.h"
+#include "tvm/ir/expr.h"
+#include "tvm/relay/attrs/nn.h"
+#include "tvm/relay/type.h"
+#include "tvm/runtime/data_type.h"
+#include "tvm/runtime/logging.h"
+#include "tvm/topi/reduction.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+bool QnnSoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  // Expected Types: input, scale, zero_point, output_scale, output_zero_point, output
+  ICHECK_EQ(types.size(), 6);
+  const auto* x = types[0].as<TensorTypeNode>();
+  if (x == nullptr) return false;
+  ICHECK(x->dtype == DataType::Int(8))
+      << "Expected quantized softmax type (int8) for input but was " << x->dtype;
+
+  // Check the types of scale and zero points.
+  for (size_t i = 1; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
+
+  // Assign types for scale and zero points.
+  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
+  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point
+  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
+  reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point
+
+  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
+  // IdentityRel infer type function.
+  Array<Type> tensor_types = {types[0], types[5]};
+  return IdentityRel(tensor_types, 2, attrs, reporter);
+}
+
+// Positional relay function to create the quantized softmax operator, used by the frontend FFI.
+Expr MakeQuantizedSoftmax(Expr x, int axis, Expr scale, Expr zero_point, Expr output_scale,
+                          Expr output_zero_point) {
+  auto attrs = make_object<SoftmaxAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("qnn.softmax");
+  return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(attrs), {});
+}
+
+/*
+ * \brief Canonicalizes the QNN softmax op.
+ * \param attrs The Softmax attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for the softmax op.
+ * \note This op is highly experimental and sometimes lacks accuracy.
+ *       Be aware that the input scale must be in the range of 0 to 1.
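+ *
+ *       An informal sketch of the integer lowering emitted below (a summary of the
+ *       code, not a specification; see the referenced paper for the derivation):
+ *         1. x_p ~ x * log2(e) is built from shifts only: x + (x >> 1) - (x >> 4).
+ *         2. x_p is decomposed as q * (-x_0) + r with x_0 = round(1 / input_scale), so
+ *            the 2^(-q) factor becomes a relative bit shift (q is clipped to [0, 20] to
+ *            avoid oversized shifts).
+ *         3. The remaining factor x_0 * 2^(r * input_scale) is approximated linearly
+ *            as (r >> 1) + x_0.
+ *         4. Each term is divided by the per-axis sum in fixed point (m = 30 fractional
+ *            bits) and requantized from scale 1 / 2^(bits - 1) to the output scale.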
+ */
+Expr QnnSoftmaxCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+                            const Array<tvm::relay::Type>& arg_types) {
+  // Expected: input, scale, zero_point, output_scale, output_zero_point
+  ICHECK_EQ(new_args.size(), 5);
+
+  const auto const_i32 = [&](int32_t val) { return MakeConstantScalar(DataType::Int(32), val); };
+  const auto const_f32 = [&](float val) { return MakeConstantScalar(DataType::Float(32), val); };
+
+  const auto const_input_scale = new_args[1].as<ConstantNode>();
+  ICHECK(const_input_scale) << "Input scale should be constant.";
+  ICHECK(const_input_scale->is_scalar()) << "Input scale should be scalar.";
+  const float input_scale = static_cast<float*>(const_input_scale->data->data)[0];
+  ICHECK(input_scale <= 1.f) << "Input scale should be less than or equal to 1.";
+
+  const Expr input_zero_point = new_args[2];
+  const Expr output_scale = new_args[3];
+  const Expr output_zero_point = new_args[4];
+  const int axis = attrs.as<SoftmaxAttrs>()->axis;
+
+  // Refer to Algorithm 1 in https://arxiv.org/pdf/2207.01405.pdf
+
+  const Expr quantized_data = Subtract(Cast(new_args[0], DataType::Int(32)), input_zero_point);
+
+  const Expr x_0 = ConvertDtype(const_f32(std::round(1.f / input_scale)), DataType::Int(32));
+  const Expr max = Max(quantized_data, {axis}, true, false);
+  const Expr x = Subtract(quantized_data, max);
+
+  const int m = 30;
+  const int bits = 8;
+  const Expr x_p = Subtract(Add(x, RightShift(x, const_i32(1))), RightShift(x, const_i32(4)));
+  const Expr q = Clip(Divide(x_p, Negative(x_0)), 0, 20);
+  const Expr max_q = Max(q, {axis}, true, false);
+  const Expr r = Subtract(x_p, Multiply(q, Negative(x_0)));
+  const Expr x_b = Add(RightShift(r, const_i32(1)), x_0);
+  const Expr exps = LeftShift(x_b, Subtract(max_q, q));
+  const Expr sums = Sum(exps, {axis}, true, false);
+  const Expr output =
+      RightShift(Multiply(Divide(const_i32(1 << m), sums), exps), const_i32(m - (bits - 1)));
+  const Expr requantized = Requantize(output, arg_types[0].as<TensorTypeNode>()->shape,
+                                      const_f32(1.f / (1 << (bits - 1))), const_i32(0),
+                                      output_scale, output_zero_point, DataType::Int(bits), 0);
+
+  return requantized;
+}
+
+RELAY_REGISTER_OP("qnn.softmax")
+    .describe("Softmax for quantized tensors.")
+    .set_attrs_type<SoftmaxAttrs>()
+    .set_num_inputs(5)
+    .add_argument("data", "Quantized Tensor", "The input data.")
+    .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
+    .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+    .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+    .add_argument("output_zero_point", "Tensor",
+                  "The quantization zero_point of the output tensor.")
+    .set_support_level(11)
+    .add_type_rel("QSoftmax", QnnSoftmaxRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSoftmaxCanonicalize);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.softmax").set_body_typed(MakeQuantizedSoftmax);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc
index 31353d5aa25e..b767924ecd1f 100644
--- a/src/relay/transforms/fake_quantization_to_integer.cc
+++ b/src/relay/transforms/fake_quantization_to_integer.cc
@@ -163,8 +163,12 @@ void SubgraphExtractor::VisitExpr_(const CallNode* call_node) {
 
 class SubgraphMutator : public ExprMutator {
  public:
-  SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
-      : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
+  SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail,
+                  const std::unordered_set<std::string>& optional_qnn_ops)
+      : subgraph_(subgraph),
+        affine_types_(affine_types),
+        hard_fail_(hard_fail),
+        optional_qnn_ops_(optional_qnn_ops) {}
 
   Expr MutateSubgraph(const Expr& expr) {
     if (subgraph_.size() == 0) {
@@ -176,9 +180,14 @@ class SubgraphMutator : public ExprMutator {
     out_type_ = affine_types_[expr];
     static auto fqfq =
         Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+    static auto opt_fqfq =
+        Op::HasAttrMap("FTVMOptionalFakeQuantizationToInteger")
+            ? Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMOptionalFakeQuantizationToInteger")
+            : fqfq;
     for (auto node : subgraph_) {
       const Op op = Downcast<Op>(node.as<CallNode>()->op);
-      if (!fqfq.count(Downcast<Op>(op))) {
+      if (!fqfq.count(Downcast<Op>(op)) &&
+          !(optional_qnn_ops_.count(op->name) && opt_fqfq.count(Downcast<Op>(op)))) {
         // Only modify the subgraph if we have translation
         // rules for every op
         if (hard_fail_) {
@@ -207,8 +216,12 @@ class SubgraphMutator : public ExprMutator {
 
     static auto fqfq =
         Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+    static auto opt_fqfq =
+        Op::HasAttrMap("FTVMOptionalFakeQuantizationToInteger")
+            ? Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMOptionalFakeQuantizationToInteger")
+            : fqfq;
     Op op = Downcast<Op>(call_node->op);
-    if (fqfq.count(op)) {
+    if (fqfq.count(op) || (optional_qnn_ops_.count(op->name) && opt_fqfq.count(op))) {
       Expr expr;
       if (op == dequantize_op_) {
         expr = GetRef<Expr>(call_node);
@@ -219,7 +232,7 @@ class SubgraphMutator : public ExprMutator {
         affine_types_.Set(expr, out_type_);
       }
       // Call the rewrite
-      Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
+      Array<ObjectRef> vals = (fqfq.count(op) ? fqfq : opt_fqfq)[op](expr, affine_types_);
       // Save the outputs of the rewrite
       ICHECK(vals.size() == 2)
           << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
@@ -256,13 +269,16 @@ class SubgraphMutator : public ExprMutator {
   AffineTypeMap affine_types_;
   AffineType out_type_;
   const bool hard_fail_;
+  const std::unordered_set<std::string>& optional_qnn_ops_;
   const Op quantize_op_ = Op::Get("qnn.quantize");
   const Op dequantize_op_ = Op::Get("qnn.dequantize");
 };
 
 class FakeQuantizationRewriter : public MixedModeMutator {
  public:
-  explicit FakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
+  explicit FakeQuantizationRewriter(bool hard_fail,
+                                    const std::unordered_set<std::string>& optional_qnn_ops)
+      : hard_fail_(hard_fail), optional_qnn_ops_(optional_qnn_ops) {}
 
  protected:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
@@ -286,8 +302,8 @@ class FakeQuantizationRewriter : public MixedModeMutator {
       for (auto expr : subgraph) {
         post_subgraph.insert(memo_[expr]);
       }
-      Expr out =
-          SubgraphMutator(post_subgraph, post_affine_types, hard_fail_).MutateSubgraph(post);
+      Expr out = SubgraphMutator(post_subgraph, post_affine_types, hard_fail_, optional_qnn_ops_)
+                     .MutateSubgraph(post);
       return out;
     }
   }
@@ -295,6 +311,7 @@ class FakeQuantizationRewriter : public MixedModeMutator {
   }
   const Op quantize_op_ = Op::Get("qnn.quantize");
   const bool hard_fail_;
+  const std::unordered_set<std::string>& optional_qnn_ops_;
 };
 
 /* Checks if the operation to convert QAT pass is enabled.
@@ -404,8 +421,12 @@ class QATSubgraphExtractor : public ExprVisitor {
 
 class QATSubgraphMutator : public ExprMutator {
  public:
-  QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
-      : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
+  QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail,
+                     const std::unordered_set<std::string>& optional_qnn_ops)
+      : subgraph_(subgraph),
+        affine_types_(affine_types),
+        hard_fail_(hard_fail),
+        optional_qnn_ops_(optional_qnn_ops) {}
 
   Expr MutateSubgraph(const Expr& expr) {
     if (subgraph_.size() == 0) {
@@ -447,9 +468,13 @@ class QATSubgraphMutator : public ExprMutator {
     Expr out;
 
     static auto fqfq =
        Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+    static auto opt_fqfq =
+        Op::HasAttrMap("FTVMOptionalFakeQuantizationToInteger")
+            ? Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMOptionalFakeQuantizationToInteger")
+            : fqfq;
     Op op = Downcast<Op>(call_node->op);
-    if (fqfq.count(op)) {
+    if (fqfq.count(op) || (optional_qnn_ops_.count(op->name) && opt_fqfq.count(op))) {
       Expr expr;
       if (op == dequantize_op_) {
         expr = GetRef<Expr>(call_node);
@@ -457,7 +482,7 @@ class QATSubgraphMutator : public ExprMutator {
         expr = ExprMutator::VisitExpr_(call_node);
       }
       // Call the rewrite
-      Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
+      Array<ObjectRef> vals = (fqfq.count(op) ? fqfq : opt_fqfq)[op](expr, affine_types_);
       // Save the outputs of the rewrite
       ICHECK(vals.size() == 2)
           << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
@@ -500,13 +525,15 @@ class QATSubgraphMutator : public ExprMutator {
   ExprSet subgraph_;
   AffineTypeMap affine_types_;
   const bool hard_fail_;
+  const std::unordered_set<std::string>& optional_qnn_ops_;
   const Op dequantize_op_ = Op::Get("qnn.dequantize");
   const CallNode* quantize_node_ = nullptr;
 };
 
 class QATRewriter : public MixedModeMutator {
  public:
-  explicit QATRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
+  explicit QATRewriter(bool hard_fail, const std::unordered_set<std::string>& optional_qnn_ops)
+      : hard_fail_(hard_fail), optional_qnn_ops_(optional_qnn_ops) {}
 
  protected:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
@@ -516,31 +543,37 @@ class QATRewriter : public MixedModeMutator {
         QATSubgraphExtractor extractor;
         ExprSet subgraph = extractor.GetSubgraph(post);
         AffineTypeMap affine_types = extractor.GetAffineTypes();
-        Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
+        Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_, optional_qnn_ops_)
+                       .MutateSubgraph(post);
         return out;
       }
     }
     return post;
   }
   const bool hard_fail_;
+  const std::unordered_set<std::string>& optional_qnn_ops_;
 };
 
-Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail,
-                               bool use_qat) {
-  auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr);
+Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail, bool use_qat,
+                               const Array<String>& optional_qnn_ops) {
+  const std::unordered_set<std::string> optional_qnn_ops_(optional_qnn_ops.begin(),
+                                                          optional_qnn_ops.end());
+  auto fq_expr = FakeQuantizationRewriter(hard_fail, optional_qnn_ops_).Mutate(expr);
   if (use_qat) {
     fq_expr = tvm::relay::InferType(fq_expr);
-    fq_expr = QATRewriter(hard_fail).Mutate(fq_expr);
+    fq_expr = QATRewriter(hard_fail, optional_qnn_ops_).Mutate(fq_expr);
   }
   return fq_expr;
 }
 
 namespace transform {
 
-Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) {
+Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat,
+                               const Array<String>& optional_qnn_ops) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(FakeQuantizationToInteger(f, m, hard_fail, use_qat));
+        return Downcast<Function>(
+            FakeQuantizationToInteger(f, m, hard_fail, use_qat, optional_qnn_ops));
       };
   return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType", "DivToMul"});
 }
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index aa4ef03c95a4..50c2e0029885 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -770,6 +770,10 @@ inline Expr Copy(Expr data) {
   return Call(op, {data}, Attrs(), {});
 }
 
+inline Expr Max(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
+  return MakeReduce(data, axis, keepdims, exclude, "max");
+}
+
 inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
   return MakeReduce(data, axis, keepdims, exclude, "mean");
 }
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
index f349e0979395..3425a9a72b9b 100644
--- a/tests/python/relay/test_pass_fake_quantization_to_integer.py
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -1114,5 +1114,53 @@ def test_fake_quantize_take():
     compare_fq_to_int(op, [x_np])
 
 
+def test_fake_quantize_softmax():
+    shape = [5, 10]
+    x_ = relay.var("x", shape=shape, dtype="int8")
+
+    is_sorted = lambda a: np.all(a[:-1] <= a[1:])
+
+    for scale in [1.0, 0.1, 0.01]:
+        x = relay.qnn.op.dequantize(x_, relay.const(scale), relay.const(0))
+        op = relay.op.nn.softmax(x, axis=1)
+        op = relay.qnn.op.quantize(
+            op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8"
+        )
+
+        x_np = np.random.randint(-128, 127, size=shape, dtype="int8")
+        x_np = np.sort(x_np)
+        args = [x_np]
+
+        mod = tvm.IRModule.from_expr(op)
+        mod = tvm.relay.transform.InferType()(mod)
+        mod_int = tvm.relay.transform.FakeQuantizationToInteger(
+            hard_fail=True, optional_qnn_ops=["nn.softmax"]
+        )(mod)
+        assert not tvm.ir.structural_equal(mod, mod_int)
+
+        result = (
+            relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+            .evaluate()(*args)
+            .numpy()
+        )
+        result_int = (
+            relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm")
+            .evaluate()(*args)
+            .numpy()
+        )
+
+        # At least check that the softmax output stays in ascending order, since the
+        # integer approximation is not accurate enough for a tight allclose comparison.
+        for qdq, qop in zip(result, result_int):
+            assert is_sorted(qdq)
+            assert is_sorted(qop)
+
+        try:
+            np.testing.assert_allclose(result_int, result, atol=1)
+        except AssertionError as e:
+            # Print the difference for inspection instead of failing the test.
+            print(e)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
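Usage sketch (illustrative only, mirroring the added test): the new conversion is opt-in,
so a dequantize -> nn.softmax -> quantize region is rewritten to qnn.softmax only when
"nn.softmax" is passed via optional_qnn_ops. The shapes and quantization parameters below
are placeholders taken from the test, not recommended settings.

    import tvm
    from tvm import relay

    x = relay.var("x", shape=[5, 10], dtype="int8")
    dq = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(0))
    sm = relay.op.nn.softmax(dq, axis=1)
    q = relay.qnn.op.quantize(sm, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8")

    mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(q))
    # Without optional_qnn_ops=["nn.softmax"], the pass leaves this region untouched.
    mod_int = tvm.relay.transform.FakeQuantizationToInteger(
        hard_fail=True, optional_qnn_ops=["nn.softmax"]
    )(mod)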