Skip to content

Commit

Permalink
Align GPTQ hyperparameters between the Keras and PyTorch facades to match the accuracy runner
Browse files — browse the repository at this point in the history
  • Loading branch information
reuvenp committed Oct 22, 2024
1 parent 7f30c3f commit 4057c7b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 16 deletions.
9 changes: 8 additions & 1 deletion model_compression_toolkit/gptq/common/gptq_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,11 @@

# GPTQ config constant
QUANT_PARAM_LEARNING_STR = 'quantization_parameter_learning'
MAX_LSB_STR = 'max_lsbs_change_map'
MAX_LSB_STR = 'max_lsbs_change_map'

# GPTQ learning hyperparameters
LR_DEFAULT = 3e-2
LR_REST_DEFAULT = 1e-4
LR_BIAS_DEFAULT = 1e-3
LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
GPTQ_MOMENTUM = 0.9
11 changes: 3 additions & 8 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing import Callable, Tuple, Union
from packaging import version

from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
LR_BIAS_DEFAULT, GPTQ_MOMENTUM
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.verify_packages import FOUND_TF
Expand All @@ -33,13 +33,8 @@
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata
from model_compression_toolkit.metadata import create_model_metadata

LR_DEFAULT = 0.15
LR_REST_DEFAULT = 1e-4
LR_BIAS_DEFAULT = 1e-4
LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
GPTQ_MOMENTUM = 0.9

if FOUND_TF:
import tensorflow as tf
Expand Down
11 changes: 4 additions & 7 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
from model_compression_toolkit.core.common.hessian import HessianEstimationDistribution
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
Expand All @@ -27,18 +27,15 @@
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.gptq.common.gptq_config import (
GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig)
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
LR_BIAS_DEFAULT, GPTQ_MOMENTUM
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.metadata import create_model_metadata
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.verify_packages import FOUND_TORCH

LR_DEFAULT = 1e-4
LR_REST_DEFAULT = 1e-4
LR_BIAS_DEFAULT = 1e-4
LR_QUANTIZATION_PARAM_DEFAULT = 1e-4


if FOUND_TORCH:
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
Expand Down

0 comments on commit 4057c7b

Please sign in to comment.