Extend quantiser support so as to accelerate more binary models. #668

Merged: 1 commit, Jul 2, 2021

9 changes: 9 additions & 0 deletions larq_compute_engine/core/bitpacking/bitpack_aarch64.h
@@ -9,12 +9,20 @@

#include "larq_compute_engine/core/types.h"
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace compute_engine {
namespace core {
namespace bitpacking {

template <typename T>
inline void bitpack_aarch64_4x32(const T* input, std::size_t num_blocks,
TBitpacked* output, const T zero_point) {
TFLITE_ASSERT_FALSE;
}

// Bitpack an array of `4 * 32 * num_blocks` floats.
template <>
inline void bitpack_aarch64_4x32(const float* input, std::size_t num_blocks,
TBitpacked* output,
const float zero_point /*ignored*/) {
@@ -227,6 +235,7 @@ inline void bitpack_aarch64_4x32(const float* input, std::size_t num_blocks,
}

// Bitpack an array of `4 * 32 * num_blocks` int8 bytes.
template <>
inline void bitpack_aarch64_4x32(const std::int8_t* input,
std::size_t num_blocks, TBitpacked* output,
const std::int8_t zero_byte) {
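A minimal NumPy sketch of the packing semantics these specialisations implement may help (the new generic template exists only to trip an assert if an unsupported element type is ever instantiated). The bit convention (a set bit encodes a value below the zero point, i.e. -1) and the LSB-first bit order are assumptions for illustration and may not match the NEON kernel exactly.

```python
import numpy as np

def bitpack_4x32_reference(x, zero_point=0):
    """Hypothetical reference for the bitpack_aarch64_4x32 semantics.

    Packs 4 * 32 * num_blocks elements into one bit each; a bit is set when
    the element is below the zero point (assumed encoding of -1). For float
    input the zero point is effectively 0, matching the /*ignored*/ argument.
    """
    x = np.ravel(x)
    assert x.size % 128 == 0, "the kernel consumes blocks of 4 * 32 elements"
    bits = (x < zero_point).reshape(-1, 32)
    packed = np.packbits(bits, axis=1, bitorder="little")  # (num_words, 4) uint8
    return packed.view(np.int32).ravel()                   # one TBitpacked word per 32 inputs
```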
8 changes: 4 additions & 4 deletions larq_compute_engine/mlir/ir/lce_ops.td
@@ -70,11 +70,11 @@ def LQ_QuantizeOp : LQ_Op<"Quantize", [NoSideEffect]> {
let summary = "Binary quantize operator";

let description = [{
Converts floating point or integer tensors to binarized bitpacked tensors.
Converts floating point, integer, or boolean tensors to binarized bitpacked tensors.
}];

let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$x
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$x
);

let results = (outs
@@ -90,15 +90,15 @@ def LQ_DequantizeOp : LQ_Op<"Dequantize", [NoSideEffect]> {
let summary = "Binary dequantize operator";

let description = [{
Converts binarized bitpacked tensors to floating point or integer tensors.
Converts binarized bitpacked tensors to floating point, integer, or boolean tensors.
}];

let arguments = (ins
TensorOf<[I32]>:$x
);

let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$y
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$y
);

let hasFolder = 1;
44 changes: 44 additions & 0 deletions larq_compute_engine/mlir/tests/optimize.mlir
@@ -1,6 +1,50 @@
// RUN: lce-tf-opt %s -tfl-optimize-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM
// RUN: lce-tf-opt %s -tfl-optimize-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE

// CHECK-LABEL: @optimize_quantize_greater_equal_zero
func @optimize_quantize_greater_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
%cst = constant dense<0.0> : tensor<f32>
%0 = "tfl.greater_equal"(%arg0, %cst) : (tensor<48x48x64xf32>, tensor<f32>) -> tensor<48x48x64xi1>
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
return %1 : tensor<48x48x2xi32>

// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
// CHECK-NEXT: return %0
}

// CHECK-LABEL: @optimize_quantize_greater_equal_non_zero
func @optimize_quantize_greater_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
%0 = "tfl.greater_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
return %1 : tensor<48x48x2xi32>

// CHECK-NEXT: %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<48x48x64xf32>
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
// CHECK-NEXT: return %1
}

// CHECK-LABEL: @optimize_quantize_less_equal_zero
func @optimize_quantize_less_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
%cst = constant dense<0.0> : tensor<64xf32>
%0 = "tfl.less_equal"(%cst, %arg0) : (tensor<64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
return %1 : tensor<48x48x2xi32>

// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
// CHECK-NEXT: return %0
}

// CHECK-LABEL: @optimize_quantize_less_equal_non_zero
func @optimize_quantize_less_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
%0 = "tfl.less_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
return %1 : tensor<48x48x2xi32>

// CHECK-NEXT: %0 = tfl.sub %arg1, %arg0 {fused_activation_function = "NONE"} : tensor<48x48x64xf32>
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
// CHECK-NEXT: return %1
}

// CHECK-LABEL: @fuse_add_into_bconv2d
func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
%cst = constant dense<1.5> : tensor<16xf32>
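The rewrites exercised above rest on simple sign identities, which can be checked numerically. In the sketch below, `binarize` and `binarize_bool` are hypothetical stand-ins for the `lq.Quantize` semantics on float and boolean input (+1 for non-negative or true, -1 otherwise):

```python
import numpy as np

def binarize(x):
    # Hypothetical stand-in for lq.Quantize on float input: +1 when x >= 0, else -1.
    return np.where(x >= 0, 1, -1)

def binarize_bool(b):
    # Hypothetical stand-in for lq.Quantize on boolean input: True -> +1, False -> -1.
    return np.where(b, 1, -1)

x = np.random.randn(48, 48, 64).astype(np.float32)
t = np.random.randn(48, 48, 64).astype(np.float32)

# greater_equal against zero folds away entirely: quantize(x >= 0) == quantize(x)
assert (binarize_bool(x >= 0) == binarize(x)).all()
# a non-zero threshold becomes a subtraction: quantize(x >= t) == quantize(x - t)
assert (binarize_bool(x >= t) == binarize(x - t)).all()
# less_equal is the same rewrite with operands swapped: quantize(x <= t) == quantize(t - x)
assert (binarize_bool(x <= t) == binarize(t - x)).all()
```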
67 changes: 65 additions & 2 deletions larq_compute_engine/mlir/tests/prepare-tf.mlir
@@ -1,8 +1,71 @@
// RUN: lce-tf-opt %s -tfl-prepare-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM
// RUN: lce-tf-opt %s -tfl-prepare-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE

// CHECK-LABEL: @fuse_bsign
func @fuse_bsign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-LABEL: @fuse_bsign_tf_where
func @fuse_bsign_tf_where(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
%cst_l = constant dense<1.0> : tensor<8x16xf32>
%cst_r = constant dense<-1.0> : tensor<8x16xf32>
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>

// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
// CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %1
}

// CHECK-LABEL: @fuse_bsign_tf_where_inverted
func @fuse_bsign_tf_where_inverted(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
%cst_l = constant dense<-1.0> : tensor<8x16xf32>
%cst_r = constant dense<1.0> : tensor<8x16xf32>
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>

// CHECK-NEXT: %0 = "tf.LogicalNot"(%arg0) : (tensor<8x16xi1>) -> tensor<8x16xi1>
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
// CHECK-NEXT: %2 = "lq.Dequantize"(%1) : (tensor<8x1xi32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %2
}

// CHECK-LABEL: @fuse_bsign_tf_where_broadcast_cond
func @fuse_bsign_tf_where_broadcast_cond(%arg0: tensor<8x1xi1>) -> tensor<8x16xf32> {
%cst_l = constant dense<1.0> : tensor<8x16xf32>
%cst_r = constant dense<-1.0> : tensor<8x16xf32>
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x1xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>

// CHECK-NEXT: %cst = constant dense<[8, 16]> : tensor<2xi64>
// CHECK-NEXT: %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<8x1xi1>, tensor<2xi64>) -> tensor<8x16xi1>
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
// CHECK-NEXT: %2 = "lq.Dequantize"(%1) : (tensor<8x1xi32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %2
}

// CHECK-LABEL: @fuse_bsign_tf_where_broadcast_lhs_rhs
func @fuse_bsign_tf_where_broadcast_lhs_rhs(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
%cst_l = constant dense<1.0> : tensor<f32>
%cst_r = constant dense<-1.0> : tensor<8x1xf32>
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<f32>, tensor<8x1xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>

// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
// CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %1
}

// CHECK-LABEL: @fuse_bsign_tf_where_select_v1_op
func @fuse_bsign_tf_where_select_v1_op(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
%cst_l = constant dense<1.0> : tensor<8x16xf32>
%cst_r = constant dense<-1.0> : tensor<8x16xf32>
%0 = "tf.Select"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>

// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
// CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
// CHECK-NEXT: return %1
}

// CHECK-LABEL: @fuse_bsign_legacy_tf_sign
func @fuse_bsign_legacy_tf_sign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Sign"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
%cst = constant dense<0.1> : tensor<f32>
%2 = "tf.AddV2"(%0, %cst) : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
24 changes: 24 additions & 0 deletions larq_compute_engine/mlir/transforms/optimize_patterns_common.td
@@ -11,6 +11,30 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;

def : Pat<(LQ_QuantizeOp
(TFL_GreaterEqualOp:$ge_op
$input,
(ConstantOp ConstantValue<"0.0f">))),
(LQ_QuantizeOp $input),
[(HasOneUse $ge_op)],
(addBenefit 150)>;

def : Pat<(LQ_QuantizeOp
(TFL_GreaterEqualOp:$ge_op
$input,
$threshold)),
(LQ_QuantizeOp
(TFL_SubOp $input, $threshold, TFL_AF_None)),
[(HasOneUse $ge_op)],
(addBenefit 100)>;

def : Pat<(LQ_QuantizeOp
(TFL_LessEqualOp:$ge_op $lhs, $rhs)),
(LQ_QuantizeOp
(TFL_GreaterEqualOp $rhs, $lhs)),
[(HasOneUse $ge_op)],
(addBenefit 100)>;

// TODO: Check shapes before fusing
multiclass FuseAddOrSubWithBConv2D<Op binaryOp> {
def : Pat<(binaryOp
51 changes: 49 additions & 2 deletions larq_compute_engine/mlir/transforms/prepare_patterns_common.td
@@ -4,9 +4,56 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "larq_compute_engine/mlir/ir/lce_ops.td"
include "larq_compute_engine/mlir/transforms/op_removal_patterns.td"

class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;

// This relies on implementation details of larq.math.sign. We should make
// this more general in the future
def CreateTFBroadcastToOp : NativeCodeCall<
"$_builder.create<TF::BroadcastToOp>("
"$0.getLoc(),"
"RankedTensorType::get("
"$0.getType().cast<RankedTensorType>().getShape(),"
"getElementTypeOrSelf($1.getType())),"
"$1,"
"$2)">;

def CreateTFShapeOp : NativeCodeCall<
"$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;

// Base quantiser patterns that match the `tf.where` implementation of `ste_sign`.
multiclass QuantDequantPatterns<Op SelectOp> {
def : Pat<(SelectOp:$select_op
$cond,
(ConstantOp ConstantValue<"1.0f">),
(ConstantOp ConstantValue<"-1.0f">)),
(LQ_DequantizeOp
(LQ_QuantizeOp
(CreateTFBroadcastToOp
$select_op,
$cond,
(CreateTFShapeOp
$select_op,
$select_op,
/*use 32bit*/ConstBoolAttrFalse)))),
[], (addBenefit 100)>;
def : Pat<(SelectOp:$select_op
$cond,
(ConstantOp ConstantValue<"-1.0f">),
(ConstantOp ConstantValue<"1.0f">)),
(LQ_DequantizeOp
(LQ_QuantizeOp
(CreateTFBroadcastToOp
$select_op,
(TF_LogicalNotOp $cond),
(CreateTFShapeOp
$select_op,
$select_op,
/*use 32bit*/ConstBoolAttrFalse)))),
[], (addBenefit 100)>;
}
foreach SelectOp = [TF_SelectOp, TF_SelectV2Op]<Op> in
defm : QuantDequantPatterns<SelectOp>;

// A fallback for the old version of `ste_sign` that uses a specific `tf.sign`
// based implementation of `larq.math.sign`.
def : Pat<(TF_SignOp (TF_AddV2Op (TF_SignOp $arg), $c)),
(LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>;
def : Pat<(TF_SignOp (TF_AddV2Op $c, (TF_SignOp $arg))),
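A short TensorFlow sketch of why the patterns route the condition through `CreateTFBroadcastToOp`: `tf.SelectV2` broadcasts its condition against the branches, so a narrower condition must be expanded to the select's result shape before it can replace the select as the input to `lq.Quantize`. Shapes and values below are illustrative only:

```python
import tensorflow as tf

cond = tf.constant([[True], [False]])       # shape (2, 1), analogous to the 8x1 test case
ones = tf.ones((2, 4))
neg_ones = -tf.ones((2, 4))

# tf.SelectV2 broadcasts the condition to the result shape internally ...
selected = tf.where(cond, ones, neg_ones)   # shape (2, 4)

# ... so the rewrite makes that broadcast explicit before quantizing the condition.
cond_full = tf.broadcast_to(cond, tf.shape(selected))
assert bool(tf.reduce_all(tf.where(cond_full, ones, neg_ones) == selected))
```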
8 changes: 8 additions & 0 deletions larq_compute_engine/mlir/transforms/prepare_tf.cc
@@ -36,6 +36,14 @@ struct PrepareLCE : public PassWrapper<PrepareLCE, FunctionPass> {
clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))};
};

bool IsConstantValue(Attribute values, float expected_value) {
if (!values.isa<DenseElementsAttr>()) return false;

for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
if (value != expected_value) return false;
}
return true;
}
DenseElementsAttr GetConstantVector(Attribute filter, float val) {
auto filter_type = filter.getType().cast<ShapedType>();
auto filter_shape = filter_type.getShape();
36 changes: 25 additions & 11 deletions larq_compute_engine/tests/end2end_test.py
@@ -1,3 +1,4 @@
import functools
import os
import sys
import tempfile
@@ -23,7 +24,7 @@ def convert_keras_model_as_saved_model(model, **kwargs):
return convert_saved_model(saved_model_dir, **kwargs)


def toy_model(**kwargs):
def toy_model(binary_quantizer="ste_sign", **kwargs):
def block(padding, pad_values, activation):
def dummy(x):
shortcut = x
@@ -32,8 +33,8 @@ def dummy(x):
kernel_size=3,
padding=padding,
pad_values=pad_values,
input_quantizer="ste_sign",
kernel_quantizer="ste_sign",
input_quantizer=binary_quantizer,
kernel_quantizer=binary_quantizer,
use_bias=False,
activation=activation,
)(x)
@@ -59,7 +60,7 @@ def dummy(x):
return tf.keras.Model(inputs=img_input, outputs=out)


def toy_model_sequential(**kwargs):
def toy_model_sequential(binary_quantizer="ste_sign", **kwargs):
return tf.keras.models.Sequential(
[
tf.keras.layers.Input((224, 224, 3)),
@@ -70,8 +71,8 @@ def toy_model_sequential(**kwargs):
lq.layers.QuantConv2D(
32,
(3, 3),
input_quantizer="ste_sign",
kernel_quantizer="ste_sign",
input_quantizer=binary_quantizer,
kernel_quantizer=binary_quantizer,
padding="same",
pad_values=1.0,
use_bias=False,
@@ -85,8 +86,8 @@ def toy_model_sequential(**kwargs):
lq.layers.QuantConv2D(
32,
(3, 3),
input_quantizer="ste_sign",
kernel_quantizer="ste_sign",
input_quantizer=binary_quantizer,
kernel_quantizer=binary_quantizer,
strides=(2, 2),
padding="same",
pad_values=1.0,
@@ -104,8 +105,8 @@ def toy_model_sequential(**kwargs):
lq.layers.QuantConv2D(
32,
(3, 3),
input_quantizer="ste_sign",
kernel_quantizer="ste_sign",
input_quantizer=binary_quantizer,
kernel_quantizer=binary_quantizer,
padding="same",
pad_values=1.0,
use_bias=False,
@@ -165,12 +166,25 @@ def dataset():
)


def tf_where_binary_quantizer(x):
return tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))


@pytest.mark.parametrize(
"conversion_function", [convert_keras_model, convert_keras_model_as_saved_model]
)
@pytest.mark.parametrize(
"model_cls",
[toy_model, toy_model_sequential, toy_model_int8, lqz.sota.QuickNetSmall],
[
toy_model,
functools.partial(toy_model, binary_quantizer=tf_where_binary_quantizer),
toy_model_sequential,
functools.partial(
toy_model_sequential, binary_quantizer=tf_where_binary_quantizer
),
toy_model_int8,
lqz.sota.QuickNetSmall,
],
)
def test_simple_model(dataset, conversion_function, model_cls):
model = model_cls(weights="imagenet")
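As a quick illustration of what the new quantizer computes (not part of the test itself), the tf.where-based helper maps non-negative values, including zero, to +1 and negative values to -1:

```python
import tensorflow as tf

def tf_where_binary_quantizer(x):
    # Same definition as in the test above.
    return tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))

print(tf_where_binary_quantizer(tf.constant([-0.5, 0.0, 2.0])).numpy())  # [-1.  1.  1.]
```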
1 change: 1 addition & 0 deletions larq_compute_engine/tflite/kernels/BUILD
@@ -41,6 +41,7 @@ cc_library(
"//larq_compute_engine/core/indirect_bgemm:kernels",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:type_to_tflitetype",
"@org_tensorflow//tensorflow/lite/kernels/internal:kernel_utils",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"@ruy//ruy/profiler:instrumentation",
Expand Down
Loading