Skip to content

Commit

Permalink
readme for Adding a new TrainingMethod (#556)
Browse files Browse the repository at this point in the history
readme for Adding a new TrainingMethod in keras and PyTorch
  • Loading branch information
elad-c authored Feb 2, 2023
1 parent 2e5ec7c commit 75f3e2b
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 0 deletions.
9 changes: 9 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,12 @@ We basically follow the development style used in many GitHub repositories.
Our maintainers will then review your changes.

5. Once your change is finalized, the maintainer will merge your change.


## Contribution Example

Here is a simple tutorial on how to add your quantization method to MCT.
Check out the readme files below:

- [Adding a QAT method in Keras](./model_compression_toolkit/qat/keras/quantizer)
- [Adding a QAT method in PyTorch](./model_compression_toolkit/qat/pytorch/quantizer)
62 changes: 62 additions & 0 deletions model_compression_toolkit/qat/keras/quantizer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Training Methods for QAT

## Introduction

Several training methods may be applied by the user to train the QAT ready model
created by `keras_quantization_aware_training_init` method in [`keras/quantization_facade`](../quantization_facade.py).
Each `TrainingMethod` (an enum defined in the [`qat_config`](../../common/qat_config.py))
and [`QuantizationMethod`](../../../core/common/target_platform/op_quantization_config.py)
selects a quantizer for weights and a quantizer for activations.

Currently, only the STE (straight through estimator) training method is implemented by the MCT.

## Make your own training method

Follow these steps in order to set the quantizers required by your training method:
- Add your `TrainingMethod` enum in [`qat_config`](../../common/qat_config.py).
- Add your quantizers for weights and activation as explained in [quantizer readme](../../../quantizers_infrastructure/keras).
- Add your `TrainingMethod` and quantizers to `METHOD2WEIGHTQUANTIZER` and `METHOD2ACTQUANTIZER` in [`quantization_dispatcher_builder.py`](../quantizer/quantization_dispatcher_builder.py)
according to your desired `QuantizationMethod`.
- Set your `TrainingMethod` in the `QATConfig` and generate the QAT ready model for training.


## Example: Adding a new training method

In this example we'll add a new quantization method, called MTM (my training method).

First, we update the `TrainingMethod` enum in [`qat_config`](../../common/qat_config.py)
```python
class TrainingMethod(Enum):
"""
An enum for selecting a QAT training method
STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers
MTM - MyTrainingMethod.
"""
STE = 0
MTM = 1
```

Then we implement a weight quantizer class that implements the desired training scheme: MTMWeightQuantizer

And update the quantizer selection dictionary `METHOD2WEIGHTQUANTIZER` in [`quantization_dispatcher_builder.py`](../quantizer/quantization_dispatcher_builder.py)

```python
from my_quantizers import MTMWeightQuantizer

METHOD2WEIGHTQUANTIZER = {TrainingMethod.STE: {qi.QuantizationMethod.SYMMETRIC: STEWeightQuantizer,
qi.QuantizationMethod.POWER_OF_TWO: STEWeightQuantizer,
qi.QuantizationMethod.UNIFORM: STEUniformWeightQuantizer},
TrainingMethod.MTM: {qi.QuantizationMethod.POWER_OF_TWO: MTMWeightQuantizer}
}
```

Finally, we're ready to generate the model for quantization aware training
by calling `keras_quantization_aware_training_init` method in [`keras/quantization_facade`](../quantization_facade.py)
with the following [`qat_config`](../../common/qat_config.py):

```python
from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod

qat_config = QATConfig(weight_training_method=TrainingMethod.MTM)
```
62 changes: 62 additions & 0 deletions model_compression_toolkit/qat/pytorch/quantizer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Training Methods for QAT

## Introduction

Several training methods may be applied by the user to train the QAT ready model
created by `pytorch_quantization_aware_training_init` method in [`pytorch/quantization_facade`](../quantization_facade.py).
Each `TrainingMethod` (an enum defined in the [`qat_config`](../../common/qat_config.py))
and [`QuantizationMethod`](../../../core/common/target_platform/op_quantization_config.py)
selects a quantizer for weights and a quantizer for activations.

Currently, only the STE (straight through estimator) training method is implemented by the MCT.

## Make your own training method

Follow these steps in order to set the quantizers required by your training method:
- Add your `TrainingMethod` enum in [`qat_config`](../../common/qat_config.py).
- Add your quantizers for weights and activation as explained in [quantizer readme](../../../quantizers_infrastructure/pytorch).
- Add your `TrainingMethod` and quantizers to `METHOD2WEIGHTQUANTIZER` and `METHOD2ACTQUANTIZER` in [`quantization_dispatcher_builder.py`](../quantizer/quantization_dispatcher_builder.py)
according to your desired `QuantizationMethod`.
- Set your `TrainingMethod` in the `QATConfig` and generate the QAT ready model for training.


## Example: Adding a new training method

In this example we'll add a new quantization method, called MTM (my training method).

First, we update the `TrainingMethod` enum in [`qat_config`](../../common/qat_config.py)
```python
class TrainingMethod(Enum):
"""
An enum for selecting a QAT training method
STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers
MTM - MyTrainingMethod.
"""
STE = 0
MTM = 1
```

Then we implement a weight quantizer class that implements the desired training scheme: MTMWeightQuantizer

And update the quantizer selection dictionary `METHOD2WEIGHTQUANTIZER` in [`quantization_dispatcher_builder.py`](../quantizer/quantization_dispatcher_builder.py)

```python
from my_quantizers import MTMWeightQuantizer

METHOD2WEIGHTQUANTIZER = {TrainingMethod.STE: {qi.QuantizationMethod.SYMMETRIC: STEWeightQuantizer,
qi.QuantizationMethod.POWER_OF_TWO: STEWeightQuantizer,
qi.QuantizationMethod.UNIFORM: STEUniformWeightQuantizer},
TrainingMethod.MTM: {qi.QuantizationMethod.POWER_OF_TWO: MTMWeightQuantizer}
}
```

Finally, we're ready to generate the model for quantization aware training
by calling `pytorch_quantization_aware_training_init` method in [`pytorch/quantization_facade`](../quantization_facade.py)
with the following [`qat_config`](../../common/qat_config.py):

```python
from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod

qat_config = QATConfig(weight_training_method=TrainingMethod.MTM)
```

0 comments on commit 75f3e2b

Please sign in to comment.