diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 24dd9afd7bfa..5aaa1417ae4d 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -626,6 +626,7 @@ def callback(
             ofm_zero_point=params.ofm.q_params.zero_point,
             pool_shape=params.pool_shape,
             ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]],
+            ofm_dtype=params.ofm.dtype,
             strides=params.strides,
             padding=params.padding,
             activation=activation,
@@ -975,10 +976,8 @@ def __init__(self):
 
 class MeanRewriter(DFPatternCallback):
     """Convert ethosu.mean composite functions to an equivalent legalization:
-    - Case 1 (axis == [1, 2] and keepsdims == True):
-      ethosu_depthwise_conv2d + ethosu_binary_elementwise
-    - Case 2 (ifm qparams == ofm qparams): ethosu_pooling
-    - Case 3 (else): ethosu_depthwise_conv2d
+    - Case 1 (ifm qparams == ofm qparams): ethosu_pooling
+    - Case 2 (else): ethosu_depthwise_conv2d
     """
 
     def __init__(self):
@@ -1021,56 +1020,7 @@ def callback(
             filter_height = 1
         reduced_op = relay.reshape(reduced_op, ifm_shape)
 
-        if axis == [1, 2] and params.keepdims:
-            weight_scale = 1
-            weight_values = np.ones([out_channels, filter_height, filter_width, 1])
-            scale_bias = vela_api.pack_biases(
-                biases=np.zeros(ifm_shape[-1]),
-                ifm_scale=params.ifm.q_params.scale_f32,
-                ifm_dtype=np.dtype(params.ifm.dtype),
-                weight_scales=np.array([weight_scale], dtype=np.float),
-                ofm_scale=params.ofm.q_params.scale_f32,
-                is_activation_tanh_or_sigmoid=False,
-            )
-
-            reduced_op = ethosu_ops.ethosu_depthwise_conv2d(
-                ifm=reduced_op,
-                weight=relay.const(weight_values, params.ifm.dtype),
-                scale_bias=relay.const(scale_bias, "uint8"),
-                lut=lut,
-                ifm_scale=float(params.ifm.q_params.scale_f32),
-                ifm_zero_point=int(params.ifm.q_params.zero_point),
-                weight_zero_point=0,
-                ofm_scale=float(params.ofm.q_params.scale_f32),
-                ofm_zero_point=int(params.ofm.q_params.zero_point),
-                kernel_shape=(filter_height, filter_width),
-                ofm_channels=out_channels,
-                ofm_dtype="int16",
-            )
-
-            n = int(filter_height * filter_width)
-            eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
-
-            scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int16"), dtype="int16")
-
-            reduced_op = ethosu_ops.ethosu_binary_elementwise(
-                ifm=reduced_op,
-                ifm2=scalar_tensor,
-                lut=lut,
-                operator_type="MUL",
-                ifm_scale=float(params.ofm.q_params.scale_f32),
-                ifm_zero_point=int(params.ofm.q_params.zero_point),
-                ifm2_scale=1 / (n - eps),
-                ifm2_zero_point=0,
-                ofm_scale=float(params.ofm.q_params.scale_f32),
-                ofm_zero_point=int(params.ofm.q_params.zero_point),
-                ifm_channels=out_channels,
-                ifm2_channels=out_channels,
-                reversed_operands=False,
-                ofm_dtype="int8",
-                rounding_mode="NATURAL",
-            )
-        elif (
+        if (
             params.ifm.q_params.scale_f32 == params.ofm.q_params.scale_f32
             and params.ifm.q_params.zero_point == params.ofm.q_params.zero_point
         ):
@@ -1084,6 +1034,7 @@ def callback(
                 ofm_zero_point=0,
                 pool_shape=(filter_height, filter_width),
                 ofm_channels=out_channels,
+                ofm_dtype=params.ofm.dtype,
                 rounding_mode="TRUNCATE",
             )
         else:
@@ -1112,6 +1063,7 @@ def callback(
                 kernel_shape=(filter_height, filter_width),
                 ofm_channels=out_channels,
                 rounding_mode="NATURAL",
+                ofm_dtype=params.ofm.dtype,
             )
 
         # Reshape to original ofm shape
@@ -1168,6 +1120,7 @@ def callback(
             ofm_zero_point=0,
             pool_shape=(1, 1),
             ofm_channels=1,
+            ofm_dtype="int32",
             activation=activation,
             clip_min=clip_min,
             clip_max=clip_max,
@@ -1319,6 +1272,7 @@ def callback(
             ofm_zero_point=int(params.ofm.q_params.zero_point),
             pool_shape=pool_shape,
             ofm_channels=in_channels,
+            ofm_dtype=params.ofm.dtype,
             strides=[1, 1],
             padding=padding,
             upscale="NEAREST",
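Note on the legalize.py changes above: with the depthwise + multiply special case removed, the rewriter chooses between the two remaining mean legalizations purely on quantization parameters. A minimal sketch of that decision, assuming a `params` object shaped like the one used in the diff (`choose_mean_legalization` is illustrative, not a helper in the codebase):

```python
# Sketch of the simplified MeanRewriter decision; `params` mirrors the
# matched ethosu.mean parameters referenced in the diff above.
def choose_mean_legalization(params):
    same_quantization = (
        params.ifm.q_params.scale_f32 == params.ofm.q_params.scale_f32
        and params.ifm.q_params.zero_point == params.ofm.q_params.zero_point
    )
    # Case 1: matching IFM/OFM quantization -> a single ethosu_pooling ("AVG").
    # Case 2: otherwise -> ethosu_depthwise_conv2d, which carries the requantize.
    return "ethosu_pooling" if same_quantization else "ethosu_depthwise_conv2d"
```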
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
index 2d5aff9bec3c..4d12704acb0f 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
@@ -39,6 +39,7 @@ def _extract_ethosu_pooling_params(attrs, args):
     ofm_zero_point = attrs.ofm_zero_point
     pool_shape = attrs.pool_shape
     ofm_channels = attrs.ofm_channels
+    ofm_dtype = attrs.ofm_dtype
     strides = attrs.strides
     padding = attrs.padding
     activation = attrs.activation
@@ -59,6 +60,7 @@ def _extract_ethosu_pooling_params(attrs, args):
         ofm_zero_point,
         pool_shape,
         ofm_channels,
+        ofm_dtype,
         strides,
         padding,
         activation,
@@ -100,6 +102,7 @@ def ethosu_pooling(
     ofm_zero_point: int,
     pool_shape: Tuple[int, int],
     ofm_channels: int,
+    ofm_dtype: str,
     strides: Tuple[int, int] = (1, 1),
     padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
     activation: str = "NONE",
@@ -121,7 +124,7 @@ def ethosu_pooling(
     lut : tvm.relay.Expr
         The look-up table of values to use if activation = "LUT".
     pooling_type: str
-        The type of the pooling. "AVG" - average pool, "MAX" - max pool.
+        The type of the pooling. "AVG" - average pool, "MAX" - max pool, "SUM" - reduce sum pool.
     ifm_scale : float
         The quantization scale for the Input Feature Map tensor.
     ifm_zero_point : int
@@ -134,6 +137,10 @@ def ethosu_pooling(
         The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).
     ofm_channels : int
         The number of the Output Feature Map channels
+    ofm_dtype : str
+        The Output Feature Map tensor data type.
+        "AVG" or "MAX" pooling - can be "int8", "uint8", or "int16".
+        "SUM" pooling - can be "int32".
     strides : tuple of int, optional
         The 2 dimensional strides as (stride_height, stride_width).
     padding : tuple of int, optional
@@ -179,6 +186,7 @@ def ethosu_pooling(
         ofm_zero_point,
         pool_shape,
         ofm_channels,
+        ofm_dtype,
         strides,
         padding,
         activation,
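The new `ofm_dtype` argument slots in between `ofm_channels` and `strides`. A hedged example call against the updated API (the shapes and quantization constants below are made up for illustration):

```python
# Example invocation of the updated ethosu_pooling API; values are illustrative.
from tvm import relay
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

ifm = relay.var("ifm", shape=(1, 8, 8, 4), dtype="int8")
lut = relay.const([], dtype="int8")  # no LUT activation used here
pool = ethosu_ops.ethosu_pooling(
    ifm,
    lut,
    pooling_type="MAX",
    ifm_scale=1.0,
    ifm_zero_point=0,
    ofm_scale=1.0,
    ofm_zero_point=0,
    pool_shape=(2, 2),
    ofm_channels=4,
    ofm_dtype="int8",  # new argument: must equal the IFM dtype for "AVG"/"MAX"
)
```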
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
index 6843046fd01e..730810324041 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -36,6 +36,7 @@ def pooling_compute(
     ofm_zero_point: int,
     pool_shape: Tuple[int, int],
     ofm_channels: int,
+    ofm_dtype: str,
     strides: Tuple[int, int],
     padding: Tuple[int, int, int, int],
     activation: str,
@@ -68,6 +69,10 @@ def pooling_compute(
         The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).
     ofm_channels : int
         The number of the Output Feature Map channels
+    ofm_dtype : str
+        The Output Feature Map tensor data type.
+        "AVG" or "MAX" pooling - can be "int8", "uint8", or "int16".
+        "SUM" pooling - can be "int32".
     strides : Tuple[int, int]
         The 2 dimensional strides as (stride_height, stride_width).
     padding : Tuple[int, int, int, int]
@@ -124,7 +129,6 @@ def pooling_compute(
     rh = te.reduce_axis((0, pool_shape_h), name="ry")
     rw = te.reduce_axis((0, pool_shape_w), name="rx")
     rc = te.reduce_axis((0, 1 if pooling_type != "SUM" else ifm_channels), name="rc")
-    ofm_dtype = ifm.dtype if pooling_type != "SUM" else "int32"
 
     pooling_attrs = {
         "op": "ethosu_pooling",
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index d74140da5db2..8ec06d3a923e 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1336,30 +1336,46 @@ def check_axis(num_dims, axis):
                 return axis in ([0], [1], [0, 1])
             return axis in ([1], [2], [1, 2])
 
-        tensor_params = [self.ifm, self.ofm]
-        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+        def check_single_axis_across_height(num_dims, axis):
+            return len(axis) == 1 and (num_dims in (2, 3) and axis == [0] or axis == [1])
+
+        same_quantization = (
+            self.ifm.q_params.scale_f32 == self.ofm.q_params.scale_f32
+            and self.ifm.q_params.zero_point == self.ofm.q_params.zero_point
+        )
+
+        # IFM must be int8 or uint8
+        if not check_valid_dtypes([self.ifm], [np.int8, np.uint8]):
             return False
-        if self.ifm.dtype != self.ofm.dtype:
+        # OFM must be int8, uint8 or int16
+        if not check_valid_dtypes([self.ofm], [np.int8, np.uint8, np.int16]):
             return False
+        # Input tensor must be at least 2D
         if not len(self.ifm.shape) in [2, 3, 4]:
             return False
+        # Axis indices must correspond to height and width axes
        if not check_axis(len(self.ifm.shape), self.axis):
             return False
-        # MEAN has further restrictions on the input size, depending on legalization method.
         input_size = self.height * self.width
+
+        # Product of height and width must be no greater than 65536
         if input_size > 65536:
             return False
-        if (
-            self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32
-            or self.ifm.q_params.zero_point != self.ofm.q_params.zero_point
-        ) and input_size > 4096:
-            return False
-        if self.axis == [1, 2] and self.keepdims and self.ifm.dtype == "int8" and input_size > 256:
-            return False
-        # Large kernel height reshape only when axis is [1, 2]
-        if self.axis != [1, 2] and self.height > 64:
+        # Product of height and width must be no greater than 4096 when:
+        #   IFM and OFM have different scale or zero point; or
+        #   'keep_dims' is True
+        if input_size > 4096 and (not same_quantization or self.keepdims):
             return False
+        # For single axis averages across the height dimension:
+        if check_single_axis_across_height(len(self.ifm.shape), self.axis):
+            # IFM height must be no greater than 256 if the IFM and OFM scale and zero point match
+            if self.height > 256 and same_quantization:
+                return False
+            # IFM height must be no greater than 64 if the IFM and OFM scale or zero point
+            # do not match
+            if self.height > 64 and not same_quantization:
+                return False
         return True
 
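For quick reference, a standalone mirror of the updated `mean` support restrictions, hand-derived from the checks above purely for illustration (the authoritative logic is the `MeanParams` checker in the diff; this sketch also simplifies the single-axis-across-height test):

```python
def mean_is_supported(ifm_shape, axis, keepdims, same_quantization):
    """Illustrative mirror of the mean() constraints in the diff above."""
    num_dims = len(ifm_shape)
    if num_dims not in (2, 3, 4):
        return False
    height_axis = 0 if num_dims in (2, 3) else 1
    height, width = ifm_shape[height_axis], ifm_shape[height_axis + 1]
    if height * width > 65536:
        return False
    # Requantizing means, and means with keep_dims, are limited to 4096 elements.
    if height * width > 4096 and (not same_quantization or keepdims):
        return False
    # A single-axis mean across the height dimension has its own height limits.
    if list(axis) == [height_axis]:
        if same_quantization and height > 256:
            return False
        if not same_quantization and height > 64:
            return False
    return True

# e.g. a mean over a height of 70 that also requantizes is rejected:
assert not mean_is_supported((1, 70, 4, 4), [1], False, same_quantization=False)
```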
diff --git a/src/relay/op/contrib/ethosu/op_attrs.h b/src/relay/op/contrib/ethosu/op_attrs.h
index e4ba2cfb9bad..74e7fe856e89 100644
--- a/src/relay/op/contrib/ethosu/op_attrs.h
+++ b/src/relay/op/contrib/ethosu/op_attrs.h
@@ -349,6 +349,7 @@ struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> {
   int ofm_zero_point;
   Array<IndexExpr> pool_shape;
   IndexExpr ofm_channels;
+  String ofm_dtype;
   Array<IndexExpr> strides;
   Array<IndexExpr> padding;
   String activation;
@@ -376,6 +377,10 @@ struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> {
     TVM_ATTR_FIELD(ofm_channels)
         .describe(" The number of the Output Feature Map channels.")
         .set_default(NullValue<IndexExpr>());
+    TVM_ATTR_FIELD(ofm_dtype).describe(
+        "The Output Feature Map tensor data type. "
+        "'AVG' or 'MAX' pooling - can be 'int8', 'uint8', or 'int16'. "
+        "'SUM' pooling - can be 'int32'.");
     TVM_ATTR_FIELD(strides)
         .set_default(Array<IndexExpr>({1, 1}))
         .describe("The 2 dimensional strides as (stride_height, stride_width).");
diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc
index a9c072a01121..92e704f667ed 100644
--- a/src/relay/op/contrib/ethosu/pooling.cc
+++ b/src/relay/op/contrib/ethosu/pooling.cc
@@ -61,15 +61,27 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                                             DataType::Int(16), DataType::Int(32)};
   std::initializer_list<DataType>& allowed_ifm_dtypes = max_avg_pooling_ifm_dtypes;
-  auto ofm_dtype = ifm->dtype;
   if (param->pooling_type == "SUM") {
     allowed_ifm_dtypes = sum_pooling_ifm_dtypes;
-    ofm_dtype = DataType::Int(32);
   }
 
   CheckDataType(reporter, ifm->dtype, allowed_ifm_dtypes, operator_name, "ifm",
                 param->pooling_type);
 
+  DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);
+
+  std::initializer_list<DataType> max_avg_pooling_ofm_dtypes = {DataType::Int(8), DataType::UInt(8),
+                                                                DataType::Int(16)};
+  if (param->pooling_type == "AVG" || param->pooling_type == "MAX") {
+    CheckDataType(reporter, ofm_dtype, max_avg_pooling_ofm_dtypes, operator_name, "ofm",
+                  param->pooling_type);
+    CheckDataTypeMatch(reporter, ofm_dtype, ifm->dtype, operator_name, "ifm", "ofm",
+                       param->pooling_type);
+  } else {
+    CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm",
+                  param->pooling_type);
+  }
+
   CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);
 
   Array<IndexExpr> ifm_shape = ifm->shape;
@@ -88,7 +100,7 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
 Expr MakeEthosuPooling(Expr ifm, Expr lut, String pooling_type, double ifm_scale,
                        int ifm_zero_point, double ofm_scale, int ofm_zero_point,
-                       Array<IndexExpr> pool_shape, IndexExpr ofm_channels,
+                       Array<IndexExpr> pool_shape, IndexExpr ofm_channels, String ofm_dtype,
                        Array<IndexExpr> strides, Array<IndexExpr> padding, String activation,
                        int clip_min, int clip_max, String rounding_mode, String upscale,
                        String ifm_layout, String ofm_layout) {
@@ -100,6 +112,7 @@ Expr MakeEthosuPooling(Expr ifm, Expr lut, String pooling_type, double ifm_scale,
   attrs->ofm_zero_point = ofm_zero_point;
   attrs->pool_shape = std::move(pool_shape);
   attrs->ofm_channels = std::move(ofm_channels);
+  attrs->ofm_dtype = std::move(ofm_dtype);
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->activation = std::move(activation);
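Summarising the dtype rules the updated type relation now enforces, as a small Python sketch (hand-written from the C++ above; the `CheckDataType`/`CheckDataTypeMatch` helpers in pooling.cc are the authoritative implementation):

```python
# Python sketch of the OFM dtype checking added to EthosuPoolingRel above.
def check_pooling_dtypes(pooling_type, ifm_dtype, ofm_dtype):
    if pooling_type in ("AVG", "MAX"):
        if ifm_dtype not in ("int8", "uint8", "int16"):
            raise ValueError(f"invalid ifm dtype {ifm_dtype} for {pooling_type}")
        if ofm_dtype != ifm_dtype:
            raise ValueError("AVG/MAX pooling requires ofm dtype == ifm dtype")
    else:  # "SUM"
        if ifm_dtype not in ("int8", "uint8", "int16", "int32"):
            raise ValueError(f"invalid ifm dtype {ifm_dtype} for SUM")
        if ofm_dtype != "int32":
            raise ValueError("SUM pooling requires an int32 ofm")
```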
diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py
index 38aeee05f936..1faec87ba2aa 100644
--- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py
+++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py
@@ -49,6 +49,7 @@ def test_ethosu_pooling_matcher(pool_shape, stride, padding, ifm_layout, ofm_layout):
         ofm_zero_point=0,
         pool_shape=pool_shape,
         ofm_channels=ofm_channels,
+        ofm_dtype="int8",
         strides=stride,
         padding=padding,
         activation="NONE",
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index b205a6d3350b..c621155827a9 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -639,6 +639,7 @@ def make_ethosu_pooling(
     pooling_type,
     pool_shape,
     ofm_channels,
+    ofm_dtype,
     strides,
     padding,
     activation="NONE",
@@ -657,6 +658,7 @@ def make_ethosu_pooling(
         ofm_zero_point=0,
         pool_shape=pool_shape,
         ofm_channels=ofm_channels,
+        ofm_dtype=ofm_dtype,
         strides=strides,
         padding=padding,
         activation=activation,
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 1df9e8891495..14441d8e9313 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -391,31 +391,29 @@ def binary_elementwise(lhs, rhs):
     )
 
 
-@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/12634")
 @pytest.mark.parametrize(
     "accel_type",
     ACCEL_TYPES,
 )
 @pytest.mark.parametrize(
-    "ifm_shape, axis, keep_dims, use_same_quantization",
+    "ifm_shape, axis, keep_dims, use_same_quantization, dtype",
     [
-        # mean to depthwise + multiply
-        [(1, 8, 16, 16), (1, 2), True, False],
-        [(1, 3, 4), (0, 1), True, False],
-        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
         # mean to average pool
-        [(1, 8, 16, 16), (2,), False, True],
-        [(3, 3, 4), (0,), True, True],
-        [(8, 5), (0,), False, True],
+        [(1, 8, 16, 16), (2,), False, True, "int8"],
+        [(1, 8, 16, 16), (2,), False, True, "uint8"],
+        [(3, 3, 4), (0,), True, True, "int8"],
+        [(8, 5), (0,), False, True, "int8"],
         # mean to depthwise
-        [(1, 8, 16, 16), (2,), True, False],
-        [(1, 8, 16, 16), (2, 1), False, False],
-        [(8, 4), (0,), False, False],
+        [(1, 8, 16, 16), (2,), True, False, "int8"],
+        [(1, 8, 16, 16), (2,), True, False, "uint8"],
+        [(1, 8, 16, 16), (2, 1), False, False, "int8"],
+        [(8, 4), (0,), False, False, "int8"],
+        [(1, 65, 2, 1), (1, 2), True, False, "int8"],  # special case when h > 64
+        [(1, 65, 2, 1), (1, 2), True, False, "uint8"],  # special case when h > 64
     ],
 )
-def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization):
+def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization, dtype):
     np.random.seed(0)
-    dtype = "int8"
 
     def create_mod_from_tflite():
         class Model(tf.Module):
@@ -462,12 +460,14 @@ def create_mod_from_relay():
             input_zero_point=relay.const(0, dtype="int32"),
             output_scale=relay.const(1.0, dtype="float32"),
             output_zero_point=relay.const(0, dtype="int32"),
+            out_dtype=dtype,
         )
 
         func = relay.Function(relay.analysis.free_vars(requantize), requantize)
         mod = tvm.IRModule.from_expr(func)
 
-        input_data = {"input": np.random.randint(low=-127, high=128, size=ifm_shape, dtype=dtype)}
+        low, high = (0, 256) if dtype == "uint8" else (-127, 128)
+        input_data = {"input": np.random.randint(low=low, high=high, size=ifm_shape, dtype=dtype)}
         output_data = generate_ref_data(mod, input_data)
         return mod, input_data, output_data
 
@@ -546,6 +546,7 @@ def create_model():
             pooling_type="SUM",
             pool_shape=(1, 1),
             ofm_channels=1,
+            ofm_dtype="int32",
             strides=(1, 1),
             padding=(0, 0, 0, 0),
             rounding_mode="NATURAL",
diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
index f90f0f2e627d..3ae58dfc81ba 100644
--- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
@@ -78,12 +78,14 @@ def test_simple_strided_slice_identity_removal():
     in the graph and a compute operation follows."""
 
     def get_graph(get_expected=False):
-        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
-        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, (1, 1), (0, 0))
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype=dtype)
+        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, dtype, (1, 1), (0, 0))
         x = relay.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 2, 2, 2])
         if not get_expected:
             x = infra.make_ethosu_identity(x)
-        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 2, (1, 1), (0, 0))
+        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0))
         return relay.Function(relay.analysis.free_vars(x), x)
 
     actual = _optimize(get_graph())
@@ -95,9 +97,11 @@ def test_no_identity():
     """Check the graph is not affected when there is no identity in the graph."""
 
     def get_graph():
-        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype=dtype)
         x = infra.make_ethosu_conv2d(x, 4, 4, (1, 1), (0, 0), (1, 1), (1, 1))
-        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, (1, 1), (0, 0))
+        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, dtype, (1, 1), (0, 0))
         x = infra.make_ethosu_depthwise_conv2d(x, 4, (1, 1), (0, 0), (1, 1), (1, 1))
         x = infra.make_ethosu_unary_elementwise(x, 4, "ABS")
         return relay.Function(relay.analysis.free_vars(x), x)
diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
index 05b9dce4c929..69d549acbb3b 100644
--- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
@@ -147,6 +147,7 @@ def get_graph(get_expected=False):
             pooling_type="SUM",
             pool_shape=(1, 1),
             ofm_channels=1,
+            ofm_dtype="int32",
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout=layout,
@@ -330,13 +331,16 @@ def test_ignore_concatnate_with_layout_transform():
     """
 
     def get_graph():
-        in_1 = relay.var("x", shape=(1, 16, 16, 8), dtype="int8")
-        in_2 = relay.var("y", shape=(1, 16, 16, 8), dtype="int8")
+        dtype = "int8"
+
+        in_1 = relay.var("x", shape=(1, 16, 16, 8), dtype=dtype)
+        in_2 = relay.var("y", shape=(1, 16, 16, 8), dtype=dtype)
         pool_1 = infra.make_ethosu_pooling(
             in_1,
             "MAX",
             (1, 1),
             ofm_channels=8,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -347,6 +351,7 @@ def get_graph():
             "MAX",
             (1, 1),
             ofm_channels=8,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -358,6 +363,7 @@ def get_graph():
             "MAX",
             (1, 1),
             ofm_channels=8,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -385,12 +391,15 @@ def test_multiple_inputs():
     def get_graph():
         poolings = []
         for _ in range(3):
-            inp = relay.var("x", shape=(1, 3, 3, 4), dtype="int8")
+            dtype = "int8"
+
+            inp = relay.var("x", shape=(1, 3, 3, 4), dtype=dtype)
             pool = infra.make_ethosu_pooling(
                 inp,
                 "MAX",
                 (1, 1),
                 ofm_channels=4,
+                ofm_dtype=dtype,
                 strides=(1, 1),
                 padding=(0, 0),
                 ifm_layout="NHWC",
@@ -428,12 +437,15 @@ def test_multiple_outputs():
     """
 
     def get_graph(get_expected=False):
-        in_1 = relay.var("x", shape=(1, 4, 4, 8), dtype="int8")
+        dtype = "int8"
+
+        in_1 = relay.var("x", shape=(1, 4, 4, 8), dtype=dtype)
         pool_1 = infra.make_ethosu_pooling(
             in_1,
             "MAX",
             (1, 1),
             ofm_channels=4,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -447,6 +459,7 @@ def get_graph(get_expected=False):
             "MAX",
             (1, 1),
             ofm_channels=4,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHCWB16" if get_expected else "NHWC",
@@ -527,7 +540,9 @@ def test_multiple_pooling():
     """
 
     def get_graph(get_expected=False):
-        x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 8, 8, 4), dtype=dtype)
         for i in range(3):
             ifm_layout = "NHCWB16" if get_expected and i != 0 else "NHWC"
             ofm_layout = "NHCWB16" if get_expected and i != 2 else "NHWC"
@@ -536,6 +551,7 @@ def get_graph(get_expected=False):
                 "MAX",
                 (1, 1),
                 ofm_channels=4,
+                ofm_dtype=dtype,
                 strides=(1, 1),
                 padding=(0, 0),
                 ifm_layout=ifm_layout,
@@ -594,8 +610,9 @@ def test_op_without_ethosu_consumer():
     def get_graph(get_expected=False):
         exp_layout = "NHCWB16" if get_expected else "NHWC"
+        dtype = "int8"
 
-        x = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
+        x = relay.var("x", shape=(1, 2, 2, 2), dtype=dtype)
         depthwise = infra.make_ethosu_depthwise_conv2d(
             x, 2, (1, 1), (0, 0), (1, 1), (0, 0), ofm_layout=exp_layout
         )
@@ -609,7 +626,7 @@ def get_graph(get_expected=False):
             (0, 0),
             ifm_layout=exp_layout,
         )
-        pool = infra.make_ethosu_pooling(conv, "MAX", (1, 1), 2, (1, 1), (0, 0))
+        pool = infra.make_ethosu_pooling(conv, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0))
         concat = relay.concatenate([conv, pool], axis=0)
         return relay.Function(relay.analysis.free_vars(concat), concat)
 
@@ -639,21 +656,31 @@ def test_diamond_graph():
     def get_graph(get_expected=False):
         exp_layout = "NHCWB16" if get_expected else "NHWC"
-        x = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 2, 2, 2), dtype=dtype)
         pool_1 = infra.make_ethosu_pooling(
-            x, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=exp_layout
+            x, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), ofm_layout=exp_layout
         )
         pool_2 = infra.make_ethosu_pooling(
-            pool_1, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout
+            pool_1, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), ifm_layout=exp_layout
         )
         pool_3 = infra.make_ethosu_pooling(
-            pool_2, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=exp_layout
+            pool_2, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), ofm_layout=exp_layout
        )
         pool_4 = infra.make_ethosu_pooling(
-            pool_3, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout, ofm_layout=exp_layout
+            pool_3,
+            "MAX",
+            (1, 1),
+            2,
+            dtype,
+            (1, 1),
+            (0, 0),
+            ifm_layout=exp_layout,
+            ofm_layout=exp_layout,
         )
         pool_5 = infra.make_ethosu_pooling(
-            pool_4, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout
+            pool_4, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), ifm_layout=exp_layout
         )
         concat = relay.concatenate([pool_2, pool_5], axis=0)
         return relay.Function(relay.analysis.free_vars(concat), concat)
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 594f4a0e2aef..6330930fa5f8 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1535,15 +1535,10 @@ def representative_dataset():
     assert tuple(func_body.args[1].checked_type.shape) == (256,)
 
 
+@pytest.mark.parametrize("dtype", ["int8", "uint8"])
 @pytest.mark.parametrize(
     "ifm_shape, axis, keep_dims, use_same_quantization",
     [
-        # mean to depthwise + multiply
-        [(1, 8, 16, 16), (1, 2), True, False],
-        [(1, 8, 16, 16), (2, 1), True, False],
-        [(1, 3, 4), (0, 1), True, False],
-        [(8, 5), (1, 0), True, False],
-        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
         # mean to average pool
         [(1, 8, 16, 16), (1,), True, True],
         [(1, 8, 16, 16), (2,), False, True],
@@ -1557,11 +1552,10 @@ def representative_dataset():
         [(1, 8, 16, 16), (2,), True, False],
         [(1, 8, 16, 16), (1, 2), False, False],
         [(8, 4), (0,), False, False],
+        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
     ],
 )
-def test_mean(ifm_shape, axis, keep_dims, use_same_quantization):
-    dtype = "int8"
-
+def test_mean(ifm_shape, axis, keep_dims, use_same_quantization, dtype):
     def create_tflite_graph():
         class Model(tf.Module):
             @tf.function
@@ -1606,6 +1600,7 @@ def create_relay_graph_with_same_quantization():
             input_zero_point=relay.const(0, dtype="int32"),
             output_scale=relay.const(1.0, dtype="float32"),
             output_zero_point=relay.const(0, dtype="int32"),
+            out_dtype=dtype,
         )
 
         func = relay.Function(relay.analysis.free_vars(requantize), requantize)
@@ -1616,7 +1611,6 @@ def verify(ext_func):
         out_var = ext_func.body
 
         next_op = out_var
-        mul_op = None
         pooling_op = None
         depthwise_op = None
         if (
@@ -1625,9 +1619,6 @@ def verify(ext_func):
             and next_op.op.name == "reshape"
         ):
             next_op = next_op.args[0]
-        if util.is_named_ethosu_op(next_op, "binary_elementwise"):
-            mul_op = next_op
-            next_op = next_op.args[0]
         if util.is_named_ethosu_op(next_op, "pooling"):
             pooling_op = next_op
             next_op = next_op.args[0]
@@ -1654,24 +1645,33 @@ def calculate_expected_output_shape():
         # check IFM
         assert tuple(in_var.checked_type.shape) == ifm_shape
-        assert in_var.checked_type.dtype == dtype
+
+        if use_same_quantization:
+            assert in_var.checked_type.dtype == dtype
+        else:
+            # in_var's dtype is equal to int8 due to TFLite's requantize
+            assert in_var.checked_type.dtype == "int8"
 
         # check OFM
         assert tuple(out_var.checked_type.shape) == out_shape
-        assert out_var.checked_type.dtype == dtype
+        if use_same_quantization:
+            assert out_var.checked_type.dtype == dtype
+        else:
+            # out_var's dtype is equal to int8 due to TFLite's requantize
+            assert out_var.checked_type.dtype == "int8"
 
         # check expected legalization case
-        if axis in [(1, 2), (2, 1), (0, 1), (1, 0)] and keep_dims and dtype == "int8":
-            assert depthwise_op and mul_op
-            assert mul_op.attrs.operator_type == "MUL"
-        elif pooling_op:
+        if pooling_op:
             attrs = pooling_op.attrs
             assert (
                 attrs.ifm_scale == attrs.ofm_scale and attrs.ifm_zero_point == attrs.ofm_zero_point
             )
         else:
             assert depthwise_op
-            assert not mul_op
+            attrs = depthwise_op.attrs
+            assert (
+                attrs.ifm_scale != attrs.ofm_scale or attrs.ifm_zero_point != attrs.ofm_zero_point
+            )
 
     rewriter = legalize.MeanRewriter()
     pattern_table = [
diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py
index 1ef59e0b9b03..e4438eb62abd 100644
--- a/tests/python/contrib/test_ethosu/test_replace_pooling.py
+++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py
@@ -169,12 +169,15 @@ def test_avg_max_pooling_single(
     # hardcoded padding values are used for each case.
     padding = (1, 1, 1, 0) if upscale == "NONE" else (0, 0, 0, 0)
 
-    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    dtype = "int8"
+
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
     pooling = make_ethosu_pooling(
         ifm,
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         activation,
@@ -232,6 +235,7 @@ def test_sum_pooling_single(
         pooling_type="SUM",
         pool_shape=(1, 1),
         ofm_channels=1,
+        ofm_dtype="int32",
         strides=(1, 1),
         padding=(0, 0, 0, 0),
         activation=activation,
@@ -276,13 +280,15 @@ def test_correct_stride_with_multiple_pooling():
     pool_shape = (1, 1)
     strides = (1, 1)
     padding = (0, 0, 0, 0)
+    dtype = "int8"
 
-    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
     op = make_ethosu_pooling(
         ifm,
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         ifm_layout="NHWC",
@@ -293,6 +299,7 @@ def test_correct_stride_with_multiple_pooling():
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         ifm_layout="NHCWB16",
diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py
index 380d7532f8c1..48a4dbde81c3 100644
--- a/tests/python/contrib/test_ethosu/test_type_inference.py
+++ b/tests/python/contrib/test_ethosu/test_type_inference.py
@@ -201,6 +201,7 @@ def test_ethosu_pooling_type_inference(
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         ifm_layout=ifm_layout,
@@ -215,6 +216,7 @@ def test_ethosu_pooling_type_inference(
 def test_ethosu_pooling_invalid_pooling_type():
     invalid_pooling_type = "A"
     dtype = "int8"
+
     ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=dtype)
     pool_shape = (3, 2)
     ofm_channels = 55
@@ -225,6 +227,7 @@ def test_ethosu_pooling_invalid_pooling_type():
         invalid_pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
     )
@@ -246,6 +249,7 @@ def test_ethosu_pooling_invalid_dtype():
         pooling_type,
         pool_shape,
         ofm_channels,
+        "int8",
         strides,
         padding,
     )
@@ -256,12 +260,15 @@ def test_ethosu_pooling_invalid_dtype():
 def test_ethosu_pooling_invalid_upscale_method():
     invalid_upscale_method = "FOO"
 
-    ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype="int8")
+    dtype = "int8"
+
+    ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=dtype)
     pooling = make_ethosu_pooling(
         ifm,
         "MAX",
         (3, 2),
         55,
+        dtype,
         (1, 2),
         (0, 1, 2, 3),
         upscale=invalid_upscale_method,