Skip to content

Commit

Permalink
[microNPU] Add support for MEAN with uint8 ifm (apache#14353)
Browse files Browse the repository at this point in the history
This PR involves supporting the legalization case of MEAN where axis == [1, 2], keep_dims == True and input dtype == 'uint8'.
  • Loading branch information
ilyag-grovety authored Apr 13, 2023
1 parent 606e2b7 commit 815422c
Show file tree
Hide file tree
Showing 14 changed files with 177 additions and 128 deletions.
62 changes: 8 additions & 54 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def callback(
ofm_zero_point=params.ofm.q_params.zero_point,
pool_shape=params.pool_shape,
ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]],
ofm_dtype=params.ofm.dtype,
strides=params.strides,
padding=params.padding,
activation=activation,
Expand Down Expand Up @@ -975,10 +976,8 @@ def __init__(self):

class MeanRewriter(DFPatternCallback):
"""Convert ethosu.mean composite functions to an equivalent legalization:
- Case 1 (axis == [1, 2] and keepsdims == True):
ethosu_depthwise_conv2d + ethosu_binary_elementwise
- Case 2 (ifm qparams == ofm qparams): ethosu_pooling
- Case 3 (else): ethosu_depthwise_conv2d
- Case 1 (ifm qparams == ofm qparams): ethosu_pooling
- Case 2 (else): ethosu_depthwise_conv2d
"""

def __init__(self):
Expand Down Expand Up @@ -1021,56 +1020,7 @@ def callback(
filter_height = 1
reduced_op = relay.reshape(reduced_op, ifm_shape)

if axis == [1, 2] and params.keepdims:
weight_scale = 1
weight_values = np.ones([out_channels, filter_height, filter_width, 1])
scale_bias = vela_api.pack_biases(
biases=np.zeros(ifm_shape[-1]),
ifm_scale=params.ifm.q_params.scale_f32,
ifm_dtype=np.dtype(params.ifm.dtype),
weight_scales=np.array([weight_scale], dtype=np.float),
ofm_scale=params.ofm.q_params.scale_f32,
is_activation_tanh_or_sigmoid=False,
)

reduced_op = ethosu_ops.ethosu_depthwise_conv2d(
ifm=reduced_op,
weight=relay.const(weight_values, params.ifm.dtype),
scale_bias=relay.const(scale_bias, "uint8"),
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
weight_zero_point=0,
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
kernel_shape=(filter_height, filter_width),
ofm_channels=out_channels,
ofm_dtype="int16",
)

n = int(filter_height * filter_width)
eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0

scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int16"), dtype="int16")

reduced_op = ethosu_ops.ethosu_binary_elementwise(
ifm=reduced_op,
ifm2=scalar_tensor,
lut=lut,
operator_type="MUL",
ifm_scale=float(params.ofm.q_params.scale_f32),
ifm_zero_point=int(params.ofm.q_params.zero_point),
ifm2_scale=1 / (n - eps),
ifm2_zero_point=0,
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
ifm_channels=out_channels,
ifm2_channels=out_channels,
reversed_operands=False,
ofm_dtype="int8",
rounding_mode="NATURAL",
)
elif (
if (
params.ifm.q_params.scale_f32 == params.ofm.q_params.scale_f32
and params.ifm.q_params.zero_point == params.ofm.q_params.zero_point
):
Expand All @@ -1084,6 +1034,7 @@ def callback(
ofm_zero_point=0,
pool_shape=(filter_height, filter_width),
ofm_channels=out_channels,
ofm_dtype=params.ofm.dtype,
rounding_mode="TRUNCATE",
)
else:
Expand Down Expand Up @@ -1112,6 +1063,7 @@ def callback(
kernel_shape=(filter_height, filter_width),
ofm_channels=out_channels,
rounding_mode="NATURAL",
ofm_dtype=params.ofm.dtype,
)

# Reshape to original ofm shape
Expand Down Expand Up @@ -1168,6 +1120,7 @@ def callback(
ofm_zero_point=0,
pool_shape=(1, 1),
ofm_channels=1,
ofm_dtype="int32",
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
Expand Down Expand Up @@ -1319,6 +1272,7 @@ def callback(
ofm_zero_point=int(params.ofm.q_params.zero_point),
pool_shape=pool_shape,
ofm_channels=in_channels,
ofm_dtype=params.ofm.dtype,
strides=[1, 1],
padding=padding,
upscale="NEAREST",
Expand Down
10 changes: 9 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/op/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _extract_ethosu_pooling_params(attrs, args):
ofm_zero_point = attrs.ofm_zero_point
pool_shape = attrs.pool_shape
ofm_channels = attrs.ofm_channels
ofm_dtype = attrs.ofm_dtype
strides = attrs.strides
padding = attrs.padding
activation = attrs.activation
Expand All @@ -59,6 +60,7 @@ def _extract_ethosu_pooling_params(attrs, args):
ofm_zero_point,
pool_shape,
ofm_channels,
ofm_dtype,
strides,
padding,
activation,
Expand Down Expand Up @@ -100,6 +102,7 @@ def ethosu_pooling(
ofm_zero_point: int,
pool_shape: Tuple[int, int],
ofm_channels: int,
ofm_dtype: str,
strides: Tuple[int, int] = (1, 1),
padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
activation: str = "NONE",
Expand All @@ -121,7 +124,7 @@ def ethosu_pooling(
lut : tvm.relay.Expr
The look-up table of values to use if activation = "LUT".
pooling_type: str
The type of the pooling. "AVG" - average pool, "MAX" - max pool.
The type of the pooling. "AVG" - average pool, "MAX" - max pool, "SUM" - reduce sum pool.
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
Expand All @@ -134,6 +137,10 @@ def ethosu_pooling(
The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).
ofm_channels : int
The number of the Output Feature Map channels
ofm_dtype : str
The Output Feature Map tensor data type.
"AVG" or "MAX" pooling - can be "int8", "uint8", or "int16".
"SUM" pooling - can be "int32".
strides : tuple of int, optional
The 2 dimensional strides as (stride_height, stride_width).
padding : tuple of int, optional
Expand Down Expand Up @@ -179,6 +186,7 @@ def ethosu_pooling(
ofm_zero_point,
pool_shape,
ofm_channels,
ofm_dtype,
strides,
padding,
activation,
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def pooling_compute(
ofm_zero_point: int,
pool_shape: Tuple[int, int],
ofm_channels: int,
ofm_dtype: str,
strides: Tuple[int, int],
padding: Tuple[int, int, int, int],
activation: str,
Expand Down Expand Up @@ -68,6 +69,10 @@ def pooling_compute(
The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).
ofm_channels : int
The number of the Output Feature Map channels
ofm_dtype : str
The Output Feature Map tensor data type.
"AVG" or "MAX" pooling - can be "int8", "uint8", or "int16".
"SUM" pooling - can be "int32".
strides : Tuple[int, int]
The 2 dimensional strides as (stride_height, stride_width).
padding : Tuple[int, int, int, int]
Expand Down Expand Up @@ -124,7 +129,6 @@ def pooling_compute(
rh = te.reduce_axis((0, pool_shape_h), name="ry")
rw = te.reduce_axis((0, pool_shape_w), name="rx")
rc = te.reduce_axis((0, 1 if pooling_type != "SUM" else ifm_channels), name="rc")
ofm_dtype = ifm.dtype if pooling_type != "SUM" else "int32"

pooling_attrs = {
"op": "ethosu_pooling",
Expand Down
42 changes: 29 additions & 13 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,30 +1336,46 @@ def check_axis(num_dims, axis):
return axis in ([0], [1], [0, 1])
return axis in ([1], [2], [1, 2])

tensor_params = [self.ifm, self.ofm]
if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
def check_single_axis_across_height(num_dims, axis):
return len(axis) == 1 and (num_dims in (2, 3) and axis == [0] or axis == [1])

same_quantization = (
self.ifm.q_params.scale_f32 == self.ofm.q_params.scale_f32
and self.ifm.q_params.zero_point == self.ofm.q_params.zero_point
)

# IFM must be int8 or uint8
if not check_valid_dtypes([self.ifm], [np.int8, np.uint8]):
return False
if self.ifm.dtype != self.ofm.dtype:
# OFM must be int8, uint8 or int16
if not check_valid_dtypes([self.ofm], [np.int8, np.uint8, np.int16]):
return False
# Input tensor must be at least 2D
if not len(self.ifm.shape) in [2, 3, 4]:
return False
# Axis indices must correspond to height and width axes
if not check_axis(len(self.ifm.shape), self.axis):
return False

# MEAN has further restrictions on the input size, depending on legalization method.
input_size = self.height * self.width

# Product of height and width must be no greater than 65536
if input_size > 65536:
return False
if (
self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32
or self.ifm.q_params.zero_point != self.ofm.q_params.zero_point
) and input_size > 4096:
return False
if self.axis == [1, 2] and self.keepdims and self.ifm.dtype == "int8" and input_size > 256:
return False
# Large kernel height reshape only when axis is [1, 2]
if self.axis != [1, 2] and self.height > 64:
# Product of height and width must be no greater than 4096 when:
# IFM and OFM have different scale or zero point; or
# 'keep_dims' is True
if input_size > 4096 and (not same_quantization or self.keepdims):
return False
# For single axis averages across the height dimension:
if check_single_axis_across_height(len(self.ifm.shape), self.axis):
# IFM height must be no greater than 256 if the IFM and OFM scale and zero point match
if self.height > 256 and same_quantization:
return False
# IFM height must be no greater than 64 if the IFM and OFM scale or zero point
# do not match
if self.height > 64 and not same_quantization:
return False
return True


Expand Down
5 changes: 5 additions & 0 deletions src/relay/op/contrib/ethosu/op_attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> {
int ofm_zero_point;
Array<IndexExpr> pool_shape;
IndexExpr ofm_channels;
String ofm_dtype;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
String activation;
Expand Down Expand Up @@ -376,6 +377,10 @@ struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> {
TVM_ATTR_FIELD(ofm_channels)
.describe(" The number of the Output Feature Map channels.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(ofm_dtype).describe(
"The Output Feature Map tensor data type. "
"'AVG' or 'MAX' pooling - can be 'int8', 'uint8', or 'int16'. "
"'SUM' pooling - can be 'int32'.");
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1}))
.describe("The 2 dimensional strides as (stride_height, stride_width).");
Expand Down
19 changes: 16 additions & 3 deletions src/relay/op/contrib/ethosu/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,27 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
DataType::Int(16), DataType::Int(32)};

std::initializer_list<DataType>& allowed_ifm_dtypes = max_avg_pooling_ifm_dtypes;
auto ofm_dtype = ifm->dtype;
if (param->pooling_type == "SUM") {
allowed_ifm_dtypes = sum_pooling_ifm_dtypes;
ofm_dtype = DataType::Int(32);
}

CheckDataType(reporter, ifm->dtype, allowed_ifm_dtypes, operator_name, "ifm",
param->pooling_type);

DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);

std::initializer_list<DataType> max_avg_pooling_ofm_dtypes = {DataType::Int(8), DataType::UInt(8),
DataType::Int(16)};
if (param->pooling_type == "AVG" || param->pooling_type == "MAX") {
CheckDataType(reporter, ofm_dtype, max_avg_pooling_ofm_dtypes, operator_name, "ofm",
param->pooling_type);
CheckDataTypeMatch(reporter, ofm_dtype, ifm->dtype, operator_name, "ifm", "ofm",
param->pooling_type);
} else {
CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm",
param->pooling_type);
}

CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);

Array<IndexExpr> ifm_shape = ifm->shape;
Expand All @@ -88,7 +100,7 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att

Expr MakeEthosuPooling(Expr ifm, Expr lut, String pooling_type, double ifm_scale,
int ifm_zero_point, double ofm_scale, int ofm_zero_point,
Array<IndexExpr> pool_shape, IndexExpr ofm_channels,
Array<IndexExpr> pool_shape, IndexExpr ofm_channels, String ofm_dtype,
Array<IndexExpr> strides, Array<IndexExpr> padding, String activation,
int clip_min, int clip_max, String rounding_mode, String upscale,
String ifm_layout, String ofm_layout) {
Expand All @@ -100,6 +112,7 @@ Expr MakeEthosuPooling(Expr ifm, Expr lut, String pooling_type, double ifm_scale
attrs->ofm_zero_point = ofm_zero_point;
attrs->pool_shape = std::move(pool_shape);
attrs->ofm_channels = std::move(ofm_channels);
attrs->ofm_dtype = std::move(ofm_dtype);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->activation = std::move(activation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_ethosu_pooling_matcher(pool_shape, stride, padding, ifm_layout, ofm_lay
ofm_zero_point=0,
pool_shape=pool_shape,
ofm_channels=ofm_channels,
ofm_dtype="int8",
strides=stride,
padding=padding,
activation="NONE",
Expand Down
2 changes: 2 additions & 0 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ def make_ethosu_pooling(
pooling_type,
pool_shape,
ofm_channels,
ofm_dtype,
strides,
padding,
activation="NONE",
Expand All @@ -657,6 +658,7 @@ def make_ethosu_pooling(
ofm_zero_point=0,
pool_shape=pool_shape,
ofm_channels=ofm_channels,
ofm_dtype=ofm_dtype,
strides=strides,
padding=padding,
activation=activation,
Expand Down
Loading

0 comments on commit 815422c

Please sign in to comment.