fix default loss function in torch
reuvenperetz authored Jan 7, 2025
1 parent 009a4f7 commit 36a1f23
Showing 1 changed file with 3 additions and 1 deletion.
@@ -104,7 +104,6 @@ def get_pytorch_gptq_config(n_epochs: int,
     if regularization_factor is None:
         regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
 
-    loss = loss or multiple_tensors_mse_loss
     hessian_weights_config = None
     if use_hessian_sample_attention:
         if not use_hessian_based_weights:  # pragma: no cover
@@ -118,6 +117,9 @@ def get_pytorch_gptq_config(n_epochs: int,
         hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
                                                          hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
                                                          hessian_batch_size=hessian_batch_size)
 
+    # If a loss was not passed (and was not initialized due to use_hessian_sample_attention), use the default loss
+    loss = loss or multiple_tensors_mse_loss
+
     if isinstance(gradual_activation_quantization, bool):
         gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
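For context, a minimal usage sketch of the call path this change affects. This is an assumption-laden illustration, not code from this commit: the public entry point mct.gptq.get_pytorch_gptq_config exists in model_compression_toolkit and n_epochs is shown in the hunk header above, but use_hessian_sample_attention as a keyword argument and the exact loss initialized inside that branch are inferred from the diff, not shown in it.

import model_compression_toolkit as mct

# loss is left unset, so the function picks a default:
# - with use_hessian_sample_attention=True, the loss initialized inside that branch is kept,
#   because the multiple_tensors_mse_loss fallback now runs only afterwards;
# - otherwise, multiple_tensors_mse_loss is used as before.
gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5,
                                               use_hessian_sample_attention=True)

Before this commit, the fallback ran before the branch, so loss was already non-None when the hessian-sample-attention path was reached and its attention-specific default would presumably never apply; moving the fallback after the branch fixes that ordering.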
