diff --git a/docsrc/source/api/experimental_api_docs/index.rst b/docsrc/source/api/experimental_api_docs/index.rst
index 3a90e9bfc..c3b9e8ced 100644
--- a/docsrc/source/api/experimental_api_docs/index.rst
+++ b/docsrc/source/api/experimental_api_docs/index.rst
@@ -26,7 +26,8 @@ Functions
 - :ref:`pytorch_post_training_quantization_experimental`: A function to use for post training quantization of PyTorch models (experimental).
 - :ref:`keras_post_training_quantization_experimental`: A function to use for post training quantization of Keras models (experimental).
 - :ref:`keras_gradient_post_training_quantization_experimental`: A function to use for gradient-based post training quantization of Keras models (experimental).
-- :ref:`keras_quantization_aware_training_experimental`: A function to use for preparing a model for Quantization Aware Training (Experimental)
+- :ref:`keras_quantization_aware_training_init`: A function to use for preparing a model for Quantization Aware Training (Experimental)
+- :ref:`keras_quantization_aware_training_finalize`: A function to convert a model after Quantization Aware Training into a model without QuantizeWrappers (Experimental)
 - :ref:`get_keras_gptq_config`: A function to create a GradientPTQConfig instance to use for Keras models when using GPTQ (experimental).
 - :ref:`get_target_platform_capabilities`: A function to get a target platform model for Tensorflow and Pytorch.
 - :ref:`keras_kpi_data_experimental`: A function to compute KPI data that can be used to calculate the desired target KPI for mixed-precision quantization of Keras models (experimental).
diff --git a/docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_export.rst b/docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_export.rst
new file mode 100644
index 000000000..64b5f5272
--- /dev/null
+++ b/docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_export.rst
@@ -0,0 +1,10 @@
+:orphan:
+
+.. _ug-keras_quantization_aware_training_finalize:
+
+
+================================================
+Keras Quantization Aware Training Model Finalize
+================================================
+
+.. autofunction:: model_compression_toolkit.keras_quantization_aware_training_finalize
diff --git a/docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_experimental.rst b/docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_init.rst
similarity index 81%
rename from docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_experimental.rst
rename to docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_init.rst
index ffce8709b..50072b514 100644
--- a/docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_experimental.rst
+++ b/docsrc/source/api/experimental_api_docs/methods/keras_quantization_aware_training_init.rst
@@ -1,6 +1,6 @@
 :orphan:
 
-.. _ug-keras_quantization_aware_training_experimental:
+.. _ug-keras_quantization_aware_training_init:
 
 
 ================================================
diff --git a/docsrc/source/index.rst b/docsrc/source/index.rst
index 4656c56a1..d1a8d0ffa 100644
--- a/docsrc/source/index.rst
+++ b/docsrc/source/index.rst
@@ -60,7 +60,8 @@ Keras:
 * :ref:`Hardware-friendly Post Training Quantization` [1]
 * :ref:`Gradient base post training using knowledge distillation` (Experimental)
 * :ref:`Mixed-precision post training quantization` (Experimental)
-* :ref:`Init model for Quantization Aware Training` (Experimental)
+* :ref:`Init model for Quantization Aware Training` (Experimental)
+* :ref:`Finalize model after Quantization Aware Training` (Experimental)
 
 Pytorch (Experimental):
diff --git a/model_compression_toolkit/__init__.py b/model_compression_toolkit/__init__.py
index 468e5bea9..b2b0b911d 100644
--- a/model_compression_toolkit/__init__.py
+++ b/model_compression_toolkit/__init__.py
@@ -37,7 +37,8 @@
 from model_compression_toolkit.gptq.keras.quantization_facade import \
     keras_gradient_post_training_quantization_experimental
 from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
-from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init
+from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, \
+    keras_quantization_aware_training_finalize
 from model_compression_toolkit.core.pytorch.quantization_facade import pytorch_post_training_quantization, \
     pytorch_post_training_quantization_mixed_precision
 from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
diff --git a/model_compression_toolkit/core/common/defaultdict.py b/model_compression_toolkit/core/common/defaultdict.py
index 5ea8e9ff4..6bd670cb8 100644
--- a/model_compression_toolkit/core/common/defaultdict.py
+++ b/model_compression_toolkit/core/common/defaultdict.py
@@ -56,3 +56,10 @@ def get(self, key: Any) -> Any:
         if self.default_factory is not None:
             return self.default_factory()
         return None
+
+    def keys(self):
+        """
+        Get the keys of known_dict.
+
+        Returns: The keys of known_dict.
+        """
+        return self.known_dict.keys()
diff --git a/model_compression_toolkit/core/common/framework_info.py b/model_compression_toolkit/core/common/framework_info.py
index 4fb84ac37..edd3355e6 100644
--- a/model_compression_toolkit/core/common/framework_info.py
+++ b/model_compression_toolkit/core/common/framework_info.py
@@ -99,13 +99,12 @@ def __init__(self,
         self.kernel_ops_attributes_mapping = kernel_ops_attributes_mapping
         self.out_channel_axis_mapping = out_channel_axis_mapping
 
-
     def get_kernel_op_attributes(self, node_type: Any) -> List[str]:
         """
         Get a list of attributes of a layer's weights to quantize.
 
         Args:
-            node_type: Layer to get its' attributes.
+            node_type: Layer to get its attributes.
 
         Returns:
             A list of attributes the layer has and should be quantized.
@@ -113,6 +112,18 @@ def get_kernel_op_attributes(self, node_type: Any) -> List[str]:
         attr_list = self.kernel_ops_attributes_mapping.get(node_type)
         return attr_list
 
+    def is_kernel_op(self, node_type: Any) -> bool:
+        """
+        Check if the node type is a kernel operation.
+
+        Args:
+            node_type: Layer type to check.
+
+        Returns:
+            True if the node type is a kernel operation, else False.
+        """
+        return node_type in self.kernel_ops_attributes_mapping.keys()
+
     def layers_has_min_max(self, layer: Any) -> bool:
         """
         Check if a layer is in a layer to min/max mapping the FrameworkInfo holds.
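
Reviewer note: the new DefaultDict.keys() accessor is what makes the is_kernel_op membership test above possible, because DefaultDict.get() never reports a miss once a default_factory is set. A minimal, self-contained sketch of that interplay (the stand-in class below mirrors the patched DefaultDict for illustration only; it is not part of MCT):

from typing import Any, Callable, Dict

class DefaultDict:
    # Stand-in mirroring model_compression_toolkit.core.common.defaultdict.DefaultDict.
    def __init__(self, known_dict: Dict, default_factory: Callable = None):
        self.known_dict = known_dict
        self.default_factory = default_factory

    def get(self, key: Any) -> Any:
        # A miss falls back to the factory, so get() cannot answer
        # "was this key explicitly registered?".
        if key in self.known_dict:
            return self.known_dict.get(key)
        if self.default_factory is not None:
            return self.default_factory()
        return None

    def keys(self):
        # The accessor added by this patch: it exposes only the known keys.
        return self.known_dict.keys()

mapping = DefaultDict({'Conv2D': ['kernel']}, default_factory=list)
print('Conv2D' in mapping.keys())  # True  -> is_kernel_op would wrap this layer type
print('ReLU' in mapping.keys())    # False -> not a kernel op
print(mapping.get('ReLU'))         # []    -> the factory masks the miss for get()
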
diff --git a/model_compression_toolkit/qat/keras/model_builder.py b/model_compression_toolkit/qat/keras/model_builder.py
index e78fc9484..488159544 100644
--- a/model_compression_toolkit/qat/keras/model_builder.py
+++ b/model_compression_toolkit/qat/keras/model_builder.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 from typing import Any, Tuple
+import copy
 import tensorflow as tf
 from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
@@ -22,35 +23,79 @@
 from model_compression_toolkit.core.common import Logger
 from model_compression_toolkit.core.common.constants import TENSORFLOW
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
 from model_compression_toolkit.core.keras.back2framework.model_builder import get_node_name_from_layer, \
     is_layer_fake_quant
 from model_compression_toolkit.qat.keras.quantizer.config_factory import quantization_config_builder
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
 from model_compression_toolkit.core.keras.back2framework.model_builder import model_builder as core_model_builder
+from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
 from model_compression_toolkit import get_target_platform_capabilities
 
 DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
+
+def is_qat_applicable(node: common.BaseNode, fw_info: FrameworkInfo) -> bool:
+    """
+    Decide whether a layer should be fine-tuned during QAT.
+
+    Args:
+        node (BaseNode): Node for the quantization decision.
+        fw_info (FrameworkInfo): Keras quantization information.
+
+    Returns:
+        True if the layer should be wrapped with a QuantizeWrapper, else False.
+    """
+
+    return fw_info.is_kernel_op(node.type) and node.is_weights_quantization_enabled()
+
+
 def model_builder(graph: common.Graph,
-                  fw_info: FrameworkInfo = DEFAULT_KERAS_INFO) -> Tuple[tf.keras.models.Model, Any]:
+                  fw_info: FrameworkInfo,
+                  fw_impl: KerasImplementation) -> Tuple[tf.keras.models.Model, Any]:
+    """
+    Prepare a model for Quantization Aware Training: build a Keras model, then wrap the
+    required layers with a QuantizeWrapper.
+
+    Args:
+        graph (Graph): Network graph to build.
+        fw_info (FrameworkInfo): Keras quantization information.
+        fw_impl (KerasImplementation): FrameworkImplementation object with Keras-specific methods.
+
+    Returns:
+        A Keras model prepared for QAT, and user information.
+    """
+
     #################################################
     # Prepare model for Quantization Aware Training
     #################################################
-    quantized_model, user_info = core_model_builder(graph,
+
+    # Quantize the graph weights that are not to be fine-tuned during QAT
+    graph_to_quantize = copy.deepcopy(graph)
+    for n in graph_to_quantize.nodes:
+        if is_qat_applicable(n, fw_info):
+            n.final_weights_quantization_cfg.enable_weights_quantization = False
+    quantized_tg = quantize_graph_weights(graph_to_quantize,
+                                          fw_info=fw_info,
+                                          fw_impl=fw_impl)
+
+    # Build the Keras model
+    quantized_model, user_info = core_model_builder(quantized_tg,
                                                     mode=ModelBuilderMode.QUANTIZED,
                                                     fw_info=fw_info)
 
+    # Wrap layers to be fine-tuned during QAT with QuantizeWrapper
     def _quantize(layer):
         nodes = graph.find_node_by_name(get_node_name_from_layer(layer))
         if len(nodes) == 1:
             node = nodes[0]
-            return QuantizeWrapper(layer, quantization_config_builder(node, fw_info))
+            if is_qat_applicable(node, fw_info):
+                return QuantizeWrapper(layer, quantization_config_builder(node, fw_info))
+            else:
+                return layer
         elif is_layer_fake_quant(layer):
+            # A fake-quant layer was added by the core_model_builder to quantize the activations
             return layer
         else:
             Logger.error(f"Mismatch between keras model and graph can't find node named: {get_node_name_from_layer(layer)}")
diff --git a/model_compression_toolkit/qat/keras/quantization_facade.py b/model_compression_toolkit/qat/keras/quantization_facade.py
index 8ef8b4d64..10fe58cdd 100644
--- a/model_compression_toolkit/qat/keras/quantization_facade.py
+++ b/model_compression_toolkit/qat/keras/quantization_facade.py
@@ -30,6 +30,10 @@
 if FOUND_TF:
+    import tensorflow as tf
+    import tensorflow_model_optimization as tfmot
+    from model_compression_toolkit.qat.keras.quantizer.config_factory import WeightQuantizeConfig
+
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
@@ -148,12 +152,105 @@ def keras_quantization_aware_training_init(in_model: Model,
 
         tg = ptq_runner(tg, fw_info, fw_impl, tb_w)
 
-        qat_model, user_info = model_builder(tg, fw_info=fw_info)
+        qat_model, user_info = model_builder(tg, fw_info=fw_info, fw_impl=fw_impl)
 
         user_info.mixed_precision_cfg = bit_widths_config
 
         return qat_model, user_info, QUANTIZATION_CONFIGS_DICT
 
+
+    def keras_quantization_aware_training_finalize(in_model: Model):
+        """
+        Convert a model fine-tuned by the user to a network without QuantizeWrappers. The exported
+        model contains float (fake-quantized) parameters and fake-quantization layers for quantizing
+        the activations.
+
+        Args:
+            in_model (Model): Keras model to remove QuantizeWrappers from.
+
+        Returns:
+            A quantized model without QuantizeWrappers.
+
+        Examples:
+
+        Import MCT:
+
+        >>> import model_compression_toolkit as mct
+
+        Import a Keras model:
+
+        >>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
+        >>> model = MobileNetV2()
+
+        Create a random dataset generator:
+
+        >>> import numpy as np
+        >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
+
+        Create an MCT core config, containing the quantization configuration:
+
+        >>> config = mct.CoreConfig()
+
+        If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model with different bitwidths for different layers.
+        The candidate bitwidths for quantization should be defined in the target platform model:
+
+        >>> config = mct.CoreConfig(mixed_precision_config=mct.MixedPrecisionQuantizationConfigV2())
+
+        For mixed precision, set a target KPI object:
+        Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
+        that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
+        while the bias will not):
+
+        >>> kpi = mct.KPI(model.count_params() * 0.75)  # About 0.75 of the model size when quantized with 8 bits.
+
+        Pass the model, the representative dataset generator, the configuration and the target KPI to get a
+        quantized model:
+
+        >>> quantized_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
+
+        Use the quantized model for fine-tuning. To load the model from a file, use the custom_objects dictionary:
+
+        >>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)
+        >>> quantized_model = mct.keras_quantization_aware_training_finalize(quantized_model)
+
+        """
+
+        def _export(layer):
+            if isinstance(layer, tfmot.quantization.keras.QuantizeWrapper):
+                if not isinstance(layer.quantize_config, tuple(QUANTIZATION_CONFIGS_DICT.values())):
+                    Logger.error(f'Only supported quantization configs are {tuple(QUANTIZATION_CONFIGS_DICT.keys())}')
+                if isinstance(layer.quantize_config, WeightQuantizeConfig):
+                    # Rebuild the wrapped layer from its config, then copy (and quantize) its weights
+                    new_layer = layer.layer.__class__.from_config(layer.layer.get_config())
+                    with tf.name_scope(new_layer.name):
+                        new_layer.build(layer.input_shape)
+                    weights_list = []
+                    for w in new_layer.weights:
+                        val = None
+                        for qw in layer.weights:
+                            if w.name in qw.name:
+                                if w.name.split('/')[-1].split(':')[0] in layer.quantize_config.weight_attrs:
+                                    # Quantize the kernel with the wrapper's weight quantizer
+                                    val = layer.quantize_config.get_weights_and_quantizers(layer.layer)[0][1](
+                                        qw, False, layer.quantize_config.weight_quantizer.quantizer_parameters)
+                                else:
+                                    val = qw
+                                val = val.numpy()
+                        if val is None:
+                            Logger.error(f'Could not match weight name: {w.name}')
+                        weights_list.append(val)
+                    new_layer.set_weights(weights_list)
+                    new_layer.trainable = False
+                    return new_layer
+                else:
+                    Logger.error('Undefined quantize_config')
+            else:
+                return layer
+
+        # Clone each layer in the model and apply _export to layers wrapped with a QuantizeWrapper.
+        exported_model = tf.keras.models.clone_model(in_model, input_tensors=None, clone_function=_export)
+
+        return exported_model
+
 else:
     # If tensorflow or tensorflow_model_optimization are not installed,
     # we raise an exception when trying to use these functions.
@@ -161,3 +258,8 @@ def keras_quantization_aware_training_init(*args, **kwargs):
         Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
                         'when using keras_quantization_aware_training_init. '
                        'Could not find Tensorflow package.')
+
+    def keras_quantization_aware_training_finalize(*args, **kwargs):
+        Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
+                        'when using keras_quantization_aware_training_finalize. '
+                        'Could not find Tensorflow package.')
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/qat_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/qat_test.py
index 10de73ed5..9fc5c2850 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/qat_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/qat_test.py
@@ -26,9 +26,10 @@
 class QuantizationAwareTrainingTest(BaseKerasFeatureNetworkTest):
 
-    def __init__(self, unit_test, weight_bits=2, activation_bits=4):
+    def __init__(self, unit_test, weight_bits=2, activation_bits=4, finalize=False):
         self.weight_bits = weight_bits
         self.activation_bits = activation_bits
+        self.finalize = finalize
         super().__init__(unit_test)
 
     def get_tpc(self):
@@ -42,19 +43,24 @@ def create_networks(self):
     def run_test(self):
         model_float = self.create_networks()
         qc = self.get_quantization_config()
-        ptq_model, quantization_info, custom_opjects = mct.keras_quantization_aware_training_init(model_float,
+        ptq_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model_float,
                                                                                                    self.representative_data_gen,
                                                                                                    fw_info=self.get_fw_info(),
                                                                                                    target_platform_capabilities=self.get_tpc())
+        if self.finalize:
+            ptq_model = mct.keras_quantization_aware_training_finalize(ptq_model)
+
         self.compare(ptq_model,
                      model_float,
                      input_x=self.representative_data_gen(),
                      quantization_info=quantization_info)
 
     def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
-        self.unit_test.assertTrue(isinstance(quantized_model.layers[2].layer, layers.Conv2D))
-        self.unit_test.assertTrue(isinstance(quantized_model.layers[3].layer, layers.Activation))
-        _, qconfig = quantized_model.layers[2].quantize_config.get_weights_and_quantizers(quantized_model.layers[2].layer)[0]
-        self.unit_test.assertTrue(qconfig.num_bits == self.weight_bits)
-        self.unit_test.assertTrue(isinstance(quantized_model.layers[3].quantize_config, NoOpQuantizeConfig))
+        if self.finalize:
+            self.unit_test.assertTrue(isinstance(quantized_model.layers[2], layers.Conv2D))
+        else:
+            self.unit_test.assertTrue(isinstance(quantized_model.layers[2].layer, layers.Conv2D))
+            _, qconfig = quantized_model.layers[2].quantize_config.get_weights_and_quantizers(quantized_model.layers[2].layer)[0]
+            self.unit_test.assertTrue(qconfig.num_bits == self.weight_bits)
+            self.unit_test.assertTrue(isinstance(quantized_model.layers[3], layers.Activation))
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index 43b33269c..bd36eae6a 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -484,6 +484,7 @@ def test_layer_norm_substitution(self):
 
     def test_qat(self):
         QuantizationAwareTrainingTest(self).run_test()
+        QuantizationAwareTrainingTest(self, finalize=True).run_test()
 
 
 if __name__ == '__main__':
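
Reviewer note: to see the two new facade entry points together, the intended init -> fine-tune -> finalize flow, pieced together from the docstring examples in this patch, looks roughly as follows; the compile settings and training step are placeholders and not part of the PR:

import numpy as np
import model_compression_toolkit as mct
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

def repr_datagen():
    # Random representative data, as in the docstring example.
    return [np.random.random((1, 224, 224, 3))]

model = MobileNetV2()

# Init: returns a model whose kernel-op layers are wrapped with QuantizeWrapper;
# custom_objects is needed to reload the model from disk during training.
qat_model, quantization_info, custom_objects = \
    mct.keras_quantization_aware_training_init(model, repr_datagen)

# Fine-tune as a regular Keras model (placeholder training setup).
qat_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
# qat_model.fit(train_dataset, epochs=...)  # user-supplied data

# Finalize: unwrap the QuantizeWrappers, bake the (fake-quantized) weights into
# plain, frozen Keras layers, and keep the activation fake-quant layers in place.
deployable_model = mct.keras_quantization_aware_training_finalize(qat_model)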