Skip to content

Commit

Permalink
Take score computation out to a new LFH importance score calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Dec 3, 2023
1 parent 0e69771 commit 0433744
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 86 deletions.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import List

from abc import abstractmethod, ABC

from model_compression_toolkit.core.common import BaseNode


class BaseImportanceMetric(ABC):

@abstractmethod
def get_entry_node_to_score(self, sections_input_nodes:List[BaseNode]):
raise Exception



Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from model_compression_toolkit.core.common.pruning import ImportanceMetric
from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
from model_compression_toolkit.core.common.pruning.importance_metrics.lfh_importance_metric import LFHImportanceMetric

im_dict = {ImportanceMetric.LFH: LFHImportanceMetric}


def get_importance_metric(im: ImportanceMetric, **kwargs) -> BaseImportanceMetric:
im = im_dict.get(im)
return im(**kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Callable, List

from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity
from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
import numpy as np

from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig


class LFHImportanceMetric(BaseImportanceMetric):

def __init__(self,
graph:Graph,
representative_data_gen: Callable,
fw_impl: FrameworkImplementation,
pruning_config: PruningConfig,
fw_info: FrameworkInfo):

self.float_graph= graph
self.representative_data_gen = representative_data_gen
self.fw_impl = fw_impl
self.pruning_config = pruning_config
self.fw_info = fw_info

def get_entry_node_to_score(self, sections_input_nodes:List[BaseNode]):
# Initialize services and variables for pruning process.
hessian_info_service = HessianInfoService(graph=self.float_graph,
representative_dataset=self.representative_data_gen,
fw_impl=self.fw_impl)

# Calculate the LFH (Label-Free Hessian) score for each prunable channel.
scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(
mode=HessianMode.WEIGHTS,
granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
nodes=sections_input_nodes,
required_size=self.pruning_config.num_score_approximations)


# Average the scores across approximations and map them to the corresponding nodes.
entry_node_to_score = {node: np.mean(scores, axis=0) for node, scores in
zip(sections_input_nodes, scores_per_prunable_node)}

l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
entry_node_to_l2norm=l2_oc_norm,
entry_node_to_nparmas=count_oc_nparams)
return entry_node_to_score


def normalize_lfh_scores(self,
entry_node_to_score,
entry_node_to_l2norm,
entry_node_to_nparmas):
new_scores = {}
for node, trace_vector in entry_node_to_score.items():
new_scores[node] = trace_vector*entry_node_to_l2norm[node]/entry_node_to_nparmas[node]
return new_scores

def count_oc_nparams(self, entry_nodes: List[BaseNode]):
node_channel_params = {}
for entry_node in entry_nodes:
kernel = entry_node.get_weights_by_keys('kernel')
ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]

# Calculate the number of parameters for each output channel
params_per_channel = np.prod(kernel.shape) / kernel.shape[ox_axis]

# Create an array with the number of parameters per channel
num_params_array = np.full(kernel.shape[ox_axis], params_per_channel)

# Store in node_channel_params a dictionary from node to a np.array where
# each element corresponds to the number of parameters of this channel
node_channel_params[entry_node] = num_params_array

return node_channel_params


def get_l2_out_channel_norm(self, entry_nodes: List[BaseNode]):
node_l2_channel_norm = {}
for entry_node in entry_nodes:
kernel = entry_node.get_weights_by_keys('kernel')
ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]

# Compute the l2 norm of each output channel
channels = np.split(kernel, indices_or_sections=kernel.shape[ox_axis], axis=ox_axis)
l2_norms = [np.linalg.norm(c.flatten(), ord=2) ** 2 for c in channels]

# Store in node_l2_channel_norm a dictionary from node to a np.array where
# each element corresponds to the l2 norm of this channel
node_l2_channel_norm[entry_node] = l2_norms

return node_l2_channel_norm
93 changes: 9 additions & 84 deletions model_compression_toolkit/core/common/pruning/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core.common.pruning.greedy_mask_calculator import GreedyMaskCalculator
from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import \
get_importance_metric
from model_compression_toolkit.core.common.pruning.prune_graph import build_pruned_graph
from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig, ChannelsFilteringStrategy, \
ImportanceMetric
Expand Down Expand Up @@ -104,95 +106,18 @@ def get_score_per_entry_point(self, sections_input_nodes):
Dict[BaseNode, np.ndarray]: A dictionary mapping each entry node to its corresponding importance score.
"""
# Initialize a dictionary to hold scores for each node.
entry_node_to_score = None

# Compute scores based on the importance metric defined in the pruning configuration.
if self.pruning_config.importance_metric == ImportanceMetric.LFH:
# Initialize services and variables for pruning process.
hessian_info_service = HessianInfoService(graph=self.float_graph,
representative_dataset=self.representative_data_gen,
fw_impl=self.fw_impl)

# Calculate the LFH (Label-Free Hessian) score for each prunable channel.
scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(
mode=HessianMode.WEIGHTS,
granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
nodes=sections_input_nodes,
required_size=self.pruning_config.num_score_approximations)


# Average the scores across approximations and map them to the corresponding nodes.
entry_node_to_score = {node: np.mean(scores, axis=0) for node, scores in
zip(sections_input_nodes, scores_per_prunable_node)}

l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
entry_node_to_l2norm=l2_oc_norm,
entry_node_to_nparmas=count_oc_nparams)


elif self.pruning_config.importance_metric == ImportanceMetric.RANDOM:
random_scores = [np.random.random(
node.get_weights_by_keys('kernel').shape[self.fw_info.kernel_channels_mapping.get(node.type)[0]])
for node in sections_input_nodes]
entry_node_to_score = {node: scores for node, scores in zip(sections_input_nodes, random_scores)}
l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
entry_node_to_l2norm=l2_oc_norm,
entry_node_to_nparmas=count_oc_nparams)
im = get_importance_metric(self.pruning_config.importance_metric,
graph=self.float_graph,
representative_data_gen=self.representative_data_gen,
fw_impl=self.fw_impl,
pruning_config=self.pruning_config,
fw_info=self.fw_info)

else:
# Log an error if an unsupported importance metric is specified.
Logger.error(f"Not supported importance metric: {self.pruning_config.importance_metric}")
entry_node_to_score = im.get_entry_node_to_score(sections_input_nodes)

# Return the dictionary of nodes mapped to their importance scores.
return entry_node_to_score

def normalize_lfh_scores(self,
entry_node_to_score,
entry_node_to_l2norm,
entry_node_to_nparmas):
new_scores = {}
for node, trace_vector in entry_node_to_score.items():
new_scores[node] = trace_vector*entry_node_to_l2norm[node]/entry_node_to_nparmas[node]
return new_scores

def count_oc_nparams(self, entry_nodes: List[BaseNode]):
node_channel_params = {}
for entry_node in entry_nodes:
kernel = entry_node.get_weights_by_keys('kernel')
ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]

# Calculate the number of parameters for each output channel
params_per_channel = np.prod(kernel.shape) / kernel.shape[ox_axis]

# Create an array with the number of parameters per channel
num_params_array = np.full(kernel.shape[ox_axis], params_per_channel)

# Store in node_channel_params a dictionary from node to a np.array where
# each element corresponds to the number of parameters of this channel
node_channel_params[entry_node] = num_params_array

return node_channel_params


def get_l2_out_channel_norm(self, entry_nodes: List[BaseNode]):
node_l2_channel_norm = {}
for entry_node in entry_nodes:
kernel = entry_node.get_weights_by_keys('kernel')
ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]

# Compute the l2 norm of each output channel
channels = np.split(kernel, indices_or_sections=kernel.shape[ox_axis], axis=ox_axis)
l2_norms = [np.linalg.norm(c.flatten(), ord=2) ** 2 for c in channels]

# Store in node_l2_channel_norm a dictionary from node to a np.array where
# each element corresponds to the l2 norm of this channel
node_l2_channel_norm[entry_node] = l2_norms

return node_l2_channel_norm

def get_pruning_info(self) -> PruningInfo:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ class ImportanceMetric(Enum):
Enum for specifying the metric used to determine the importance of channels when pruning.
"""
LFH = 0 # Hessian approximation based on weights, to determine channel importance without explicit labels.
RANDOM = 2 # Random importance metric, possibly used as a baseline comparison.


class ChannelsFilteringStrategy(Enum):
Expand Down
33 changes: 33 additions & 0 deletions tests/keras_tests/pruning_tests/random_importance_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Callable, List

from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity
from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
import numpy as np

from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig


class RandomImportanceMetric(BaseImportanceMetric):

def __init__(self,
graph:Graph,
representative_data_gen: Callable,
fw_impl: FrameworkImplementation,
pruning_config: PruningConfig,
fw_info: FrameworkInfo):

self.float_graph= graph
self.representative_data_gen = representative_data_gen
self.fw_impl = fw_impl
self.pruning_config = pruning_config
self.fw_info = fw_info

def get_entry_node_to_score(self, sections_input_nodes:List[BaseNode]):
random_scores = [np.random.random(
node.get_weights_by_keys('kernel').shape[self.fw_info.kernel_channels_mapping.get(node.type)[0]])
for node in sections_input_nodes]
entry_node_to_score = {node: scores for node, scores in zip(sections_input_nodes, random_scores)}
return entry_node_to_score
13 changes: 12 additions & 1 deletion tests/keras_tests/pruning_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
from enum import Enum

import unittest

import tensorflow as tf

import model_compression_toolkit as mct
import numpy as np

from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import im_dict
from tests.keras_tests.pruning_tests.random_importance_metric import RandomImportanceMetric

keras = tf.keras
layers = keras.layers

NUM_PRUNING_RATIOS = 5

class TestImportanceMetric(Enum):
RANDOM = 'random'

im_dict.update({TestImportanceMetric.RANDOM: RandomImportanceMetric})

class ModelsPruningTest(unittest.TestCase):
def representative_dataset(self, in_shape=(1,224,224,3)):
for _ in range(1):
Expand Down Expand Up @@ -117,7 +127,8 @@ def run_test(self, cr, dense_model, test_retraining=False):
representative_data_gen=self.representative_dataset,
pruning_config=mct.pruning.PruningConfig(
num_score_approximations=1,
importance_metric=mct.pruning.ImportanceMetric.RANDOM))
importance_metric=TestImportanceMetric.RANDOM))

pruned_nparams = sum([l.count_params() for l in pruned_model.layers])
actual_cr = pruned_nparams / dense_nparams
print(f"Target remaining cr: {cr*100}, Actual remaining cr: {actual_cr*100} ")
Expand Down

0 comments on commit 0433744

Please sign in to comment.