Skip to content

Commit

Permalink
Move STE/LSQ activation quantizers from QAT to trainable infrastructure for PyTorch (#1178)

Browse files Browse the repository at this point in the history

* move activation quantizers from qat to trainable infrastructure for pytorch

* update pytorch qat and trainable_infrastructure README

* code review fixes

* update import
  • Loading branch information
irenaby authored Sep 3, 2024
1 parent 36ba0f9 commit 3a3cc2c
Show file tree
Hide file tree
Showing 41 changed files with 806 additions and 574 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
get_trainable_quantizer_weights_config
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
get_trainable_quantizer_class


def quantization_builder(n: common.BaseNode,
gptq_config: GradientPTQConfig,
kernel_attr: str = None
) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer], List[BasePyTorchInferableQuantizer]]:
) -> Tuple[Dict[str, BasePytorchGPTQTrainableQuantizer], List[BasePyTorchInferableQuantizer]]:
"""
Build quantizers for a node according to its quantization configuration and
a global NoOpQuantizeConfig object.
Expand Down
4 changes: 2 additions & 2 deletions model_compression_toolkit/qat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
from model_compression_toolkit.qat.common.qat_config import QATConfig

from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init_experimental, keras_quantization_aware_training_finalize_experimental
from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init_experimental, pytorch_quantization_aware_training_finalize_experimental
from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init_experimental, pytorch_quantization_aware_training_finalize_experimental
20 changes: 1 addition & 19 deletions model_compression_toolkit/qat/common/qat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# ==============================================================================

from typing import Dict
from enum import Enum
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.trainable_infrastructure import TrainingMethod


def is_qat_applicable(node: common.BaseNode,
Expand All @@ -38,23 +37,6 @@ def is_qat_applicable(node: common.BaseNode,
or node.is_activation_quantization_enabled()



class TrainingMethod(Enum):
"""
An enum for selecting a QAT training method
STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers
DQA - DNN Quantization with Attention. Includes a smooth quantization introduces by DQA method
LSQ - Learned Step size Quantization. Includes PowerOfTwo, symmetric & uniform quantizers: https://arxiv.org/pdf/1902.08153.pdf
"""
STE = "STE",
DQA = "DQA",
LSQ = "LSQ"


class QATConfig:
"""
QAT configuration class.
Expand Down
4 changes: 2 additions & 2 deletions model_compression_toolkit/qat/keras/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Assume that the quantizer has a new training method called `MyTraining` which is

```python
import tensorflow as tf
from model_compression_toolkit.qat import TrainingMethod
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
from mct_quantizers import mark_quantizer, QuantizationMethod, QuantizationTarget
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer

Expand Down Expand Up @@ -67,7 +67,7 @@ Assume that the quantizer has a new training method called `MyTraining` which is
import tensorflow as tf

NEW_PARAM = "new_param_name"
from model_compression_toolkit.qat import TrainingMethod
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from mct_quantizers import mark_quantizer, QuantizationTarget
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/qat/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig
from mct_quantizers import KerasActivationQuantizationHolder
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.ptq.runner import ptq_runner
Expand All @@ -34,6 +33,7 @@
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model

from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tensorflow.python.framework.tensor_shape import TensorShape
from model_compression_toolkit.constants import SIGNED

from model_compression_toolkit.qat import TrainingMethod
from model_compression_toolkit.trainable_infrastructure import TrainingMethod

from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.qat import TrainingMethod
from model_compression_toolkit.trainable_infrastructure import TrainingMethod

from mct_quantizers import mark_quantizer, QuantizationMethod, QuantizationTarget
from mct_quantizers.keras.quantizers import \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from model_compression_toolkit.constants import SIGNED
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX

from model_compression_toolkit.qat import TrainingMethod
from model_compression_toolkit.trainable_infrastructure import TrainingMethod

from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from model_compression_toolkit.qat import TrainingMethod
from model_compression_toolkit.trainable_infrastructure import TrainingMethod

from mct_quantizers import mark_quantizer, QuantizationMethod, QuantizationTarget
from mct_quantizers.keras.quantizers import \
Expand Down
58 changes: 29 additions & 29 deletions model_compression_toolkit/qat/pytorch/README.md
Original file line number Diff line number Diff line change
@@ -1,37 +1,39 @@
# QAT Quantizers

## Introduction
[`BasePytorchQATTrainableQuantizer`](./quantizer/base_pytorch_qat_quantizer.py) is an interface that utilizes the Quantization Infrastructure's [`BasePytorchTrainableQuantizer`](../../trainable_infrastructure/pytorch/base_pytorch_quantizer.py) class to enable easy development of quantizers dedicated to Quantization-Aware Training (QAT).
All available training types for QAT are defined in the Enum [`TrainingMethod`](./quantizer/README.md).
All available training types for QAT are defined in the Enum [`TrainingMethod`](../../trainable_infrastructure/common/training_method.py).
A trainable quantizer can be Weights Quantizer or Activation Quantizer.
Any Activation Quantizer defined in [Activation Quantizers](../../trainable_infrastructure/pytorch/activation_quantizers) can be used for QAT.


## Make your own Pytorch trainable quantizers
A trainable quantizer can be Weights Quantizer or Activation Quantizer.
In order to make your new quantizer you need to create your quantizer class, `MyTrainingQuantizer` and do as follows:
- `MyTrainingQuantizer` should inherit from [`BasePytorchTrainableQuantizer`](../../trainable_infrastructure/pytorch/base_pytorch_quantizer.py).
- `MyTrainingQuantizer` should have [`init`](../../trainable_infrastructure/common/base_trainable_quantizer.py) function that gets `quantization_config` which is [`NodeWeightsQuantizationConfig`](../../core/common/quantization/node_quantization_config.py#L228) if you choose to implement weights quantizer or [`NodeActivationQuantizationConfig`](../../core/common/quantization/node_quantization_config.py#L63) if you choose activation quantizer.
- Implement [`initialize_quantization`](../../trainable_infrastructure/common/base_trainable_quantizer.py) where you can define your parameters for the quantizer.
- Implement [`__call__`](../../trainable_infrastructure/common/base_trainable_quantizer.py) method to quantize the given inputs while training. This is your custom quantization itself.
- Implement [`convert2inferable`](../../trainable_infrastructure/common/base_trainable_quantizer.py) method. This method exports your quantizer for inference (deployment). For doing that you need to choose one of the available Inferable Quantizers from the [MCT Quantizers](https://github.com/sony/mct_quantizers) package, according to the target when implementing `convert2inferable`, and set your learned quantization parameters there.
- `MyTrainingQuantizer` should inherit from `BasePytorchQATWeightTrainableQuantizer` for weights quantizer or `BasePytorchActivationTrainableQuantizer` for activation quantizer
- `MyTrainingQuantizer` should have `__init__` method that accepts `quantization_config` of type `TrainableQuantizerWeightsConfig` for weights quantizer or `TrainableQuantizerActivationConfig` for activation quantizer.
- Implement `initialize_quantization` where you can define your parameters for the quantizer.
- Implement `__call__` method to quantize the given inputs while training. This is your custom quantization itself.
- Implement `convert2inferable` method. This method exports your quantizer for inference (deployment). For doing that you need to choose one of the available Inferable Quantizers from the [MCT Quantizers](https://github.com/sony/mct_quantizers) package, according to the target when implementing `convert2inferable`, and set your learned quantization parameters there.
- Decorate `MyTrainingQuantizer` class with the `@mark_quantizer` decorator (provided by the [MCT Quantizers](https://github.com/sony/mct_quantizers) package) and choose the appropriate properties to set for your quantizer. The "identifier" property for the decorator should be of the type `TrainingMethod` enum. See explanation about `@mark_quantizer` and how to use it under the [Pytorch Quantization Infrastructure](../../trainable_infrastructure/pytorch/README.md).

## Example: Symmetric Weights Quantizer
To create custom `MyWeightsTrainingQuantizer` which is a symmetric weights training quantizer you need to set
`qi.QuantizationTarget.Weights` as target and `qi.QuantizationMethod.SYMMETRIC` as method.
Assume that the quantizer has a new training method called `MyTrainig` which is defined in the `TrainingMethod` Enum.
`QuantizationTarget.Weights` as target and `QuantizationMethod.SYMMETRIC` as method.
Assume that the quantizer has a new training method called `MyTraining` which is defined in the `TrainingMethod` Enum.

```python
NEW_PARAM = "new_param_name"
from model_compression_toolkit TrainingMethod
from mct_quantizers import mark_quantizer, QuantizationMethod, QuantizationTarget
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
from mct_quantizers.pytorch.quantizers import WeightsSymmetricInferableQuantizer
from model_compression_toolkit.trainable_infrastructure import TrainingMethod, TrainableQuantizerWeightsConfig
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer


@mark_quantizer(quantization_target=QuantizationTarget.Weights,
quantization_method=[QuantizationMethod.SYMMETRIC],
identifier=TrainingMethod.MyQuantizerType)
class MyWeightsTrainingQuantizer(BasePytorchQATTrainableQuantizer):
def __init__(self, quantization_config: NodeWeightsQuantizationConfig):
super(MyWeightsTrainingQuantizer, self).__init__(quantization_config)
identifier=TrainingMethod.MyTraining)
class MyWeightsTrainingQuantizer(BasePytorchQATWeightTrainableQuantizer):
def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
super().__init__(quantization_config)
# Define your new params here:
self.new_param = ...

Expand All @@ -50,28 +52,26 @@ class MyWeightsTrainingQuantizer(BasePytorchQATTrainableQuantizer):
return quantized_inputs

def convert2inferable(self):
return qi.WeightsUniformInferableQuantizer(...)
return WeightsSymmetricInferableQuantizer(...)
```

## Example: Symmetric Activations Quantizer
To create custom `MyActivationsTrainingQuantizer` which is a symmetric activations training quantizer you need to set `qi.QuantizationTarget.Activation` as target and `qi.QuantizationMethod.SYMMETRIC` as method.
Assume that the quantizer has a new training method called `MyTrainig` which is defined in the `TrainingMethod` Enum.
To create custom `MyActivationsTrainingQuantizer` which is a symmetric activations training quantizer you need to set `QuantizationTarget.Activation` as target and `QuantizationMethod.SYMMETRIC` as method.
Assume that the quantizer has a new training method called `MyTraining` which is defined in the `TrainingMethod` Enum.

```python
NEW_PARAM = "new_param_name"
from model_compression_toolkit.qat import TrainingMethod
from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer

from mct_quantizers.pytorch.quantizers import ActivationSymmetricInferableQuantizer
from model_compression_toolkit.trainable_infrastructure import TrainingMethod, TrainableQuantizerActivationConfig
from model_compression_toolkit.trainable_infrastructure import BasePytorchActivationTrainableQuantizer

@mark_quantizer(quantization_target=QuantizationTarget.Activation,
quantization_method=[QuantizationMethod.SYMMETRIC],
identifier=TrainingMethod.MyQuantizerType)
class MyActivationsTrainingQuantizer(BasePytorchQATTrainableQuantizer):
def __init__(self, quantization_config: NodeActivationQuantizationConfig):
super(MyActivationsTrainingQuantizer, self).__init__(quantization_config,
QuantizationTarget.Activation,
[QuantizationMethod.SYMMETRIC])
identifier=TrainingMethod.MyTraining)
class MyActivationsTrainingQuantizer(BasePytorchActivationTrainableQuantizer):
def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
super().__init__(quantization_config)
# Define your new params here:
self.new_param = ...

Expand All @@ -90,5 +90,5 @@ class MyActivationsTrainingQuantizer(BasePytorchQATTrainableQuantizer):
return quantized_inputs

def convert2inferable(self):
return ActivationUniformInferableQuantizer(...)
return ActivationSymmetricInferableQuantizer(...)
```
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from mct_quantizers import QuantizationTarget, PytorchQuantizationWrapper, mark_quantizer
from model_compression_toolkit.constants import RANGE_MAX, RANGE_MIN
from model_compression_toolkit.qat import TrainingMethod
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_quantizer import BasePytorchQATTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from model_compression_toolkit.qat.pytorch.quantizer.quantizer_utils import uniform_quantizer
from model_compression_toolkit.trainable_infrastructure.pytorch.quantizer_utils import uniform_quantizer
from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX


@mark_quantizer(quantization_target=QuantizationTarget.Weights,
quantization_method=[QuantizationMethod.UNIFORM],
identifier=TrainingMethod.DQA)
class DQAUniformWeightQuantizer(BasePytorchQATTrainableQuantizer):
class DQAUniformWeightQuantizer(BasePytorchQATWeightTrainableQuantizer):
"""
Trainable constrained quantizer to quantize a layer inputs.
"""
Expand Down
11 changes: 5 additions & 6 deletions model_compression_toolkit/qat/pytorch/quantizer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@

Several training methods may be applied by the user to train the QAT ready model
created by `pytorch_quantization_aware_training_init` method in [`pytorch/quantization_facade`](../quantization_facade.py).
Each `TrainingMethod` (an enum defined in the [`qat_config`](../../common/qat_config.py))
Each [`TrainingMethod`](../../../trainable_infrastructure/common/training_method.py)
and [`QuantizationMethod`](../../../target_platform_capabilities/target_platform/op_quantization_config.py)
selects a quantizer for weights and a quantizer for activations.

Currently, only the STE (straight through estimator) training method is implemented by the MCT.

## Make your own training method

Follow these steps in order to set the quantizers required by your training method:
- Add your `TrainingMethod` enum in [`qat_config`](../../common/qat_config.py).
- Add your training method to the `TrainingMethod` enum.
- Add your quantizers for weights and activation as explained in [quantizer readme](../../../trainable_infrastructure/pytorch).
- Import your quantizer package in the quantizer [`__init.py__`](./__init__.py) file.
- Set your `TrainingMethod` in the `QATConfig` and generate the QAT ready model for training.
Expand All @@ -23,7 +21,7 @@ Follow these steps in order to set the quantizers required by your training meth

In this example we'll add a new quantization method, called MTM (my training method).

First, we update the `TrainingMethod` enum in [`qat_config`](../../common/qat_config.py)
First, we update the [`TrainingMethod`](../../../trainable_infrastructure/common/training_method.py)
```python
class TrainingMethod(Enum):
"""
Expand All @@ -48,7 +46,8 @@ by calling `pytorch_quantization_aware_training_init` method in [`pytorch/quanti
with the following [`qat_config`](../../common/qat_config.py):

```python
from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
from model_compression_toolkit.qat.common.qat_config import QATConfig
from model_compression_toolkit.trainable_infrastructure import TrainingMethod

qat_config = QATConfig(weight_training_method=TrainingMethod.MTM)
```
Loading

0 comments on commit 3a3cc2c

Please sign in to comment.