Skip to content

Commit

Permalink
Added tests for Keras to check the manual bit width selection feature…
Browse files Browse the repository at this point in the history
…. Added an informative message in the case that a node has an existing manual bit width configuration but a new manual configuration request of has been received, and the previous value is being overridden (#1160)

Co-authored-by: liord <[email protected]>
  • Loading branch information
lior-dikstein and liord authored Aug 13, 2024
1 parent 65da7bd commit e11bf38
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,10 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
if len(filtered_nodes) == 0:
Logger.critical(f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width for n in filtered_nodes})
for n in filtered_nodes:
# check if a manual configuration exists for this node
if n in nodes_to_change_bit_width:
Logger.info(
f"Node {n} has an existing manual bit width configuration of {nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
return nodes_to_change_bit_width
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf

import model_compression_toolkit as mct
from mct_quantizers import KerasActivationQuantizationHolder, KerasQuantizationWrapper
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.core.common.network_editors import NodeNameFilter, NodeTypeFilter
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
from tests.common_tests.helpers.generate_test_tp_model import generate_test_op_qc, generate_test_attr_configs
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
from tests.keras_tests.tpc_keras import get_tpc_with_activation_mp_keras

keras = tf.keras
layers = keras.layers

get_op_set = lambda x, x_list: [op_set for op_set in x_list if op_set.name == x][0]


class ManualBitWidthSelectionTest(BaseKerasFeatureNetworkTest):
"""
This test check the manual bit width configuration.
Call it with a layer type filter or list of layer type filters, bit width or list of bit widths.
Uses the manual bit width API in the "get_core_configs" method.
"""

def __init__(self, unit_test, filters, bit_widths):
self.filters = filters
self.bit_widths = bit_widths
self.layer_types = {}
self.layer_names = {}
self.functional_names = {}

filters = [filters] if not isinstance(filters, list) else filters
bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
if len(bit_widths) < len(filters):
bit_widths = [bit_widths[0] for f in filters]
for filter, bit_width in zip(filters, bit_widths):
if isinstance(filter, NodeNameFilter):
self.layer_names.update({filter.node_name: bit_width})
elif isinstance(filter, NodeTypeFilter):
self.layer_types.update({filter.node_type: bit_width})
super().__init__(unit_test)

def create_networks(self):
input_tensor = layers.Input(shape=self.get_input_shapes()[0][1:], name='input')
x1 = layers.Conv2D(filters=32, kernel_size=(1, 1), padding='same', name='conv1')(input_tensor)
x1 = layers.Add(name='add1')([x1, np.ones((3,), dtype=np.float32)])

# Second convolutional block
x2 = layers.Conv2D(filters=32, kernel_size=(1, 1), padding='same', name='conv2')(x1)
x2 = layers.BatchNormalization(name='bn1')(x2)
x2 = layers.ReLU(name='relu1')(x2)

# Addition
x = layers.Add(name='add2')([x1, x2])

# Flatten and fully connected layer
x = layers.Flatten()(x)
output_tensor = layers.Dense(units=10, activation='softmax', name='fc')(x)

return keras.Model(inputs=input_tensor, outputs=output_tensor)

def get_tpc(self):
eight_bits = generate_test_op_qc(**generate_test_attr_configs())
default_config = eight_bits.clone_and_edit(attr_weights_configs_mapping={})
# set only 8 and 4 bit candidates for test, to verify that all layers get exactly 4 bits
mixed_precision_candidates_list = [(8, 8), (8, 4), (8, 2), (4, 8), (4, 4), (4, 2), (2, 8), (2, 4), (2, 2)]

return get_tpc_with_activation_mp_keras(base_config=eight_bits,
default_config=default_config,
mp_bitwidth_candidates_list=mixed_precision_candidates_list,
name="mixed_precision_4bit_test")

def get_mp_core_config(self):
qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
relu_bound_to_power_of_2=False, weights_bias_correction=True,
input_scaling=False, activation_channel_equalization=False)
mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)

core_config = mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)
return core_config

def get_core_config(self):
# Configures the core settings including manual bit width adjustments.
core_config = self.get_mp_core_config()
core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths)
return core_config

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
# in the compare we need bit_widths to be a list
bit_widths = [self.bit_widths] if not isinstance(self.bit_widths, list) else self.bit_widths

for layer in quantized_model.layers:
# check if the layer is an activation quantizer
if isinstance(layer, KerasActivationQuantizationHolder):
# get the layer that's activation is being quantized
layer_q = quantized_model.layers[quantized_model.layers.index(layer) - 1]
if isinstance(layer_q, KerasQuantizationWrapper):
layer_q = layer_q.layer
# check if this layer is in the layer types to change bit width and check that the correct bit width was applied.
layer_q_bit_width = self.layer_names.get(layer_q.name) if self.layer_names.get(
layer_q.name) is not None else self.layer_types.get(type(layer_q))
if layer_q_bit_width is not None:
self.unit_test.assertTrue(layer.activation_holder_quantizer.num_bits == layer_q_bit_width)
else:
# make sure that the bit width of other layers was not changed.
self.unit_test.assertFalse(layer.activation_holder_quantizer.num_bits in bit_widths,
msg=f"name {layer_q.name}, layer.activation_holder_quantizer.num_bits {layer.activation_holder_quantizer.num_bits}, {self.bit_widths}")


class Manual16BitWidthSelectionTest(ManualBitWidthSelectionTest):
"""
This test check the manual bit width configuration for 16 bits.
The network is built such that one multiply can be configured to 16 bit (mul1) and one cannot (mul2).
Call it with a layer type filter or list of layer type filters, bit width or list of bit widths.
Uses the manual bit width API in the "get_core_configs" method.
"""
def get_tpc(self):
tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4')
# Force Mul base_config to 16bit only
mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
mul_op_set.qc_options.base_config = \
[l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
return tpc

def create_networks(self):
inputs = layers.Input(shape=self.get_input_shapes()[0][1:], name='input')
x = layers.Multiply(name='mul1')([inputs, inputs])
x1 = layers.Add(name='add1')([x, x])
x2 = layers.Subtract(name='sub1')([x1, x])
x = layers.Multiply(name='mul2')([x, x2])
x = layers.Conv2D(3, 1, name='conv1')(x)
outputs = tf.divide(x, 2 * np.ones((3,), dtype=np.float32))
return keras.Model(inputs=inputs, outputs=outputs)


class Manual16BitWidthSelectionMixedPrecisionTest(Manual16BitWidthSelectionTest):
"""
This test check the manual bit width configuration for 16 bits with mixed precision.
The network is built such that one multiply can be configured to 16 bit (mul1) and one cannot (mul2).
Call it with a layer type filter or list of layer type filters, bit width or list of bit widths.
Uses the manual bit width API in the "get_core_configs" method.
"""
def get_tpc(self):
tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4')
mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set)
mul_op_set.qc_options.base_config = \
[l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0]
tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config
mul_op_set.qc_options.quantization_config_list.extend(
[mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4),
mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)])
tpc.layer2qco[tf.multiply].quantization_config_list.extend([
tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=4),
tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=2)])

return tpc

def get_resource_utilization(self):
return mct.core.ResourceUtilization(activation_memory=400)
93 changes: 93 additions & 0 deletions tests/keras_tests/feature_networks_tests/test_features_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from model_compression_toolkit.core import QuantizationErrorMethod
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.gptq import RoundingType
from model_compression_toolkit.target_platform_capabilities import constants as C
Expand Down Expand Up @@ -56,6 +57,8 @@
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
LUTActivationQuantizerTest
from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \
Manual16BitWidthSelectionTest, Manual16BitWidthSelectionMixedPrecisionTest
from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision.requires_mixed_precision_test import \
RequiresMixedPrecision, RequiresMixedPrecisionWeights
from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_bops_test import \
Expand Down Expand Up @@ -809,6 +812,96 @@ def test_16bit_activations(self):
Activation16BitTest(self).run_test()
Activation16BitMixedPrecisionTest(self).run_test()

def test_invalid_bit_width_selection(self):
with self.assertRaises(Exception) as context:
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Conv2D), 7).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Manually selected activation bit-width 7 is invalid for node Conv2D:conv1.")

with self.assertRaises(Exception) as context:
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 3).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Manually selected activation bit-width 3 is invalid for node Add:add2.")

with self.assertRaises(Exception) as context:
ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 3).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Manually selected activation bit-width 3 is invalid for node ReLU:relu1.")

def test_mul_16_bit_manual_selection(self):
"""
This test checks the execptions in the manual bit-width selection feature.
"""
# This "mul" can be configured to 16 bit
Manual16BitWidthSelectionTest(self, NodeNameFilter('mul1'), 16).run_test()
Manual16BitWidthSelectionMixedPrecisionTest(self, NodeNameFilter('mul1'), 16).run_test()

# This "mul" cannot be configured to 16 bit
with self.assertRaises(Exception) as context:
Manual16BitWidthSelectionTest(self, NodeNameFilter('mul2'), 16).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Manually selected activation bit-width 16 is invalid for node Multiply:mul2.")

# This "mul" cannot be configured to 16 bit
with self.assertRaises(Exception) as context:
Manual16BitWidthSelectionMixedPrecisionTest(self, NodeNameFilter('mul2'), 16).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Manually selected activation bit-width 16 is invalid for node Multiply:mul2.")

def test_exceptions_manual_selection(self):
"""
This test checks the execptions in the manual bit-width selection feature.
"""
# Node name doesn't exist in graph
with self.assertRaises(Exception) as context:
Manual16BitWidthSelectionTest(self, NodeNameFilter('mul_3'), 16).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Node Filtering Error: No nodes found in the graph for filter {'node_name': 'mul_3'} to change their bit width to 16.")

# Invalid inputs to API
with self.assertRaises(Exception) as context:
ManualBitWidthSelectionTest(self,
[NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')],
[2, 4]).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Configuration Error: The number of provided bit_width values 2 must match the number of filters 3, or a single bit_width value should be provided for all filters.")

def test_manual_bit_width_selection(self):
"""
This test checks the manual bit-width selection feature.
"""
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Conv2D), 4).run_test()
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Conv2D), 2).run_test()
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Dense), 4).run_test()
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 4).run_test()
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 2).run_test()
ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)],
[2, 4]).run_test()
ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)],
[4, 4]).run_test()
ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Add)],
[2, 4]).run_test()
ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Conv2D)],
[4, 4]).run_test()
ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Dense)],
4).run_test()
ManualBitWidthSelectionTest(self, NodeNameFilter('input'), 4).run_test()
ManualBitWidthSelectionTest(self, NodeNameFilter('conv1'), 4).run_test()
ManualBitWidthSelectionTest(self, NodeNameFilter('fc'), 4).run_test()
ManualBitWidthSelectionTest(self, NodeNameFilter('add1'), 4).run_test()
ManualBitWidthSelectionTest(self, NodeNameFilter('add2'), 4).run_test()
ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 4).run_test()
ManualBitWidthSelectionTest(self, [NodeNameFilter('add1'), NodeNameFilter('conv1')], [2, 4]).run_test()
ManualBitWidthSelectionTest(self, [NodeNameFilter('add2'), NodeNameFilter('relu1')], 4).run_test()
ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')],[4, 2]).run_test()


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit e11bf38

Please sign in to comment.