Several training methods may be applied by the user to train the QAT ready model
created by keras_quantization_aware_training_init
method in keras/quantization_facade
.
Each TrainingMethod
(an enum defined in the qat_config
)
and QuantizationMethod
selects a quantizer for weights and a quantizer for activations.
Currently, only the STE (straight through estimator) training method is implemented by the MCT.
Follow these steps in order to set the quantizers required by your training method:
- Add your
TrainingMethod
enum inqat_config
. - Add your quantizers for weights and activation as explained in quantizer readme.
- Import your quantizer package in the quantizer
__init.py__
file. - Set your
TrainingMethod
in theQATConfig
and generate the QAT ready model for training.
In this example we'll add a new quantization method, called MTM (my training method).
First, we update the TrainingMethod
enum in qat_config
class TrainingMethod(Enum):
"""
An enum for selecting a QAT training method
STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers
MTM - MyTrainingMethod.
"""
STE = 0
MTM = 1
Then we implement a weight quantizer class that implements the desired training scheme: MTMWeightQuantizer,
under a new package in qat/keras/quantizer/mtm_quantizer/mtm.py
, and import it in the quantizer __init__.py
file.
import model_compression_toolkit.qat.keras.quantizer.mtm_quantizer.mtm
Finally, we're ready to generate the model for quantization aware training
by calling keras_quantization_aware_training_init
method in keras/quantization_facade
with the following qat_config
:
from model_compression_toolkit.qat.common.qat_config import QATConfig, TrainingMethod
qat_config = QATConfig(weight_training_method=TrainingMethod.MTM)