Several training methods may be applied by the user to train the QAT-ready model created by the `pytorch_quantization_aware_training_init` method in `pytorch/quantization_facade`.
Each `TrainingMethod` and `QuantizationMethod` selects a quantizer for weights and a quantizer for activations.
Follow these steps in order to set the quantizers required by your training method:
- Add your training method to the `TrainingMethod` enum.
- Add your quantizers for weights and activations as explained in the quantizer readme.
- Import your quantizer package in the quantizer `__init__.py` file.
- Set your `TrainingMethod` in the `QATConfig` and generate the QAT-ready model for training.
In this example we'll add a new training method, called MTM (My Training Method).
First, we update the [`TrainingMethod`](../../../trainable_infrastructure/common/training_method.py) enum:
```python
from enum import Enum

class TrainingMethod(Enum):
    """
    An enum for selecting a QAT training method:

    STE - Standard straight-through estimator. Includes PowerOfTwo, symmetric & uniform quantizers.
    MTM - MyTrainingMethod.
    """
    STE = 0
    MTM = 1
```
Then we implement a weight quantizer class, `MTMWeightQuantizer`, that realizes the desired training scheme under a new package in `qat/pytorch/quantizer/mtm_quantizer/mtm.py`, and import it in the quantizer `__init__.py` file:
```python
import model_compression_toolkit.qat.pytorch.quantizer.mtm_quantizer.mtm
```
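To make the training scheme concrete, here is a minimal, self-contained sketch of the kind of fake-quantization math a weight quantizer such as `MTMWeightQuantizer` might apply during training (symmetric per-tensor quantization with a straight-through estimator). Everything in it, including the class name `MTMWeightQuantizerSketch` and the bit-width handling, is an illustrative assumption; the real quantizer must subclass the trainable quantizer base class and register itself as explained in the quantizer readme.
```python
import torch

class MTMWeightQuantizerSketch:
    """Illustrative sketch only: shows the fake-quantization math, not MCT's
    quantizer API. The real MTMWeightQuantizer should follow the base class and
    registration mechanism described in the quantizer readme."""

    def __init__(self, num_bits: int = 8):
        self.num_bits = num_bits

    def __call__(self, w: torch.Tensor) -> torch.Tensor:
        # Signed symmetric range, e.g. [-128, 127] for 8 bits.
        n_pos = 2 ** (self.num_bits - 1) - 1
        n_neg = -(2 ** (self.num_bits - 1))
        # Per-tensor scale derived from the current weight range.
        scale = w.abs().max().clamp(min=1e-8) / n_pos
        # Round to the quantization grid and clip to the representable range.
        w_q = torch.clamp(torch.round(w / scale), n_neg, n_pos) * scale
        # Straight-through estimator: forward uses the quantized weights,
        # backward passes gradients through as if no rounding happened.
        return w + (w_q - w).detach()
```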
Finally, we're ready to generate the model for quantization-aware training by calling the `pytorch_quantization_aware_training_init` method in `pytorch/quantization_facade` with the following `qat_config`:
```python
from model_compression_toolkit.qat.common.qat_config import QATConfig
from model_compression_toolkit.trainable_infrastructure import TrainingMethod

qat_config = QATConfig(weight_training_method=TrainingMethod.MTM)
```
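For context, a rough usage sketch is shown below. The float model, the representative data generator, and the exact package path and return values of the init method are assumptions for illustration; consult `pytorch/quantization_facade` in your installed MCT version for the actual signature.
```python
import torch
import model_compression_toolkit as mct
from torchvision.models import mobilenet_v2

# Placeholder float model and calibration data generator (assumptions for illustration).
float_model = mobilenet_v2()

def representative_data_gen():
    for _ in range(10):
        yield [torch.randn(1, 3, 224, 224)]

# NOTE: the exposure path (mct.qat) and the returned values may differ between
# MCT versions; check pytorch/quantization_facade for the exact signature.
qat_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init(
    float_model,
    representative_data_gen,
    qat_config=qat_config)
```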