diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py b/model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py b/model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py
new file mode 100644
index 000000000..9d0b32a61
--- /dev/null
+++ b/model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py
@@ -0,0 +1,15 @@
+from abc import ABC, abstractmethod
+from typing import List
+
+from model_compression_toolkit.core.common import BaseNode
+
+
+class BaseImportanceMetric(ABC):
+    """Interface for metrics that score the importance of prunable output channels."""
+
+    @abstractmethod
+    def get_entry_node_to_score(self, sections_input_nodes: List[BaseNode]):
+        """Return a mapping from each entry node to its per-output-channel importance scores."""
+        raise NotImplementedError(f'{self.__class__.__name__} must implement get_entry_node_to_score.')
diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py b/model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py
new file mode 100644
index 000000000..cd72744d2
--- /dev/null
+++ b/model_compression_toolkit/core/common/pruning/importance_metrics/importance_metric_factory.py
@@ -0,0 +1,10 @@
+from model_compression_toolkit.core.common.pruning import ImportanceMetric
+from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
+from model_compression_toolkit.core.common.pruning.importance_metrics.lfh_importance_metric import LFHImportanceMetric
+
+# Maps each importance-metric key to its implementation class.
+im_dict = {ImportanceMetric.LFH: LFHImportanceMetric}
+
+
+def get_importance_metric(im: ImportanceMetric, **kwargs) -> BaseImportanceMetric:
+    im_class = im_dict.get(im)
+    if im_class is None:
+        raise ValueError(f'Unsupported importance metric: {im}')
+    return im_class(**kwargs)
diff --git a/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py
new file mode 100644
index 000000000..9963f3b5d
--- /dev/null
+++ b/model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py
@@ -0,0 +1,96 @@
+from typing import Callable, List
+
+import numpy as np
+
+from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity
+from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
+from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
+
+
+class LFHImportanceMetric(BaseImportanceMetric):
+    """Scores prunable output channels using a Label-Free Hessian (LFH) approximation."""
+
+    def __init__(self,
+                 graph: Graph,
+                 representative_data_gen: Callable,
+                 fw_impl: FrameworkImplementation,
+                 pruning_config: PruningConfig,
+                 fw_info: FrameworkInfo):
+        self.float_graph = graph
+        self.representative_data_gen = representative_data_gen
+        self.fw_impl = fw_impl
+        self.pruning_config = pruning_config
+        self.fw_info = fw_info
+
+    def get_entry_node_to_score(self, sections_input_nodes: List[BaseNode]):
+        # Initialize the Hessian-information service used to approximate channel importance.
+        hessian_info_service = HessianInfoService(graph=self.float_graph,
+                                                  representative_dataset=self.representative_data_gen,
+                                                  fw_impl=self.fw_impl)
+
+        # Calculate the LFH (Label-Free Hessian) score for each prunable channel.
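+        # fetch_scores_for_multiple_nodes returns, per node, num_score_approximations
+        # score vectors; PER_OUTPUT_CHANNEL granularity yields one value per output channel.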
+        scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(
+            mode=HessianMode.WEIGHTS,
+            granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
+            nodes=sections_input_nodes,
+            required_size=self.pruning_config.num_score_approximations)
+
+        # Average the scores across approximations and map them to the corresponding nodes.
+        entry_node_to_score = {node: np.mean(scores, axis=0) for node, scores in
+                               zip(sections_input_nodes, scores_per_prunable_node)}
+
+        l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
+        count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
+        entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
+                                                        entry_node_to_l2norm=l2_oc_norm,
+                                                        entry_node_to_nparams=count_oc_nparams)
+        return entry_node_to_score
+
+    def normalize_lfh_scores(self,
+                             entry_node_to_score,
+                             entry_node_to_l2norm,
+                             entry_node_to_nparams):
+        # Scale each node's Hessian trace vector by the per-channel squared L2 norm,
+        # normalized by the number of parameters in each channel.
+        new_scores = {}
+        for node, trace_vector in entry_node_to_score.items():
+            new_scores[node] = trace_vector * entry_node_to_l2norm[node] / entry_node_to_nparams[node]
+        return new_scores
+
+    def count_oc_nparams(self, entry_nodes: List[BaseNode]):
+        node_channel_params = {}
+        for entry_node in entry_nodes:
+            kernel = entry_node.get_weights_by_keys('kernel')
+            # Output-channel axis of the kernel for this node type.
+            oc_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]
+
+            # Calculate the number of parameters for each output channel.
+            params_per_channel = np.prod(kernel.shape) / kernel.shape[oc_axis]
+
+            # Create an array with the number of parameters per channel.
+            num_params_array = np.full(kernel.shape[oc_axis], params_per_channel)
+
+            # Map the node to an np.array where each element is the number of
+            # parameters of the corresponding output channel.
+            node_channel_params[entry_node] = num_params_array
+
+        return node_channel_params
+
+    def get_l2_out_channel_norm(self, entry_nodes: List[BaseNode]):
+        node_l2_channel_norm = {}
+        for entry_node in entry_nodes:
+            kernel = entry_node.get_weights_by_keys('kernel')
+            # Output-channel axis of the kernel for this node type.
+            oc_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0]
+
+            # Compute the squared L2 norm of each output channel.
+            channels = np.split(kernel, indices_or_sections=kernel.shape[oc_axis], axis=oc_axis)
+            l2_norms = [np.linalg.norm(c.flatten(), ord=2) ** 2 for c in channels]
+
+            # Map the node to an np.array where each element is the squared L2
+            # norm of the corresponding output channel.
+            node_l2_channel_norm[entry_node] = np.array(l2_norms)
+
+        return node_l2_channel_norm
diff --git a/model_compression_toolkit/core/common/pruning/pruner.py b/model_compression_toolkit/core/common/pruning/pruner.py
index 55e99a433..f0c763f1c 100644
--- a/model_compression_toolkit/core/common/pruning/pruner.py
+++ b/model_compression_toolkit/core/common/pruning/pruner.py
@@ -7,6 +7,8 @@
 from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianMode, HessianInfoGranularity
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.pruning.greedy_mask_calculator import GreedyMaskCalculator
+from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import \
+    get_importance_metric
 from model_compression_toolkit.core.common.pruning.prune_graph import build_pruned_graph
 from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig, ChannelsFilteringStrategy, \
     ImportanceMetric
@@ -104,95 +106,18 @@ def get_score_per_entry_point(self, sections_input_nodes):
             Dict[BaseNode, np.ndarray]: A dictionary mapping each entry node to its corresponding importance score.
         """
-        # Initialize a dictionary to hold scores for each node.
-        entry_node_to_score = None
-
-        # Compute scores based on the importance metric defined in the pruning configuration.
-        if self.pruning_config.importance_metric == ImportanceMetric.LFH:
-            # Initialize services and variables for pruning process.
-            hessian_info_service = HessianInfoService(graph=self.float_graph,
-                                                      representative_dataset=self.representative_data_gen,
-                                                      fw_impl=self.fw_impl)
-
-            # Calculate the LFH (Label-Free Hessian) score for each prunable channel.
-            scores_per_prunable_node = hessian_info_service.fetch_scores_for_multiple_nodes(
-                mode=HessianMode.WEIGHTS,
-                granularity=HessianInfoGranularity.PER_OUTPUT_CHANNEL,
-                nodes=sections_input_nodes,
-                required_size=self.pruning_config.num_score_approximations)
-
-
-            # Average the scores across approximations and map them to the corresponding nodes.
-            entry_node_to_score = {node: np.mean(scores, axis=0) for node, scores in
-                                   zip(sections_input_nodes, scores_per_prunable_node)}
-
-            l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
-            count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
-            entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
-                                                            entry_node_to_l2norm=l2_oc_norm,
-                                                            entry_node_to_nparmas=count_oc_nparams)
-
-
-        elif self.pruning_config.importance_metric == ImportanceMetric.RANDOM:
-            random_scores = [np.random.random(
-                node.get_weights_by_keys('kernel').shape[self.fw_info.kernel_channels_mapping.get(node.type)[0]])
-                for node in sections_input_nodes]
-            entry_node_to_score = {node: scores for node, scores in zip(sections_input_nodes, random_scores)}
-            l2_oc_norm = self.get_l2_out_channel_norm(entry_nodes=sections_input_nodes)
-            count_oc_nparams = self.count_oc_nparams(entry_nodes=sections_input_nodes)
-            entry_node_to_score = self.normalize_lfh_scores(entry_node_to_score=entry_node_to_score,
-                                                            entry_node_to_l2norm=l2_oc_norm,
-                                                            entry_node_to_nparmas=count_oc_nparams)
+        # Instantiate the importance metric specified in the pruning configuration.
+        im = get_importance_metric(self.pruning_config.importance_metric,
+                                   graph=self.float_graph,
+                                   representative_data_gen=self.representative_data_gen,
+                                   fw_impl=self.fw_impl,
+                                   pruning_config=self.pruning_config,
+                                   fw_info=self.fw_info)
-        else:
-            # Log an error if an unsupported importance metric is specified.
-            Logger.error(f"Not supported importance metric: {self.pruning_config.importance_metric}")
+        entry_node_to_score = im.get_entry_node_to_score(sections_input_nodes)
 
         # Return the dictionary of nodes mapped to their importance scores.
return entry_node_to_score - def normalize_lfh_scores(self, - entry_node_to_score, - entry_node_to_l2norm, - entry_node_to_nparmas): - new_scores = {} - for node, trace_vector in entry_node_to_score.items(): - new_scores[node] = trace_vector*entry_node_to_l2norm[node]/entry_node_to_nparmas[node] - return new_scores - - def count_oc_nparams(self, entry_nodes: List[BaseNode]): - node_channel_params = {} - for entry_node in entry_nodes: - kernel = entry_node.get_weights_by_keys('kernel') - ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0] - - # Calculate the number of parameters for each output channel - params_per_channel = np.prod(kernel.shape) / kernel.shape[ox_axis] - - # Create an array with the number of parameters per channel - num_params_array = np.full(kernel.shape[ox_axis], params_per_channel) - - # Store in node_channel_params a dictionary from node to a np.array where - # each element corresponds to the number of parameters of this channel - node_channel_params[entry_node] = num_params_array - - return node_channel_params - - - def get_l2_out_channel_norm(self, entry_nodes: List[BaseNode]): - node_l2_channel_norm = {} - for entry_node in entry_nodes: - kernel = entry_node.get_weights_by_keys('kernel') - ox_axis = self.fw_info.kernel_channels_mapping.get(entry_node.type)[0] - - # Compute the l2 norm of each output channel - channels = np.split(kernel, indices_or_sections=kernel.shape[ox_axis], axis=ox_axis) - l2_norms = [np.linalg.norm(c.flatten(), ord=2) ** 2 for c in channels] - - # Store in node_l2_channel_norm a dictionary from node to a np.array where - # each element corresponds to the l2 norm of this channel - node_l2_channel_norm[entry_node] = l2_norms - - return node_l2_channel_norm def get_pruning_info(self) -> PruningInfo: """ diff --git a/model_compression_toolkit/core/common/pruning/pruning_config.py b/model_compression_toolkit/core/common/pruning/pruning_config.py index cd36bada9..1f80d2e48 100644 --- a/model_compression_toolkit/core/common/pruning/pruning_config.py +++ b/model_compression_toolkit/core/common/pruning/pruning_config.py @@ -8,7 +8,6 @@ class ImportanceMetric(Enum): Enum for specifying the metric used to determine the importance of channels when pruning. """ LFH = 0 # Hessian approximation based on weights, to determine channel importance without explicit labels. - RANDOM = 2 # Random importance metric, possibly used as a baseline comparison. 
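+    # Additional metrics (e.g., the random baseline used in tests) can be
+    # registered via importance_metric_factory.im_dict without extending this enum.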
 class ChannelsFilteringStrategy(Enum):
diff --git a/tests/keras_tests/pruning_tests/random_importance_metric.py b/tests/keras_tests/pruning_tests/random_importance_metric.py
new file mode 100644
index 000000000..c388601c2
--- /dev/null
+++ b/tests/keras_tests/pruning_tests/random_importance_metric.py
@@ -0,0 +1,33 @@
+from typing import Callable, List
+
+import numpy as np
+
+from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
+from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
+
+
+class RandomImportanceMetric(BaseImportanceMetric):
+    """Test-only metric that assigns random per-output-channel scores, serving as a baseline."""
+
+    def __init__(self,
+                 graph: Graph,
+                 representative_data_gen: Callable,
+                 fw_impl: FrameworkImplementation,
+                 pruning_config: PruningConfig,
+                 fw_info: FrameworkInfo):
+        self.float_graph = graph
+        self.representative_data_gen = representative_data_gen
+        self.fw_impl = fw_impl
+        self.pruning_config = pruning_config
+        self.fw_info = fw_info
+
+    def get_entry_node_to_score(self, sections_input_nodes: List[BaseNode]):
+        # Draw one random score per output channel of each entry node's kernel.
+        random_scores = [np.random.random(
+            node.get_weights_by_keys('kernel').shape[self.fw_info.kernel_channels_mapping.get(node.type)[0]])
+            for node in sections_input_nodes]
+        entry_node_to_score = {node: scores for node, scores in zip(sections_input_nodes, random_scores)}
+        return entry_node_to_score
diff --git a/tests/keras_tests/pruning_tests/test_models.py b/tests/keras_tests/pruning_tests/test_models.py
index b983254fe..2be8f95b3 100644
--- a/tests/keras_tests/pruning_tests/test_models.py
+++ b/tests/keras_tests/pruning_tests/test_models.py
@@ -1,3 +1,5 @@
+from enum import Enum
+
 import unittest
 
 import tensorflow as tf
@@ -5,11 +7,19 @@
 import model_compression_toolkit as mct
 import numpy as np
 
+from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import im_dict
+from tests.keras_tests.pruning_tests.random_importance_metric import RandomImportanceMetric
+
 keras = tf.keras
 layers = keras.layers
 
 NUM_PRUNING_RATIOS = 5
 
+class TestImportanceMetric(Enum):
+    RANDOM = 'random'
+
+# Register the test-only metric so PruningConfig can reference it below.
+im_dict.update({TestImportanceMetric.RANDOM: RandomImportanceMetric})
+
 class ModelsPruningTest(unittest.TestCase):
     def representative_dataset(self, in_shape=(1,224,224,3)):
         for _ in range(1):
@@ -117,7 +127,8 @@ def run_test(self, cr, dense_model, test_retraining=False):
                                            representative_data_gen=self.representative_dataset,
                                            pruning_config=mct.pruning.PruningConfig(
                                                num_score_approximations=1,
-                                               importance_metric=mct.pruning.ImportanceMetric.RANDOM))
+                                               importance_metric=TestImportanceMetric.RANDOM))
+
             pruned_nparams = sum([l.count_params() for l in pruned_model.layers])
             actual_cr = pruned_nparams / dense_nparams
             print(f"Target remaining cr: {cr*100}, Actual remaining cr: {actual_cr*100} ")
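For reviewers: a minimal sketch of the extension point this refactor creates, mirroring the registration pattern used in test_models.py. The names MyImportanceMetric, ConstantImportanceMetric, and the 'constant' key are hypothetical, invented for illustration; only BaseImportanceMetric, im_dict, and the factory keyword arguments come from the diff.

```python
from enum import Enum
from typing import List

import numpy as np

from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.pruning.importance_metrics.base_importance_metric import BaseImportanceMetric
from model_compression_toolkit.core.common.pruning.importance_metrics.importance_metric_factory import im_dict


class MyImportanceMetric(Enum):
    CONSTANT = 'constant'  # hypothetical key, not part of MCT


class ConstantImportanceMetric(BaseImportanceMetric):
    """Hypothetical metric: every output channel gets the same score."""

    def __init__(self, fw_info=None, **kwargs):
        # The factory forwards graph, representative_data_gen, fw_impl,
        # pruning_config, and fw_info as keyword arguments; this toy
        # metric only needs fw_info.
        self.fw_info = fw_info

    def get_entry_node_to_score(self, sections_input_nodes: List[BaseNode]):
        scores = {}
        for node in sections_input_nodes:
            kernel = node.get_weights_by_keys('kernel')
            oc_axis = self.fw_info.kernel_channels_mapping.get(node.type)[0]
            # One identical score per output channel.
            scores[node] = np.ones(kernel.shape[oc_axis])
        return scores


# Register the metric; PruningConfig(importance_metric=MyImportanceMetric.CONSTANT)
# would then route through get_importance_metric to this class.
im_dict.update({MyImportanceMetric.CONSTANT: ConstantImportanceMetric})
```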
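And a small self-contained sketch of the arithmetic normalize_lfh_scores applies, using made-up numbers for a single node with three output channels of 27 parameters each (e.g., a 3x3x3 kernel per channel):

```python
import numpy as np

# Hypothetical per-channel values for one node; illustrative only.
trace_vector = np.array([0.10, 0.40, 0.25])  # averaged Hessian score approximations
l2_norms = np.array([4.0, 1.0, 9.0])         # squared L2 norm per output channel
nparams = np.full(3, 27.0)                   # parameters per output channel

# Same formula as normalize_lfh_scores: score * ||w||^2 / nparams.
normalized = trace_vector * l2_norms / nparams
print(normalized)  # ~[0.0148 0.0148 0.0833] -> the third channel ranks most important
```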