diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py
index 38b09aaa8..86b557a94 100644
--- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py
+++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -104,7 +104,6 @@ def get_pytorch_gptq_config(n_epochs: int,
     if regularization_factor is None:
         regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
 
-    loss = loss or multiple_tensors_mse_loss
     hessian_weights_config = None
     if use_hessian_sample_attention:
         if not use_hessian_based_weights:  # pragma: no cover
@@ -118,6 +117,9 @@ def get_pytorch_gptq_config(n_epochs: int,
         hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
                                                          hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
                                                          hessian_batch_size=hessian_batch_size)
+
+    # If a loss was not passed (and was not initialized due to use_hessian_sample_attention), use the default loss
+    loss = loss or multiple_tensors_mse_loss
 
     if isinstance(gradual_activation_quantization, bool):
         gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
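
The patch defers the `loss = loss or multiple_tensors_mse_loss` fallback until after the Hessian-weights configuration, so that the `use_hessian_sample_attention` branch can assign its own loss before the generic default applies. Below is a minimal, standalone sketch of that resolution order; it is not the MCT implementation, and `sample_layer_attention_loss` and `resolve_loss` are hypothetical names used only for illustration.

# Minimal sketch of the ordering the patch enforces (assumed names, not MCT code).

def multiple_tensors_mse_loss(y_list, x_list):
    # Generic default loss in this sketch: average of per-tensor MSEs.
    return sum(((y - x) ** 2).mean() for y, x in zip(y_list, x_list)) / len(y_list)

def sample_layer_attention_loss(y_list, x_list, weights):
    # Stand-in for the loss the sample-attention path would set (assumption).
    raise NotImplementedError

def resolve_loss(loss, use_hessian_sample_attention):
    if use_hessian_sample_attention and loss is None:
        # Branch-specific default is chosen first (assumed behavior of that branch).
        loss = sample_layer_attention_loss
    # As in the patched code: fall back to the generic default only if nothing set `loss` above.
    return loss or multiple_tensors_mse_loss

# Fallback order after the patch:
assert resolve_loss(None, use_hessian_sample_attention=False) is multiple_tensors_mse_loss
assert resolve_loss(None, use_hessian_sample_attention=True) is sample_layer_attention_loss

Before the patch, the fallback ran before the Hessian block, so a `None` loss was always replaced by `multiple_tensors_mse_loss` and the sample-attention branch never got to choose a more suitable default.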