Commit

small fixes
irenaby committed Oct 14, 2024
1 parent 0400a32 commit a630aaa
Showing 7 changed files with 73 additions and 32 deletions.
@@ -23,7 +23,6 @@


# type hints aliases
SampleHash = str
LayerName = str
Tensor = Any

@@ -37,6 +36,7 @@ class Query:


class HessianCache:
""" Hessian cache """
def __init__(self):
self._data: Dict[Query, Tensor] = {}

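For context on the cached structure, a minimal sketch of the pattern: a hashable query dataclass keys a plain dictionary of tensors. The Query fields and the update/fetch helpers below are illustrative assumptions, not the repository's actual definitions.

from dataclasses import dataclass
from typing import Any, Dict

@dataclass(frozen=True)  # frozen instances are hashable, so Query can key a dict
class Query:
    node_name: str
    mode: str
    granularity: str

class HessianCache:
    """Cache Hessian tensors per query (sketch)."""
    def __init__(self):
        self._data: Dict[Query, Any] = {}

    def update(self, query: Query, tensor: Any) -> None:
        self._data[query] = tensor

    def fetch(self, query: Query) -> Any:
        return self._data.get(query)
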
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, Iterable, Sequence, Optional
from typing import Iterable, Sequence, Optional, TYPE_CHECKING
import dataclasses

from enum import Enum

if TYPE_CHECKING: # pragma: no cover
from model_compression_toolkit.core.common import BaseNode


class HessianMode(Enum):
"""
@@ -60,7 +63,7 @@ class HessianScoresRequest:
Attributes:
mode: Mode of Hessian-approximation score (w.r.t weights or activations).
granularity: Granularity level for the approximation.
target_nodes: The node names in the float graph for which the Hessian's approximation scores is targeted.
target_nodes: The node objects in the float graph for which the Hessian approximation scores are computed.
data_loader: Data loader to compute hessian approximations on. Should reflect the desired batch size for
the computation. Can be None if all hessians for the request are expected to be pre-computed previously.
n_samples: The number of samples to fetch hessian estimations for. If None, fetch hessians for a full pass
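To illustrate the updated contract (target_nodes now carries node objects rather than node names), a minimal, hedged sketch of building a request; build_activation_request is an illustrative helper name, not part of the repository.

from typing import Iterable, Sequence
from model_compression_toolkit.core.common.hessian import (HessianScoresRequest, HessianMode,
                                                            HessianScoresGranularity)

def build_activation_request(target_nodes: Sequence, data_loader: Iterable) -> HessianScoresRequest:
    """Sketch: request per-output-channel activation Hessian scores for the given node objects."""
    return HessianScoresRequest(mode=HessianMode.ACTIVATION,
                                granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
                                target_nodes=list(target_nodes),  # node objects, not node names
                                data_loader=data_loader,
                                n_samples=None)  # None: compute Hessians over a full pass of the data loader
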
29 changes: 19 additions & 10 deletions model_compression_toolkit/gptq/common/gptq_training.py
@@ -13,23 +13,21 @@
# limitations under the License.
# ==============================================================================
import copy
import hashlib
from abc import ABC, abstractmethod
from typing import Callable, List, Any, Iterable, Optional, Generator

import numpy as np
from typing import Callable, List, Any, Dict, Iterable, Optional, Generator

from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
HessianScoresGranularity, hessian_info_utils as hessian_utils
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
HessianScoresGranularity
from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils


class GPTQTrainer(ABC):
@@ -177,7 +175,18 @@ def compute_hessian_based_weights(self, data_loader: Iterable) -> np.ndarray:
return log_weights - np.min(log_weights)

def _build_hessian_request(self, granularity: HessianScoresGranularity, data_loader: Iterable,
n_samples: Optional[int]):
n_samples: Optional[int]) -> HessianScoresRequest:
"""
Build hessian request for hessian service.
Args:
granularity: requested granularity.
data_loader: data loader yielding samples to compute hessians on.
n_samples: request number of samples.
Returns:
Hessian request.
"""
return HessianScoresRequest(
mode=HessianMode.ACTIVATION,
granularity=granularity,
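The visible tail of compute_hessian_based_weights shifts log-scale scores so the smallest weight becomes zero; a tiny self-contained illustration with hypothetical values.

import numpy as np

log_weights = np.array([-3.0, -1.0, 0.0, 1.0])  # hypothetical per-layer log-scale scores
weights = log_weights - np.min(log_weights)     # shift so the smallest weight becomes 0
print(weights)                                  # [0. 2. 3. 4.]
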
1 change: 1 addition & 0 deletions model_compression_toolkit/gptq/keras/gptq_training.py
@@ -127,6 +127,7 @@ def _get_compare_points_loss_weights(self):
hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
return self.compute_hessian_based_weights(hess_dataloader)

num_nodes = len(self.compare_points)
return np.ones((num_nodes,)) / num_nodes

52 changes: 40 additions & 12 deletions model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
import copy
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Tuple, Union, Generator

import numpy as np
import torch
@@ -123,30 +123,54 @@ def _get_total_grad_steps():

self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)

def _prepare_train_dataloader_sla(self, data_gen_fn):
def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
"""
Computes Sample-Layer Attention score and builds a train dataloader.
Args:
data_gen_fn: factory for representative dataset generator.
Returns:
PyTorch dataloader yielding three outputs - samples, weights for the distillation loss and
weights for regularization.
"""
fixed_dataset = FixedDatasetFromGenerator(data_gen_fn)
orig_batch_size = fixed_dataset.orig_batch_size
# compute hessians for the whole dataset
hess_data_loader = DataLoader(fixed_dataset,
batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size,
shuffle=False)
request = self._build_hessian_request(granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
data_loader=hess_data_loader,
n_samples=None)
layers_hessians = self.hessian_service.fetch_hessian(request, force_compute=True)
# score is defined as max over channels

# compute sla score defined as max over channels
layers_hessians = {layer: to_torch_tensor(hess.max(1)) for layer, hess in layers_hessians.items()}

# samples X layers
hessians_tensor = torch.stack([layers_hessians[layer.name] for layer in self.compare_points], dim=1)
# build train dataset and dataloader
hessians_tensor = torch.stack([layers_hessians[layer.name] for layer in self.compare_points], dim=1) # samples X layers
assert hessians_tensor.shape[1] == len(self.compare_points)
loss_weights = list(hessians_tensor)
# TODO: in the research repo the mean is computed per batch; the mean over all samples is presumably not worse.
reg_weights = hessians_tensor.mean(dim=0)
sla_train_dataset = FixedSampleInfoDataset(fixed_dataset.samples, loss_weights)
return DataLoader(sla_train_dataset, batch_size=orig_batch_size, shuffle=True,
collate_fn=get_collate_fn_with_extra_outputs(reg_weights))

def _prepare_train_dataloader_for_non_sla(self, data_gen_fn):
reg_weights = hessians_tensor.mean(dim=0)
# use collate to add a single value to each batch
collate_fn = get_collate_fn_with_extra_outputs(reg_weights)

return DataLoader(sla_train_dataset, batch_size=orig_batch_size, shuffle=True, collate_fn=collate_fn)

def _prepare_train_dataloader_for_non_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
"""
Computes loss weights and builds a train dataloader.
Args:
data_gen_fn: factory for representative dataset generator.
Returns:
PyTorch dataloader yielding three outputs - samples, weights for the distillation loss and
weights for regularization.
"""
dataset = IterableDatasetFromGenerator(data_gen_fn)
num_nodes = len(self.compare_points)

@@ -156,10 +180,14 @@ def _prepare_train_dataloader_for_non_sla(self, data_gen_fn):
else:
loss_weights = torch.ones(num_nodes) / num_nodes

reg_weights = to_torch_tensor(torch.ones(num_nodes))
train_dataset = IterableSampleWithConstInfoDataset(dataset, to_torch_tensor(loss_weights))

reg_weights = to_torch_tensor(torch.ones(num_nodes))
# use collate to add a single value to each batch
collate_fn = get_collate_fn_with_extra_outputs(reg_weights)

return DataLoader(train_dataset, batch_size=dataset.orig_batch_size,
collate_fn=get_collate_fn_with_extra_outputs(reg_weights))
collate_fn=collate_fn)

def _is_gptq_weights_trainable(self,
node: BaseNode) -> bool:
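The "use collate to add a single value to each batch" comments refer to get_collate_fn_with_extra_outputs; below is a hedged sketch of what such a wrapper may do (an assumption for illustration, not the repository's implementation).

import torch
from torch.utils.data import default_collate

def collate_with_extra_outputs(*extra_outputs):
    """Return a collate_fn that appends the same extra tensors to every collated batch."""
    def collate_fn(batch):
        collated = default_collate(batch)            # e.g. [samples, per-sample loss weights]
        return list(collated) + list(extra_outputs)  # e.g. + [reg_weights], identical for every batch
    return collate_fn

# Hypothetical usage with 10 compare points:
reg_weights = torch.ones(10)
collate_fn = collate_with_extra_outputs(reg_weights)

With this wiring, each batch from the SLA dataloader carries three outputs: the samples, their per-sample rows of the samples X layers Hessian tensor as distillation-loss weights, and the per-layer mean as regularization weights.
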
@@ -47,14 +47,14 @@ def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Te
Args:
model: A model to be quantized with SoftRounding.
entropy_reg: Entropy value to scale the quantizer regularization.
layer_weights: a vector of layers weights or a matrix of shape samples X layers.
layer_weights: a vector of per-layer weights.
Returns: Regularization value.
"""
layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)]

if layer_weights.shape[-1] != len(layers):
raise ValueError(f'Expected weights.shape[-1] to be {len(layers)}, '
if layer_weights.shape[0] != len(layers):
raise ValueError(f'Expected weights.shape[0] to be {len(layers)}, '
f'received shape {layer_weights.shape}.') # pragma: no cover
max_w = layer_weights.max()

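To make the tightened shape check concrete, a small illustrative example (the layer count and shapes are hypothetical).

import torch

num_layers = 5
vec = torch.ones(num_layers)     # per-layer weights vector: shape[0] == 5, passes the new check
mat = torch.ones(8, num_layers)  # samples x layers matrix: shape[0] == 8 != 5, now rejected

Under the previous check (shape[-1]) the matrix would also have passed; restricting the input to a per-layer vector is consistent with the PyTorch changes above, where per-sample weights now feed the distillation loss and only the per-layer mean reaches the regularization.
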
8 changes: 4 additions & 4 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -661,11 +661,11 @@ def test_gptq_with_sample_layer_attention(self):
hessian_weights=True, hessian_num_samples=None,
estimator_distribution=HessianEstimationDistribution.RADEMACHER,
norm_scores=False, log_norm_weights=False, scaled_log_norm=False)
# GPTQAccuracyTest(self, **kwargs).run_test()
GPTQAccuracyTest(self, **kwargs).run_test()
GPTQAccuracyTest(self, hessian_batch_size=16, rounding_type=RoundingType.SoftQuantizer, **kwargs).run_test()
# GPTQAccuracyTest(self, hessian_batch_size=5, rounding_type=RoundingType.SoftQuantizer,
# gradual_activation_quantization=True, **kwargs).run_test()
# GPTQAccuracyTest(self, rounding_type=RoundingType.STE, **kwargs)
GPTQAccuracyTest(self, hessian_batch_size=5, rounding_type=RoundingType.SoftQuantizer,
gradual_activation_quantization=True, **kwargs).run_test()
GPTQAccuracyTest(self, rounding_type=RoundingType.STE, **kwargs)

def test_qat(self):
"""