Skip to content

Commit

Permalink
Add initial Sample-Layer Attention for GPTQ (PyTorch) (#1237)
Browse files Browse the repository at this point in the history
* initial sample layer attention implementation for torch
  • Loading branch information
irenaby authored Oct 7, 2024
1 parent 508e8fa commit b26dd82
Show file tree
Hide file tree
Showing 13 changed files with 452 additions and 143 deletions.
4 changes: 3 additions & 1 deletion model_compression_toolkit/core/common/hessian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, HessianScoresGranularity
from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
)
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import hashlib

import numpy as np
from functools import partial
from tqdm import tqdm
from typing import Callable, List, Dict, Any, Tuple
from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
HessianScoresGranularity, HessianMode
from model_compression_toolkit.logger import Logger
if TYPE_CHECKING: # pragma: no cover
from model_compression_toolkit.core.common import BaseNode


class HessianInfoService:
Expand Down Expand Up @@ -228,6 +231,61 @@ def compute(self,
return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \
and len(next_iter_remain_samples[0]) > 0 else None

def compute_trackable_per_sample_hessian(self,
                                         hessian_scores_request: HessianScoresRequest,
                                         inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]:
    """
    Compute per-sample Hessian scores keyed by image hash. We compute the score directly for
    images rather than via a data generator, as a data generator might yield different images
    each time, depending on how it was defined.

    Note: sorts ``hessian_scores_request.target_nodes`` in place (topological order) so the
    scores returned by the framework calculator line up with the request's node order.

    Args:
        hessian_scores_request: hessian scores request.
        inputs_batch: a list containing a batch of inputs.

    Returns:
        A dict of Hessian scores per image hash per layer {image hash: {layer: score}}.

    Raises:
        TypeError: if inputs_batch is not a non-empty list.
        NotImplementedError: if the network has multiple inputs.
    """
    # Validate inputs before doing any work (and before mutating the caller's request).
    if not inputs_batch or not isinstance(inputs_batch, list):
        raise TypeError('Expected a non-empty list of inputs')  # pragma: no cover
    if len(inputs_batch) > 1:
        raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs')  # pragma: no cover

    # Sort the requested nodes topologically so each node's score can be zipped back correctly.
    topo_sorted_nodes_names = [node.name for node in self.graph.get_topo_sorted_nodes()]
    hessian_scores_request.target_nodes.sort(key=lambda node: topo_sorted_nodes_names.index(node.name))

    # Get the framework-specific calculator for Hessian-approximation scores.
    fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(
        graph=self.graph,
        input_images=inputs_batch,
        hessian_scores_request=hessian_scores_request,
        num_iterations_for_approximation=self.num_iterations_for_approximation)
    hessian_scores = fw_hessian_calculator.compute()

    # Key each sample's per-node scores by a content hash of the image, so scores can be
    # re-associated with the same image later regardless of batching order.
    hessian_score_by_image_hash = {}
    for i in range(inputs_batch[0].shape[0]):
        img_hash = self.calc_image_hash(inputs_batch[0][i])
        hessian_score_by_image_hash[img_hash] = {
            node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
        }

    return hessian_score_by_image_hash

@staticmethod
def calc_image_hash(image):
"""
Calculates hash for an input image.
Args:
image: input 3d image (without batch).
Returns:
Image hash.
"""
if not len(image.shape) == 3: # pragma: no cover
raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}')
image_bytes = image.astype(np.float32).tobytes()
return hashlib.md5(image_bytes).hexdigest()

def fetch_hessian(self,
hessian_scores_request: HessianScoresRequest,
required_size: int,
Expand All @@ -248,7 +306,7 @@ def fetch_hessian(self,
OC for per-output-channel when the requested node has OC output-channels, etc.)
"""

if len(hessian_scores_request.target_nodes) == 0:
if len(hessian_scores_request.target_nodes) == 0: # pragma: no cover
return []

if required_size == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum):
PER_TENSOR = 2


class HessianEstimationDistribution(str, Enum):
    """
    Distribution for Hutchinson estimator random vector.

    Subclassing str makes each member compare equal to, and usable as, its plain
    string value (e.g. HessianEstimationDistribution.GAUSSIAN == 'gaussian'),
    which keeps the enum serialization-friendly.
    """
    GAUSSIAN = 'gaussian'
    RADEMACHER = 'rademacher'


class HessianScoresRequest:
"""
Request configuration for the Hessian-approximation scores.
Expand All @@ -53,7 +61,8 @@ class HessianScoresRequest:
def __init__(self,
mode: HessianMode,
granularity: HessianScoresGranularity,
target_nodes: List):
target_nodes: List,
distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN):
"""
Attributes:
mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations).
Expand All @@ -64,16 +73,18 @@ def __init__(self,
self.mode = mode # w.r.t activations or weights
self.granularity = granularity # per element, per layer, per channel
self.target_nodes = target_nodes
self.distribution = distribution

def __eq__(self, other):
# Checks if the other object is an instance of HessianScoresRequest
# and then checks if all attributes are equal.
return isinstance(other, HessianScoresRequest) and \
self.mode == other.mode and \
self.granularity == other.granularity and \
self.target_nodes == other.target_nodes
self.target_nodes == other.target_nodes and \
self.distribution == other.distribution

def __hash__(self):
# Computes the hash based on the attributes.
# The use of a tuple here ensures that the hash is influenced by all the attributes.
return hash((self.mode, self.granularity, tuple(self.target_nodes)))
return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution))
Loading

0 comments on commit b26dd82

Please sign in to comment.