From e11bf380a24b9777078bccd47f15e9ad94c42f94 Mon Sep 17 00:00:00 2001
From: Lior Dikstein <78903511+lior-dikstein@users.noreply.github.com>
Date: Tue, 13 Aug 2024 14:32:26 +0300
Subject: [PATCH] Added tests for Keras to check the manual bit width selection
 feature. Added an informative message for the case where a node has an
 existing manual bit width configuration, a new manual configuration request
 has been received, and the previous value is being overridden (#1160)

Co-authored-by: liord
---
 .../common/quantization/bit_width_config.py   |   7 +-
 .../feature_networks/manual_bit_selection.py  | 175 ++++++++++++++++++
 .../test_features_runner.py                   |  93 ++++++++++
 .../feature_models/manual_bit_selection.py    |  30 ++-
 .../model_tests/test_feature_models_runner.py |  10 +-
 5 files changed, 290 insertions(+), 25 deletions(-)
 create mode 100644 tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py

diff --git a/model_compression_toolkit/core/common/quantization/bit_width_config.py b/model_compression_toolkit/core/common/quantization/bit_width_config.py
index d2d0bd1c4..e057f0c54 100644
--- a/model_compression_toolkit/core/common/quantization/bit_width_config.py
+++ b/model_compression_toolkit/core/common/quantization/bit_width_config.py
@@ -87,5 +87,10 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
             if len(filtered_nodes) == 0:
                 Logger.critical(f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
                                 f"to change their bit width to {manual_bit_width_selection.bit_width}.")
-            nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width for n in filtered_nodes})
+            for n in filtered_nodes:
+                # Check if a manual configuration already exists for this node.
+                if n in nodes_to_change_bit_width:
+                    Logger.info(
+                        f"Node {n} has an existing manual bit width configuration of {nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
+                nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
         return nodes_to_change_bit_width
\ No newline at end of file
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py b/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py
new file mode 100644
index 000000000..b9098931f
--- /dev/null
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py
@@ -0,0 +1,175 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import tensorflow as tf
+
+import model_compression_toolkit as mct
+from mct_quantizers import KerasActivationQuantizationHolder, KerasQuantizationWrapper
+from model_compression_toolkit.constants import TENSORFLOW
+from model_compression_toolkit.core.common.network_editors import NodeNameFilter, NodeTypeFilter
+from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
+from tests.common_tests.helpers.generate_test_tp_model import generate_test_op_qc, generate_test_attr_configs
+from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
+from tests.keras_tests.tpc_keras import get_tpc_with_activation_mp_keras
+
+keras = tf.keras
+layers = keras.layers
+
+get_op_set = lambda x, x_list: [op_set for op_set in x_list if op_set.name == x][0]
+
+
+class ManualBitWidthSelectionTest(BaseKerasFeatureNetworkTest):
+    """
+    This test checks the manual bit width configuration.
+    Call it with a layer type filter or a list of layer type filters, and a bit width or a list of bit widths.
+    Uses the manual bit width API in the "get_core_config" method.
+    """
+
+    def __init__(self, unit_test, filters, bit_widths):
+        self.filters = filters
+        self.bit_widths = bit_widths
+        self.layer_types = {}
+        self.layer_names = {}
+        self.functional_names = {}
+
+        filters = [filters] if not isinstance(filters, list) else filters
+        bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
+        if len(bit_widths) < len(filters):
+            bit_widths = [bit_widths[0] for f in filters]
+        for filter, bit_width in zip(filters, bit_widths):
+            if isinstance(filter, NodeNameFilter):
+                self.layer_names.update({filter.node_name: bit_width})
+            elif isinstance(filter, NodeTypeFilter):
+                self.layer_types.update({filter.node_type: bit_width})
+        super().__init__(unit_test)
+
+    def create_networks(self):
+        input_tensor = layers.Input(shape=self.get_input_shapes()[0][1:], name='input')
+        x1 = layers.Conv2D(filters=32, kernel_size=(1, 1), padding='same', name='conv1')(input_tensor)
+        x1 = layers.Add(name='add1')([x1, np.ones((3,), dtype=np.float32)])
+
+        # Second convolutional block
+        x2 = layers.Conv2D(filters=32, kernel_size=(1, 1), padding='same', name='conv2')(x1)
+        x2 = layers.BatchNormalization(name='bn1')(x2)
+        x2 = layers.ReLU(name='relu1')(x2)
+
+        # Addition
+        x = layers.Add(name='add2')([x1, x2])
+
+        # Flatten and fully connected layer
+        x = layers.Flatten()(x)
+        output_tensor = layers.Dense(units=10, activation='softmax', name='fc')(x)
+
+        return keras.Model(inputs=input_tensor, outputs=output_tensor)
+
+    def get_tpc(self):
+        eight_bits = generate_test_op_qc(**generate_test_attr_configs())
+        default_config = eight_bits.clone_and_edit(attr_weights_configs_mapping={})
+        # Set 8, 4 and 2 bit mixed precision candidates for the test, to verify
+        # that the manually selected bit widths are applied exactly.
+        mixed_precision_candidates_list = [(8, 8), (8, 4), (8, 2), (4, 8), (4, 4), (4, 2), (2, 8), (2, 4), (2, 2)]
+
+        return get_tpc_with_activation_mp_keras(base_config=eight_bits,
+                                                default_config=default_config,
+                                                mp_bitwidth_candidates_list=mixed_precision_candidates_list,
+                                                name="mixed_precision_4bit_test")
+
+    def get_mp_core_config(self):
+        qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
+                                         relu_bound_to_power_of_2=False, weights_bias_correction=True,
+                                         input_scaling=False, activation_channel_equalization=False)
+        mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)
+
+        core_config = mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)
+        return core_config
+
+    def get_core_config(self):
+        # Configures the core settings, including the manual bit width adjustments.
+        core_config = self.get_mp_core_config()
+        core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths)
+        return core_config
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        # In compare, bit_widths needs to be a list.
+        bit_widths = [self.bit_widths] if not isinstance(self.bit_widths, list) else self.bit_widths
+
+        for layer in quantized_model.layers:
+            # Check if the layer is an activation quantizer.
+            if isinstance(layer, KerasActivationQuantizationHolder):
+                # Get the layer whose activation is being quantized.
+                layer_q = quantized_model.layers[quantized_model.layers.index(layer) - 1]
+                if isinstance(layer_q, KerasQuantizationWrapper):
+                    layer_q = layer_q.layer
+                # Check if this layer's bit width was manually configured, and verify that the correct bit width was applied.
+                layer_q_bit_width = self.layer_names.get(layer_q.name) if self.layer_names.get(
+                    layer_q.name) is not None else self.layer_types.get(type(layer_q))
+                if layer_q_bit_width is not None:
+                    self.unit_test.assertTrue(layer.activation_holder_quantizer.num_bits == layer_q_bit_width)
+                else:
+                    # Make sure that the bit width of other layers was not changed.
+                    self.unit_test.assertFalse(layer.activation_holder_quantizer.num_bits in bit_widths,
+                                               msg=f"name {layer_q.name}, layer.activation_holder_quantizer.num_bits {layer.activation_holder_quantizer.num_bits}, {self.bit_widths}")
+
+
+class Manual16BitWidthSelectionTest(ManualBitWidthSelectionTest):
+    """
+    This test checks the manual bit width configuration for 16 bits.
+    The network is built such that one multiply can be configured to 16 bit (mul1) and one cannot (mul2).
+    Call it with a layer type filter or a list of layer type filters, and a bit width or a list of bit widths.
+    Uses the manual bit width API in the "get_core_config" method.
+    """
+    def get_tpc(self):
+        tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4')
+        # Force the Mul base_config to 16 bit only.
+        mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
+        mul_op_set.qc_options.base_config = \
+            [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
+        tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
+        return tpc
+
+    def create_networks(self):
+        inputs = layers.Input(shape=self.get_input_shapes()[0][1:], name='input')
+        x = layers.Multiply(name='mul1')([inputs, inputs])
+        x1 = layers.Add(name='add1')([x, x])
+        x2 = layers.Subtract(name='sub1')([x1, x])
+        x = layers.Multiply(name='mul2')([x, x2])
+        x = layers.Conv2D(3, 1, name='conv1')(x)
+        outputs = tf.divide(x, 2 * np.ones((3,), dtype=np.float32))
+        return keras.Model(inputs=inputs, outputs=outputs)
+
+
+class Manual16BitWidthSelectionMixedPrecisionTest(Manual16BitWidthSelectionTest):
+    """
+    This test checks the manual bit width configuration for 16 bits with mixed precision.
+    The network is built such that one multiply can be configured to 16 bit (mul1) and one cannot (mul2).
+    Call it with a layer type filter or a list of layer type filters, and a bit width or a list of bit widths.
+    Uses the manual bit width API in the "get_core_config" method.
+ """ + def get_tpc(self): + tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4') + mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) + mul_op_set.qc_options.base_config = \ + [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + mul_op_set.qc_options.quantization_config_list.extend( + [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4), + mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)]) + tpc.layer2qco[tf.multiply].quantization_config_list.extend([ + tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=4), + tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=2)]) + + return tpc + + def get_resource_utilization(self): + return mct.core.ResourceUtilization(activation_memory=400) diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 28b3bfb31..d238b3be2 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -24,6 +24,7 @@ from model_compression_toolkit.core import QuantizationErrorMethod from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting +from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod from model_compression_toolkit.gptq import RoundingType from model_compression_toolkit.target_platform_capabilities import constants as C @@ -56,6 +57,8 @@ ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \ LUTActivationQuantizerTest +from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \ + Manual16BitWidthSelectionTest, Manual16BitWidthSelectionMixedPrecisionTest from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision.requires_mixed_precision_test import \ RequiresMixedPrecision, RequiresMixedPrecisionWeights from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_bops_test import \ @@ -809,6 +812,96 @@ def test_16bit_activations(self): Activation16BitTest(self).run_test() Activation16BitMixedPrecisionTest(self).run_test() + def test_invalid_bit_width_selection(self): + with self.assertRaises(Exception) as context: + ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Conv2D), 7).run_test() + # Check that the correct exception message was raised + self.assertEqual(str(context.exception), + "Manually selected activation bit-width 7 is invalid for node Conv2D:conv1.") + + with self.assertRaises(Exception) as context: + ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 3).run_test() + # Check that the correct exception message was raised + self.assertEqual(str(context.exception), + "Manually selected activation bit-width 3 is invalid for node Add:add2.") + + with self.assertRaises(Exception) as context: + ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 3).run_test() + # Check that the correct exception message was raised + self.assertEqual(str(context.exception), + "Manually selected activation 
+
+    def test_mul_16_bit_manual_selection(self):
+        """
+        This test checks the manual 16-bit selection for multiply layers.
+        """
+        # This "mul" can be configured to 16 bit
+        Manual16BitWidthSelectionTest(self, NodeNameFilter('mul1'), 16).run_test()
+        Manual16BitWidthSelectionMixedPrecisionTest(self, NodeNameFilter('mul1'), 16).run_test()
+
+        # This "mul" cannot be configured to 16 bit
+        with self.assertRaises(Exception) as context:
+            Manual16BitWidthSelectionTest(self, NodeNameFilter('mul2'), 16).run_test()
+        # Check that the correct exception message was raised
+        self.assertEqual(str(context.exception),
+                         "Manually selected activation bit-width 16 is invalid for node Multiply:mul2.")
+
+        # This "mul" cannot be configured to 16 bit
+        with self.assertRaises(Exception) as context:
+            Manual16BitWidthSelectionMixedPrecisionTest(self, NodeNameFilter('mul2'), 16).run_test()
+        # Check that the correct exception message was raised
+        self.assertEqual(str(context.exception),
+                         "Manually selected activation bit-width 16 is invalid for node Multiply:mul2.")
+
+    def test_exceptions_manual_selection(self):
+        """
+        This test checks the exceptions in the manual bit-width selection feature.
+        """
+        # Node name doesn't exist in graph
+        with self.assertRaises(Exception) as context:
+            Manual16BitWidthSelectionTest(self, NodeNameFilter('mul_3'), 16).run_test()
+        # Check that the correct exception message was raised
+        self.assertEqual(str(context.exception),
+                         "Node Filtering Error: No nodes found in the graph for filter {'node_name': 'mul_3'} to change their bit width to 16.")
+
+        # Invalid inputs to API
+        with self.assertRaises(Exception) as context:
+            ManualBitWidthSelectionTest(self,
+                                        [NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')],
+                                        [2, 4]).run_test()
+        # Check that the correct exception message was raised
+        self.assertEqual(str(context.exception),
+                         "Configuration Error: The number of provided bit_width values 2 must match the number of filters 3, or a single bit_width value should be provided for all filters.")
+
+    def test_manual_bit_width_selection(self):
+        """
+        This test checks the manual bit-width selection feature.
+ """ + ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Conv2D), 4).run_test() + ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Conv2D), 2).run_test() + ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Dense), 4).run_test() + ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 4).run_test() + ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 2).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], + [2, 4]).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], + [4, 4]).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Add)], + [2, 4]).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Conv2D)], + [4, 4]).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Dense)], + 4).run_test() + ManualBitWidthSelectionTest(self, NodeNameFilter('input'), 4).run_test() + ManualBitWidthSelectionTest(self, NodeNameFilter('conv1'), 4).run_test() + ManualBitWidthSelectionTest(self, NodeNameFilter('fc'), 4).run_test() + ManualBitWidthSelectionTest(self, NodeNameFilter('add1'), 4).run_test() + ManualBitWidthSelectionTest(self, NodeNameFilter('add2'), 4).run_test() + ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 4).run_test() + ManualBitWidthSelectionTest(self, [NodeNameFilter('add1'), NodeNameFilter('conv1')], [2, 4]).run_test() + ManualBitWidthSelectionTest(self, [NodeNameFilter('add2'), NodeNameFilter('relu1')], 4).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')],[4, 2]).run_test() + if __name__ == '__main__': unittest.main() diff --git a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py index d9ffb2d1a..1f93c9e04 100644 --- a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py +++ b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py @@ -53,15 +53,17 @@ def __init__(self, input_shape): b, in_channels, h, w = input_shape[0] self.conv1 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1)) self.bn1 = torch.nn.BatchNorm2d(in_channels) + self.bn2 = torch.nn.BatchNorm2d(in_channels) self.conv2 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1)) self.relu = torch.nn.ReLU() self.fc = torch.nn.Linear(in_channels * h * w, 5) def forward(self, inp): - out1 = self.conv1(inp) - out1 = out1 + 3 - x = self.bn1(out1) + x = self.conv1(inp) + out1 = self.bn1(x) + x = out1 + 3 x = self.conv2(x) + x = self.bn2(x) x = self.relu(x) x = x + out1 # Flatten the tensor @@ -84,6 +86,11 @@ def get_mp_core_config(self): core_config = mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc) return core_config + def get_core_configs(self): + # Configures the core settings including manual bit width adjustments. + core_config = self.get_mp_core_config() + core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths) + return {"mixed_precision_activation_model": core_config} class ManualBitWidthByLayerTypeTest(BaseManualBitWidthSelectionTest): """ @@ -109,12 +116,6 @@ def __init__(self, unit_test, filters, bit_widths): super().__init__(unit_test) - def get_core_configs(self): - # Configures the core settings including manual bit width adjustments. 
-        core_config = super().get_mp_core_config()
-        core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths)
-        return {"mixed_precision_activation_model": core_config}
-
     def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
         # in the compare we need bit_widths to be a list
         bit_widths = [self.bit_widths] if not isinstance(self.bit_widths, list) else self.bit_widths
@@ -161,11 +162,6 @@ def __init__(self, unit_test, filters, bit_widths):
 
         super().__init__(unit_test)
 
-    def get_core_configs(self):
-        # Configures the core settings including manual bit width adjustments.
-        core_config = super().get_mp_core_config()
-        core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths)
-        return {"mixed_precision_activation_model": core_config}
 
     def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
         # in the compare we need bit_widths to be a list
@@ -187,10 +183,6 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
 
 
 class Manual16BitTest(ManualBitWidthByLayerNameTest):
-    # def get_tpc(self):
-    #     tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v4')
-    #     return {'mixed_precision_activation_model': tpc}
-
     def get_tpc(self):
         tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v4')
         mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
@@ -224,7 +216,7 @@ def get_tpc(self):
         return {'mixed_precision_activation_model': tpc}
 
     def get_resource_utilization(self):
-        return mct.core.ResourceUtilization(activation_memory=10000)
+        return mct.core.ResourceUtilization(activation_memory=6200)
 
 
     def create_feature_network(self, input_shape):
diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
index c0d67f0b7..0b0a8e32e 100644
--- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py
+++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -678,7 +678,7 @@ def test_invalid_bit_width_selection(self):
         with self.assertRaises(Exception) as context:
             ManualBitWidthByLayerTypeTest(self, NodeTypeFilter(torch.nn.Conv2d), 7).run_test()
         # Check that the correct exception message was raised
-        self.assertEqual(str(context.exception), "Manually selected activation bit-width 7 is invalid for node Conv2d:conv1.")
+        self.assertEqual(str(context.exception), "Manually selected activation bit-width 7 is invalid for node Conv2d:conv1_bn.")
 
         with self.assertRaises(Exception) as context:
             ManualBitWidthByLayerTypeTest(self, NodeTypeFilter(operator.add), 3).run_test()
@@ -752,14 +752,14 @@ def test_manual_bit_width_selection_by_layer_name(self):
         This test checks the manual bit-width selection feature by layer name filtering.
""" ManualBitWidthByLayerNameTest(self, NodeNameFilter('inp'), 4).run_test() - ManualBitWidthByLayerNameTest(self, NodeNameFilter('conv1'), 4).run_test() + ManualBitWidthByLayerNameTest(self, NodeNameFilter('conv1_bn'), 4).run_test() ManualBitWidthByLayerNameTest(self, NodeNameFilter('fc'), 4).run_test() ManualBitWidthByLayerNameTest(self, NodeNameFilter('add'), 4).run_test() ManualBitWidthByLayerNameTest(self, NodeNameFilter('add_1'), 4).run_test() - ManualBitWidthByLayerNameTest(self, NodeNameFilter('bn_conv2'), 4).run_test() + ManualBitWidthByLayerNameTest(self, NodeNameFilter('conv2_bn'), 4).run_test() ManualBitWidthByLayerNameTest(self, NodeNameFilter('relu'), 4).run_test() - ManualBitWidthByLayerNameTest(self, [NodeNameFilter('add'), NodeNameFilter('conv1')], [2, 4]).run_test() - ManualBitWidthByLayerNameTest(self, [NodeNameFilter('add'), NodeNameFilter('conv1')], 4).run_test() + ManualBitWidthByLayerNameTest(self, [NodeNameFilter('add'), NodeNameFilter('conv1_bn')], [2, 4]).run_test() + ManualBitWidthByLayerNameTest(self, [NodeNameFilter('add'), NodeNameFilter('conv1_bn')], 4).run_test()