
Commit

update QAT and add export API (#308)
Add weight quantization in case there are weights not to be fine-tuned
Add export method to remove QuantizeWrappers from the QAT model after fine-tuning
Update docsrc with new API

Co-authored-by: elad cohen <[email protected]>
elad-c and elad cohen authored Jul 31, 2022
1 parent 6467251 commit fb7c037
Showing 11 changed files with 203 additions and 18 deletions.
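A minimal sketch of the resulting workflow (mirroring the docstring example added in quantization_facade.py below; the fine-tuning step is only indicated, and the argument usage is an assumption based on that docstring):

    import numpy as np
    import model_compression_toolkit as mct
    from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

    def repr_datagen():
        return [np.random.random((1, 224, 224, 3))]

    model = MobileNetV2()
    config = mct.CoreConfig()

    # Wrap the layers to be fine-tuned with QuantizeWrappers; weights that are not
    # fine-tuned are quantized up front (see model_builder.py below).
    qat_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(
        model, repr_datagen, core_config=config)

    # ... fine-tune qat_model with the usual Keras compile()/fit() calls ...

    # Remove the QuantizeWrappers after fine-tuning; the exported model keeps
    # fake-quantized weights and activation fake-quantization layers.
    final_model = mct.keras_quantization_aware_training_finalize(qat_model)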
3 changes: 2 additions & 1 deletion docsrc/source/api/experimental_api_docs/index.rst
@@ -26,7 +26,8 @@ Functions
- :ref:`pytorch_post_training_quantization_experimental<ug-pytorch_post_training_quantization_experimental>`: A function to use for post training quantization of PyTorch models (experimental).
- :ref:`keras_post_training_quantization_experimental<ug-keras_post_training_quantization_experimental>`: A function to use for post training quantization of Keras models (experimental).
- :ref:`keras_gradient_post_training_quantization_experimental<ug-keras_gradient_post_training_quantization_experimental>`: A function to use for gradient-based post training quantization of Keras models (experimental).
- :ref:`keras_quantization_aware_training_experimental<ug-keras_quantization_aware_training_experimental>`: A function to use for preparing a model for Quantization Aware Training (Experimental)
- :ref:`keras_quantization_aware_training_init<ug-keras_quantization_aware_training_init>`: A function to use for preparing a model for Quantization Aware Training (Experimental).
- :ref:`keras_quantization_aware_training_finalize<ug-keras_quantization_aware_training_finalize>`: A function to finalize a model after Quantization Aware Training into a model without QuantizeWrappers (Experimental).
- :ref:`get_keras_gptq_config<ug-get_keras_gptq_config>`: A function to create a GradientPTQConfig instance to use for Keras models when using GPTQ (experimental).
- :ref:`get_target_platform_capabilities<ug-get_target_platform_capabilities>`: A function to get a target platform model for Tensorflow and Pytorch.
- :ref:`keras_kpi_data_experimental<ug-keras_kpi_data_experimental>`: A function to compute KPI data that can be used to calculate the desired target KPI for mixed-precision quantization of Keras models (experimental).
@@ -0,0 +1,10 @@
:orphan:

.. _ug-keras_quantization_aware_training_finalize:


================================================
Keras Quantization Aware Training Model Finalize
================================================

.. autofunction:: model_compression_toolkit.keras_quantization_aware_training_finalize
@@ -1,6 +1,6 @@
:orphan:

.. _ug-keras_quantization_aware_training_experimental:
.. _ug-keras_quantization_aware_training_init:


================================================
3 changes: 2 additions & 1 deletion docsrc/source/index.rst
@@ -60,7 +60,8 @@ Keras:
* :ref:`Hardware-friendly Post Training Quantization<ug-keras_post_training_quantization>` [1]
* :ref:`Gradient-based post training using knowledge distillation<ug-GradientPTQConfig>` (Experimental)
* :ref:`Mixed-precision post training quantization<ug-keras_post_training_quantization_mixed_precision>` (Experimental)
* :ref:`Init model for Quantization Aware Training<ug-keras_quantization_aware_training_experimental>` (Experimental)
* :ref:`Init model for Quantization Aware Training<ug-keras_quantization_aware_training_init>` (Experimental)
* :ref:`Finalize model after Quantization Aware Training<ug-keras_quantization_aware_training_finalize>` (Experimental)

Pytorch (Experimental):

3 changes: 2 additions & 1 deletion model_compression_toolkit/__init__.py
@@ -37,7 +37,8 @@
from model_compression_toolkit.gptq.keras.quantization_facade import \
keras_gradient_post_training_quantization_experimental
from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init
from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init, \
keras_quantization_aware_training_finalize
from model_compression_toolkit.core.pytorch.quantization_facade import pytorch_post_training_quantization, \
pytorch_post_training_quantization_mixed_precision
from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
7 changes: 7 additions & 0 deletions model_compression_toolkit/core/common/defaultdict.py
@@ -56,3 +56,10 @@ def get(self, key: Any) -> Any:
if self.default_factory is not None:
return self.default_factory()
return None

def keys(self):
"""
Get the keys of known_dict.
Returns: The keys of known_dict.
"""
return self.known_dict.keys()
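The new keys() pass-through makes membership tests against the wrapped known_dict possible, which the FrameworkInfo.is_kernel_op added below relies on. A minimal sketch (is_known_kernel_op and kernel_attrs are hypothetical helper names; the mapping instance is assumed to already exist, e.g. FrameworkInfo.kernel_ops_attributes_mapping):

    def is_known_kernel_op(mapping, node_type):
        # Membership check enabled by the new DefaultDict.keys().
        return node_type in mapping.keys()

    def kernel_attrs(mapping, node_type):
        # get() still falls back to the default factory (or None) for unknown keys.
        return mapping.get(node_type)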
15 changes: 13 additions & 2 deletions model_compression_toolkit/core/common/framework_info.py
@@ -99,20 +99,31 @@ def __init__(self,
self.kernel_ops_attributes_mapping = kernel_ops_attributes_mapping
self.out_channel_axis_mapping = out_channel_axis_mapping


def get_kernel_op_attributes(self, node_type: Any) -> List[str]:
"""
Get a list of attributes of a layer's weights to quantize.
Args:
node_type: Layer to get its' attributes.
node_type: Layer to get its attributes.
Returns:
A list of attributes the layer has and should be quantized.
"""
attr_list = self.kernel_ops_attributes_mapping.get(node_type)
return attr_list

def is_kernel_op(self, node_type: Any) -> bool:
"""
Check if the node is a kernel operation.
Args:
node_type: Layer type to check.
Returns:
True if node type is a kernel operation, else False.
"""
return node_type in self.kernel_ops_attributes_mapping.keys()

def layers_has_min_max(self, layer: Any) -> bool:
"""
Check if a layer is in a layer to min/max mapping the FrameworkInfo holds.
53 changes: 49 additions & 4 deletions model_compression_toolkit/qat/keras/model_builder.py
@@ -14,6 +14,7 @@
# ==============================================================================

from typing import Any, Tuple
import copy

import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
@@ -22,35 +23,79 @@
from model_compression_toolkit.core.common import Logger
from model_compression_toolkit.core.common.constants import TENSORFLOW
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.core.keras.back2framework.model_builder import get_node_name_from_layer, \
is_layer_fake_quant
from model_compression_toolkit.qat.keras.quantizer.config_factory import quantization_config_builder

from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.core.keras.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.core.keras.back2framework.model_builder import model_builder as core_model_builder
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights

from model_compression_toolkit import get_target_platform_capabilities

DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)


def is_qat_applicable(node: common.BaseNode, fw_info: FrameworkInfo) -> bool:
"""
A function for deciding if a layer should be fine-tuned during QAT
Args:
node (BaseNode): Node for quantization decision
fw_info (FrameworkInfo): Keras quantization information
Returns:
A boolean indicating whether the layer should be wrapped with a QuantizeWrapper.
"""

return fw_info.is_kernel_op(node.type) and node.is_weights_quantization_enabled()


def model_builder(graph: common.Graph,
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO) -> Tuple[tf.keras.models.Model, Any]:
fw_info: FrameworkInfo,
fw_impl: KerasImplementation) -> Tuple[tf.keras.models.Model, Any]:
"""
Prepare model for Quantization Aware Training. Build a keras model and then wrap
required layers with a QuantizeWrapper
Args:
graph (Graph): network graph to build
fw_info (FrameworkInfo): Keras quantization information
fw_impl (KerasImplementation): Framework implementation object with Keras-specific methods.
Returns:
A Keras model prepared for QAT (layers to be fine-tuned are wrapped with a QuantizeWrapper) and user information.
"""

#################################################
# Prepare model for Quantization Aware Training
#################################################
quantized_model, user_info = core_model_builder(graph,

# Quantize graph weights that are not to be fine-tuned during QAT
graph_to_quantize = copy.deepcopy(graph)
for n in graph_to_quantize.nodes:
if is_qat_applicable(n, fw_info):
n.final_weights_quantization_cfg.enable_weights_quantization = False
quantized_tg = quantize_graph_weights(graph_to_quantize,
fw_info=fw_info,
fw_impl=fw_impl)

# build keras model
quantized_model, user_info = core_model_builder(quantized_tg,
mode=ModelBuilderMode.QUANTIZED,
fw_info=fw_info)

# Wrap layers to be fine-tuned during QAT with QuantizeWrapper
def _quantize(layer):
nodes = graph.find_node_by_name(get_node_name_from_layer(layer))
if len(nodes) == 1:
node = nodes[0]
return QuantizeWrapper(layer, quantization_config_builder(node, fw_info))
if is_qat_applicable(node, fw_info):
return QuantizeWrapper(layer, quantization_config_builder(node, fw_info))
else:
return layer
elif is_layer_fake_quant(layer):
# A fake quant layer was added in the core_model_builder to quantize the activations
return layer
else:
Logger.error(f"Mismatch between keras model and graph can't find node named: {get_node_name_from_layer(layer)}")
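As a quick check of what the builder above produces (sketch; qat_model stands for the model returned by keras_quantization_aware_training_init), kernel ops whose weights are fine-tuned come back wrapped, while the remaining layers are plain Keras layers carrying already-quantized weights:

    from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper

    wrapped = [l.layer.__class__.__name__ for l in qat_model.layers
               if isinstance(l, QuantizeWrapper)]
    plain = [l.__class__.__name__ for l in qat_model.layers
             if not isinstance(l, QuantizeWrapper)]
    print('fine-tuned under QuantizeWrapper:', wrapped)
    print('kept as plain (pre-quantized) layers:', plain)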
104 changes: 103 additions & 1 deletion model_compression_toolkit/qat/keras/quantization_facade.py
@@ -30,6 +30,10 @@


if FOUND_TF:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from model_compression_toolkit.qat.keras.quantizer.config_factory import WeightQuantizeConfig

from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
@@ -148,16 +152,114 @@ def keras_quantization_aware_training_init(in_model: Model,

tg = ptq_runner(tg, fw_info, fw_impl, tb_w)

qat_model, user_info = model_builder(tg, fw_info=fw_info)
qat_model, user_info = model_builder(tg, fw_info=fw_info, fw_impl=fw_impl)

user_info.mixed_precision_cfg = bit_widths_config

return qat_model, user_info, QUANTIZATION_CONFIGS_DICT


def keras_quantization_aware_training_finalize(in_model: Model):
"""
Convert a model fine-tuned by the user to a network without QuantizeWrappers. The exported
model contains float (fake-quantized) parameters and fake-quantization layers for quantizing
the activations.
Args:
in_model (Model): Keras model to remove QuantizeWrappers from.
Returns:
A quantized model without QuantizeWrappers.
Examples:
Import MCT:
>>> import model_compression_toolkit as mct
Import a Keras model:
>>> from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
>>> model = MobileNetV2()
Create a random dataset generator:
>>> import numpy as np
>>> def repr_datagen(): return [np.random.random((1,224,224,3))]
Create an MCT core config containing the quantization configuration:
>>> config = mct.CoreConfig()
If mixed precision is desired, create an MCT core config with a mixed-precision configuration to quantize the model with different bit widths for different layers.
The candidate bit widths for quantization should be defined in the target platform model:
>>> config = mct.CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfigV2())
For mixed precision, set a target KPI object:
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
while the bias will not):
>>> kpi = mct.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
Pass the model, the representative dataset generator, the configuration and the target KPI to get a
quantized model:
>>> quantized_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model, repr_datagen, kpi, core_config=config)
Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:
>>> quantized_model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)
>>> quantized_model = mct.keras_quantization_aware_training_finalize(quantized_model)
"""

def _export(layer):
if isinstance(layer, tfmot.quantization.keras.QuantizeWrapper):
if not isinstance(layer.quantize_config, tuple(QUANTIZATION_CONFIGS_DICT.values())):
Logger.error(f'Only supported quantization configs are {tuple(QUANTIZATION_CONFIGS_DICT.keys())}')
if isinstance(layer.quantize_config, WeightQuantizeConfig):
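# Rebuild the wrapped layer from its config; its variables are re-created and filled in below.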
new_layer = layer.layer.__class__.from_config(layer.layer.get_config())
with tf.name_scope(new_layer.name):
new_layer.build(layer.input_shape)
weights_list = []
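# Match each variable of the rebuilt layer to the wrapped layer's variables: weights listed in weight_attrs go through the config's quantizer, the rest are copied as-is.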
for w in new_layer.weights:
val = None
for qw in layer.weights:
if w.name in qw.name:
if w.name.split('/')[-1].split(':')[0] in layer.quantize_config.weight_attrs:
val = layer.quantize_config.get_weights_and_quantizers(layer.layer)[0][1](qw,
False,
layer.quantize_config.weight_quantizer.quantizer_parameters)
else:
val = qw
val = val.numpy()
if val is None:
Logger.error(f'Could not match weight name: {w.name}')
weights_list.append(val)
new_layer.set_weights(weights_list)
new_layer.trainable = False
return new_layer
else:
Logger.error(f'Undefined quantize_config')
else:
return layer

# clone each layer in the model and apply _export to layers wrapped with a QuantizeWrapper.
exported_model = tf.keras.models.clone_model(in_model, input_tensors=None, clone_function=_export)

return exported_model

else:
# If tensorflow or tensorflow_model_optimization are not installed,
# we raise an exception when trying to use these functions.
def keras_quantization_aware_training_init(*args, **kwargs):
Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
'when using keras_quantization_aware_training_init. '
'Could not find Tensorflow package.')

def keras_quantization_aware_training_finalize(*args, **kwargs):
Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
'when using keras_quantization_aware_training_finalize. '
'Could not find Tensorflow package.')
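For context, the export path above relies on Keras model cloning with a per-layer clone_function. A stripped-down sketch of that mechanism (generic unwrapping only; MCT's _export above additionally rebuilds, quantizes, and copies the wrapped weights; qat_model is assumed to come from keras_quantization_aware_training_init):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def _unwrap(layer):
        # Replace each QuantizeWrapper with a fresh copy of its inner layer; all
        # other layers pass through unchanged. Weight values still have to be
        # copied into the new layers, as _export does above.
        if isinstance(layer, tfmot.quantization.keras.QuantizeWrapper):
            return layer.layer.__class__.from_config(layer.layer.get_config())
        return layer

    stripped = tf.keras.models.clone_model(qat_model, clone_function=_unwrap)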
@@ -26,9 +26,10 @@


class QuantizationAwareTrainingTest(BaseKerasFeatureNetworkTest):
def __init__(self, unit_test, weight_bits=2, activation_bits=4):
def __init__(self, unit_test, weight_bits=2, activation_bits=4, finalize=False):
self.weight_bits = weight_bits
self.activation_bits = activation_bits
self.finalize = finalize
super().__init__(unit_test)

def get_tpc(self):
@@ -42,19 +43,24 @@ def create_networks(self):
def run_test(self):
model_float = self.create_networks()
qc = self.get_quantization_config()
ptq_model, quantization_info, custom_opjects = mct.keras_quantization_aware_training_init(model_float,
ptq_model, quantization_info, custom_objects = mct.keras_quantization_aware_training_init(model_float,
self.representative_data_gen,
fw_info=self.get_fw_info(),
target_platform_capabilities=self.get_tpc())

if self.finalize:
ptq_model = mct.keras_quantization_aware_training_finalize(ptq_model)

self.compare(ptq_model,
model_float,
input_x=self.representative_data_gen(),
quantization_info=quantization_info)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
self.unit_test.assertTrue(isinstance(quantized_model.layers[2].layer, layers.Conv2D))
self.unit_test.assertTrue(isinstance(quantized_model.layers[3].layer, layers.Activation))
_, qconfig = quantized_model.layers[2].quantize_config.get_weights_and_quantizers(quantized_model.layers[2].layer)[0]
self.unit_test.assertTrue(qconfig.num_bits == self.weight_bits)
self.unit_test.assertTrue(isinstance(quantized_model.layers[3].quantize_config, NoOpQuantizeConfig))
if self.finalize:
self.unit_test.assertTrue(isinstance(quantized_model.layers[2], layers.Conv2D))
else:
self.unit_test.assertTrue(isinstance(quantized_model.layers[2].layer, layers.Conv2D))
_, qconfig = quantized_model.layers[2].quantize_config.get_weights_and_quantizers(quantized_model.layers[2].layer)[0]
self.unit_test.assertTrue(qconfig.num_bits == self.weight_bits)
self.unit_test.assertTrue(isinstance(quantized_model.layers[3], layers.Activation))
@@ -484,6 +484,7 @@ def test_layer_norm_substitution(self):

def test_qat(self):
QuantizationAwareTrainingTest(self).run_test()
QuantizationAwareTrainingTest(self, finalize=True).run_test()


if __name__ == '__main__':
