
Commit

replace Gaussian with Rademacher distribution in hessian estimation for torch
irenaby committed Oct 22, 2024
1 parent 8c444b9 commit 1e1e223
Showing 10 changed files with 92 additions and 149 deletions.

model_compression_toolkit/core/common/hessian/__init__.py (1 addition, 1 deletion)

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
-    HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
+    HessianScoresRequest, HessianMode, HessianScoresGranularity
 )
 from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
 import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils

model_compression_toolkit/core/common/hessian/hessian_scores_request.py

@@ -44,14 +44,6 @@ class HessianScoresGranularity(Enum):
     PER_TENSOR = 2
 
 
-class HessianEstimationDistribution(str, Enum):
-    """
-    Distribution for Hutchinson estimator random vector
-    """
-    GAUSSIAN = 'gaussian'
-    RADEMACHER = 'rademacher'
-
-
 @dataclasses.dataclass
 class HessianScoresRequest:
     """
@@ -68,15 +60,12 @@ class HessianScoresRequest:
             the computation. Can be None if all hessians for the request are expected to be pre-computed previously.
         n_samples: The number of samples to fetch hessian estimations for. If None, fetch hessians for a full pass
             of the data loader.
-        distribution: Distribution to use in Hutchinson estimation.
     """
     mode: HessianMode
     granularity: HessianScoresGranularity
     target_nodes: Sequence['BaseNode']
     data_loader: Optional[Iterable]
     n_samples: Optional[int]
-    # TODO remove
-    distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN
 
     def __post_init__(self):
         if self.data_loader is None and self.n_samples is None:
@@ -85,4 +74,3 @@ def __post_init__(self):
     def clone(self, **kwargs):
         """ Create a clone with optional overrides """
         return dataclasses.replace(self, **kwargs)
-
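With the distribution field gone, a request is fully described by its mode, granularity, target nodes, data loader, and sample count. A hypothetical construction after this change (the node list and data loader are placeholders, and HessianMode.ACTIVATION is assumed from the enum imported above):

import model_compression_toolkit  # noqa: assumes an installed MCT with this commit
from model_compression_toolkit.core.common.hessian import (
    HessianScoresRequest, HessianMode, HessianScoresGranularity)

# Hypothetical request; `target_nodes` and `representative_dataloader` are placeholders.
request = HessianScoresRequest(mode=HessianMode.ACTIVATION,
                               granularity=HessianScoresGranularity.PER_TENSOR,
                               target_nodes=target_nodes,
                               data_loader=representative_dataloader,
                               n_samples=32)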

model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py

@@ -15,20 +15,19 @@
 
 from typing import List
 
+import numpy as np
+import torch
 from torch import autograd
 from tqdm import tqdm
-import numpy as np
 
 from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.core.common.hessian import (HessianScoresRequest, HessianScoresGranularity,
-                                                           HessianEstimationDistribution)
+from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
 from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
 from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
     HessianScoresCalculatorPytorch
 from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
 from model_compression_toolkit.logger import Logger
-import torch
 
 
 class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
@@ -92,29 +91,6 @@ def forward_pass(self):
         output = self.concat_tensors(output_tensors)
         return output, target_activation_tensors
 
-    def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution,
-                                       device: torch.device) -> torch.Tensor:
-        """
-        Generate a batch of random vectors for Hutchinson estimation
-
-        Args:
-            shape: target shape
-            distribution: distribution to sample from
-            device: target device
-
-        Returns:
-            Random tensor
-        """
-        if distribution == HessianEstimationDistribution.GAUSSIAN:
-            return torch.randn(shape, device=device)
-
-        if distribution == HessianEstimationDistribution.RADEMACHER:
-            v = torch.randint(high=2, size=shape, device=device)
-            v[v == 0] = -1
-            return v
-
-        raise ValueError(f'Unknown distribution {distribution}')  # pragma: no cover
-
     def compute(self) -> List[np.ndarray]:
         """
         Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations.
@@ -142,7 +118,7 @@ def _compute_per_tensor(self, output, target_activation_tensors):
         prev_mean_results = None
         for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
             # Getting a random vector with normal distribution
-            v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
+            v = self._generate_random_vectors_batch(output.shape, output.device)
             f_v = torch.sum(v * output)
             for i, ipt_tensor in enumerate(target_activation_tensors):  # Per Interest point activation tensor
                 # Computing the hessian-approximation scores by getting the gradient of (output * v)
@@ -183,7 +159,7 @@ def _compute_per_channel(self, output, target_activation_tensors):
                                for _ in range(len(target_activation_tensors))]
 
         for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
-            v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
+            v = self._generate_random_vectors_batch(output.shape, output.device)
             f_v = torch.sum(v * output)
             for i, ipt_tensor in enumerate(target_activation_tensors):  # Per Interest point activation tensor
                 hess_v = autograd.grad(outputs=f_v,
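Both loops above follow the same Hutchinson pattern: project the model output onto a random vector, reduce to a scalar, and differentiate once to obtain a Jacobian-vector product. A minimal self-contained sketch of that pattern (toy tensors, not MCT's API):

import torch
from torch import autograd

x = torch.randn(8, requires_grad=True)
a = torch.tanh(x)      # stand-in for an intermediate activation
y = a * a.sum()        # stand-in for the model output

# Rademacher probe: +-1 entries matching the output's shape and dtype.
v = torch.randint(0, 2, y.shape, dtype=y.dtype) * 2 - 1

f_v = torch.sum(v * y)                               # scalar v^T y
(jac_v,) = autograd.grad(f_v, a, retain_graph=True)  # J^T v w.r.t. the activation

# Averaging jac_v ** 2 over many probes approximates diag(J^T J),
# since E[v v^T] = I holds for Rademacher (and Gaussian) probes.
score = jac_v.pow(2).mean()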

model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py

@@ -15,41 +15,31 @@
 
 from typing import Union, List
 
-from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
-from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.core.common.hessian import HessianScoresRequest
+import torch
+
 from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
 from model_compression_toolkit.logger import Logger
-import torch
 
 
 class HessianScoresCalculatorPytorch(HessianScoresCalculator):
     """
     Pytorch-specific implementation of the Hessian approximation scores Calculator.
     This class serves as a base for other Pytorch-specific Hessian approximation scores calculators.
     """
-    def __init__(self,
-                 graph: Graph,
-                 input_images: List[torch.Tensor],
-                 fw_impl,
-                 hessian_scores_request: HessianScoresRequest,
-                 num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
+    def _generate_random_vectors_batch(self, shape: tuple, device: torch.device) -> torch.Tensor:
         """
+        Generate a batch of random vectors for Hutchinson estimation using Rademacher distribution.
+
         Args:
-            graph: Computational graph for the float model.
-            input_images: List of input images for the computation.
-            fw_impl: Framework-specific implementation for Hessian scores computation.
-            hessian_scores_request: Configuration request for which to compute the Hessian approximation scores.
-            num_iterations_for_approximation: Number of iterations to use when approximating the Hessian based scores.
+            shape: target shape.
+            device: target device.
+
+        Returns:
+            Random tensor.
         """
-        super(HessianScoresCalculatorPytorch, self).__init__(graph=graph,
-                                                             input_images=input_images,
-                                                             fw_impl=fw_impl,
-                                                             hessian_scores_request=hessian_scores_request,
-                                                             num_iterations_for_approximation=num_iterations_for_approximation)
-
+        v = torch.randint(high=2, size=shape, device=device)
+        v[v == 0] = -1
+        return v
 
     def concat_tensors(self, tensors_to_concate: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
         """
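For context on the switch: Hutchinson's estimator tr(H) ~ (1/m) * sum_i v_i^T H v_i is unbiased for any probe distribution with E[v v^T] = I, but with Rademacher (+-1) probes its single-probe variance is 2 * (||H||_F^2 - sum_i H_ii^2), versus 2 * ||H||_F^2 for Gaussian probes, so it never does worse and typically needs fewer iterations. A standalone sketch, independent of MCT:

import torch

def hutchinson_trace(hvp, dim, n_iters=1000):
    """Estimate tr(H) as the mean of v^T (H v) over Rademacher probes.
    `hvp` is any callable returning the Hessian-vector product H @ v."""
    total = 0.0
    for _ in range(n_iters):
        v = torch.randint(0, 2, (dim,), dtype=torch.float32) * 2 - 1  # +-1 entries
        total += torch.dot(v, hvp(v)).item()
    return total / n_iters

# Toy check against an explicit symmetric matrix standing in for H:
H = torch.randn(16, 16)
H = (H + H.T) / 2
print(hutchinson_trace(lambda v: H @ v, dim=16), torch.trace(H).item())

Note that the helper added above returns an integer tensor; multiplying it by a float output promotes the result to float under torch's type-promotion rules, so the downstream autograd code is unaffected.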

model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py

@@ -12,19 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from tqdm import tqdm
 from typing import List
 
+import numpy as np
 import torch
 from torch import autograd
-import numpy as np
+from tqdm import tqdm
 
+from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
+from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
     HessianScoresCalculatorPytorch
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
-from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_EPS
 
 
 class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
@@ -85,7 +87,7 @@ def compute(self) -> List[np.ndarray]:
         prev_mean_results = None
         for j in tqdm(range(self.num_iterations_for_approximation)):
             # Getting a random vector with normal distribution and the same shape as the model output
-            v = torch.randn_like(output_tensor, device=device)
+            v = self._generate_random_vectors_batch(output_tensor.shape, device=device)
             f_v = torch.mean(torch.sum(v * output_tensor, dim=-1))
             for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor
 
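The replaced line changes only the sampling distribution; shape and device handling stay the same. A quick illustrative comparison (toy tensor, not the calculator's real output):

import torch

output_tensor = torch.randn(4, 10)

v_gaussian = torch.randn_like(output_tensor)               # old: N(0, 1) entries
v_rademacher = torch.randint(high=2, size=output_tensor.shape,
                             device=output_tensor.device)  # new: {0, 1} entries...
v_rademacher[v_rademacher == 0] = -1                       # ...mapped to {-1, +1}

# Either probe satisfies E[v v^T] = I, so the estimator below keeps its expectation.
f_v = torch.mean(torch.sum(v_rademacher * output_tensor, dim=-1))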

model_compression_toolkit/gptq/common/gptq_config.py (0 additions, 2 deletions)
@@ -17,7 +17,6 @@
 from typing import Callable, Any, Dict, Optional
 
 from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
-from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
 from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT


@@ -54,7 +53,6 @@ class GPTQHessianScoresConfig:
     scale_log_norm: bool = False
     hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
     per_sample: bool = False
-    estimator_distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN
 
 
 @dataclass
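Callers now simply omit the argument. A hypothetical construction (field values are illustrative, and it assumes the remaining fields keep their dataclass defaults):

from model_compression_toolkit.gptq.common.gptq_config import GPTQHessianScoresConfig

# Hypothetical: estimator_distribution is no longer an accepted keyword here.
hessian_cfg = GPTQHessianScoresConfig(per_sample=True,
                                      hessian_batch_size=32)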

model_compression_toolkit/gptq/common/gptq_training.py (1 addition, 2 deletions)
@@ -192,8 +192,7 @@ def _build_hessian_request(self, granularity: HessianScoresGranularity, data_loa
             granularity=granularity,
             target_nodes=self.compare_points,
             data_loader=data_loader,
-            n_samples=n_samples,
-            distribution=self.gptq_config.hessian_weights_config.estimator_distribution
+            n_samples=n_samples
         )
 
     @abstractmethod

model_compression_toolkit/gptq/pytorch/quantization_facade.py (0 additions, 2 deletions)
@@ -18,7 +18,6 @@
 from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
@@ -119,7 +118,6 @@ def get_pytorch_gptq_config(n_epochs: int,
             scale_log_norm=False,
             hessian_batch_size=hessian_batch_size,
             per_sample=True,
-            estimator_distribution=HessianEstimationDistribution.RADEMACHER
         )
         loss = loss or sample_layer_attention_loss
     else:
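A hypothetical call exercising this branch (the use_hessian_sample_attention keyword is an assumption about the facade's signature, not shown in this diff):

import model_compression_toolkit as mct

# Assumed entry point: with sample-layer attention enabled, the Hessian scores
# config above is built internally, now always with Rademacher probes.
gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5,
                                               use_hessian_sample_attention=True)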

tests/pytorch_tests/model_tests/feature_models/gptq_test.py (8 additions, 15 deletions)
@@ -12,26 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import random
-
 import numpy as np
 import torch
 import torch.nn as nn
 
+import mct_quantizers
+import model_compression_toolkit as mct
 from model_compression_toolkit import DefaultDict
 from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES
-from model_compression_toolkit.core.common.hessian import HessianEstimationDistribution
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
-from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR
-from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
-import model_compression_toolkit as mct
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GradientPTQConfig, RoundingType, \
-    GPTQHessianScoresConfig, GradualActivationQuantizationConfig
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, \
+    GPTQHessianScoresConfig, GradualActivationQuantizationConfig
+from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR
 from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc
 from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
+from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
 from tests.pytorch_tests.utils import extract_model_weights
 
 tp = mct.target_platform
@@ -60,7 +57,7 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM
                  hessian_weights=True, norm_scores=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True,
                  num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, gradual_activation_quantization=False,
                  hessian_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, sample_layer_attention=False,
-                 loss=multiple_tensors_mse_loss, hessian_batch_size=1, estimator_distribution=HessianEstimationDistribution.GAUSSIAN):
+                 loss=multiple_tensors_mse_loss, hessian_batch_size=1):
         super().__init__(unit_test, input_shape=(3, 16, 16), num_calibration_iter=num_calibration_iter)
         self.seed = 0
         self.rounding_type = rounding_type
@@ -79,7 +76,6 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM
         self.sample_layer_attention = sample_layer_attention
         self.loss = loss
         self.hessian_batch_size = hessian_batch_size
-        self.estimator_distribution = estimator_distribution
 
     def get_quantization_config(self):
         return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING,
@@ -156,10 +152,7 @@ def get_gptq_config(self):
                                                              norm_scores=self.norm_scores,
                                                              per_sample=self.sample_layer_attention,
                                                              hessians_num_samples=self.hessian_num_samples,
-                                                             hessian_batch_size=self.hessian_batch_size,
-                                                             estimator_distribution=self.estimator_distribution),
-
-
+                                                             hessian_batch_size=self.hessian_batch_size),
                               gptq_quantizer_params_override=self.override_params,
                               gradual_activation_quantization_config=gradual_act_cfg)
 