diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py
index 7d5ba1f7e..31bc0548a 100644
--- a/model_compression_toolkit/gptq/keras/quantization_facade.py
+++ b/model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -115,7 +115,6 @@ def get_keras_gptq_config(n_epochs: int,
     if regularization_factor is None:
         regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
 
-    loss = loss or GPTQMultipleTensorsLoss()
     hessian_weights_config = None
     if use_hessian_sample_attention:
         if not use_hessian_based_weights:    # pragma: no cover
@@ -129,7 +128,10 @@ def get_keras_gptq_config(n_epochs: int,
         hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
                                                          hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
                                                          hessian_batch_size=hessian_batch_size)
-
+
+    # If a loss was not passed (and was not initialized due to use_hessian_sample_attention), use the default loss
+    loss = loss or GPTQMultipleTensorsLoss()
+
     if isinstance(gradual_activation_quantization, bool):
         gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
     elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig):
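
A minimal, self-contained sketch of the ordering issue the diff addresses. It assumes, based on the surrounding file, that the sample-attention branch assigns its own loss when none was supplied; the simplified control flow and the stand-in string values below are illustrative assumptions, not the library's exact implementation.

# Sketch: why the default loss must be assigned *after* the sample-attention branch.
# Stand-in strings are used in place of the real callables
# (sample_layer_attention_loss, GPTQMultipleTensorsLoss).

def pick_loss(loss, use_hessian_sample_attention):
    """Return the loss the facade would end up using (simplified)."""
    # Before the change, the generic default was applied first, so `loss` was
    # never None by the time this branch ran and the sample-attention loss
    # could never be selected as a default.
    if use_hessian_sample_attention and loss is None:
        loss = "sample_layer_attention_loss"    # assumed stand-in
    # After the change, the generic default is applied only if nothing else
    # has set a loss by this point.
    loss = loss or "GPTQMultipleTensorsLoss()"  # assumed stand-in
    return loss


if __name__ == "__main__":
    # With no user-supplied loss, sample attention now gets its own default.
    assert pick_loss(None, use_hessian_sample_attention=True) == "sample_layer_attention_loss"
    # Without sample attention, the generic default is still used.
    assert pick_loss(None, use_hessian_sample_attention=False) == "GPTQMultipleTensorsLoss()"
    # An explicitly passed loss always wins.
    assert pick_loss("custom_loss", use_hessian_sample_attention=True) == "custom_loss"
    print("ordering behaves as expected")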