Skip to content

Commit

Permalink
move qat activation quantizers to trainable infra
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Oct 13, 2024
1 parent 508e8fa commit 8866392
Show file tree
Hide file tree
Showing 17 changed files with 653 additions and 443 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,14 @@

if FOUND_TF:

class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
class BaseKerasQATWeightTrainableQuantizer(BaseKerasTrainableQuantizer):
"""
A base class for trainable Keras quantizer for QAT.
"""

def __init__(self,
quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
"""
Initializes BaseKerasQATTrainableQuantizer object.
Args:
quantization_config: quantizer config class contains all the information about a quantizer configuration.
"""

super().__init__(quantization_config)
pass

else: # pragma: no cover
class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
class BaseKerasQATWeightTrainableQuantizer(BaseKerasTrainableQuantizer):
def __init__(self,
quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):

Expand Down
96 changes: 2 additions & 94 deletions model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
from model_compression_toolkit import constants as C

from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import BaseKerasQATWeightTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
TrainableQuantizerActivationConfig
from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, \
Expand Down Expand Up @@ -68,7 +68,7 @@ def symmetric_lsq_quantizer(x: tf.Tensor,
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
identifier=TrainingMethod.LSQ)
class LSQWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
class LSQWeightQATQuantizer(BaseKerasQATWeightTrainableQuantizer):
"""
Trainable constrained quantizer to quantize layer's weights.
"""
Expand Down Expand Up @@ -159,95 +159,3 @@ def convert2inferable(self) -> Union[WeightsPOTInferableQuantizer, WeightsSymmet
input_rank=len(self.threshold_shape))


@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                identifier=TrainingMethod.LSQ)
class LSQActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
    """
    Trainable quantizer for layer activations using the LSQ (Learned Step Size)
    method with a symmetric / power-of-two threshold.
    """

    def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
        """
        Initialize an LSQActivationQATQuantizer with the parameters to use
        for the quantization.

        Args:
            quantization_config: trainable activation quantizer configuration.
        """
        super().__init__(quantization_config)
        params = quantization_config.activation_quantization_params
        self.power_of_two = (quantization_config.activation_quantization_method
                             == QuantizationMethod.POWER_OF_TWO)
        self.threshold_values = float(params[C.THRESHOLD])
        self.threshold_shape = np.asarray(self.threshold_values).shape
        self.sign = params[SIGNED]
        self.num_bits = quantization_config.activation_n_bits
        positive_bits = self.num_bits - int(self.sign)
        # Integer grid: [-2^p, 2^p - 1] when signed, [0, 2^p - 1] when unsigned.
        self.min_int = -int(self.sign) * (2 ** positive_bits)
        self.max_int = (2 ** positive_bits) - 1
        if self.power_of_two:
            # Snap the initial threshold up to the nearest power of two.
            self.threshold_values = np.power(
                2.0, np.ceil(np.log2(np.maximum(self.threshold_values, C.MIN_THRESHOLD))))

    def initialize_quantization(self,
                                tensor_shape: TensorShape,
                                name: str,
                                layer: KerasTrainableQuantizationWrapper):
        """
        Register the trainable threshold variable on the wrapped layer.

        Args:
            tensor_shape: tensor shape of the quantized tensor.
            name: tensor name.
            layer: layer to quantize.
        """
        threshold_var = layer.add_weight(
            name + THRESHOLD_TENSOR,
            shape=(),
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=True)
        threshold_var.assign(self.threshold_values)

        # Track the variable so it can be fetched in __call__ and convert2inferable.
        self.add_quantizer_variable(THRESHOLD_TENSOR, threshold_var, VariableGroup.QPARAMS)

    def __call__(self,
                 inputs: tf.Tensor,
                 training: bool):
        """
        Quantize a tensor.

        Args:
            inputs: input tensor to quantize.
            training: whether the graph is in training mode.

        Returns:
            The fake-quantized tensor.
        """
        thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR)
        num_channels = inputs.shape[-1]
        # LSQ gradient scale factor (per the LSQ paper: 1 / sqrt(Q_max * num_elements)).
        grad_scale = 1.0 / np.sqrt(self.max_int * num_channels)
        return symmetric_lsq_quantizer(inputs, thresholds, self.num_bits, self.sign,
                                       self.min_int, self.max_int, grad_scale)

    def convert2inferable(self) -> Union[ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer]:
        """
        Convert quantizer to an inferable (inference-only) quantizer.

        Returns:
            BaseKerasInferableQuantizer object (POT or symmetric variant).
        """
        learned_threshold = self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()
        if self.power_of_two:
            # Round the learned threshold up to the nearest power of two.
            pot_threshold = 2 ** np.ceil(np.log2(learned_threshold))
            # Activation quantization is per-tensor only - thus the threshold is
            # passed as a list with a len of 1.
            return ActivationPOTInferableQuantizer(num_bits=self.num_bits,
                                                   threshold=[pot_threshold],
                                                   signed=self.sign)
        # Activation quantization is per-tensor only - thus the threshold is
        # passed as a list with a len of 1.
        return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits,
                                                     threshold=[learned_threshold],
                                                     signed=self.sign)
95 changes: 3 additions & 92 deletions model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import tensorflow as tf
from tensorflow.python.framework.tensor_shape import TensorShape
from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import \
BaseKerasQATWeightTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
Expand All @@ -26,7 +28,6 @@

from model_compression_toolkit import constants as C

from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
TrainableQuantizerActivationConfig
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
Expand Down Expand Up @@ -66,7 +67,7 @@ def uniform_lsq_quantizer(x: tf.Tensor,
@mark_quantizer(quantization_target=QuantizationTarget.Weights,
quantization_method=[QuantizationMethod.UNIFORM],
identifier=TrainingMethod.LSQ)
class LSQUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
class LSQUniformWeightQATQuantizer(BaseKerasQATWeightTrainableQuantizer):
"""
Trainable constrained quantizer to quantize layer's weights.
"""
Expand Down Expand Up @@ -158,93 +159,3 @@ def convert2inferable(self) -> BaseKerasInferableQuantizer:
channel_axis=self.channel_axis,
input_rank=len(self.min_max_shape))


@mark_quantizer(quantization_target=QuantizationTarget.Activation,
                quantization_method=[QuantizationMethod.UNIFORM],
                identifier=TrainingMethod.LSQ)
class LSQUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
    """
    Trainable quantizer for layer activations using the uniform (asymmetric
    min/max range) LSQ method.
    """

    def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
        """
        Initialize an LSQUniformActivationQATQuantizer with the parameters to
        use for the quantization.

        Args:
            quantization_config: trainable activation quantizer configuration.
        """
        super().__init__(quantization_config)

        params = quantization_config.activation_quantization_params
        self.num_bits = quantization_config.activation_n_bits
        self.min_range = np.array(params[C.RANGE_MIN])
        self.max_range = np.array(params[C.RANGE_MAX])
        # Unsigned integer grid: [0, 2^bits - 1].
        self.min_int = 0
        self.max_int = (2 ** self.num_bits) - 1

    def initialize_quantization(self,
                                tensor_shape: TensorShape,
                                name: str,
                                layer: KerasTrainableQuantizationWrapper):
        """
        Register the trainable min/max range variables on the wrapped layer.

        Args:
            tensor_shape: tensor shape of the quantized tensor.
            name: tensor name.
            layer: layer to quantize.
        """
        range_min_var = layer.add_weight(
            name + FQ_MIN,
            shape=(),
            initializer=tf.keras.initializers.Constant(-1.0),
            trainable=True)
        range_min_var.assign(self.min_range)

        range_max_var = layer.add_weight(
            name + FQ_MAX,
            shape=(),
            initializer=tf.keras.initializers.Constant(1.0),
            trainable=True)
        range_max_var.assign(self.max_range)

        # Track the variables so they can be fetched in __call__ and convert2inferable.
        self.add_quantizer_variable(FQ_MIN, range_min_var, VariableGroup.QPARAMS)
        self.add_quantizer_variable(FQ_MAX, range_max_var, VariableGroup.QPARAMS)

    def __call__(self,
                 inputs: tf.Tensor,
                 training: bool):
        """
        Quantize a tensor.

        Args:
            inputs: input tensor to quantize.
            training: whether the graph is in training mode.

        Returns:
            The fake-quantized tensor.
        """
        low = self.get_quantizer_variable(FQ_MIN)
        high = self.get_quantizer_variable(FQ_MAX)
        num_channels = inputs.shape[-1]
        # LSQ gradient scale factor (per the LSQ paper: 1 / sqrt(Q_max * num_elements)).
        grad_scale = 1.0 / np.sqrt(self.max_int * num_channels)
        return uniform_lsq_quantizer(inputs, low, high, self.num_bits,
                                     self.min_int, self.max_int, grad_scale)

    def convert2inferable(self) -> BaseKerasInferableQuantizer:
        """
        Convert quantizer to an inferable (inference-only) quantizer.

        Returns:
            BaseKerasInferableQuantizer object.
        """
        adj_min, adj_max = fix_range_to_include_zero(
            self.get_quantizer_variable(FQ_MIN).numpy(),
            self.get_quantizer_variable(FQ_MAX).numpy(),
            self.num_bits)
        # Activation quantization is per-tensor only - thus min/max are passed
        # as lists with a len of 1.
        return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
                                                   min_range=[adj_min],
                                                   max_range=[adj_max])
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@
from typing import Tuple, Dict, List, Callable

from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.qat.common.qat_config import QATConfig
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
from mct_quantizers import QuantizationTarget, KerasActivationQuantizationHolder
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import \
BaseKerasQATWeightTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config, \
get_trainable_quantizer_quantization_candidates
from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
get_trainable_quantizer_class
from model_compression_toolkit.trainable_infrastructure.keras.activation_quantizers import \
BaseKerasActivationTrainableQuantizer


def get_activation_quantizer_holder(n: common.BaseNode,
Expand Down Expand Up @@ -55,7 +56,7 @@ def get_activation_quantizer_holder(n: common.BaseNode,
def quantization_builder(n: common.BaseNode,
qat_config: QATConfig,
kernel_attr: str = None,
) -> Tuple[Dict[str, BaseKerasQATTrainableQuantizer], List[BaseKerasQATTrainableQuantizer]]:
) -> Tuple[Dict[str, BaseKerasQATWeightTrainableQuantizer], List[BaseKerasActivationTrainableQuantizer]]:
"""
Build quantizers for a node according to its quantization configuration.
Expand All @@ -82,7 +83,7 @@ def quantization_builder(n: common.BaseNode,
quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Weights,
qat_config.weight_training_method,
quant_method,
BaseKerasQATTrainableQuantizer)
BaseKerasQATWeightTrainableQuantizer)

weight_quantizers.update({kernel_attr: quantizer_class(get_trainable_quantizer_weights_config(n,
attr_name=kernel_attr,
Expand All @@ -98,7 +99,7 @@ def quantization_builder(n: common.BaseNode,
quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Activation,
qat_config.activation_training_method,
quant_method,
BaseKerasQATTrainableQuantizer)
BaseKerasActivationTrainableQuantizer)

activation_quantizers = [quantizer_class(get_trainable_quantizer_activation_config(n, aq_cand),
**qat_config.activation_quantizer_params_override)] * len(output_shapes)
Expand Down
Loading

0 comments on commit 8866392

Please sign in to comment.