Enable SLA + gradual activation quantization by default for PyTorch (#1260)

* update default reg factor and bias lr

* update gptq config, set sla + gradual activation as default

* fix tests

* update hessian config

* fix bug in hessian service
irenaby authored Nov 12, 2024
1 parent f34972e commit 9ac9f54
Showing 15 changed files with 186 additions and 120 deletions.
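For orientation, here is a minimal usage sketch of the behavior change. It assumes the standard `import model_compression_toolkit as mct` entry point and the `mct.gptq` facade exposed by the package; the epoch count is arbitrary and not part of this diff.

```python
import model_compression_toolkit as mct

# After this commit, sample layer attention (SLA) and gradual activation
# quantization are enabled by default for the PyTorch GPTQ facade.
gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)

# The previous behavior can still be requested explicitly; only the default
# values of these existing keyword arguments changed.
legacy_config = mct.gptq.get_pytorch_gptq_config(
    n_epochs=5,
    use_hessian_sample_attention=False,
    gradual_activation_quantization=False)
```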
@@ -204,7 +204,7 @@ def _fetch_hessians_with_compute(self, request: HessianScoresRequest, n_iteratio
target_nodes = [n for n in orig_request.target_nodes if n.name in missing]
request = request.clone(target_nodes=target_nodes)
self._compute_hessians(request, n_iterations, count_by_cache=True)
res, missing = self.cache.fetch_hessian(request)
res, missing = self.cache.fetch_hessian(orig_request)
assert not missing
return res
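To illustrate why the one-line change above fixes the bug, here is a small, self-contained sketch using a hypothetical dict-based cache (not the project's actual cache API): fetching with the narrowed request would drop entries that were already cached, while fetching with the original request returns cached and freshly computed scores together.

```python
# Hypothetical illustration of the fetch/compute/fetch pattern fixed above.
cache = {"node_a": 0.11}                                  # node_a already cached
orig_request = ["node_a", "node_b"]                       # caller asked for both nodes

missing = [n for n in orig_request if n not in cache]     # -> ["node_b"]
for n in missing:
    cache[n] = 0.42                                       # stand-in for _compute_hessians(...)

wrong = {n: cache[n] for n in missing}                    # old behavior: node_a is dropped
fixed = {n: cache[n] for n in orig_request}               # new behavior: fetch with orig_request
assert set(fixed) == set(orig_request) and "node_a" not in wrong
```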

49 changes: 26 additions & 23 deletions model_compression_toolkit/gptq/common/gptq_config.py
@@ -16,8 +16,7 @@
from enum import Enum
from typing import Callable, Any, Dict, Optional

from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE


class RoundingType(Enum):
@@ -39,20 +38,26 @@ class GPTQHessianScoresConfig:
Configuration to use for computing the Hessian-based scores for GPTQ loss metric.
Args:
per_sample (bool): Whether to use per sample attention score.
hessians_num_samples (int|None): Number of samples to use for computing the Hessian-based scores.
If None, compute Hessian for all images.
norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
per_sample (bool): Whether to use per sample attention score.
"""
hessians_num_samples: Optional[int] = GPTQ_HESSIAN_NUM_SAMPLES
norm_scores: bool = True
log_norm: bool = True
per_sample: bool
hessians_num_samples: Optional[int]
norm_scores: bool = None
log_norm: bool = None
scale_log_norm: bool = False
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
per_sample: bool = False

def __post_init__(self):
if self.norm_scores is None:
self.norm_scores = not self.per_sample
if self.log_norm is None:
self.log_norm = not self.per_sample
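A short sketch of how the reworked dataclass resolves its normalization defaults, based on the `__post_init__` above. The import path matches the file being edited; the sample counts are arbitrary.

```python
from model_compression_toolkit.gptq.common.gptq_config import GPTQHessianScoresConfig

# SLA-style configuration: per-sample scores, so normalization defaults to off.
sla_cfg = GPTQHessianScoresConfig(per_sample=True, hessians_num_samples=None)
assert sla_cfg.norm_scores is False and sla_cfg.log_norm is False

# Classic weighted-loss configuration: normalization defaults to on.
plain_cfg = GPTQHessianScoresConfig(per_sample=False, hessians_num_samples=16)
assert plain_cfg.norm_scores is True and plain_cfg.log_norm is True
```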


@dataclass
@@ -107,32 +112,30 @@ class GradientPTQConfig:
Args:
n_epochs: Number of representative dataset epochs to train.
optimizer: Optimizer to use.
optimizer_rest: Optimizer to use for bias and quantizer parameters.
loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface.
log_function: Function to log information about the GPTQ process.
optimizer: Optimizer to use.
optimizer_rest: Default optimizer to use for bias and quantizer parameters.
train_bias: Whether to update the bias during the training or not.
rounding_type: An enum that defines the rounding type.
use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss.
optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
optimizer_bias: Optimizer to override the rest optimizer for bias.
regularization_factor: A floating point number that defines the regularization factor.
hessian_weights_config: A configuration that include all necessary arguments to run a computation of
Hessian scores for the GPTQ loss.
gradual_activation_quantization_config: A configuration for Gradual Activation Quantization.
regularization_factor: A floating point number that defines the regularization factor.
rounding_type: An enum that defines the rounding type.
optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
optimizer_bias: Optimizer to override the rest optimizer for bias.
log_function: Function to log information about the GPTQ process.
gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation.
"""
n_epochs: int
loss: Callable
optimizer: Any
optimizer_rest: Any = None
loss: Callable = None
log_function: Callable = None
train_bias: bool = True
optimizer_rest: Any
train_bias: bool
hessian_weights_config: Optional[GPTQHessianScoresConfig]
gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig]
regularization_factor: float
rounding_type: RoundingType = RoundingType.SoftQuantizer
use_hessian_based_weights: bool = True
optimizer_quantization_parameter: Any = None
optimizer_bias: Any = None
regularization_factor: float = REG_DEFAULT
hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig)
gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None
log_function: Callable = None
gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict)
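Because the core fields of `GradientPTQConfig` no longer carry defaults, direct construction now looks roughly like the sketch below. It mirrors what the facades in this commit build; the optimizer placeholders and the trivial loss lambda are illustrative, not part of the diff.

```python
import torch
from model_compression_toolkit.gptq.common.gptq_config import (
    GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig)

gptq_cfg = GradientPTQConfig(
    n_epochs=5,
    loss=lambda *args, **kwargs: torch.zeros(()),                  # placeholder loss callable
    optimizer=torch.optim.Adam([torch.Tensor([])], lr=3e-2),       # LR_DEFAULT
    optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=1e-4),  # LR_REST_DEFAULT
    train_bias=True,
    hessian_weights_config=GPTQHessianScoresConfig(per_sample=True,
                                                   hessians_num_samples=None),
    gradual_activation_quantization_config=GradualActivationQuantizationConfig(),
    regularization_factor=10)                                      # REG_DEFAULT_SLA in this commit
# rounding_type, optimizer_quantization_parameter, optimizer_bias, log_function and
# gptq_quantizer_params_override keep their defaults.
```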
4 changes: 2 additions & 2 deletions model_compression_toolkit/gptq/common/gptq_constants.py
@@ -14,6 +14,7 @@
MIM_TEMP = 0.5
MAX_TEMP = 1.0
REG_DEFAULT = 0.01
REG_DEFAULT_SLA = 10
MAX_LSB_CHANGE = 1

# Soft rounding arguments values
@@ -27,6 +28,5 @@
# GPTQ learning hyperparameters
LR_DEFAULT = 3e-2
LR_REST_DEFAULT = 1e-4
LR_BIAS_DEFAULT = 1e-3
LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
LR_BIAS_DEFAULT = 1e-4
GPTQ_MOMENTUM = 0.9
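For context, a hedged sketch of how the new `REG_DEFAULT_SLA` constant is intended to be consumed. The helper function is hypothetical and simply mirrors the selection logic added to the PyTorch facade later in this diff.

```python
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, REG_DEFAULT_SLA

def resolve_regularization_factor(regularization_factor, use_hessian_sample_attention):
    # Hypothetical helper: None means "use the per-mode default", as in the facade.
    if regularization_factor is None:
        return REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
    return regularization_factor

assert resolve_regularization_factor(None, True) == 10      # SLA default (new)
assert resolve_regularization_factor(None, False) == 0.01   # classic default
```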
2 changes: 1 addition & 1 deletion model_compression_toolkit/gptq/common/gptq_training.py
@@ -75,7 +75,7 @@ def __init__(self,
fw_info=self.fw_info)

self.fxp_model, self.gptq_user_info = self.build_gptq_model()
if self.gptq_config.use_hessian_based_weights:
if self.gptq_config.hessian_weights_config:
if not isinstance(hessian_info_service, HessianInfoService):
Logger.critical(f"When using Hessian-based approximations for sensitivity evaluation, "
f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover
2 changes: 1 addition & 1 deletion model_compression_toolkit/gptq/keras/gptq_training.py
@@ -139,7 +139,7 @@ def _get_total_grad_steps():

def _get_compare_points_loss_weights(self):
""" Get compare points weights for the distillation loss. """
if self.gptq_config.use_hessian_based_weights:
if self.gptq_config.hessian_weights_config:
hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
return self.compute_hessian_based_weights(hess_dataloader)
14 changes: 9 additions & 5 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -21,7 +21,7 @@
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
LR_BIAS_DEFAULT, GPTQ_MOMENTUM
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
@@ -117,16 +117,20 @@ def get_keras_gptq_config(n_epochs: int,
raise TypeError(f'gradual_activation_quantization argument should be bool or '
f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')

return GradientPTQConfig(n_epochs,
optimizer,
hessian_weights_config = None
if use_hessian_based_weights:
hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
hessian_batch_size=hessian_batch_size)
return GradientPTQConfig(n_epochs=n_epochs,
optimizer=optimizer,
optimizer_rest=optimizer_rest,
loss=loss,
log_function=log_function,
train_bias=True,
optimizer_bias=bias_optimizer,
use_hessian_based_weights=use_hessian_based_weights,
regularization_factor=regularization_factor,
hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size),
hessian_weights_config=hessian_weights_config,
gradual_activation_quantization_config=gradual_quant_config)
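A hedged usage sketch for the Keras facade after this change, assuming TensorFlow is installed, the standard `mct.gptq` entry point, and that `use_hessian_based_weights` still defaults to True: disabling Hessian-based weighting now leaves `hessian_weights_config` as `None`, which downstream code checks by truthiness instead of a separate boolean flag.

```python
import model_compression_toolkit as mct

# Default: Hessian-based loss weighting is on, so a GPTQHessianScoresConfig is attached.
cfg = mct.gptq.get_keras_gptq_config(n_epochs=5)
assert cfg.hessian_weights_config is not None

# Opting out now simply leaves hessian_weights_config as None.
cfg_plain = mct.gptq.get_keras_gptq_config(n_epochs=5, use_hessian_based_weights=False)
assert cfg_plain.hessian_weights_config is None
```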


4 changes: 2 additions & 2 deletions model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -116,7 +116,7 @@ def _get_total_grad_steps():
trainable_threshold)
hessian_cfg = self.gptq_config.hessian_weights_config

self.use_sample_layer_attention = hessian_cfg.per_sample
self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
if self.use_sample_layer_attention:
# normalization is currently not supported, make sure the config reflects it.
if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
@@ -178,7 +178,7 @@ def _prepare_train_dataloader_for_non_sla(self, data_gen_fn: Callable[[], Genera
dataset = IterableDatasetFromGenerator(data_gen_fn)
num_nodes = len(self.compare_points)

if self.gptq_config.use_hessian_based_weights:
if self.gptq_config.hessian_weights_config:
hess_dataloader = DataLoader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
loss_weights = torch.from_numpy(self.compute_hessian_based_weights(hess_dataloader))
else:
47 changes: 26 additions & 21 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -15,7 +15,7 @@
import copy
from typing import Callable, Union

from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH, GPTQ_HESSIAN_NUM_SAMPLES
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -27,7 +27,7 @@
from model_compression_toolkit.gptq.common.gptq_config import (
GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig)
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
LR_BIAS_DEFAULT, GPTQ_MOMENTUM
LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
from model_compression_toolkit.gptq.runner import gptq_runner
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.metadata import create_model_metadata
@@ -55,10 +55,10 @@ def get_pytorch_gptq_config(n_epochs: int,
loss: Callable = None,
log_function: Callable = None,
use_hessian_based_weights: bool = True,
regularization_factor: float = REG_DEFAULT,
regularization_factor: float = None,
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
use_hessian_sample_attention: bool = False,
gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False,
use_hessian_sample_attention: bool = True,
gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = True,
) -> GradientPTQConfig:
"""
Create a GradientPTQConfig instance for Pytorch models.
@@ -94,25 +94,26 @@
"""
optimizer = optimizer or Adam([torch.Tensor([])], lr=LR_DEFAULT)
optimizer_rest = optimizer_rest or Adam([torch.Tensor([])], lr=LR_REST_DEFAULT)

# TODO this contradicts the docstring for optimizer_rest
bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)

if regularization_factor is None:
regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT

loss = loss or multiple_tensors_mse_loss
hessian_weights_config = None
if use_hessian_sample_attention:
if not use_hessian_based_weights: # pragma: no cover
raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')

hessian_weights_config = GPTQHessianScoresConfig(
hessians_num_samples=None,
norm_scores=False,
log_norm=False,
scale_log_norm=False,
hessian_batch_size=hessian_batch_size,
per_sample=True,
)
hessian_weights_config = GPTQHessianScoresConfig(per_sample=True,
hessians_num_samples=None,
hessian_batch_size=hessian_batch_size)
loss = loss or sample_layer_attention_loss
else:
hessian_weights_config = GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)
loss = loss or multiple_tensors_mse_loss
elif use_hessian_based_weights:
hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
hessian_batch_size=hessian_batch_size)

if isinstance(gradual_activation_quantization, bool):
gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
@@ -122,12 +122,16 @@
raise TypeError(f'gradual_activation_quantization argument should be bool or '
f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')

return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
use_hessian_based_weights=use_hessian_based_weights,
return GradientPTQConfig(n_epochs=n_epochs,
loss=loss,
optimizer=optimizer,
optimizer_rest=optimizer_rest,
optimizer_bias=bias_optimizer,
train_bias=True,
regularization_factor=regularization_factor,
hessian_weights_config=hessian_weights_config,
gradual_activation_quantization_config=gradual_quant_config)
gradual_activation_quantization_config=gradual_quant_config,
log_function=log_function)

def pytorch_gradient_post_training_quantization(model: Module,
representative_data_gen: Callable,
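Finally, a hedged sketch of what the PyTorch facade now returns by default (assuming the `mct.gptq` entry point; the assertions reflect the defaults introduced in this commit):

```python
import model_compression_toolkit as mct

cfg = mct.gptq.get_pytorch_gptq_config(n_epochs=5)

# With the new defaults the returned config reflects SLA + gradual activation:
assert cfg.hessian_weights_config.per_sample is True
assert cfg.gradual_activation_quantization_config is not None
assert cfg.regularization_factor == 10                 # REG_DEFAULT_SLA

# Turning both features off restores the previous-style configuration:
cfg_plain = mct.gptq.get_pytorch_gptq_config(n_epochs=5,
                                             use_hessian_sample_attention=False,
                                             gradual_activation_quantization=False)
assert cfg_plain.hessian_weights_config.per_sample is False
assert cfg_plain.regularization_factor == 0.01         # REG_DEFAULT
```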
