fix default loss function setting in keras
reuvenperetz authored Jan 7, 2025
1 parent f1afb23 commit 009a4f7
Showing 1 changed file with 4 additions and 2 deletions.
model_compression_toolkit/gptq/keras/quantization_facade.py — 6 changes: 4 additions & 2 deletions
@@ -115,7 +115,6 @@ def get_keras_gptq_config(n_epochs: int,
     if regularization_factor is None:
         regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
 
-    loss = loss or GPTQMultipleTensorsLoss()
     hessian_weights_config = None
     if use_hessian_sample_attention:
         if not use_hessian_based_weights:  # pragma: no cover
@@ -129,7 +128,10 @@ def get_keras_gptq_config(n_epochs: int,
         hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
                                                          hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
                                                          hessian_batch_size=hessian_batch_size)
-
+
+    # If a loss was not passed (and was not initialized due to use_hessian_sample_attention), use the default loss
+    loss = loss or GPTQMultipleTensorsLoss()
+
     if isinstance(gradual_activation_quantization, bool):
         gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
     elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig):
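In effect, the fallback to GPTQMultipleTensorsLoss now runs only after the Hessian sample-attention branch has had a chance to set its own loss, instead of being assigned up front and masking that choice. A minimal sketch of how this plays out at the call site, assuming the public alias mct.gptq.get_keras_gptq_config and the parameter names visible in the diff (n_epochs, loss, use_hessian_sample_attention); exact keyword arguments and defaults may differ by MCT version, and my_custom_loss_fn is a hypothetical user-defined loss:

import model_compression_toolkit as mct

# Sample-attention path: leave `loss` unset so the Hessian branch can select its
# own sample-layer-attention loss; before this fix the early default preempted it.
sla_config = mct.gptq.get_keras_gptq_config(n_epochs=5,
                                            use_hessian_sample_attention=True)

# Plain path: with sample attention off and no loss passed, the facade now falls
# back to GPTQMultipleTensorsLoss() at this later point in the function.
default_config = mct.gptq.get_keras_gptq_config(n_epochs=5,
                                                use_hessian_sample_attention=False)

# An explicitly passed loss still takes precedence on either path.
custom_config = mct.gptq.get_keras_gptq_config(n_epochs=5,
                                               loss=my_custom_loss_fn)  # hypothetical user-defined loss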
