Skip to content

Commit 962a202

Browse files
committed
Extend quantiser support so as to accelerate more binary models.
Add the ability to convert `tf.where`-style binary quantisers, and add support for boolean input to `LceQuantize` and `LceDequantize`.
1 parent 72e5150 commit 962a202

File tree

11 files changed

+264
-23
lines changed

11 files changed

+264
-23
lines changed

larq_compute_engine/core/bitpacking/bitpack_aarch64.h

+9
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,20 @@
99

1010
#include "larq_compute_engine/core/types.h"
1111
#include "ruy/profiler/instrumentation.h"
12+
#include "tensorflow/lite/kernels/op_macros.h"
1213

1314
namespace compute_engine {
1415
namespace core {
1516
namespace bitpacking {
1617

18+
template <typename T>
19+
inline void bitpack_aarch64_4x32(const T* input, std::size_t num_blocks,
20+
TBitpacked* output, const T zero_point) {
21+
TFLITE_ASSERT_FALSE;
22+
}
23+
1724
// Bitpack an array of `4 * 32 * num_blocks` floats.
25+
template <>
1826
inline void bitpack_aarch64_4x32(const float* input, std::size_t num_blocks,
1927
TBitpacked* output,
2028
const float zero_point /*ignored*/) {
@@ -227,6 +235,7 @@ inline void bitpack_aarch64_4x32(const float* input, std::size_t num_blocks,
227235
}
228236

229237
// Bitpack an array of `4 * 32 * num_blocks` int8 bytes.
238+
template <>
230239
inline void bitpack_aarch64_4x32(const std::int8_t* input,
231240
std::size_t num_blocks, TBitpacked* output,
232241
const std::int8_t zero_byte) {

larq_compute_engine/mlir/ir/lce_ops.td

+4-4
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def LQ_QuantizeOp : LQ_Op<"Quantize", [NoSideEffect]> {
7070
let summary = "Binary quantize operator";
7171

7272
let description = [{
73-
Converts floating point or integer tensors to binarized bitpacked tensors.
73+
Converts floating point, integer, or boolean tensors to binarized bitpacked tensors.
7474
}];
7575

7676
let arguments = (ins
77-
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$x
77+
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$x
7878
);
7979

8080
let results = (outs
@@ -90,15 +90,15 @@ def LQ_DequantizeOp : LQ_Op<"Dequantize", [NoSideEffect]> {
9090
let summary = "Binary dequantize operator";
9191

9292
let description = [{
93-
Converts binarized bitpacked tensors to floating point or integer tensors.
93+
Converts binarized bitpacked tensors to floating point, integer, or boolean tensors.
9494
}];
9595

9696
let arguments = (ins
9797
TensorOf<[I32]>:$x
9898
);
9999

100100
let results = (outs
101-
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$y
101+
TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$y
102102
);
103103

104104
let hasFolder = 1;

larq_compute_engine/mlir/tests/optimize.mlir

+44
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,50 @@
11
// RUN: lce-tf-opt %s -tfl-optimize-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM
22
// RUN: lce-tf-opt %s -tfl-optimize-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE
33

4+
// CHECK-LABEL: @optimize_quantize_greater_equal_zero
5+
func @optimize_quantize_greater_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
6+
%cst = constant dense<0.0> : tensor<f32>
7+
%0 = "tfl.greater_equal"(%arg0, %cst) : (tensor<48x48x64xf32>, tensor<f32>) -> tensor<48x48x64xi1>
8+
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
9+
return %1 : tensor<48x48x2xi32>
10+
11+
// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
12+
// CHECK-NEXT: return %0
13+
}
14+
15+
// CHECK-LABEL: @optimize_quantize_greater_equal_non_zero
16+
func @optimize_quantize_greater_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
17+
%0 = "tfl.greater_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
18+
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
19+
return %1 : tensor<48x48x2xi32>
20+
21+
// CHECK-NEXT: %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<48x48x64xf32>
22+
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
23+
// CHECK-NEXT: return %1
24+
}
25+
26+
// CHECK-LABEL: @optimize_quantize_less_equal_zero
27+
func @optimize_quantize_less_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
28+
%cst = constant dense<0.0> : tensor<64xf32>
29+
%0 = "tfl.less_equal"(%cst, %arg0) : (tensor<64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
30+
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
31+
return %1 : tensor<48x48x2xi32>
32+
33+
// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
34+
// CHECK-NEXT: return %0
35+
}
36+
37+
// CHECK-LABEL: @optimize_quantize_less_equal_non_zero
38+
func @optimize_quantize_less_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> {
39+
%0 = "tfl.less_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1>
40+
%1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32>
41+
return %1 : tensor<48x48x2xi32>
42+
43+
// CHECK-NEXT: %0 = tfl.sub %arg1, %arg0 {fused_activation_function = "NONE"} : tensor<48x48x64xf32>
44+
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32>
45+
// CHECK-NEXT: return %1
46+
}
47+
448
// CHECK-LABEL: @fuse_add_into_bconv2d
549
func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
650
%cst = constant dense<1.5> : tensor<16xf32>

larq_compute_engine/mlir/tests/prepare-tf.mlir

+65-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,71 @@
11
// RUN: lce-tf-opt %s -tfl-prepare-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM
22
// RUN: lce-tf-opt %s -tfl-prepare-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE
33

4-
// CHECK-LABEL: @fuse_bsign
5-
func @fuse_bsign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
4+
// CHECK-LABEL: @fuse_bsign_tf_where
5+
func @fuse_bsign_tf_where(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
6+
%cst_l = constant dense<1.0> : tensor<8x16xf32>
7+
%cst_r = constant dense<-1.0> : tensor<8x16xf32>
8+
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
9+
return %0 : tensor<8x16xf32>
10+
11+
// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
12+
// CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
13+
// CHECK-NEXT: return %1
14+
}
15+
16+
// CHECK-LABEL: @fuse_bsign_tf_where_inverted
17+
func @fuse_bsign_tf_where_inverted(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
18+
%cst_l = constant dense<-1.0> : tensor<8x16xf32>
19+
%cst_r = constant dense<1.0> : tensor<8x16xf32>
20+
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
21+
return %0 : tensor<8x16xf32>
22+
23+
// CHECK-NEXT: %0 = "tf.LogicalNot"(%arg0) : (tensor<8x16xi1>) -> tensor<8x16xi1>
24+
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
25+
// CHECK-NEXT: %2 = "lq.Dequantize"(%1) : (tensor<8x1xi32>) -> tensor<8x16xf32>
26+
// CHECK-NEXT: return %2
27+
}
28+
29+
// CHECK-LABEL: @fuse_bsign_tf_where_broadcast_cond
30+
func @fuse_bsign_tf_where_broadcast_cond(%arg0: tensor<8x1xi1>) -> tensor<8x16xf32> {
31+
%cst_l = constant dense<1.0> : tensor<8x16xf32>
32+
%cst_r = constant dense<-1.0> : tensor<8x16xf32>
33+
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x1xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
34+
return %0 : tensor<8x16xf32>
35+
36+
// CHECK-NEXT: %cst = constant dense<[8, 1]> : tensor<2xi64>
37+
// CHECK-NEXT: %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<8x1xi1>, tensor<2xi64>) -> tensor<8x16xi1>
38+
// CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
39+
// CHECK-NEXT: %2 = "lq.Dequantize"(%1) : (tensor<8x1xi32>) -> tensor<8x16xf32>
40+
// CHECK-NEXT: return %2
41+
}
42+
43+
// CHECK-LABEL: @fuse_bsign_tf_where_broadcast_lhs_rhs
44+
func @fuse_bsign_tf_where_broadcast_lhs_rhs(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
45+
%cst_l = constant dense<1.0> : tensor<f32>
46+
%cst_r = constant dense<-1.0> : tensor<8x1xf32>
47+
%0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<f32>, tensor<8x1xf32>) -> tensor<8x16xf32>
48+
return %0 : tensor<8x16xf32>
49+
50+
// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
51+
// CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
52+
// CHECK-NEXT: return %1
53+
}
54+
55+
// CHECK-LABEL: @fuse_bsign_tf_where_select_v1_op
56+
func @fuse_bsign_tf_where_select_v1_op(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> {
57+
%cst_l = constant dense<1.0> : tensor<8x16xf32>
58+
%cst_r = constant dense<-1.0> : tensor<8x16xf32>
59+
%0 = "tf.Select"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
60+
return %0 : tensor<8x16xf32>
61+
62+
// CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32>
63+
// CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32>
64+
// CHECK-NEXT: return %1
65+
}
66+
67+
// CHECK-LABEL: @fuse_bsign_legacy_tf_sign
68+
func @fuse_bsign_legacy_tf_sign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
669
%0 = "tf.Sign"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
770
%cst = constant dense<0.1> : tensor<f32>
871
%2 = "tf.AddV2"(%0, %cst) : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>

larq_compute_engine/mlir/transforms/optimize_patterns_common.td

+24
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,30 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
1111

1212
class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;
1313

14+
def : Pat<(LQ_QuantizeOp
15+
(TFL_GreaterEqualOp:$ge_op
16+
$input,
17+
(ConstantOp:$threshold ConstantValue<"0.0f">))),
18+
(LQ_QuantizeOp $input),
19+
[(HasOneUse $ge_op), (HasOneUse $threshold)],
20+
(addBenefit 150)>;
21+
22+
def : Pat<(LQ_QuantizeOp
23+
(TFL_GreaterEqualOp:$ge_op
24+
$input,
25+
$threshold)),
26+
(LQ_QuantizeOp
27+
(TFL_SubOp $input, $threshold, TFL_AF_None)),
28+
[(HasOneUse $ge_op)],
29+
(addBenefit 100)>;
30+
31+
def : Pat<(LQ_QuantizeOp
32+
(TFL_LessEqualOp:$ge_op $lhs, $rhs)),
33+
(LQ_QuantizeOp
34+
(TFL_GreaterEqualOp $rhs, $lhs)),
35+
[(HasOneUse $ge_op)],
36+
(addBenefit 100)>;
37+
1438
// TODO: Check shapes before fusing
1539
multiclass FuseAddOrSubWithBConv2D<Op binaryOp> {
1640
def : Pat<(binaryOp

larq_compute_engine/mlir/transforms/prepare_patterns_common.td

+49-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,56 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
44
include "larq_compute_engine/mlir/ir/lce_ops.td"
55
include "larq_compute_engine/mlir/transforms/op_removal_patterns.td"
66

7+
class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>;
78

8-
// This relies on implementation details of larq.math.sign. We should make
9-
// this more general in the future
9+
def CreateTFBroadcastToOp : NativeCodeCall<
10+
"$_builder.create<TF::BroadcastToOp>("
11+
"$0.getLoc(),"
12+
"RankedTensorType::get("
13+
"$0.getType().cast<RankedTensorType>().getShape(),"
14+
"getElementTypeOrSelf($1.getType())),"
15+
"$1,"
16+
"$2)">;
17+
18+
def CreateTFShapeOp : NativeCodeCall<
19+
"$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
20+
21+
// Base quantiser patterns that match the `tf.where` implementation of `ste_sign`.
22+
multiclass QuantDequantPatterns<Op SelectOp> {
23+
def : Pat<(SelectOp:$select_op
24+
$cond,
25+
(ConstantOp ConstantValue<"1.0f">),
26+
(ConstantOp ConstantValue<"-1.0f">)),
27+
(LQ_DequantizeOp
28+
(LQ_QuantizeOp
29+
(CreateTFBroadcastToOp
30+
$select_op,
31+
$cond,
32+
(CreateTFShapeOp
33+
$select_op,
34+
$cond,
35+
/*use 32bit*/ConstBoolAttrFalse)))),
36+
[], (addBenefit 100)>;
37+
def : Pat<(SelectOp:$select_op
38+
$cond,
39+
(ConstantOp ConstantValue<"-1.0f">),
40+
(ConstantOp ConstantValue<"1.0f">)),
41+
(LQ_DequantizeOp
42+
(LQ_QuantizeOp
43+
(CreateTFBroadcastToOp
44+
$select_op,
45+
(TF_LogicalNotOp $cond),
46+
(CreateTFShapeOp
47+
$select_op,
48+
$cond,
49+
/*use 32bit*/ConstBoolAttrFalse)))),
50+
[], (addBenefit 100)>;
51+
}
52+
foreach SelectOp = [TF_SelectOp, TF_SelectV2Op] in
53+
defm : QuantDequantPatterns<!cast<Op>(SelectOp)>;
54+
55+
// A fallback for the old version of `ste_sign` that uses a specific `tf.sign`
56+
// based implementation of `larq.math.sign`.
1057
def : Pat<(TF_SignOp (TF_AddV2Op (TF_SignOp $arg), $c)),
1158
(LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>;
1259
def : Pat<(TF_SignOp (TF_AddV2Op $c, (TF_SignOp $arg))),

larq_compute_engine/mlir/transforms/prepare_tf.cc

+8
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ struct PrepareLCE : public PassWrapper<PrepareLCE, FunctionPass> {
3636
clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))};
3737
};
3838

39+
bool IsConstantValue(Attribute values, float expected_value) {
40+
if (!values.isa<DenseElementsAttr>()) return false;
41+
42+
for (auto value : values.cast<DenseElementsAttr>().getValues<float>()) {
43+
if (value != expected_value) return false;
44+
}
45+
return true;
46+
}
3947
DenseElementsAttr GetConstantVector(Attribute filter, float val) {
4048
auto filter_type = filter.getType().cast<ShapedType>();
4149
auto filter_shape = filter_type.getShape();

larq_compute_engine/tests/end2end_test.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import os
23
import sys
34
import tempfile
@@ -23,7 +24,7 @@ def convert_keras_model_as_saved_model(model, **kwargs):
2324
return convert_saved_model(saved_model_dir, **kwargs)
2425

2526

26-
def toy_model(**kwargs):
27+
def toy_model(binary_quantizer="ste_sign", **kwargs):
2728
def block(padding, pad_values, activation):
2829
def dummy(x):
2930
shortcut = x
@@ -32,8 +33,8 @@ def dummy(x):
3233
kernel_size=3,
3334
padding=padding,
3435
pad_values=pad_values,
35-
input_quantizer="ste_sign",
36-
kernel_quantizer="ste_sign",
36+
input_quantizer=binary_quantizer,
37+
kernel_quantizer=binary_quantizer,
3738
use_bias=False,
3839
activation=activation,
3940
)(x)
@@ -59,7 +60,7 @@ def dummy(x):
5960
return tf.keras.Model(inputs=img_input, outputs=out)
6061

6162

62-
def toy_model_sequential(**kwargs):
63+
def toy_model_sequential(binary_quantizer="ste_sign", **kwargs):
6364
return tf.keras.models.Sequential(
6465
[
6566
tf.keras.layers.Input((224, 224, 3)),
@@ -70,8 +71,8 @@ def toy_model_sequential(**kwargs):
7071
lq.layers.QuantConv2D(
7172
32,
7273
(3, 3),
73-
input_quantizer="ste_sign",
74-
kernel_quantizer="ste_sign",
74+
input_quantizer=binary_quantizer,
75+
kernel_quantizer=binary_quantizer,
7576
padding="same",
7677
pad_values=1.0,
7778
use_bias=False,
@@ -85,8 +86,8 @@ def toy_model_sequential(**kwargs):
8586
lq.layers.QuantConv2D(
8687
32,
8788
(3, 3),
88-
input_quantizer="ste_sign",
89-
kernel_quantizer="ste_sign",
89+
input_quantizer=binary_quantizer,
90+
kernel_quantizer=binary_quantizer,
9091
strides=(2, 2),
9192
padding="same",
9293
pad_values=1.0,
@@ -104,8 +105,8 @@ def toy_model_sequential(**kwargs):
104105
lq.layers.QuantConv2D(
105106
32,
106107
(3, 3),
107-
input_quantizer="ste_sign",
108-
kernel_quantizer="ste_sign",
108+
input_quantizer=binary_quantizer,
109+
kernel_quantizer=binary_quantizer,
109110
padding="same",
110111
pad_values=1.0,
111112
use_bias=False,
@@ -165,12 +166,25 @@ def dataset():
165166
)
166167

167168

169+
def tf_where_binary_quantizer(x):
170+
return tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))
171+
172+
168173
@pytest.mark.parametrize(
169174
"conversion_function", [convert_keras_model, convert_keras_model_as_saved_model]
170175
)
171176
@pytest.mark.parametrize(
172177
"model_cls",
173-
[toy_model, toy_model_sequential, toy_model_int8, lqz.sota.QuickNetSmall],
178+
[
179+
toy_model,
180+
functools.partial(toy_model, binary_quantizer=tf_where_binary_quantizer),
181+
toy_model_sequential,
182+
functools.partial(
183+
toy_model_sequential, binary_quantizer=tf_where_binary_quantizer
184+
),
185+
toy_model_int8,
186+
lqz.sota.QuickNetSmall,
187+
],
174188
)
175189
def test_simple_model(dataset, conversion_function, model_cls):
176190
model = model_cls(weights="imagenet")

larq_compute_engine/tflite/kernels/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ cc_library(
4141
"//larq_compute_engine/core/indirect_bgemm:kernels",
4242
"@flatbuffers",
4343
"@org_tensorflow//tensorflow/lite:framework",
44+
"@org_tensorflow//tensorflow/lite:type_to_tflitetype",
4445
"@org_tensorflow//tensorflow/lite/kernels/internal:kernel_utils",
4546
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
4647
"@ruy//ruy/profiler:instrumentation",

0 commit comments

Comments
 (0)