Skip to content

Commit

Permalink
adding an option to create gptq config with given hessian batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Jun 5, 2024
1 parent f622d57 commit 88ee7d4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
11 changes: 7 additions & 4 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core import CoreConfig
Expand Down Expand Up @@ -68,7 +68,8 @@ def get_keras_gptq_config(n_epochs: int,
loss: Callable = GPTQMultipleTensorsLoss(),
log_function: Callable = None,
use_hessian_based_weights: bool = True,
regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
regularization_factor: float = REG_DEFAULT,
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE) -> GradientPTQConfig:
"""
    Create a GradientPTQConfig instance for Keras models.
Expand All @@ -80,6 +81,7 @@ def get_keras_gptq_config(n_epochs: int,
log_function (Callable): Function to log information about the gptq process.
use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
regularization_factor (float): A floating point number that defines the regularization factor.
hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
    Returns:
        a GradientPTQConfig object to use when fine-tuning the quantized model using gptq.
Expand Down Expand Up @@ -112,7 +114,8 @@ def get_keras_gptq_config(n_epochs: int,
train_bias=True,
optimizer_bias=bias_optimizer,
use_hessian_based_weights=use_hessian_based_weights,
regularization_factor=regularization_factor)
regularization_factor=regularization_factor,
hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size))


def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
Expand Down
12 changes: 8 additions & 4 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

from typing import Callable
from model_compression_toolkit.core import common
from model_compression_toolkit.constants import FOUND_TORCH
from model_compression_toolkit.constants import FOUND_TORCH, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.runner import core_runner
Expand Down Expand Up @@ -57,7 +57,9 @@ def get_pytorch_gptq_config(n_epochs: int,
loss: Callable = multiple_tensors_mse_loss,
log_function: Callable = None,
use_hessian_based_weights: bool = True,
regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
regularization_factor: float = REG_DEFAULT,
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
) -> GradientPTQConfig:
"""
    Create a GradientPTQConfig instance for Pytorch models.
Expand All @@ -69,6 +71,7 @@ def get_pytorch_gptq_config(n_epochs: int,
log_function (Callable): Function to log information about the gptq process.
use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
regularization_factor (float): A floating point number that defines the regularization factor.
hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
    Returns:
        a GradientPTQConfig object to use when fine-tuning the quantized model using gptq.
Expand All @@ -92,7 +95,8 @@ def get_pytorch_gptq_config(n_epochs: int,
return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
use_hessian_based_weights=use_hessian_based_weights,
regularization_factor=regularization_factor)
regularization_factor=regularization_factor,
hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size))


def pytorch_gradient_post_training_quantization(model: Module,
Expand Down

0 comments on commit 88ee7d4

Please sign in to comment.