Skip to content

Commit

Permalink
replace Gaussian with Rademacher distribution in hessian estimation (#…
Browse files Browse the repository at this point in the history
…1250)

* replace Gaussian with Rademacher distribution in hessian estimation for torch
* do not remove duplicate hessians in hessian cache as they may be valid
  • Loading branch information
irenaby authored Oct 22, 2024
1 parent cfb313c commit b253ebd
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 152 deletions.
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/hessian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
HessianScoresRequest, HessianMode, HessianScoresGranularity
)
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def update(self, layers_hessians: Dict[str, np.ndarray], request: HessianScoresR
"""
Updates the cache with new hessians estimations.
Note: we assume that the new hessians were computed on different samples than previously stored hessians.
If same samples were used more than once, duplicates will be stored. This can only be a problem if hessians
for the same query were computed via multiple requests and dataloader in each request yields same samples.
We cannot just filter out duplicates since in some cases we can get valid identical hessians on different
samples.
Args:
layers_hessians: a dictionary from layer names to their hessian score tensors.
request: request per which hessians were computed.
Expand All @@ -60,7 +66,7 @@ def update(self, layers_hessians: Dict[str, np.ndarray], request: HessianScoresR
for node_name, hess in layers_hessians.items():
query = Query(request.mode, request.granularity, node_name)
saved_hess = self._data.get(query)
new_hess = hess if saved_hess is None else np.unique(np.concatenate([saved_hess, hess], axis=0), axis=0)
new_hess = hess if saved_hess is None else np.concatenate([saved_hess, hess], axis=0)
self._data[query] = new_hess
n_nodes_samples.append(new_hess.shape[0])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,6 @@ class HessianScoresGranularity(Enum):
PER_TENSOR = 2


class HessianEstimationDistribution(str, Enum):
"""
Distribution for Hutchinson estimator random vector
"""
GAUSSIAN = 'gaussian'
RADEMACHER = 'rademacher'


@dataclasses.dataclass
class HessianScoresRequest:
"""
Expand All @@ -68,15 +60,12 @@ class HessianScoresRequest:
the computation. Can be None if all hessians for the request are expected to be pre-computed previously.
n_samples: The number of samples to fetch hessian estimations for. If None, fetch hessians for a full pass
of the data loader.
distribution: Distribution to use in Hutchinson estimation.
"""
mode: HessianMode
granularity: HessianScoresGranularity
target_nodes: Sequence['BaseNode']
data_loader: Optional[Iterable]
n_samples: Optional[int]
# TODO remove
distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN

def __post_init__(self):
if self.data_loader is None and self.n_samples is None:
Expand All @@ -85,4 +74,3 @@ def __post_init__(self):
def clone(self, **kwargs):
""" Create a clone with optional overrides """
return dataclasses.replace(self, **kwargs)

Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,19 @@

from typing import List

import numpy as np
import torch
from torch import autograd
from tqdm import tqdm
import numpy as np

from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import (HessianScoresRequest, HessianScoresGranularity,
HessianEstimationDistribution)
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
HessianScoresCalculatorPytorch
from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
from model_compression_toolkit.logger import Logger
import torch


class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
Expand Down Expand Up @@ -92,29 +91,6 @@ def forward_pass(self):
output = self.concat_tensors(output_tensors)
return output, target_activation_tensors

def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution,
device: torch.device) -> torch.Tensor:
"""
Generate a batch of random vectors for Hutchinson estimation
Args:
shape: target shape
distribution: distribution to sample from
device: target device
Returns:
Random tensor
"""
if distribution == HessianEstimationDistribution.GAUSSIAN:
return torch.randn(shape, device=device)

if distribution == HessianEstimationDistribution.RADEMACHER:
v = torch.randint(high=2, size=shape, device=device)
v[v == 0] = -1
return v

raise ValueError(f'Unknown distribution {distribution}') # pragma: no cover

def compute(self) -> List[np.ndarray]:
"""
Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations.
Expand All @@ -141,8 +117,8 @@ def _compute_per_tensor(self, output, target_activation_tensors):
for _ in range(len(target_activation_tensors))]
prev_mean_results = None
for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations
# Getting a random vector with normal distribution
v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
# Getting a random vector
v = self._generate_random_vectors_batch(output.shape, output.device)
f_v = torch.sum(v * output)
for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
# Computing the hessian-approximation scores by getting the gradient of (output * v)
Expand Down Expand Up @@ -183,7 +159,7 @@ def _compute_per_channel(self, output, target_activation_tensors):
for _ in range(len(target_activation_tensors))]

for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations
v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
v = self._generate_random_vectors_batch(output.shape, output.device)
f_v = torch.sum(v * output)
for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
hess_v = autograd.grad(outputs=f_v,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,31 @@

from typing import Union, List

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import HessianScoresRequest
import torch

from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
from model_compression_toolkit.logger import Logger
import torch


class HessianScoresCalculatorPytorch(HessianScoresCalculator):
"""
Pytorch-specific implementation of the Hessian approximation scores Calculator.
This class serves as a base for other Pytorch-specific Hessian approximation scores calculators.
"""
def __init__(self,
graph: Graph,
input_images: List[torch.Tensor],
fw_impl,
hessian_scores_request: HessianScoresRequest,
num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
def _generate_random_vectors_batch(self, shape: tuple, device: torch.device) -> torch.Tensor:
"""
Generate a batch of random vectors for Hutchinson estimation using Rademacher distribution.
Args:
graph: Computational graph for the float model.
input_images: List of input images for the computation.
fw_impl: Framework-specific implementation for Hessian scores computation.
hessian_scores_request: Configuration request for which to compute the Hessian approximation scores.
num_iterations_for_approximation: Number of iterations to use when approximating the Hessian based scores.
shape: target shape.
device: target device.
Returns:
Random tensor.
"""
super(HessianScoresCalculatorPytorch, self).__init__(graph=graph,
input_images=input_images,
fw_impl=fw_impl,
hessian_scores_request=hessian_scores_request,
num_iterations_for_approximation=num_iterations_for_approximation)

v = torch.randint(high=2, size=shape, device=device)
v[v == 0] = -1
return v

def concat_tensors(self, tensors_to_concate: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tqdm import tqdm
from typing import List

import numpy as np
import torch
from torch import autograd
import numpy as np
from tqdm import tqdm

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
HessianScoresCalculatorPytorch
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_EPS


class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
Expand Down Expand Up @@ -84,8 +86,8 @@ def compute(self) -> List[np.ndarray]:

prev_mean_results = None
for j in tqdm(range(self.num_iterations_for_approximation)):
# Getting a random vector with normal distribution and the same shape as the model output
v = torch.randn_like(output_tensor, device=device)
# Getting a random vector with the same shape as the model output
v = self._generate_random_vectors_batch(output_tensor.shape, device=device)
f_v = torch.mean(torch.sum(v * output_tensor, dim=-1))
for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor

Expand Down
2 changes: 0 additions & 2 deletions model_compression_toolkit/gptq/common/gptq_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Callable, Any, Dict, Optional

from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT


Expand Down Expand Up @@ -54,7 +53,6 @@ class GPTQHessianScoresConfig:
scale_log_norm: bool = False
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
per_sample: bool = False
estimator_distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN


@dataclass
Expand Down
3 changes: 1 addition & 2 deletions model_compression_toolkit/gptq/common/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ def _build_hessian_request(self, granularity: HessianScoresGranularity, data_loa
granularity=granularity,
target_nodes=self.compare_points,
data_loader=data_loader,
n_samples=n_samples,
distribution=self.gptq_config.hessian_weights_config.estimator_distribution
n_samples=n_samples
)

@abstractmethod
Expand Down
2 changes: 0 additions & 2 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
Expand Down Expand Up @@ -119,7 +118,6 @@ def get_pytorch_gptq_config(n_epochs: int,
scale_log_norm=False,
hessian_batch_size=hessian_batch_size,
per_sample=True,
estimator_distribution=HessianEstimationDistribution.RADEMACHER
)
loss = loss or sample_layer_attention_loss
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,19 @@ class multiple_inputs_model(torch.nn.Module):
def __init__(self):
super(multiple_inputs_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn1 = BatchNorm2d(3)
self.relu1 = ReLU()
self.conv2 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn2 = BatchNorm2d(3)
self.relu2 = ReLU()

def forward(self, inp1, inp2):
x1 = self.conv1(inp1)
x1 = self.bn1(x1)
x1 = self.relu1(x1)
x2 = self.conv2(inp2)
x2 = self.bn2(x2)
x2 = self.relu2(x2)
return x1 + x2


Expand Down
23 changes: 8 additions & 15 deletions tests/pytorch_tests/model_tests/feature_models/gptq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import random

import numpy as np
import torch
import torch.nn as nn

import mct_quantizers
import model_compression_toolkit as mct
from model_compression_toolkit import DefaultDict
from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES
from model_compression_toolkit.core.common.hessian import HessianEstimationDistribution
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR
from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
import model_compression_toolkit as mct
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GradientPTQConfig, RoundingType, \
GPTQHessianScoresConfig, GradualActivationQuantizationConfig
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, \
GPTQHessianScoresConfig, GradualActivationQuantizationConfig
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR
from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
from tests.pytorch_tests.utils import extract_model_weights

tp = mct.target_platform
Expand Down Expand Up @@ -60,7 +57,7 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM
hessian_weights=True, norm_scores=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True,
num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, gradual_activation_quantization=False,
hessian_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, sample_layer_attention=False,
loss=multiple_tensors_mse_loss, hessian_batch_size=1, estimator_distribution=HessianEstimationDistribution.GAUSSIAN):
loss=multiple_tensors_mse_loss, hessian_batch_size=1):
super().__init__(unit_test, input_shape=(3, 16, 16), num_calibration_iter=num_calibration_iter)
self.seed = 0
self.rounding_type = rounding_type
Expand All @@ -79,7 +76,6 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM
self.sample_layer_attention = sample_layer_attention
self.loss = loss
self.hessian_batch_size = hessian_batch_size
self.estimator_distribution = estimator_distribution

def get_quantization_config(self):
return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING,
Expand Down Expand Up @@ -156,10 +152,7 @@ def get_gptq_config(self):
norm_scores=self.norm_scores,
per_sample=self.sample_layer_attention,
hessians_num_samples=self.hessian_num_samples,
hessian_batch_size=self.hessian_batch_size,
estimator_distribution=self.estimator_distribution),


hessian_batch_size=self.hessian_batch_size),
gptq_quantizer_params_override=self.override_params,
gradual_activation_quantization_config=gradual_act_cfg)

Expand Down
Loading

0 comments on commit b253ebd

Please sign in to comment.