Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor-keras-activa…
Browse files Browse the repository at this point in the history
…tion-quantizers
  • Loading branch information
reuvenp committed Oct 14, 2024
2 parents 8866392 + dc678aa commit fbb4503
Show file tree
Hide file tree
Showing 29 changed files with 4,040 additions and 3,686 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests_common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Install Python 3
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.11
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
4 changes: 3 additions & 1 deletion model_compression_toolkit/core/common/hessian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, HessianScoresGranularity
from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
)
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import hashlib

import numpy as np
from functools import partial
from tqdm import tqdm
from typing import Callable, List, Dict, Any, Tuple
from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
HessianScoresGranularity, HessianMode
from model_compression_toolkit.logger import Logger
if TYPE_CHECKING: # pragma: no cover
from model_compression_toolkit.core.common import BaseNode


class HessianInfoService:
Expand Down Expand Up @@ -228,6 +231,61 @@ def compute(self,
return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \
and len(next_iter_remain_samples[0]) > 0 else None

def compute_trackable_per_sample_hessian(self,
                                         hessian_scores_request: HessianScoresRequest,
                                         inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]:
    """
    Compute Hessian scores per image hash.

    The scores are computed directly on the given images rather than via a data generator,
    as a data generator might yield different images each time, depending on how it was defined.

    Note that ``hessian_scores_request.target_nodes`` is sorted in place, following the order of
    the graph's topologically sorted nodes.

    Args:
        hessian_scores_request: Hessian scores request.
        inputs_batch: A list containing a single batch of inputs (batch on the first axis).

    Returns:
        A dict of Hessian scores per image hash per layer {image hash: {layer: score}}.

    Raises:
        TypeError: If inputs_batch is not a non-empty list.
        NotImplementedError: If inputs_batch contains more than one input.
    """
    # Validate before touching the request or the graph, so a rejected call has no side effects.
    # The isinstance check must come first: evaluating the truthiness of a bare multi-element
    # numpy array raises ValueError ("truth value ... is ambiguous") instead of the intended TypeError.
    if not isinstance(inputs_batch, list) or not inputs_batch:
        raise TypeError('Expected a non-empty list of inputs')  # pragma: no cover
    if len(inputs_batch) > 1:
        raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs')  # pragma: no cover

    # Order the target nodes by their position in the topologically sorted graph.
    # NOTE(review): presumably the calculator returns one score array per target node in this
    # same order (the zip below relies on it) - confirm against the framework calculators.
    topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
    hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))

    # Get the framework-specific calculator of Hessian-approximation scores
    fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph,
                                                                       input_images=inputs_batch,
                                                                       hessian_scores_request=hessian_scores_request,
                                                                       num_iterations_for_approximation=self.num_iterations_for_approximation)
    hessian_scores = fw_hessian_calculator.compute()

    images = inputs_batch[0]
    hessian_score_by_image_hash = {}
    for i, image in enumerate(images):
        # Images with identical content share a hash, so a later duplicate overwrites the earlier entry.
        img_hash = self.calc_image_hash(image)
        hessian_score_by_image_hash[img_hash] = {
            node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
        }

    return hessian_score_by_image_hash

@staticmethod
def calc_image_hash(image):
"""
Calculates hash for an input image.
Args:
image: input 3d image (without batch).
Returns:
Image hash.
"""
if not len(image.shape) == 3: # pragma: no cover
raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}')
image_bytes = image.astype(np.float32).tobytes()
return hashlib.md5(image_bytes).hexdigest()

def fetch_hessian(self,
hessian_scores_request: HessianScoresRequest,
required_size: int,
Expand All @@ -248,7 +306,7 @@ def fetch_hessian(self,
OC for per-output-channel when the requested node has OC output-channels, etc.)
"""

if len(hessian_scores_request.target_nodes) == 0:
if len(hessian_scores_request.target_nodes) == 0: # pragma: no cover
return []

if required_size == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum):
PER_TENSOR = 2


class HessianEstimationDistribution(str, Enum):
    """
    Distribution of the random vector used by the Hutchinson estimator.

    Members are also ``str`` instances, so each one compares equal to its string value.
    """
    GAUSSIAN = 'gaussian'
    RADEMACHER = 'rademacher'


class HessianScoresRequest:
"""
Request configuration for the Hessian-approximation scores.
Expand All @@ -53,7 +61,8 @@ class HessianScoresRequest:
def __init__(self,
mode: HessianMode,
granularity: HessianScoresGranularity,
target_nodes: List):
target_nodes: List,
distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN):
"""
Attributes:
mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations).
Expand All @@ -64,16 +73,18 @@ def __init__(self,
self.mode = mode # w.r.t activations or weights
self.granularity = granularity # per element, per layer, per channel
self.target_nodes = target_nodes
self.distribution = distribution

def __eq__(self, other):
# Checks if the other object is an instance of HessianScoresRequest
# and then checks if all attributes are equal.
return isinstance(other, HessianScoresRequest) and \
self.mode == other.mode and \
self.granularity == other.granularity and \
self.target_nodes == other.target_nodes
self.target_nodes == other.target_nodes and \
self.distribution == other.distribution

def __hash__(self):
# Computes the hash based on the attributes.
# The use of a tuple here ensures that the hash is influenced by all the attributes.
return hash((self.mode, self.granularity, tuple(self.target_nodes)))
return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution))
Loading

0 comments on commit fbb4503

Please sign in to comment.