diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index 41bd896df..6d93dc126 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -204,7 +204,7 @@ def _fetch_hessians_with_compute(self, request: HessianScoresRequest, n_iteratio target_nodes = [n for n in orig_request.target_nodes if n.name in missing] request = request.clone(target_nodes=target_nodes) self._compute_hessians(request, n_iterations, count_by_cache=True) - res, missing = self.cache.fetch_hessian(request) + res, missing = self.cache.fetch_hessian(orig_request) assert not missing return res diff --git a/model_compression_toolkit/gptq/common/gptq_config.py b/model_compression_toolkit/gptq/common/gptq_config.py index 3a96fd86b..d7ef753a2 100644 --- a/model_compression_toolkit/gptq/common/gptq_config.py +++ b/model_compression_toolkit/gptq/common/gptq_config.py @@ -16,8 +16,7 @@ from enum import Enum from typing import Callable, Any, Dict, Optional -from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE -from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT +from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE class RoundingType(Enum): @@ -39,20 +38,26 @@ class GPTQHessianScoresConfig: Configuration to use for computing the Hessian-based scores for GPTQ loss metric. Args: + per_sample (bool): Whether to use per sample attention score. hessians_num_samples (int|None): Number of samples to use for computing the Hessian-based scores. If None, compute Hessian for all images. norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1). log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores. scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores. hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective. - per_sample (bool): Whether to use per sample attention score. """ - hessians_num_samples: Optional[int] = GPTQ_HESSIAN_NUM_SAMPLES - norm_scores: bool = True - log_norm: bool = True + per_sample: bool + hessians_num_samples: Optional[int] + norm_scores: bool = None + log_norm: bool = None scale_log_norm: bool = False hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE - per_sample: bool = False + + def __post_init__(self): + if self.norm_scores is None: + self.norm_scores = not self.per_sample + if self.log_norm is None: + self.log_norm = not self.per_sample @dataclass @@ -107,32 +112,30 @@ class GradientPTQConfig: Args: n_epochs: Number of representative dataset epochs to train. - optimizer: Optimizer to use. - optimizer_rest: Optimizer to use for bias and quantizer parameters. loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface. - log_function: Function to log information about the GPTQ process. + optimizer: Optimizer to use. + optimizer_rest: Default optimizer to use for bias and quantizer parameters. train_bias: Whether to update the bias during the training or not. - rounding_type: An enum that defines the rounding type. - use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss. - optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters. - optimizer_bias: Optimizer to override the rest optimizer for bias. - regularization_factor: A floating point number that defines the regularization factor. hessian_weights_config: A configuration that include all necessary arguments to run a computation of Hessian scores for the GPTQ loss. gradual_activation_quantization_config: A configuration for Gradual Activation Quantization. + regularization_factor: A floating point number that defines the regularization factor. + rounding_type: An enum that defines the rounding type. + optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters. + optimizer_bias: Optimizer to override the rest optimizer for bias. + log_function: Function to log information about the GPTQ process. gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation. """ n_epochs: int + loss: Callable optimizer: Any - optimizer_rest: Any = None - loss: Callable = None - log_function: Callable = None - train_bias: bool = True + optimizer_rest: Any + train_bias: bool + hessian_weights_config: Optional[GPTQHessianScoresConfig] + gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] + regularization_factor: float rounding_type: RoundingType = RoundingType.SoftQuantizer - use_hessian_based_weights: bool = True optimizer_quantization_parameter: Any = None optimizer_bias: Any = None - regularization_factor: float = REG_DEFAULT - hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig) - gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None + log_function: Callable = None gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict) diff --git a/model_compression_toolkit/gptq/common/gptq_constants.py b/model_compression_toolkit/gptq/common/gptq_constants.py index 3451d9732..50b3304c7 100644 --- a/model_compression_toolkit/gptq/common/gptq_constants.py +++ b/model_compression_toolkit/gptq/common/gptq_constants.py @@ -14,6 +14,7 @@ MIM_TEMP = 0.5 MAX_TEMP = 1.0 REG_DEFAULT = 0.01 +REG_DEFAULT_SLA = 10 MAX_LSB_CHANGE = 1 # Soft rounding arguments values @@ -27,6 +28,5 @@ # GPTQ learning hyperparameters LR_DEFAULT = 3e-2 LR_REST_DEFAULT = 1e-4 -LR_BIAS_DEFAULT = 1e-3 -LR_QUANTIZATION_PARAM_DEFAULT = 1e-3 +LR_BIAS_DEFAULT = 1e-4 GPTQ_MOMENTUM = 0.9 diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 32fc85163..04a39e2be 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -75,7 +75,7 @@ def __init__(self, fw_info=self.fw_info) self.fxp_model, self.gptq_user_info = self.build_gptq_model() - if self.gptq_config.use_hessian_based_weights: + if self.gptq_config.hessian_weights_config: if not isinstance(hessian_info_service, HessianInfoService): Logger.critical(f"When using Hessian-based approximations for sensitivity evaluation, " f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover diff --git a/model_compression_toolkit/gptq/keras/gptq_training.py b/model_compression_toolkit/gptq/keras/gptq_training.py index 4ba44e8fe..166778b77 100644 --- a/model_compression_toolkit/gptq/keras/gptq_training.py +++ b/model_compression_toolkit/gptq/keras/gptq_training.py @@ -139,7 +139,7 @@ def _get_total_grad_steps(): def _get_compare_points_loss_weights(self): """ Get compare points weights for the distillation loss. """ - if self.gptq_config.use_hessian_based_weights: + if self.gptq_config.hessian_weights_config: hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size) return self.compute_hessian_based_weights(hess_dataloader) diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py index ead11c941..3233e1e28 100644 --- a/model_compression_toolkit/gptq/keras/quantization_facade.py +++ b/model_compression_toolkit/gptq/keras/quantization_facade.py @@ -21,7 +21,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \ LR_BIAS_DEFAULT, GPTQ_MOMENTUM from model_compression_toolkit.logger import Logger -from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE +from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES from model_compression_toolkit.verify_packages import FOUND_TF from model_compression_toolkit.core.common.user_info import UserInformation from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \ @@ -117,16 +117,20 @@ def get_keras_gptq_config(n_epochs: int, raise TypeError(f'gradual_activation_quantization argument should be bool or ' f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}') - return GradientPTQConfig(n_epochs, - optimizer, + hessian_weights_config = None + if use_hessian_based_weights: + hessian_weights_config = GPTQHessianScoresConfig(per_sample=False, + hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, + hessian_batch_size=hessian_batch_size) + return GradientPTQConfig(n_epochs=n_epochs, + optimizer=optimizer, optimizer_rest=optimizer_rest, loss=loss, log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, - use_hessian_based_weights=use_hessian_based_weights, regularization_factor=regularization_factor, - hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size), + hessian_weights_config=hessian_weights_config, gradual_activation_quantization_config=gradual_quant_config) diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index ff510e6a6..ee91a36d3 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -116,7 +116,7 @@ def _get_total_grad_steps(): trainable_threshold) hessian_cfg = self.gptq_config.hessian_weights_config - self.use_sample_layer_attention = hessian_cfg.per_sample + self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample if self.use_sample_layer_attention: # normalization is currently not supported, make sure the config reflects it. if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm: @@ -178,7 +178,7 @@ def _prepare_train_dataloader_for_non_sla(self, data_gen_fn: Callable[[], Genera dataset = IterableDatasetFromGenerator(data_gen_fn) num_nodes = len(self.compare_points) - if self.gptq_config.use_hessian_based_weights: + if self.gptq_config.hessian_weights_config: hess_dataloader = DataLoader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size) loss_weights = torch.from_numpy(self.compute_hessian_based_weights(hess_dataloader)) else: diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index 220fbb3ec..dbe964ee0 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -15,7 +15,7 @@ import copy from typing import Callable, Union -from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH +from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH, GPTQ_HESSIAN_NUM_SAMPLES from model_compression_toolkit.core import CoreConfig from model_compression_toolkit.core.analyzer import analyzer_model_quantization from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ @@ -27,7 +27,7 @@ from model_compression_toolkit.gptq.common.gptq_config import ( GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig) from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \ - LR_BIAS_DEFAULT, GPTQ_MOMENTUM + LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA from model_compression_toolkit.gptq.runner import gptq_runner from model_compression_toolkit.logger import Logger from model_compression_toolkit.metadata import create_model_metadata @@ -55,10 +55,10 @@ def get_pytorch_gptq_config(n_epochs: int, loss: Callable = None, log_function: Callable = None, use_hessian_based_weights: bool = True, - regularization_factor: float = REG_DEFAULT, + regularization_factor: float = None, hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE, - use_hessian_sample_attention: bool = False, - gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False, + use_hessian_sample_attention: bool = True, + gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = True, ) -> GradientPTQConfig: """ Create a GradientPTQConfig instance for Pytorch models. @@ -94,25 +94,26 @@ def get_pytorch_gptq_config(n_epochs: int, """ optimizer = optimizer or Adam([torch.Tensor([])], lr=LR_DEFAULT) optimizer_rest = optimizer_rest or Adam([torch.Tensor([])], lr=LR_REST_DEFAULT) - + # TODO this contradicts the docstring for optimizer_rest bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM) + if regularization_factor is None: + regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT + + loss = loss or multiple_tensors_mse_loss + hessian_weights_config = None if use_hessian_sample_attention: if not use_hessian_based_weights: # pragma: no cover raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.') - hessian_weights_config = GPTQHessianScoresConfig( - hessians_num_samples=None, - norm_scores=False, - log_norm=False, - scale_log_norm=False, - hessian_batch_size=hessian_batch_size, - per_sample=True, - ) + hessian_weights_config = GPTQHessianScoresConfig(per_sample=True, + hessians_num_samples=None, + hessian_batch_size=hessian_batch_size) loss = loss or sample_layer_attention_loss - else: - hessian_weights_config = GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size) - loss = loss or multiple_tensors_mse_loss + elif use_hessian_based_weights: + hessian_weights_config = GPTQHessianScoresConfig(per_sample=False, + hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, + hessian_batch_size=hessian_batch_size) if isinstance(gradual_activation_quantization, bool): gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None @@ -122,12 +123,16 @@ def get_pytorch_gptq_config(n_epochs: int, raise TypeError(f'gradual_activation_quantization argument should be bool or ' f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}') - return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss, - log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, - use_hessian_based_weights=use_hessian_based_weights, + return GradientPTQConfig(n_epochs=n_epochs, + loss=loss, + optimizer=optimizer, + optimizer_rest=optimizer_rest, + optimizer_bias=bias_optimizer, + train_bias=True, regularization_factor=regularization_factor, hessian_weights_config=hessian_weights_config, - gradual_activation_quantization_config=gradual_quant_config) + gradual_activation_quantization_config=gradual_quant_config, + log_function=log_function) def pytorch_gradient_post_training_quantization(model: Module, representative_data_gen: Callable, diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py index 0b7c60acc..d16eb8a16 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/gptq/gptq_test.py @@ -59,8 +59,10 @@ def build_model(in_input_shape: List[int]) -> keras.Model: class GradientPTQBaseTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, quant_method=QuantizationMethod.SYMMETRIC, rounding_type=RoundingType.STE, per_channel=True, input_shape=(1, 16, 16, 3), - hessian_weights=True, log_norm_weights=True, scaled_log_norm=False, - quantization_parameter_learning=True, num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES): + hessian_weights=True, log_norm_weights=True, scaled_log_norm=False, norm_scores=False, + quantization_parameter_learning=True, num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, + hessian_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, per_sample=False, + reg_factor=1, grad_act_quant_cfg=None): super().__init__(unit_test, input_shape=input_shape, num_calibration_iter=num_calibration_iter) @@ -68,15 +70,24 @@ def __init__(self, unit_test, quant_method=QuantizationMethod.SYMMETRIC, roundin self.quant_method = quant_method self.rounding_type = rounding_type self.per_channel = per_channel - self.hessian_weights = hessian_weights - self.log_norm_weights = log_norm_weights - self.scaled_log_norm = scaled_log_norm + + self.hessian_weights_config = None + if hessian_weights: + self.hessian_weights_config = GPTQHessianScoresConfig(per_sample=per_sample, + norm_scores=norm_scores, + log_norm=log_norm_weights, + scale_log_norm=scaled_log_norm, + hessians_num_samples=hessian_num_samples) + + self.grad_act_quant_cfg = grad_act_quant_cfg + self.reg_factor = reg_factor + if rounding_type == RoundingType.SoftQuantizer: self.override_params = {QUANT_PARAM_LEARNING_STR: quantization_parameter_learning} elif rounding_type == RoundingType.STE: self.override_params = {MAX_LSB_STR: DefaultDict(default_value=1)} else: - self.override_params = None + raise ValueError('unknown rounding type', rounding_type) def get_tpc(self): return get_tpc("gptq_test", 16, 16, self.quant_method) @@ -87,12 +98,13 @@ def get_quantization_config(self): relu_bound_to_power_of_2=True, weights_bias_correction=False) def get_gptq_config(self): - return GradientPTQConfig(5, optimizer=tf.keras.optimizers.Adam( - learning_rate=0.0001), optimizer_rest=tf.keras.optimizers.Adam( - learning_rate=0.0001), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, - use_hessian_based_weights=self.hessian_weights, - hessian_weights_config=GPTQHessianScoresConfig(log_norm=self.log_norm_weights, - scale_log_norm=self.scaled_log_norm), + return GradientPTQConfig(5, + optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), + optimizer_rest=tf.keras.optimizers.Adam(learning_rate=0.0001), + loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + hessian_weights_config=self.hessian_weights_config, + gradual_activation_quantization_config=self.grad_act_quant_cfg, + regularization_factor=self.reg_factor, gptq_quantizer_params_override=self.override_params) def create_networks(self): @@ -144,9 +156,13 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= class GradientPTQNoTempLearningTest(GradientPTQBaseTest): def get_gptq_config(self): - return GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam( - learning_rate=0.0001), optimizer_rest=tf.keras.optimizers.Adam( - learning_rate=0.0001), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + return GradientPTQConfig(1, + optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), + optimizer_rest=tf.keras.optimizers.Adam(learning_rate=0.0001), + loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + hessian_weights_config=self.hessian_weights_config, + gradual_activation_quantization_config=self.grad_act_quant_cfg, + regularization_factor=self.reg_factor, gptq_quantizer_params_override=self.override_params) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): @@ -159,9 +175,13 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= class GradientPTQWeightsUpdateTest(GradientPTQBaseTest): def get_gptq_config(self): - return GradientPTQConfig(20, optimizer=tf.keras.optimizers.Adam( - learning_rate=1e-2), optimizer_rest=tf.keras.optimizers.Adam( - learning_rate=1e-1), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + return GradientPTQConfig(20, + optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2), + optimizer_rest=tf.keras.optimizers.Adam(learning_rate=1e-1), + loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + hessian_weights_config=self.hessian_weights_config, + gradual_activation_quantization_config=self.grad_act_quant_cfg, + regularization_factor=self.reg_factor, gptq_quantizer_params_override=self.override_params) def compare(self, quantized_model, quantized_gptq_model, input_x=None, quantization_info=None): @@ -182,9 +202,13 @@ def compare(self, quantized_model, quantized_gptq_model, input_x=None, quantizat class GradientPTQLearnRateZeroTest(GradientPTQBaseTest): def get_gptq_config(self): - return GradientPTQConfig(1, optimizer=tf.keras.optimizers.SGD( - learning_rate=0.0), optimizer_rest=tf.keras.optimizers.SGD( - learning_rate=0.0), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + return GradientPTQConfig(1, + optimizer=tf.keras.optimizers.SGD(learning_rate=0.0), + optimizer_rest=tf.keras.optimizers.SGD(learning_rate=0.0), + loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + hessian_weights_config=self.hessian_weights_config, + gradual_activation_quantization_config=self.grad_act_quant_cfg, + regularization_factor=self.reg_factor, gptq_quantizer_params_override=self.override_params) def compare(self, quantized_model, quantized_gptq_model, input_x=None, quantization_info=None): @@ -200,16 +224,17 @@ def compare(self, quantized_model, quantized_gptq_model, input_x=None, quantizat class GradientPTQWeightedLossTest(GradientPTQBaseTest): + def __init__(self, unit_test, **kwargs): + super().__init__(unit_test, norm_scores=False, hessian_num_samples=16, **kwargs) def get_gptq_config(self): - return GradientPTQConfig(5, optimizer=tf.keras.optimizers.Adam( - learning_rate=0.0001), optimizer_rest=tf.keras.optimizers.Adam( - learning_rate=0.0001), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, - use_hessian_based_weights=True, - hessian_weights_config=GPTQHessianScoresConfig(hessians_num_samples=16, - norm_scores=False, - log_norm=self.log_norm_weights, - scale_log_norm=self.scaled_log_norm), + return GradientPTQConfig(5, + optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), + optimizer_rest=tf.keras.optimizers.Adam(learning_rate=0.0001), + loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + hessian_weights_config=self.hessian_weights_config, + gradual_activation_quantization_config=self.grad_act_quant_cfg, + regularization_factor=self.reg_factor, gptq_quantizer_params_override=self.override_params) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 59336b057..b11be7c04 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -710,7 +710,7 @@ def test_gptq(self): per_channel=True, hessian_weights=True, log_norm_weights=True, scaled_log_norm=True).run_test() GradientPTQWeightedLossTest(self, rounding_type=RoundingType.SoftQuantizer, - per_channel=True, hessian_weights=True, log_norm_weights=True, + per_channel=True, log_norm_weights=True, scaled_log_norm=True).run_test() GradientPTQNoTempLearningTest(self, rounding_type=RoundingType.SoftQuantizer).run_test() diff --git a/tests/keras_tests/function_tests/test_get_gptq_config.py b/tests/keras_tests/function_tests/test_get_gptq_config.py index fbdf99921..f7d2a0e0a 100644 --- a/tests/keras_tests/function_tests/test_get_gptq_config.py +++ b/tests/keras_tests/function_tests/test_get_gptq_config.py @@ -66,7 +66,10 @@ def setUp(self): weights_bias_correction=False) # disable bias correction when working with GPTQ self.cc = CoreConfig(quantization_config=self.qc) - test_hessian_weights_config = GPTQHessianScoresConfig(hessians_num_samples=2, + default_reg_factor = 1 + default_hessian_cfg = GPTQHessianScoresConfig(per_sample=False, hessians_num_samples=2) + test_hessian_weights_config = GPTQHessianScoresConfig(per_sample=False, + hessians_num_samples=2, norm_scores=False, log_norm=True, scale_log_norm=True) @@ -75,35 +78,51 @@ def setUp(self): optimizer_rest=tf.keras.optimizers.RMSprop(), train_bias=True, loss=multiple_tensors_mse_loss, - rounding_type=RoundingType.SoftQuantizer), + rounding_type=RoundingType.SoftQuantizer, + hessian_weights_config=default_hessian_cfg, + regularization_factor=default_reg_factor, + gradual_activation_quantization_config=None), GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), optimizer_rest=tf.keras.optimizers.Adam(), train_bias=True, loss=multiple_tensors_mse_loss, - rounding_type=RoundingType.SoftQuantizer), + rounding_type=RoundingType.SoftQuantizer, + hessian_weights_config=default_hessian_cfg, + regularization_factor=default_reg_factor, + gradual_activation_quantization_config=None), GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), optimizer_rest=tf.keras.optimizers.Adam(), train_bias=True, loss=multiple_tensors_mse_loss, rounding_type=RoundingType.SoftQuantizer, - regularization_factor=15), + hessian_weights_config=default_hessian_cfg, + regularization_factor=15, + gradual_activation_quantization_config=None), GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), optimizer_rest=tf.keras.optimizers.Adam(), train_bias=True, loss=multiple_tensors_mse_loss, rounding_type=RoundingType.SoftQuantizer, + hessian_weights_config=default_hessian_cfg, + regularization_factor=default_reg_factor, + gradual_activation_quantization_config=None, gptq_quantizer_params_override={QUANT_PARAM_LEARNING_STR: True}), GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), optimizer_rest=tf.keras.optimizers.Adam(), train_bias=True, loss=multiple_tensors_mse_loss, rounding_type=RoundingType.SoftQuantizer, - hessian_weights_config=test_hessian_weights_config), + hessian_weights_config=test_hessian_weights_config, + regularization_factor=default_reg_factor, + gradual_activation_quantization_config=None), GradientPTQConfig(1, optimizer=tf.keras.optimizers.Adam(), optimizer_rest=tf.keras.optimizers.Adam(), train_bias=True, loss=multiple_tensors_mse_loss, rounding_type=RoundingType.STE, + hessian_weights_config=default_hessian_cfg, + regularization_factor=default_reg_factor, + gradual_activation_quantization_config=None, gptq_quantizer_params_override={ MAX_LSB_STR: DefaultDict(default_value=1)}), get_keras_gptq_config(n_epochs=1, @@ -112,7 +131,6 @@ def setUp(self): optimizer=tf.keras.optimizers.Adam(), regularization_factor=0.001)] - pot_tp = generate_test_tp_model({'weights_quantization_method': QuantizationMethod.POWER_OF_TWO}) self.pot_weights_tpc = generate_keras_tpc(name="gptq_pot_config_test", tp_model=pot_tp) @@ -153,10 +171,12 @@ def test_get_keras_gptq_config_symmetric(self): tf.config.run_functions_eagerly(False) def test_gradual_activation_quantization_custom_config(self): - custom_config = mct.gptq.GradualActivationQuantizationConfig(q_fraction_scheduler_policy=mct.gptq.QFractionLinearAnnealingConfig(initial_q_fraction=0.2, - target_q_fraction=0.8, - start_step=1, - end_step=2)) + custom_config = mct.gptq.GradualActivationQuantizationConfig( + q_fraction_scheduler_policy=mct.gptq.QFractionLinearAnnealingConfig(initial_q_fraction=0.2, + target_q_fraction=0.8, + start_step=1, + end_step=2) + ) config = get_keras_gptq_config(n_epochs=5, gradual_activation_quantization=custom_config) self.assertEqual(config.gradual_activation_quantization_config, custom_config) diff --git a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py index 1746de6ec..0c1e8ce8f 100644 --- a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py +++ b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py @@ -112,6 +112,7 @@ def test_adding_holder_with_gradual_act_quantization(self): QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=0.9, start_step=100, end_step=500) ) gptq_cfg = mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False, + use_hessian_sample_attention=False, gradual_activation_quantization=gradual_act_quant_cfg) gptq_model = self._get_gptq_model(INPUT_SHAPE, BasicModel(), gptq_cfg) activation_holders = self._get_holders_with_validation(gptq_model, exp_n_holders=3) @@ -153,7 +154,9 @@ def _get_gptq_model(self, input_shape, in_model, gptq_cfg=None): qc=qc) graph = set_bit_widths(mixed_precision_enable=False, graph=graph) - gptq_cfg = gptq_cfg or mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False) + gptq_cfg = gptq_cfg or mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False, + use_hessian_sample_attention=False, + gradual_activation_quantization=False) trainer = PytorchGPTQTrainer(graph, graph, gptq_cfg, diff --git a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py index 602d7c839..3b8b2230c 100644 --- a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py @@ -57,25 +57,30 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM hessian_weights=True, norm_scores=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True, num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, gradual_activation_quantization=False, hessian_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, sample_layer_attention=False, - loss=multiple_tensors_mse_loss, hessian_batch_size=1): + loss=multiple_tensors_mse_loss, hessian_batch_size=1, reg_factor=1): super().__init__(unit_test, input_shape=(3, 16, 16), num_calibration_iter=num_calibration_iter) self.seed = 0 self.rounding_type = rounding_type self.weights_bits = weights_bits self.weights_quant_method = weights_quant_method self.per_channel = per_channel - self.hessian_weights = hessian_weights - self.norm_scores = norm_scores - self.log_norm_weights = log_norm_weights - self.scaled_log_norm = scaled_log_norm - self.override_params = {QUANT_PARAM_LEARNING_STR: params_learning} if \ - rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict(default_value=1)} \ - if rounding_type == RoundingType.STE else None + if rounding_type == RoundingType.SoftQuantizer: + self.override_params = {QUANT_PARAM_LEARNING_STR: params_learning} + elif rounding_type == RoundingType.STE: + self.override_params = {MAX_LSB_STR: DefaultDict(default_value=1)} + else: + raise ValueError('unknown rounding_type', rounding_type) self.gradual_activation_quantization = gradual_activation_quantization - self.hessian_num_samples = hessian_num_samples - self.sample_layer_attention = sample_layer_attention self.loss = loss - self.hessian_batch_size = hessian_batch_size + self.reg_factor = reg_factor + self.hessian_cfg = None + if hessian_weights: + self.hessian_cfg = GPTQHessianScoresConfig(per_sample=sample_layer_attention, + norm_scores=norm_scores, + log_norm=log_norm_weights, + scale_log_norm=scaled_log_norm, + hessians_num_samples=hessian_num_samples, + hessian_batch_size=hessian_batch_size) def get_quantization_config(self): return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING, @@ -145,16 +150,11 @@ def get_gptq_config(self): return GradientPTQConfig(5, optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4), optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=1e-4), loss=self.loss, train_bias=True, rounding_type=self.rounding_type, - use_hessian_based_weights=self.hessian_weights, optimizer_bias=torch.optim.Adam([torch.Tensor([])], lr=0.4), - hessian_weights_config=GPTQHessianScoresConfig(log_norm=self.log_norm_weights, - scale_log_norm=self.scaled_log_norm, - norm_scores=self.norm_scores, - per_sample=self.sample_layer_attention, - hessians_num_samples=self.hessian_num_samples, - hessian_batch_size=self.hessian_batch_size), + hessian_weights_config=self.hessian_cfg, gptq_quantizer_params_override=self.override_params, - gradual_activation_quantization_config=gradual_act_cfg) + gradual_activation_quantization_config=gradual_act_cfg, + regularization_factor=self.reg_factor) def gptq_compare(self, ptq_model, gptq_model, input_x=None): ptq_weights = torch_tensor_to_numpy(list(ptq_model.parameters())) @@ -171,7 +171,9 @@ def get_gptq_config(self): optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0.5), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, gradual_activation_quantization_config=gradual_act_cfg, - gptq_quantizer_params_override=self.override_params) + gptq_quantizer_params_override=self.override_params, + regularization_factor=self.reg_factor, + hessian_weights_config=self.hessian_cfg) def compare(self, ptq_model, gptq_model, input_x=None, max_change=None): ptq_weights = torch_tensor_to_numpy(list(ptq_model.parameters())) @@ -194,7 +196,9 @@ def get_gptq_config(self): optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0), loss=multiple_tensors_mse_loss, train_bias=False, rounding_type=self.rounding_type, gradual_activation_quantization_config=gradual_act_cfg, - gptq_quantizer_params_override=self.override_params) + gptq_quantizer_params_override=self.override_params, + regularization_factor=self.reg_factor, + hessian_weights_config=self.hessian_cfg) def gptq_compare(self, ptq_model, gptq_model, input_x=None): ptq_out = torch_tensor_to_numpy(ptq_model(input_x)) diff --git a/tests_pytest/keras/gptq/test_gradual_act_quantization.py b/tests_pytest/keras/gptq/test_gradual_act_quantization.py index 9eceff8bc..cab0a78c7 100644 --- a/tests_pytest/keras/gptq/test_gradual_act_quantization.py +++ b/tests_pytest/keras/gptq/test_gradual_act_quantization.py @@ -94,7 +94,8 @@ def test_factory_linear_common_case(self, x): def _run_factory_test(self, qdrop_cfg, get_grad_steps_fn): # Mocks are used to just pass anything - gptq_cfg = GradientPTQConfig(n_epochs=5, optimizer=Mock(), loss=Mock(), + gptq_cfg = GradientPTQConfig(n_epochs=5, optimizer=Mock(), optimizer_rest=Mock(), loss=Mock(), + hessian_weights_config=None, train_bias=False, regularization_factor=1, gradual_activation_quantization_config=qdrop_cfg) factory = get_gradual_activation_quantizer_wrapper_factory(gptq_cfg, get_grad_steps_fn, KerasLinearAnnealingScheduler) quantizer = Quantizer() diff --git a/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py b/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py index 08c440d6c..02173a7ec 100644 --- a/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +++ b/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py @@ -92,7 +92,8 @@ def test_factory_linear_common_case(self, x): def _run_factory_test(self, qdrop_cfg, get_grad_steps_fn): # Mocks are used to just pass anything - gptq_cfg = GradientPTQConfig(n_epochs=5, optimizer=Mock(), loss=Mock(), + gptq_cfg = GradientPTQConfig(n_epochs=5, optimizer=Mock(), loss=Mock(), optimizer_rest=Mock(), + hessian_weights_config=None, train_bias=False, regularization_factor=1, gradual_activation_quantization_config=qdrop_cfg) factory = get_gradual_activation_quantizer_wrapper_factory(gptq_cfg, get_grad_steps_fn, PytorchLinearAnnealingScheduler) quantizer = Quantizer()