diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index 432a81f39..3e83aa4af 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -544,10 +544,8 @@ def get_weights_configurable_nodes(self, potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)] def is_configurable(n): - kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] - return (n.is_weights_quantization_enabled(kernel_attr) and - not n.is_all_weights_candidates_equal(kernel_attr) and - (not n.reuse or include_reused_nodes)) + kernel_attrs = fw_info.get_kernel_op_attributes(n.type) + return any(n.is_configurable_weight(attr) for attr in kernel_attrs) and (not n.reuse or include_reused_nodes) return [n for n in potential_conf_nodes if is_configurable(n)] @@ -576,7 +574,7 @@ def get_activation_configurable_nodes(self) -> List[BaseNode]: Returns: A list of nodes that their activation can be configured (namely, has one or more activation qc candidate). """ - return [n for n in list(self) if n.is_activation_quantization_enabled() and not n.is_all_activation_candidates_equal()] + return [n for n in list(self) if n.has_configurable_activation()] def get_sorted_activation_configurable_nodes(self) -> List[BaseNode]: """ diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index 67c4f2f57..98e053940 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -150,6 +150,27 @@ def is_weights_quantization_enabled(self, attr_name: str) -> bool: return False + def is_configurable_weight(self, attr_name: str) -> bool: + """ + Checks whether the specific weight attribute has a configurable quantization. + + Args: + attr_name: weight attribute name. + + Returns: + Whether the weight attribute is configurable. + """ + return self.is_weights_quantization_enabled(attr_name) and not self.is_all_weights_candidates_equal(attr_name) + + def has_configurable_activation(self) -> bool: + """ + Checks whether the activation has a configurable quantization. + + Returns: + Whether the activation has a configurable quantization. + """ + return self.is_activation_quantization_enabled() and not self.is_all_activation_candidates_equal() + def __repr__(self): """ @@ -420,11 +441,15 @@ def get_total_output_params(self) -> float: Returns: Output size. 
""" - output_shapes = self.output_shape if isinstance(self.output_shape, List) else [self.output_shape] + # shape can be tuple or list, and multiple shapes can be packed in list or tuple + if self.output_shape and isinstance(self.output_shape[0], (tuple, list)): + output_shapes = self.output_shape + else: + output_shapes = [self.output_shape] # remove batch size (first element) from output shape output_shapes = [s[1:] for s in output_shapes] - + # for scalar shape (None,) prod returns 1 return sum([np.prod([x for x in output_shape if x is not None]) for output_shape in output_shapes]) def find_min_candidates_indices(self) -> List[int]: diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py index 7f31563a4..2c6dbd638 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py @@ -22,7 +22,6 @@ from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.hessian import HessianInfoService from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import ru_functions_mapping from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \ @@ -105,16 +104,11 @@ def search_bit_width(graph_to_search_cfg: Graph, disable_activation_for_metric=disable_activation_for_metric, hessian_info_service=hessian_info_service) - # Each pair of (resource utilization method, resource utilization aggregation) should match to a specific - # provided target resource utilization - ru_functions = ru_functions_mapping - # Instantiate a manager object search_manager = MixedPrecisionSearchManager(graph, fw_info, fw_impl, se, - ru_functions, target_resource_utilization, original_graph=graph_to_search_cfg) diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py index 7fbb0807b..047745ca7 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py @@ -13,23 +13,24 @@ # limitations under the License. 
# ============================================================================== -from typing import Callable, Tuple -from typing import Dict, List +from typing import Callable, Dict, List + import numpy as np from model_compression_toolkit.core.common import BaseNode -from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.core.common.graph.base_graph import Graph from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \ VirtualSplitWeightsNode, VirtualSplitActivationNode -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget, ResourceUtilization -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import RuFunctions -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric, calc_graph_cuts -from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import Cut -from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \ + RUTarget, ResourceUtilization +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ + ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import \ + MixedPrecisionRUHelper from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation +from model_compression_toolkit.logger import Logger class MixedPrecisionSearchManager: @@ -42,7 +43,6 @@ def __init__(self, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, sensitivity_evaluator: SensitivityEvaluation, - ru_functions: Dict[RUTarget, RuFunctions], target_resource_utilization: ResourceUtilization, original_graph: Graph = None): """ @@ -53,8 +53,6 @@ def __init__(self, fw_impl: FrameworkImplementation object with specific framework methods implementation. sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of a bit-width configuration for the MP model. - ru_functions: A dictionary with pairs of (MpRuMethod, MpRuAggregationMethod) mapping a RUTarget to - a couple of resource utilization metric function and resource utilization aggregation function. target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it. original_graph: In case we have a search over a virtual graph (if we have BOPS utilization target), then this argument will contain the original graph (for config reconstruction purposes). 
@@ -69,29 +67,17 @@ def __init__(self, self.compute_metric_fn = self.get_sensitivity_metric() self._cuts = None - ru_types = [ru_target for ru_target, ru_value in - target_resource_utilization.get_resource_utilization_dict().items() if ru_value < np.inf] - self.compute_ru_functions = {ru_target: ru_fn for ru_target, ru_fn in ru_functions.items() if ru_target in ru_types} + self.ru_metrics = target_resource_utilization.get_restricted_metrics() + self.ru_helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl) self.target_resource_utilization = target_resource_utilization self.min_ru_config = self.graph.get_min_candidates_config(fw_info) self.max_ru_config = self.graph.get_max_candidates_config(fw_info) - self.min_ru = self.compute_min_ru() + self.min_ru = self.ru_helper.compute_utilization(self.ru_metrics, self.min_ru_config) self.non_conf_ru_dict = self._non_configurable_nodes_ru() self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph, original_graph=self.original_graph) - @property - def cuts(self) -> List[Cut]: - """ - Calculates graph cuts. Written as property, so it will only be calculated once and - only if cuts are needed. - - """ - if self._cuts is None: - self._cuts = calc_graph_cuts(self.original_graph) - return self._cuts - def get_search_space(self) -> Dict[int, List[int]]: """ The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces @@ -122,40 +108,6 @@ def get_sensitivity_metric(self) -> Callable: return self.sensitivity_evaluator.compute_metric - def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg) -> np.ndarray: - """ - Computes a resource utilization for a certain mixed precision configuration. - The method computes a resource utilization vector for specific target resource utilization. - - Returns: resource utilization value. - - """ - # ru_fn is a pair of resource utilization computation method and - # resource utilization aggregation method (in this method we only need the first one) - if ru_target is RUTarget.ACTIVATION: - return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl, self.cuts) - else: - return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl) - - def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]: - """ - Computes a resource utilization vector with the values matching to the minimal mp configuration - (i.e., each node is configured with the quantization candidate that would give the minimal size of the - node's resource utilization). - The method computes the minimal resource utilization vector for each target resource utilization. - - Returns: A dictionary mapping each target resource utilization to its respective minimal - resource utilization values. - - """ - min_ru = {} - for ru_target, ru_fn in self.compute_ru_functions.items(): - # ru_fns is a pair of resource utilization computation method and - # resource utilization aggregation method (in this method we only need the first one) - min_ru[ru_target] = self._calc_ru_fn(ru_target, ru_fn, self.min_ru_config) - - return min_ru - def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray: """ Computes and builds a resource utilization matrix, to be used for the mixed-precision search problem formalization. @@ -184,7 +136,8 @@ def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray: # always be 0 for all entries in the results vector. 
candidate_rus = np.zeros(shape=self.min_ru[target].shape) else: - candidate_rus = self.compute_candidate_relative_ru(c, candidate_idx, target) + candidate_rus = self.compute_node_ru_for_candidate(c, candidate_idx, target) - self.min_ru[target] + ru_matrix.append(np.asarray(candidate_rus)) # We need to transpose the calculated ru matrix to allow later multiplication with @@ -195,40 +148,6 @@ def compute_resource_utilization_matrix(self, target: RUTarget) -> np.ndarray: np_ru_matrix = np.array(ru_matrix) return np.moveaxis(np_ru_matrix, source=0, destination=len(np_ru_matrix.shape) - 1) - def compute_candidate_relative_ru(self, - conf_node_idx: int, - candidate_idx: int, - target: RUTarget) -> np.ndarray: - """ - Computes a resource utilization vector for a given candidates of a given configurable node, - i.e., the matching resource utilization vector which is obtained by computing the given target's - resource utilization function on a minimal configuration in which the given - layer's candidates is changed to the new given one. - The result is normalized by subtracting the target's minimal resource utilization vector. - - Args: - conf_node_idx: The index of a node in a sorted configurable nodes list. - candidate_idx: The index of a node's quantization configuration candidate. - target: The target for which the resource utilization is calculated (a RUTarget value). - - Returns: Normalized node's resource utilization vector - - """ - return self.compute_node_ru_for_candidate(conf_node_idx, candidate_idx, target) - \ - self.get_min_target_resource_utilization(target) - - def get_min_target_resource_utilization(self, target: RUTarget) -> np.ndarray: - """ - Returns the minimal resource utilization vector (pre-calculated on initialization) of a specific target. - - Args: - target: The target for which the resource utilization is calculated (a RUTarget value). - - Returns: Minimal resource utilization vector. - - """ - return self.min_ru[target] - def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int, target: RUTarget) -> np.ndarray: """ Computes a resource utilization vector after replacing the given node's configuration candidate in the minimal @@ -243,7 +162,8 @@ def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int, """ cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx) - return self._calc_ru_fn(target, self.compute_ru_functions[target], cfg) + # TODO compute for all targets at once. Currently the way up to add_set_of_ru_constraints is per target. + return self.ru_helper.compute_utilization({target}, cfg)[target] @staticmethod def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]: @@ -270,21 +190,10 @@ def _non_configurable_nodes_ru(self) -> Dict[RUTarget, np.ndarray]: Returns: A mapping between a RUTarget and its non-configurable nodes' resource utilization vector. 
""" - - non_conf_ru_dict = {} - for target, ru_fns in self.compute_ru_functions.items(): - # Call for the ru method of the given target - empty quantization configuration list is passed since we - # compute for non-configurable nodes - if target == RUTarget.BOPS: - ru_vector = None - elif target == RUTarget.ACTIVATION: - ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl, self.cuts) - else: - ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl) - - non_conf_ru_dict[target] = ru_vector - - return non_conf_ru_dict + ru_metrics = self.ru_metrics - {RUTarget.BOPS} + ru = self.ru_helper.compute_utilization(ru_targets=ru_metrics, mp_cfg=None) + ru[RUTarget.BOPS] = None + return ru def compute_resource_utilization_for_config(self, config: List[int]) -> ResourceUtilization: """ @@ -297,29 +206,11 @@ def compute_resource_utilization_for_config(self, config: List[int]) -> Resource with the given config. """ - - ru_dict = {} - for ru_target, ru_fns in self.compute_ru_functions.items(): - # Passing False to ru methods and aggregations to indicates that the computations - # are not for constraints setting - if ru_target == RUTarget.BOPS: - configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl, False) - elif ru_target == RUTarget.ACTIVATION: - configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.graph, self.fw_info, self.fw_impl, self.cuts) - else: - configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl) - non_configurable_nodes_ru_vector = self.non_conf_ru_dict.get(ru_target) - if non_configurable_nodes_ru_vector is None or len(non_configurable_nodes_ru_vector) == 0: - ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(configurable_nodes_ru_vector, False) - else: - ru_ru = self.compute_ru_functions[ru_target].aggregate_fn( - np.concatenate([configurable_nodes_ru_vector, non_configurable_nodes_ru_vector]), False) - - ru_dict[ru_target] = ru_ru[0] - - config_ru = ResourceUtilization() - config_ru.set_resource_utilization_by_target(ru_dict) - return config_ru + act_qcs, w_qcs = self.ru_helper.get_configurable_qcs(config) + ru = self.ru_helper.ru_calculator.compute_resource_utilization( + target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs, + w_qcs=w_qcs) + return ru def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]): """ diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py index 934d24f01..3da53184a 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py @@ -12,29 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from dataclasses import dataclass from enum import Enum -from typing import Dict, Any +from typing import Dict, Any, Set import numpy as np class RUTarget(Enum): """ - Targets for which we define Resource Utilization metrics for mixed-precision search. 
- For each target that we care to consider in a mixed-precision search, there should be defined a set of - resource utilization computation function, resource utilization aggregation function, - and resource utilization target (within a ResourceUtilization object). - - Whenever adding a resource utilization metric to ResourceUtilization class we should add a matching target to this enum. - - WEIGHTS - Weights memory ResourceUtilization metric. - - ACTIVATION - Activation memory ResourceUtilization metric. - - TOTAL - Total memory ResourceUtilization metric. - - BOPS - Total Bit-Operations ResourceUtilization Metric. + Resource Utilization targets for mixed-precision search. + WEIGHTS - Weights memory. + ACTIVATION - Activation memory. + TOTAL - Total memory. + BOPS - Total Bit-Operations. """ WEIGHTS = 'weights' @@ -43,34 +35,20 @@ class RUTarget(Enum): BOPS = 'bops' +@dataclass class ResourceUtilization: """ Class to represent measurements of performance. - """ - - def __init__(self, - weights_memory: float = np.inf, - activation_memory: float = np.inf, - total_memory: float = np.inf, - bops: float = np.inf): - """ - - Args: - weights_memory: Memory of a model's weights in bytes. Note that this includes only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not). - activation_memory: Memory of a model's activation in bytes, according to the given activation resource utilization metric. - total_memory: The sum of model's activation and weights memory in bytes, according to the given total resource utilization metric. - bops: The total bit-operations in the model. - """ - self.weights_memory = weights_memory - self.activation_memory = activation_memory - self.total_memory = total_memory - self.bops = bops - def __repr__(self): - return f"Weights_memory: {self.weights_memory}, " \ - f"Activation_memory: {self.activation_memory}, " \ - f"Total_memory: {self.total_memory}, " \ - f"BOPS: {self.bops}" + weights_memory: Memory of a model's weights in bytes. + activation_memory: Memory of a model's activation in bytes. + total_memory: The sum of model's activation and weights memory in bytes. + bops: The total bit-operations in the model. + """ + weights_memory: float = np.inf + activation_memory: float = np.inf + total_memory: float = np.inf + bops: float = np.inf def weight_restricted(self): return self.weights_memory < np.inf @@ -93,34 +71,30 @@ def get_resource_utilization_dict(self) -> Dict[RUTarget, float]: RUTarget.TOTAL: self.total_memory, RUTarget.BOPS: self.bops} - def set_resource_utilization_by_target(self, ru_mapping: Dict[RUTarget, float]): + def is_satisfied_by(self, ru: 'ResourceUtilization') -> bool: """ - Setting a ResourceUtilization object values for each ResourceUtilization target in the given dictionary. + Checks whether another ResourceUtilization object satisfies the constraints defined by the current object. Args: - ru_mapping: A mapping from a RUTarget to a matching resource utilization value. + ru: A ResourceUtilization object to check against the current object. + Returns: + Whether all constraints are satisfied. 
""" - self.weights_memory = ru_mapping.get(RUTarget.WEIGHTS, np.inf) - self.activation_memory = ru_mapping.get(RUTarget.ACTIVATION, np.inf) - self.total_memory = ru_mapping.get(RUTarget.TOTAL, np.inf) - self.bops = ru_mapping.get(RUTarget.BOPS, np.inf) + return bool(ru.weights_memory <= self.weights_memory and \ + ru.activation_memory <= self.activation_memory and \ + ru.total_memory <= self.total_memory and \ + ru.bops <= self.bops) - def holds_constraints(self, ru: Any) -> bool: - """ - Checks whether the given ResourceUtilization object holds a set of ResourceUtilization constraints defined by - the current ResourceUtilization object. + def get_restricted_metrics(self) -> Set[RUTarget]: + d = self.get_resource_utilization_dict() + return {k for k, v in d.items() if v < np.inf} - Args: - ru: A ResourceUtilization object to check if it holds the constraints. - - Returns: True if all the given resource utilization values are not greater than the referenced resource utilization values. + def is_any_restricted(self) -> bool: + return bool(self.get_restricted_metrics()) - """ - if not isinstance(ru, ResourceUtilization): - return False - - return ru.weights_memory <= self.weights_memory and \ - ru.activation_memory <= self.activation_memory and \ - ru.total_memory <= self.total_memory and \ - ru.bops <= self.bops + def __repr__(self): + return f"Weights_memory: {self.weights_memory}, " \ + f"Activation_memory: {self.activation_memory}, " \ + f"Total_memory: {self.total_memory}, " \ + f"BOPS: {self.bops}" diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py new file mode 100644 index 000000000..b99b2f55d --- /dev/null +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py @@ -0,0 +1,667 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+from collections import defaultdict
+from copy import deepcopy
+from enum import Enum, auto
+from functools import lru_cache
+from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence
+
+from model_compression_toolkit.constants import FLOAT_BITWIDTH
+from model_compression_toolkit.core import FrameworkInfo
+from model_compression_toolkit.core.common import Graph, BaseNode
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
+from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut
+from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
+from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
+    RUTarget, ResourceUtilization
+from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
+    NodeActivationQuantizationConfig
+
+
+class BitwidthMode(Enum):
+    """
+    Bit-width configuration for resource utilization computation.
+
+    Float: original un-quantized configuration. Assumed to be 32-bit float.
+    Q8Bit: uniform 8-bit configuration. Assigns 8 bits to each quantized target.
+    QMaxBit: maximal bit-width configuration. Assigns each node its maximal available precision according to the
+      target platform capabilities.
+    QMinBit: minimal bit-width configuration. Assigns each node its minimal available precision according to the
+      target platform capabilities.
+    QCustom: explicitly provided bit-width configuration.
+    QDefaultSP: default single-precision bit-width configuration. Can be used either in a single-precision mode,
+      or along with TargetInclusionCriterion.QNonConfigurable, which computes the resource utilization only for
+      single-precision nodes. To compute a custom single-precision configuration, use QCustom.
+    """
+    Float = auto()
+    Q8Bit = auto()
+    QMaxBit = auto()
+    QMinBit = auto()
+    QCustom = auto()
+    QDefaultSP = auto()
+
+
+class TargetInclusionCriterion(Enum):
+    """
+    Target nodes / parameters to include for resource utilization computation.
+
+    QConfigurable: configurable targets for Mixed Precision (multiple quantization candidates).
+    QNonConfigurable: non-configurable targets (single quantization candidate).
+    AnyQuantized: any quantized targets (configurable and non-configurable).
+    Any: all targets (quantized + float).
+    """
+    QConfigurable = auto()
+    QNonConfigurable = auto()
+    AnyQuantized = auto()
+    Any = auto()
+
+
+class Utilization(NamedTuple):
+    """
+    Utility container for a single resource utilization result.
+    Supports sum, max, min over an iterable of Utilization objects.
+
+    Args:
+        size: parameters or activation tensor(s) size.
+        bytes: memory utilization.
+    """
+    size: int
+    bytes: Optional[float]
+
+    def __add__(self, other: 'Utilization') -> 'Utilization':
+        return Utilization(self.size + other.size, self.bytes + other.bytes)
+
+    def __radd__(self, other: Union['Utilization', Literal[0]]):
+        # Needed for sum (which uses a default start value of 0).
+        if other == 0:
+            return self
+        return self + other
+
+    def __gt__(self, other: 'Utilization'):
+        # Needed for max. Compare by bytes.
+ return self.bytes < other.bytes + + +class ResourceUtilizationCalculator: + """ Resource utilization calculator. """ + + _bitwidth_mode_fn = { + BitwidthMode.QMaxBit: max, + BitwidthMode.QMinBit: min, + } + + def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo): + self.graph = graph + self.fw_impl = fw_impl + self.fw_info = fw_info + + # Currently we go over the full graph even if utilization won't be requested for all nodes. + # We could fill the cache on the fly only for requested nodes, but it's probably negligible. + self._act_tensors_size = {} + self._params_cnt = {} + for n in graph.nodes: + self._act_tensors_size[n] = n.get_total_output_params() + self._params_cnt[n] = {k: v.size for k, v in n.weights.items()} + self._cuts = None + + def compute_resource_utilization(self, + target_criterion: TargetInclusionCriterion, + bitwidth_mode: BitwidthMode, + act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None, + w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None, + ru_targets: Iterable[RUTarget] = None) -> ResourceUtilization: + """ + Compute network's resource utilization. + + Args: + target_criterion: criterion to include targets for computation (applies to weights, activation). + bitwidth_mode: bit-width mode for computation. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. + w_qcs: custom weights quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable weights. For non-configurable + weights, if not provided, the default configuration will be extracted from the node. + ru_targets: metrics to include for computation. If None, all metrics are calculated. + + Returns: + Resource utilization object. 
+ """ + ru_targets = set(ru_targets) if ru_targets else set(RUTarget) + + w_total, a_total = None, None + if {RUTarget.WEIGHTS, RUTarget.TOTAL}.intersection(ru_targets): + w_total, *_ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs) + elif w_qcs is not None: # pragma: no cover + raise ValueError('Weight configuration passed but no relevant metric requested.') + + if act_qcs and not {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets): # pragma: no cover + raise ValueError('Activation configuration passed but no relevant metric requested.') + if RUTarget.ACTIVATION in ru_targets: + a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs) + + ru = ResourceUtilization() + if RUTarget.WEIGHTS in ru_targets: + ru.weights_memory = w_total + if RUTarget.ACTIVATION in ru_targets: + ru.activation_memory = a_total + if RUTarget.TOTAL in ru_targets: + # TODO use maxcut + act_tensors_total, *_ = self.compute_activation_tensors_utilization(target_criterion, bitwidth_mode, act_qcs) + ru.total_memory = w_total + act_tensors_total + if RUTarget.BOPS in ru_targets: + ru.bops, _ = self.compute_bops(target_criterion=target_criterion, + bitwidth_mode=bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs) + + assert ru.get_restricted_metrics() == set(ru_targets), 'Mismatch between the number of requested and computed metrics' + return ru + + def compute_weights_utilization(self, + target_criterion: TargetInclusionCriterion, + bitwidth_mode: BitwidthMode, + w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None) \ + -> Tuple[float, Dict[BaseNode, Utilization], Dict[BaseNode, Dict[str, Utilization]]]: + """ + Compute graph's weights resource utilization. + + Args: + target_criterion: criterion to include targets for computation. + bitwidth_mode: bit-width mode for computation. + w_qcs: custom weights quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable weights. For non-configurable + weights, if not provided, the default configuration will be extracted from the node. + + Returns: + - Total weights utilization of the network. + - Per node total weights utilization. Dict keys are nodes in a topological order. + - Detailed per node per weight attribute utilization. Dict keys are nodes in a topological order. + """ + nodes = self._get_target_weight_nodes(target_criterion, include_reused=False) + if not nodes: + return 0, {}, {} + + util_per_node: Dict[BaseNode, Utilization] = {} + util_per_node_per_weight = {} + + for n in self._topo_sort(nodes): + w_qc = w_qcs.get(n) if w_qcs else None + node_weights_util, per_weight_util = self.compute_node_weights_utilization(n, target_criterion, + bitwidth_mode, w_qc) + util_per_node[n] = node_weights_util + util_per_node_per_weight[n] = per_weight_util + + total_util = sum(util_per_node.values()) + return total_util.bytes, util_per_node, util_per_node_per_weight + + def compute_node_weights_utilization(self, + n: BaseNode, + target_criterion: TargetInclusionCriterion, + bitwidth_mode: BitwidthMode, + qc: NodeWeightsQuantizationConfig)\ + -> Tuple[Utilization, Dict[str, Utilization]]: + """ + Compute resource utilization for weights of a node. + + Args: + n: node. + target_criterion: criterion to include weights for computation. + bitwidth_mode: bit-width mode for the computation. + qc: custom weights quantization configuration. Should be provided for custom bit mode only. 
 In custom mode, must provide configuration for all configurable weights. For non-configurable
+                weights, if not provided, the default configuration will be extracted from the node.
+
+        Returns:
+            - Node's total weights utilization.
+            - Detailed per weight attribute utilization.
+        """
+        weight_attrs = self._get_target_weight_attrs(n, target_criterion)
+        if not weight_attrs:    # pragma: no cover
+            return Utilization(0, 0), {}
+
+        attr_util = {}
+        for attr in weight_attrs:
+            size = self._params_cnt[n][attr]
+            nbits = self._get_weight_nbits(n, attr, bitwidth_mode, qc)
+            bytes_ = size * nbits / 8
+            attr_util[attr] = Utilization(size, bytes_)
+
+        total_weights: Utilization = sum(attr_util.values())    # type: ignore
+        return total_weights, attr_util
+
+    def compute_activations_utilization(self,
+                                        target_criterion: TargetInclusionCriterion,
+                                        bitwidth_mode: BitwidthMode,
+                                        act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None):
+        """
+        Compute total activations utilization in the graph.
+
+        Args:
+            target_criterion: criterion to include activations for computation.
+            bitwidth_mode: bit-width mode for the computation.
+            act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only.
+                In custom mode, must provide configuration for all configurable activations. For non-configurable
+                activations, if not provided, the default configuration will be extracted from the node.
+
+        Returns:
+            Total activation utilization of the network.
+        """
+        return self.compute_cut_activation_utilization(target_criterion, bitwidth_mode, act_qcs)[0]
+
+    def compute_cut_activation_utilization(self,
+                                           target_criterion: TargetInclusionCriterion,
+                                           bitwidth_mode: BitwidthMode,
+                                           act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
+            -> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]:
+        """
+        Compute graph activation cuts utilization.
+
+        Args:
+            target_criterion: criterion to include activations for computation.
+            bitwidth_mode: bit-width mode for the computation.
+            act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only.
+                In custom mode, must provide configuration for all configurable activations. For non-configurable
+                activations, if not provided, the default configuration will be extracted from the node.
+
+        Returns:
+            - Total activation utilization of the network (the maximal cut).
+            - Total activation utilization per cut.
+            - Detailed activation utilization per cut per node.
+        """
+        if target_criterion != TargetInclusionCriterion.AnyQuantized:    # pragma: no cover
+            raise NotImplementedError('Computing MaxCut activation utilization is currently only supported for quantized targets.')
+
+        graph_target_nodes = self._get_target_activation_nodes(target_criterion, include_reused=True)
+        # if there are no target activations in the graph, don't waste time looking for cuts
+        if not graph_target_nodes:
+            return 0, {}, {}
+
+        if self._cuts is None:
+            memory_graph = MemoryGraph(deepcopy(self.graph))
+            _, _, cuts = compute_graph_max_cut(memory_graph)
+            if cuts is None:    # pragma: no cover
+                raise RuntimeError("Failed to calculate activation memory cuts for graph.")
+            cuts = [cut for cut in cuts if cut.mem_elements.elements]
+            # cache the cuts' nodes for future use, so do not filter by target
+            self._cuts = {cut: [self.graph.find_node_by_name(m.node_name)[0] for m in cut.mem_elements.elements]
+                          for cut in cuts}
+
+        util_per_cut: Dict[Cut, Utilization] = {}    # type: ignore
+        util_per_cut_per_node = defaultdict(dict)
+        for cut in self._cuts:
+            cut_target_nodes = [n for n in self._cuts[cut] if n in graph_target_nodes]
+            if not cut_target_nodes:
+                continue
+            for n in cut_target_nodes:
+                qc = act_qcs.get(n) if act_qcs else None
+                util_per_cut_per_node[cut][n] = self.compute_node_activation_tensor_utilization(n, target_criterion,
+                                                                                                bitwidth_mode, qc)
+            util_per_cut[cut] = sum(util_per_cut_per_node[cut].values())    # type: ignore
+
+        total_util = max(util_per_cut.values())
+        return total_util.bytes, util_per_cut, util_per_cut_per_node
+
+    def compute_activation_tensors_utilization(self,
+                                               target_criterion: TargetInclusionCriterion,
+                                               bitwidth_mode: BitwidthMode,
+                                               act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None,
+                                               include_reused=False) \
+            -> Tuple[float, Dict[BaseNode, Utilization]]:
+        """
+        Compute resource utilization for graph's activations tensors.
+
+        Args:
+            target_criterion: criterion to include activations for computation.
+            bitwidth_mode: bit-width mode for the computation.
+            act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only.
+                In custom mode, must provide configuration for all configurable activations. For non-configurable
+                activations, if not provided, the default configuration will be extracted from the node.
+            include_reused: whether to include reused nodes.
+
+        Returns:
+            - Maximal activation tensor utilization in the network.
+            - Detailed utilization per node. Dict keys are nodes in a topological order.
+        """
+        nodes = self._get_target_activation_nodes(target_criterion, include_reused=include_reused)
+        if not nodes:
+            return 0, {}
+
+        util_per_node: Dict[BaseNode, Utilization] = {}
+        for n in self._topo_sort(nodes):
+            qc = act_qcs.get(n) if act_qcs else None
+            util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
+            util_per_node[n] = util
+
+        total_util = max(util_per_node.values())
+        return total_util.bytes, util_per_node
+
+    def compute_node_activation_tensor_utilization(self,
+                                                   n: BaseNode,
+                                                   target_criterion: Optional[TargetInclusionCriterion],
+                                                   bitwidth_mode: BitwidthMode,
+                                                   qc: Optional[NodeActivationQuantizationConfig]) -> Utilization:
+        """
+        Compute activation resource utilization for a node.
+
+        Args:
+            n: node.
+            target_criterion: criterion to include nodes for computation. If None, the check is skipped.
+            bitwidth_mode: bit-width mode for the computation.
+            qc: activation quantization config for the node. Should be provided only in custom bit mode.
+ In custom mode, must be provided if the activation is configurable. For non-configurable activation, if + not passed, the default configuration will be extracted from the node. + Returns: + Node's activation utilization. + """ + if target_criterion: + nodes = self._get_target_activation_nodes(target_criterion=target_criterion, include_reused=True, nodes=[n]) + if not nodes: # pragma: no cover + return Utilization(0, 0) + + size = self._act_tensors_size[n] + nbits = self._get_activation_nbits(n, bitwidth_mode, qc) + bytes_ = size * nbits / 8 + return Utilization(size, bytes_) + + def compute_bops(self, + target_criterion: TargetInclusionCriterion, + bitwidth_mode: BitwidthMode, + act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None, + w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None) \ + -> Tuple[int, Dict[BaseNode, int]]: + """ + Compute bit operations based on nodes with kernel. + Note that 'target_criterion' applies to weights, and BOPS are computed for the selected nodes regardless + of the input activation quantization or lack thereof. + + Args: + target_criterion: criterion to include nodes for computation. + bitwidth_mode: bit-width mode for computation. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. + w_qcs: custom weights quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable weights. For non-configurable + weights, if not provided, the default configuration will be extracted from the node. + + Returns: + - Total BOPS count of the network. + - Detailed BOPS count per node. + """ + if target_criterion != TargetInclusionCriterion.AnyQuantized: # pragma: no cover + raise NotImplementedError('BOPS computation is currently only supported for quantized targets.') + + nodes = self._get_target_weight_nodes(target_criterion, include_reused=True) + # filter out nodes with only positional weights # TODO add as arg to get target nodes + nodes = [n for n in nodes if n.has_kernel_weight_to_quantize(self.fw_info)] + + nodes_bops = {} + for n in nodes: + w_qc = w_qcs.get(n) if w_qcs else None + nodes_bops[n] = self.compute_node_bops(n, bitwidth_mode, act_qcs=act_qcs, w_qc=w_qc) + + return sum(nodes_bops.values()), nodes_bops + + def compute_node_bops(self, + n: BaseNode, + bitwidth_mode: BitwidthMode, + act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None, + w_qc: Optional[NodeWeightsQuantizationConfig] = None) -> Union[float, int]: + """ + Compute Bit Operations of a node. + + Args: + n: node. + bitwidth_mode: bit-width mode for the computation. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. + w_qc: weights quantization config for the node. Should be provided only in custom bit mode. + Must provide configuration for all configurable weights. For non-configurable weights, will use the + provided configuration if found, or extract the default configuration from the node otherwise. + + Returns: + Node's BOPS count. 
+ """ + node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info) + if node_mac == 0: # pragma: no cover + return node_mac + + incoming_edges = self.graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX) + # TODO temporary adding this for const_representation test in torch which has Linear with const input + if not incoming_edges: + return 0 + assert len(incoming_edges) == 1, \ + f'Unexpected number of inputs {len(incoming_edges)} for BOPS calculation. Expected 1.' + input_act_node = incoming_edges[0].source_node + act_qc = act_qcs.get(input_act_node) if act_qcs else None + a_nbits = self._get_activation_nbits(input_act_node, bitwidth_mode, act_qc) + + kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type) + if len(kernel_attrs) > 1: # pragma: no cover + raise NotImplementedError('Multiple kernel attributes are not supported for BOPS computation.') + kernel_attr = kernel_attrs[0] + w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc) + + node_bops = a_nbits * w_nbits * node_mac + return node_bops + + @lru_cache + def _get_cut_target_nodes(self, cut: Cut, target_criterion: TargetInclusionCriterion) -> List[BaseNode]: + """ + Retrieve target nodes from a cut filtered by a criterion. + + Args: + cut: a graph cut. + target_criterion: criterion to include nodes for computation. + + Returns: + A list of target nodes from a cut. + """ + cut_nodes = [self.graph.find_node_by_name(e.node_name)[0] for e in cut.mem_elements.elements] + return self._get_target_activation_nodes(target_criterion, include_reused=True, nodes=cut_nodes) + + def _get_target_weight_nodes(self, + target_criterion: TargetInclusionCriterion, + include_reused: bool) -> List[BaseNode]: + """ + Collect nodes to include in weights utilization computation. + + Args: + target_criterion: criterion to include weights for computation. + include_reused: whether to include reused nodes. + + Returns: + Target nodes. + """ + if target_criterion == TargetInclusionCriterion.QConfigurable: + nodes = self.graph.get_weights_configurable_nodes(self.fw_info, include_reused_nodes=include_reused) + elif target_criterion == TargetInclusionCriterion.AnyQuantized: + nodes = [n for n in self.graph if n.has_any_weight_attr_to_quantize()] + elif target_criterion == TargetInclusionCriterion.QNonConfigurable: + # TODO this is wrong. Need to look at specific weights and not the whole node (if w1 is configurable and w2 + # is non-configurable we want to discover the node both as configurable and non-configurable) + quantized = [n for n in self.graph if n.has_any_weight_attr_to_quantize()] + configurable = self.graph.get_weights_configurable_nodes(self.fw_info, include_reused_nodes=include_reused) + nodes = [n for n in quantized if n not in configurable] + elif target_criterion == TargetInclusionCriterion.Any: + nodes = list(self.graph.nodes) + else: # pragma: no cover + raise ValueError(f'Unknown {target_criterion}.') + + if not include_reused: + nodes = [n for n in nodes if not n.reuse] + return nodes + + def _get_target_weight_attrs(self, n: BaseNode, target_criterion: TargetInclusionCriterion) -> List[str]: + """ + Collect weight attributes of a node per criterion. + + Args: + n: node. + target_criterion: selection criterion. + + Returns: + Selected weight attributes names. 
+ """ + weight_attrs = n.get_node_weights_attributes() + if target_criterion == TargetInclusionCriterion.QConfigurable: + weight_attrs = [attr for attr in weight_attrs if n.is_configurable_weight(attr)] + elif target_criterion == TargetInclusionCriterion.AnyQuantized: + weight_attrs = [attr for attr in weight_attrs if n.is_weights_quantization_enabled(attr)] + elif target_criterion == TargetInclusionCriterion.QNonConfigurable: + quantized = [attr for attr in weight_attrs if n.is_weights_quantization_enabled(attr)] + configurable = [attr for attr in weight_attrs if n.is_configurable_weight(attr)] + weight_attrs = [attr for attr in quantized if attr not in configurable] + elif target_criterion != TargetInclusionCriterion.Any: # pragma: no cover + raise ValueError(f'Unknown {target_criterion}') + return weight_attrs + + def _topo_sort(self, nodes: Sequence[BaseNode]) -> List[BaseNode]: + """ + Sort nodes in a topological order (based on graph's nodes). + + Args: + nodes: nodes to sort. + + Returns: + Nodes in topological order. + """ + graph_topo_nodes = self.graph.get_topo_sorted_nodes() + topo_nodes = [n for n in graph_topo_nodes if n in nodes] + if len(topo_nodes) != len(nodes): # pragma: no cover + missing_nodes = [n for n in nodes if n not in topo_nodes] + raise ValueError(f'Could not topo-sort, nodes {missing_nodes} do not match the graph nodes.') + return topo_nodes + + def _get_target_activation_nodes(self, + target_criterion: TargetInclusionCriterion, + include_reused: bool, + nodes: Optional[List[BaseNode]] = None) -> List[BaseNode]: + """ + Collect nodes to include in activation utilization computation. + + Args: + target_criterion: criterion to include activations for computation. + include_reused: whether to include reused nodes. + nodes: nodes to filter target nodes from. By default, uses the graph nodes. + + Returns: + Selected nodes. + """ + nodes = nodes or self.graph.nodes + if target_criterion == TargetInclusionCriterion.QConfigurable: + nodes = [n for n in nodes if n.has_configurable_activation()] + elif target_criterion == TargetInclusionCriterion.AnyQuantized: + nodes = [n for n in nodes if n.is_activation_quantization_enabled()] + elif target_criterion == TargetInclusionCriterion.QNonConfigurable: + nodes = [n for n in nodes if n.is_activation_quantization_enabled() and not n.has_configurable_activation()] + elif target_criterion != TargetInclusionCriterion.Any: # pragma: no cover + raise ValueError(f'Unknown {target_criterion}.') + if not include_reused: + nodes = [n for n in nodes if not n.reuse] + return nodes + + @classmethod + def _get_activation_nbits(cls, + n: BaseNode, + bitwidth_mode: BitwidthMode, + act_qc: Optional[NodeActivationQuantizationConfig]) -> int: + """ + Get activation bit-width for a node according to the requested bit-width mode. + + Args: + n: node. + bitwidth_mode: bit-width mode for computation. + act_qc: activation quantization config for the node. Should be provided only in custom bit mode. + In custom mode, must be provided if the activation is configurable. For non-configurable activation, if + not passed, the default configuration will be extracted from the node. + + Returns: + Activation bit-width. 
+ """ + if act_qc: + if bitwidth_mode != BitwidthMode.QCustom: # pragma: no cover + raise ValueError(f'Activation config is not expected for non-custom bit mode {bitwidth_mode}') + return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH + + if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled(): + return FLOAT_BITWIDTH + + if bitwidth_mode == BitwidthMode.Q8Bit: + return 8 + + if bitwidth_mode in cls._bitwidth_mode_fn: + candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg] + return cls._bitwidth_mode_fn[bitwidth_mode](candidates_nbits) + + if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]: + qcs = n.get_unique_activation_candidates() + if len(qcs) != 1: # pragma: no cover + raise ValueError(f'Could not retrieve the activation quantization candidate for node {n.name} ' + f'as it has {len(qcs)}!=1 unique candidates .') + return qcs[0].activation_quantization_cfg.activation_n_bits + + raise ValueError(f'Unknown mode {bitwidth_mode}') # pragma: no cover + + @classmethod + def _get_weight_nbits(cls, + n: BaseNode, + w_attr: str, + bitwidth_mode: BitwidthMode, + w_qc: Optional[NodeWeightsQuantizationConfig]) -> int: + """ + Get the bit-width of a specific weight of a node according to the requested bit-width mode. + + Args: + n: node. + w_attr: weight attribute. + bitwidth_mode: bit-width mode for the computation. + w_qc: weights quantization config for the node. Should be provided only in custom bit mode. + Must provide configuration for all configurable weights. For non-configurable weights, will use the + provided configuration if found, or extract the default configuration from the node otherwise. + + Returns: + Weight bit-width. 
+ """ + if w_qc and w_qc.has_attribute_config(w_attr): + if bitwidth_mode != BitwidthMode.QCustom: # pragma: no cover + raise ValueError('Weight config is not expected for non-custom bit mode {bitwidth_mode}') + attr_cfg = w_qc.get_attr_config(w_attr) + return attr_cfg.weights_n_bits if attr_cfg.enable_weights_quantization else FLOAT_BITWIDTH + + if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr): + return FLOAT_BITWIDTH + + if bitwidth_mode == BitwidthMode.Q8Bit: + return 8 + + node_qcs = n.get_unique_weights_candidates(w_attr) + w_qcs = [qc.weights_quantization_cfg.get_attr_config(w_attr) for qc in node_qcs] + if bitwidth_mode in cls._bitwidth_mode_fn: + return cls._bitwidth_mode_fn[bitwidth_mode]([qc.weights_n_bits for qc in w_qcs]) + + if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]: + # if configuration was not passed and the weight has only one candidate, use it + if len(w_qcs) != 1: # pragma: no cover + raise ValueError(f'Could not retrieve the quantization candidate for attr {w_attr} of node {n.name} ' + f'as it {len(w_qcs)}!=1 unique candidates.') + return w_qcs[0].weights_n_bits + + raise ValueError(f'Unknown mode {bitwidth_mode.name}') diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py index a647a2cc5..03cda3961 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py @@ -13,21 +13,17 @@ # limitations under the License. # ============================================================================== import copy -from collections import defaultdict +from typing import Callable, Any -import numpy as np -from typing import Callable, Any, Dict, Tuple - -from model_compression_toolkit.logger import Logger -from model_compression_toolkit.constants import FLOAT_BITWIDTH, BITS_TO_BYTES from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation -from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \ + RUTarget +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ + ResourceUtilizationCalculator, BitwidthMode, TargetInclusionCriterion from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities -from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import calc_graph_cuts def compute_resource_utilization_data(in_model: Any, @@ -72,174 +68,14 @@ def compute_resource_utilization_data(in_model: Any, fw_impl, tpc, bit_width_config=core_config.bit_width_config, - mixed_precision_enable=mixed_precision_enable) - - # Compute parameters sum - weights_memory_bytes, 
weights_params = compute_nodes_weights_params(graph=transformed_graph, fw_info=fw_info) - total_weights_params = 0 if len(weights_params) == 0 else sum(weights_params) - - # Compute max activation tensor - activation_output_sizes_bytes, activation_output_sizes = compute_activation_output_maxcut_sizes(graph=transformed_graph) - max_activation_tensor_size = 0 if len(activation_output_sizes) == 0 else max(activation_output_sizes) - - # Compute total memory utilization - parameters sum + max activation tensor - total_size = total_weights_params + max_activation_tensor_size - - # Compute BOPS utilization - total count of bit-operations for all configurable layers with kernel - bops_count = compute_total_bops(graph=transformed_graph, fw_info=fw_info, fw_impl=fw_impl) - bops_count = np.inf if len(bops_count) == 0 else sum(bops_count) - - return ResourceUtilization(weights_memory=total_weights_params, - activation_memory=max_activation_tensor_size, - total_memory=total_size, - bops=bops_count) - - -def compute_nodes_weights_params(graph: Graph, fw_info: FrameworkInfo) -> Tuple[np.ndarray, np.ndarray]: - """ - Calculates the memory usage in bytes and the number of weight parameters for each node within a graph. - Memory calculations are based on the maximum bit-width used for quantization per node. - - Args: - graph: A finalized Graph object, representing the model structure. - fw_info: FrameworkInfo object containing details about the specific framework's - quantization attributes for different layers' weights. - - Returns: - A tuple containing two arrays: - - The first array represents the memory in bytes for each node's weights when quantized at the maximal bit-width. - - The second array represents the total number of weight parameters for each node. - """ - weights_params = [] - weights_memory_bytes = [] - for n in graph.nodes: - # TODO: when enabling multiple attribute quantization by default (currently, - # only kernel quantization is enabled) we should include other attributes memory in the sum of all - # weights memory. - # When implementing this, we should just go over all attributes in the node instead of counting only kernels. - kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] - if kernel_attr is not None and not n.reuse: - kernel_candidates = n.get_all_weights_attr_candidates(kernel_attr) - - if len(kernel_candidates) > 0 and any([c.enable_weights_quantization for c in kernel_candidates]): - max_weight_bits = max([kc.weights_n_bits for kc in kernel_candidates]) - node_num_weights_params = 0 - for attr in fw_info.get_kernel_op_attributes(n.type): - if attr is not None: - node_num_weights_params += n.get_weights_by_keys(attr).flatten().shape[0] - - weights_params.append(node_num_weights_params) - - # multiply num params by num bits and divide by BITS_TO_BYTES to convert from bits to bytes - weights_memory_bytes.append(node_num_weights_params * max_weight_bits / BITS_TO_BYTES) - - return np.array(weights_memory_bytes), np.array(weights_params) - - -def compute_activation_output_maxcut_sizes(graph: Graph) -> Tuple[np.ndarray, np.ndarray]: - """ - Computes an array of the respective output tensor maxcut size and an array of the output tensor - cut size in bytes for each cut. - - Args: - graph: A finalized Graph object, representing the model structure. - - Returns: - A tuple containing two arrays: - - The first is an array of the size of each activation max-cut size in bytes, calculated - using the maximal bit-width for quantization. 
- - The second array an array of the size of each activation max-cut activation size in number of parameters. - - """ - cuts = calc_graph_cuts(graph) - - # map nodes to cuts. - node_to_cat_mapping = defaultdict(list) - for i, cut in enumerate(cuts): - mem_element_names = [m.node_name for m in cut.mem_elements.elements] - for m_name in mem_element_names: - if len(graph.find_node_by_name(m_name)) > 0: - node_to_cat_mapping[m_name].append(i) - else: - Logger.critical(f"Missing node: {m_name}") # pragma: no cover + mixed_precision_enable=mixed_precision_enable, + running_gptq=False) - activation_outputs = np.zeros(len(cuts)) - activation_outputs_bytes = np.zeros(len(cuts)) - for n in graph.nodes: - # Go over all nodes that have activation quantization enabled. - if n.has_activation_quantization_enabled_candidate(): - # Fetch maximum bits required for activations quantization. - max_activation_bits = max([qc.activation_quantization_cfg.activation_n_bits for qc in n.candidates_quantization_cfg]) - node_output_size = n.get_total_output_params() - for cut_index in node_to_cat_mapping[n.name]: - activation_outputs[cut_index] += node_output_size - # Calculate activation size in bytes and append to list - activation_outputs_bytes[cut_index] += node_output_size * max_activation_bits / BITS_TO_BYTES - - return activation_outputs_bytes, activation_outputs - - -# TODO maxcut: add test for this function and remove no cover -def compute_activation_output_sizes(graph: Graph) -> Tuple[np.ndarray, np.ndarray]: # pragma: no cover - """ - Computes an array of the respective output tensor size and an array of the output tensor size in bytes for - each node. - - Args: - graph: A finalized Graph object, representing the model structure. - - Returns: - A tuple containing two arrays: - - The first array represents the size of each node's activation output tensor size in bytes, - calculated using the maximal bit-width for quantization. - - The second array represents the size of each node's activation output tensor size. - - """ - activation_outputs = [] - activation_outputs_bytes = [] - for n in graph.nodes: - # Go over all nodes that have configurable activation. - if n.has_activation_quantization_enabled_candidate(): - # Fetch maximum bits required for quantizing activations - max_activation_bits = max([qc.activation_quantization_cfg.activation_n_bits for qc in n.candidates_quantization_cfg]) - node_output_size = n.get_total_output_params() - activation_outputs.append(node_output_size) - # Calculate activation size in bytes and append to list - activation_outputs_bytes.append(node_output_size * max_activation_bits / BITS_TO_BYTES) - - return np.array(activation_outputs_bytes), np.array(activation_outputs) - - -def compute_total_bops(graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation) -> np.ndarray: - """ - Computes a vector with the respective Bit-operations count for each configurable node that includes MAC operations. - The computation assumes that the graph is a representation of a float model, thus, BOPs computation uses 32-bit. - - Args: - graph: Finalized Graph object. - fw_info: FrameworkInfo object about the specific framework - (e.g., attributes of different layers' weights to quantize). - fw_impl: FrameworkImplementation object with a specific framework methods implementation. - - Returns: A vector of nodes' Bit-operations count. - - """ - - bops = [] - - # Go over all configurable nodes that have kernels. 
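For reference, the per-cut arithmetic that the deleted max-cut sizing performs, on hypothetical tensor shapes: each tensor alive in a cut contributes its parameter count times its maximal activation bit-width, converted to bytes, and the model-level activation estimate is the maximum over all cuts.

BITS_TO_BYTES = 8

# two tensors alive in the same cut: (params without batch dim, max activation bits)
cut_tensors = [(56 * 56 * 64, 8),
               (56 * 56 * 128, 16)]
cut_bytes = sum(params * nbits / BITS_TO_BYTES for params, nbits in cut_tensors)
# 200704 * 1B + 401408 * 2B = 1003520 bytes (~0.96 MB) for this cut; the activation
# memory estimate is max(cut_bytes) over all cuts.
assert cut_bytes == 1003520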
- for n in graph.get_topo_sorted_nodes(): - if n.has_kernel_weight_to_quantize(fw_info): - # If node doesn't have weights then its MAC count is 0, and we shouldn't consider it in the BOPS count. - incoming_edges = graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX) - assert len(incoming_edges) == 1, f"Can't compute BOPS metric for node {n.name} with multiple inputs." - - node_mac = fw_impl.get_node_mac_operations(n, fw_info) - - node_bops = (FLOAT_BITWIDTH ** 2) * node_mac - bops.append(node_bops) - - return np.array(bops) + ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info) + ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.Q8Bit, + ru_targets=set(RUTarget) - {RUTarget.BOPS}) + ru.bops, _ = ru_calculator.compute_bops(TargetInclusionCriterion.AnyQuantized, BitwidthMode.Float) + return ru def requires_mixed_precision(in_model: Any, @@ -268,7 +104,6 @@ def requires_mixed_precision(in_model: Any, Returns: A boolean indicating if mixed precision is needed. """ - is_mixed_precision = False core_config = _create_core_config_for_ru(core_config) transformed_graph = graph_preparation_runner(in_model, @@ -278,25 +113,13 @@ def requires_mixed_precision(in_model: Any, fw_impl, tpc, bit_width_config=core_config.bit_width_config, - mixed_precision_enable=False) - # Compute max weights memory in bytes - weights_memory_by_layer_bytes, _ = compute_nodes_weights_params(transformed_graph, fw_info) - total_weights_memory_bytes = 0 if len(weights_memory_by_layer_bytes) == 0 else sum(weights_memory_by_layer_bytes) - - # Compute max activation tensor in bytes - activation_memory_estimation_bytes, _ = compute_activation_output_maxcut_sizes(transformed_graph) - max_activation_memory_estimation_bytes = 0 if len(activation_memory_estimation_bytes) == 0 \ - else max(activation_memory_estimation_bytes) - - # Compute BOPS utilization - total count of bit-operations for all configurable layers with kernel - bops_count = compute_total_bops(graph=transformed_graph, fw_info=fw_info, fw_impl=fw_impl) - bops_count = np.inf if len(bops_count) == 0 else sum(bops_count) + mixed_precision_enable=False, + running_gptq=False) - is_mixed_precision |= target_resource_utilization.weights_memory < total_weights_memory_bytes - is_mixed_precision |= target_resource_utilization.activation_memory < max_activation_memory_estimation_bytes - is_mixed_precision |= target_resource_utilization.total_memory < total_weights_memory_bytes + max_activation_memory_estimation_bytes - is_mixed_precision |= target_resource_utilization.bops < bops_count - return is_mixed_precision + ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info) + max_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QMaxBit, + ru_targets=target_resource_utilization.get_restricted_metrics()) + return not target_resource_utilization.is_satisfied_by(max_ru) def _create_core_config_for_ru(core_config: CoreConfig) -> CoreConfig: diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py deleted file mode 100644 index 123ae4404..000000000 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. 
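The rewritten requires_mixed_precision above reduces to one rule: evaluate utilization with every node at its maximal candidate bit-width (BitwidthMode.QMaxBit), and run mixed precision only if the target budget is still violated. A minimal sketch of that rule with a simplified stand-in for ResourceUtilization:

from dataclasses import dataclass

@dataclass
class RU:  # simplified stand-in for ResourceUtilization
    weights_memory: float = float('inf')
    activation_memory: float = float('inf')

    def is_satisfied_by(self, other: 'RU') -> bool:
        return (other.weights_memory <= self.weights_memory
                and other.activation_memory <= self.activation_memory)

target = RU(weights_memory=1e6, activation_memory=5e5)
max_bitwidth_ru = RU(weights_memory=2e6, activation_memory=3e5)
# Even the largest bit-width assignment breaks the weights budget -> mixed precision is required.
assert not target.is_satisfied_by(max_bitwidth_ru)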
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import copy -from enum import Enum -from functools import partial -from typing import List, Any -import numpy as np - -from pulp import lpSum - - -def sum_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[Any]: - """ - Aggregates resource utilization vector to a single resource utilization measure by summing all values. - - Args: - ru_vector: A vector with nodes' resource utilization values. - set_constraints: A flag for utilizing the method for resource utilization computation of a - given config not for LP formalization purposes. - - Returns: A list with an lpSum object for lp problem definition with the vector's sum. - - """ - if set_constraints: - return [lpSum(ru_vector)] - return [0] if len(ru_vector) == 0 else [sum(ru_vector)] - - - -def max_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[float]: - """ - Aggregates resource utilization vector to allow max constraint in the linear programming problem formalization. - In order to do so, we need to define a separate constraint on each value in the resource utilization vector, - to be bounded by the target resource utilization. - - Args: - ru_vector: A vector with nodes' resource utilization values. - set_constraints: A flag for utilizing the method for resource utilization computation of a - given config not for LP formalization purposes. - - Returns: A list with the vector's values, to be used to define max constraint - in the linear programming problem formalization. - - """ - if set_constraints: - return [ru for ru in ru_vector] - return [0] if len(ru_vector) == 0 else [max(ru_vector)] - - - -def total_ru(ru_tensor: np.ndarray, set_constraints: bool = True) -> List[float]: - """ - Aggregates resource utilization vector to allow weights and activation total utilization constraint in the linear programming - problem formalization. In order to do so, we need to define a separate constraint on each activation memory utilization value in - the resource utilization vector, combined with the sum weights memory utilization. - Note that the given ru_tensor should contain weights and activation utilization values in each entry. - - Args: - ru_tensor: A tensor with nodes' resource utilization values for weights and activation. - set_constraints: A flag for utilizing the method for resource utilization computation of a - given config not for LP formalization purposes. - - Returns: A list with lpSum objects, to be used to define total constraint - in the linear programming problem formalization. 
- - """ - if set_constraints: - weights_ru = lpSum([ru[0] for ru in ru_tensor]) - return [weights_ru + activation_ru for _, activation_ru in ru_tensor] - else: - weights_ru = sum([ru[0] for ru in ru_tensor]) - activation_ru = max([ru[1] for ru in ru_tensor]) - return [weights_ru + activation_ru] - - -class MpRuAggregation(Enum): - """ - Defines resource utilization aggregation functions that can be used to compute final resource utilization metric. - The enum values can be used to call a function on a set of arguments. - - SUM - applies the sum_ru_values function - - MAX - applies the max_ru_values function - - TOTAL - applies the total_ru function - - """ - SUM = partial(sum_ru_values) - MAX = partial(max_ru_values) - TOTAL = partial(total_ru) - - def __call__(self, *args): - return self.value(*args) diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py deleted file mode 100644 index 86c4a3f86..000000000 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
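The three aggregation behaviors deleted with this file live on in _aggregate_for_lp (added to linear_programming.py further down in this diff). For reference, their semantics on plain numbers rather than LP expressions:

import numpy as np

weights_per_node = np.array([100., 50., 25.])     # SUM-style target: total weights memory
activation_per_cut = np.array([80., 120., 60.])   # MAX-style target: peak activation memory

weights_total = weights_per_node.sum()            # SUM aggregation -> 175
activation_peak = activation_per_cut.max()        # MAX aggregation -> 120

# TOTAL aggregation: the weights sum plus each activation value; the binding
# constraint is the worst case, i.e. weights_total + activation_peak = 295.
total_candidates = weights_total + activation_per_cut
assert total_candidates.max() == weights_total + activation_peak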
-# ============================================================================== -from typing import NamedTuple - -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric - - -# When adding a RUTarget that we want to consider in our mp search, -# a matching pair of resource_utilization_tools computation function and a resource_utilization_tools -# aggregation function should be added to this dictionary -class RuFunctions(NamedTuple): - metric_fn: MpRuMetric - aggregate_fn: MpRuAggregation - - -ru_functions_mapping = {RUTarget.WEIGHTS: RuFunctions(MpRuMetric.WEIGHTS_SIZE, MpRuAggregation.SUM), - RUTarget.ACTIVATION: RuFunctions(MpRuMetric.ACTIVATION_MAXCUT_SIZE, MpRuAggregation.MAX), - RUTarget.TOTAL: RuFunctions(MpRuMetric.TOTAL_WEIGHTS_ACTIVATION_SIZE, MpRuAggregation.TOTAL), - RUTarget.BOPS: RuFunctions(MpRuMetric.BOPS_COUNT, MpRuAggregation.SUM)} diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py index b75bf1232..b3605089f 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py @@ -12,389 +12,191 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from enum import Enum -from functools import partial -from typing import List, Optional -from copy import deepcopy +from typing import List, Set, Dict, Optional, Tuple import numpy as np from model_compression_toolkit.core import FrameworkInfo from model_compression_toolkit.core.common import Graph, BaseNode -from model_compression_toolkit.constants import BITS_TO_BYTES, FLOAT_BITWIDTH from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation -from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX -from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \ - VirtualSplitWeightsNode, VirtualSplitActivationNode -from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph -from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut, Cut -from model_compression_toolkit.logger import Logger - - -def weights_size_utilization(mp_cfg: List[int], - graph: Graph, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation) -> np.ndarray: - """ - Computes a resource utilization vector with the respective weights' memory size for the given weight configurable node, - according to the given mixed-precision configuration. - If an empty configuration is given, then computes resource utilization vector for non-configurable nodes. - - Args: - mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node) - graph: Graph object. - fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize). 
- fw_impl: FrameworkImplementation object with specific framework methods implementation (not used in this method). - - Returns: A vector of node's weights memory sizes. - Note that the vector is not necessarily of the same length as the given config. - - """ - weights_memory = [] - mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) - weights_mp_nodes = [n.name for n in graph.get_sorted_weights_configurable_nodes(fw_info)] - - if len(mp_cfg) == 0: - # Computing non-configurable nodes resource utilization - # TODO: when enabling multiple attribute quantization by default (currently, - # only kernel quantization is enabled) we should include other attributes memory in the sum of all - # weights memory (when quantized to their default 8-bit, non-configurable). - # When implementing this, we should just go over all attributes in the node instead of counting only kernels. - for n in graph.nodes: - kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] - if kernel_attr is None: - continue - non_configurable_node = n.name not in weights_mp_nodes \ - and not n.reuse \ - and n.is_all_weights_candidates_equal(kernel_attr) - - if non_configurable_node: - node_nbits = (n.candidates_quantization_cfg[0].weights_quantization_cfg - .get_attr_config(kernel_attr).weights_n_bits) - node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_nbits, fw_info) - weights_memory.append(node_weights_memory_in_bytes) - else: - # Go over configurable all nodes that should be taken into consideration when computing the weights - # resource utilization. - for n in graph.get_sorted_weights_configurable_nodes(fw_info): - # Only nodes with kernel op can be considered configurable - kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] - node_idx = mp_nodes.index(n.name) - node_qc = n.candidates_quantization_cfg[mp_cfg[node_idx]] - node_nbits = node_qc.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits - - node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_nbits, fw_info) - - weights_memory.append(node_weights_memory_in_bytes) - - return np.array(weights_memory) - - -def calc_graph_cuts(graph: Graph) -> List[Cut]: - """ - Calculate graph activation cuts. - Args: - graph: A graph object to calculate activation cuts on. - - Returns: - A list of activation cuts. - - """ - memory_graph = MemoryGraph(deepcopy(graph)) - _, _, cuts = compute_graph_max_cut(memory_graph) - - if cuts is None: - Logger.critical("Failed to calculate activation memory cuts for graph.") # pragma: no cover - # filter empty cuts and cuts that contain only nodes with activation quantization disabled. - filtered_cuts = [] - for cut in cuts: - cut_has_no_act_quant_nodes = any( - [graph.find_node_by_name(e.node_name)[0].has_activation_quantization_enabled_candidate() - for e in cut.mem_elements.elements]) - if len(cut.mem_elements.elements) > 0 and cut_has_no_act_quant_nodes: - filtered_cuts.append(cut) - return filtered_cuts - - -def activation_maxcut_size_utilization(mp_cfg: List[int], - graph: Graph, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation, - cuts: Optional[List[Cut]] = None) -> np.ndarray: - """ - Computes a resource utilization vector with the respective output memory max-cut size for activation - nodes, according to the given mixed-precision configuration. 
+from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
+from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
+    RUTarget
+from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
+    ResourceUtilizationCalculator, BitwidthMode, TargetInclusionCriterion
+from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
+    NodeActivationQuantizationConfig
+
+
+# TODO take into account Virtual nodes. Are candidates defined with respect to virtual or original nodes?
+#  Can we use the virtual graph only for bops and the original graph for everything else?
+
+class MixedPrecisionRUHelper:
+    """ Helper class for resource utilization computations for mixed precision optimization. """
+
+    def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation):
+        self.graph = graph
+        self.fw_info = fw_info
+        self.fw_impl = fw_impl
+        self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
+
+    def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Optional[List[int]]) -> Dict[RUTarget, np.ndarray]:
+        """
+        Compute utilization of the requested targets for a specific configuration, in the format expected by the
+        LP problem formulation, namely an array of ru values corresponding to the graph's configurable nodes in
+        topological order. For the activation target, the array contains values for activation cuts in unspecified
+        order (as long as it is consistent between configurations).
+
+        Args:
+            ru_targets: resource utilization targets to compute.
+            mp_cfg: a list of candidate indices for configurable layers.
+
+        Returns:
+            Dict of the computed utilization per target.
+        """
+
+        ru = {}
+
+        act_qcs, w_qcs = self.get_configurable_qcs(mp_cfg) if mp_cfg else (None, None)
+        w_util = None
+        if RUTarget.WEIGHTS in ru_targets:
+            w_util = self._weights_utilization(w_qcs)
+            ru[RUTarget.WEIGHTS] = np.array(list(w_util.values()))
+
+        # TODO make mp agnostic to activation method
+        if RUTarget.ACTIVATION in ru_targets:
+            act_util = self._activation_maxcut_utilization(act_qcs)
+            ru[RUTarget.ACTIVATION] = np.array(list(act_util.values()))
+
+        # TODO use maxcut
+        if RUTarget.TOTAL in ru_targets:
+            act_tensors_util = self._activation_tensor_utilization(act_qcs)
+            w_util = w_util or self._weights_utilization(w_qcs)
+            total = {n: (w_util.get(n, 0), act_tensors_util.get(n, 0))
+                     for n in self.graph.get_topo_sorted_nodes() if n in act_tensors_util or n in w_util}
+            ru[RUTarget.TOTAL] = np.array(list(total.values()))
+
+        if RUTarget.BOPS in ru_targets:
+            ru[RUTarget.BOPS] = self._bops_utilization(mp_cfg)
+
+        return ru
+
+    def get_configurable_qcs(self, mp_cfg) \
+            -> Tuple[Dict[BaseNode, NodeActivationQuantizationConfig], Dict[BaseNode, NodeWeightsQuantizationConfig]]:
+        """
+        Retrieve quantization candidate objects for weights and activations from the configuration list.
+
+        Args:
+            mp_cfg: a list of candidate indices for configurable layers.
+
+        Returns:
+            A mapping from nodes to activation quantization config, and a mapping from nodes to weights
+            quantization config.
+ """ + mp_nodes = self.graph.get_configurable_sorted_nodes(self.fw_info) + node_qcs = {n: n.candidates_quantization_cfg[mp_cfg[i]] for i, n in enumerate(mp_nodes)} + act_qcs = {n: node_qcs[n].activation_quantization_cfg + for n in self.graph.get_activation_configurable_nodes()} + w_qcs = {n: node_qcs[n].weights_quantization_cfg + for n in self.graph.get_weights_configurable_nodes(self.fw_info)} + return act_qcs, w_qcs + + def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> Dict[BaseNode, float]: + """ + Compute weights utilization for configurable weights if configuration is passed, + or for non-configurable nodes otherwise. + + Args: + w_qcs: nodes quantization configuration to compute, or None. + + Returns: + Weight utilization per node. + """ + if w_qcs: + target_criterion = TargetInclusionCriterion.QConfigurable + bitwidth_mode = BitwidthMode.QCustom + else: + target_criterion = TargetInclusionCriterion.QNonConfigurable + bitwidth_mode = BitwidthMode.QDefaultSP + + _, nodes_util, _ = self.ru_calculator.compute_weights_utilization(target_criterion=target_criterion, + bitwidth_mode=bitwidth_mode, + w_qcs=w_qcs) + nodes_util = {n: u.bytes for n, u in nodes_util.items()} + return nodes_util + + def _activation_maxcut_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \ + -> Optional[Dict[Cut, float]]: + """ + Compute activation utilization using MaxCut for all quantized nodes if configuration is passed. + + Args: + act_qcs: nodes activation configuration or None. + + Returns: + Activation utilization per cut, or empty dict if no configuration was passed. + """ + if act_qcs: + _, cuts_util, _ = self.ru_calculator.compute_cut_activation_utilization(TargetInclusionCriterion.AnyQuantized, + bitwidth_mode=BitwidthMode.QCustom, + act_qcs=act_qcs) + cuts_util = {c: u.bytes for c, u in cuts_util.items()} + return cuts_util - Args: - mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node) - graph: Graph object. - fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize) - (not used in this method). - fw_impl: FrameworkImplementation object with specific framework methods implementation(not used in this method). - cuts: a list of graph cuts (optional. if not provided calculated locally). - TODO maxcut: refactor - need to remove the cuts so all metric functions signatures are the same. - - Returns: A vector of node's cut memory sizes. - Note that the vector is not necessarily of the same length as the given config. - - """ - if len(mp_cfg) == 0: # Computing non-configurable nodes resource utilization for max-cut is included in the calculation of the # configurable nodes. - return np.array([]) - - activation_cut_memory = [] - mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) - # Go over all nodes that should be taken into consideration when computing the weights memory utilization. 
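With the helper's surface now in view, here is a sketch of how it is meant to be driven (assuming graph, fw_info and fw_impl are already built; mp_cfg indices follow the graph's configurable nodes in topological order). Note the pairing the private methods rely on: explicit configs select QConfigurable nodes at QCustom bit-widths, while no configs select QNonConfigurable nodes at QDefaultSP.

from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
    RUTarget

helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl)   # graph / fw objects assumed pre-built

# Utilization of one concrete assignment (candidate 0 for every configurable node),
# restricted to the weights and activation targets:
n_conf = len(graph.get_configurable_sorted_nodes(fw_info))
lp_ru = helper.compute_utilization({RUTarget.WEIGHTS, RUTarget.ACTIVATION}, mp_cfg=[0] * n_conf)

# mp_cfg=None flips the internal query from (QConfigurable, QCustom) to
# (QNonConfigurable, QDefaultSP), i.e. the fixed baseline of non-configurable nodes:
base_ru = helper.compute_utilization({RUTarget.WEIGHTS}, mp_cfg=None)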
- nodes_act_nbits = {} - for n in graph.get_sorted_activation_configurable_nodes(): - node_idx = mp_nodes.index(n.name) - node_qc = n.candidates_quantization_cfg[mp_cfg[node_idx]] - node_nbits = node_qc.activation_quantization_cfg.activation_n_bits - nodes_act_nbits[n.name] = node_nbits - - if cuts is None: - cuts = calc_graph_cuts(graph) - - for i, cut in enumerate(cuts): - mem_elements = [m.node_name for m in cut.mem_elements.elements] - mem = 0 - for op_name in mem_elements: - n = graph.find_node_by_name(op_name)[0] - if n.is_activation_quantization_enabled(): - base_nbits = n.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits - mem += _compute_node_activation_memory(n, nodes_act_nbits.get(op_name, base_nbits)) - - activation_cut_memory.append(mem) - - return np.array(activation_cut_memory) - - -# TODO maxcut: add test for this function and remove no cover -def activation_output_size_utilization(mp_cfg: List[int], - graph: Graph, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation) -> np.ndarray: # pragma: no cover - """ - Computes a resource utilization vector with the respective output memory size for each activation configurable node, - according to the given mixed-precision configuration. - If an empty configuration is given, then computes resource utilization vector for non-configurable nodes. - - Args: - mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node) - graph: Graph object. - fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize) - (not used in this method). - fw_impl: FrameworkImplementation object with specific framework methods implementation(not used in this method). - - Returns: A vector of node's activation memory sizes. - Note that the vector is not necessarily of the same length as the given config. - - """ - activation_memory = [] - mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) - activation_mp_nodes = [n.name for n in graph.get_sorted_activation_configurable_nodes()] - - if len(mp_cfg) == 0: - # Computing non-configurable nodes resource utilization - for n in graph.nodes: - non_configurable_node = n.name not in activation_mp_nodes \ - and n.has_activation_quantization_enabled_candidate() \ - and n.is_all_activation_candidates_equal() - - if non_configurable_node: - node_nbits = n.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits - node_activation_memory_in_bytes = _compute_node_activation_memory(n, node_nbits) - activation_memory.append(node_activation_memory_in_bytes) - else: - # Go over all nodes that should be taken into consideration when computing the weights memory utilization. - for n in graph.get_sorted_activation_configurable_nodes(): - node_idx = mp_nodes.index(n.name) - node_qc = n.candidates_quantization_cfg[mp_cfg[node_idx]] - node_nbits = node_qc.activation_quantization_cfg.activation_n_bits - - node_activation_memory_in_bytes = _compute_node_activation_memory(n, node_nbits) - - activation_memory.append(node_activation_memory_in_bytes) - - return np.array(activation_memory) - - -def total_weights_activation_utilization(mp_cfg: List[int], - graph: Graph, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation) -> np.ndarray: - """ - Computes resource utilization tensor with the respective weights size and output memory size for each activation configurable node, - according to the given mixed-precision configuration. 
-    If an empty configuration is given, then computes resource utilization vector for non-configurable nodes.
+        return {}
 
-    Args:
-        mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node)
-        graph: Graph object.
-        fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize)
-        (not used in this method).
-        fw_impl: FrameworkImplementation object with specific framework methods implementation(not used in this method).
+    def _activation_tensor_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]):
+        """
+        Compute activation tensors utilization for configurable nodes if a configuration is passed, or
+        for non-configurable nodes otherwise.
 
-    Returns: A 2D tensor of nodes' weights memory sizes and activation output memory size.
-    Note that the vector is not necessarily of the same length as the given config.
+        Args:
+            act_qcs: activation quantization configuration or None.
 
-    """
-    weights_activation_memory = []
-    weights_mp_nodes = [n.name for n in graph.get_sorted_weights_configurable_nodes(fw_info)]
-    activation_mp_nodes = [n.name for n in graph.get_sorted_activation_configurable_nodes()]
-
-    if len(mp_cfg) == 0:
-        # Computing non-configurable nodes utilization
-        for n in graph.nodes:
-
-            non_configurable = False
-            node_weights_memory_in_bytes, node_activation_memory_in_bytes = 0, 0
-
-            # Non-configurable Weights
-            # TODO: currently considering only kernel attributes in weights memory utilization.
-            #  When enabling multi-attribute quantization we need to modify this method to count all attributes.
-            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
-            if kernel_attr is not None:
-                is_non_configurable_weights = n.name not in weights_mp_nodes and \
-                                              n.is_all_weights_candidates_equal(kernel_attr) and \
-                                              not n.reuse
-
-                if is_non_configurable_weights:
-                    node_nbits = (n.candidates_quantization_cfg[0].weights_quantization_cfg
-                                  .get_attr_config(kernel_attr).weights_n_bits)
-                    node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_nbits, fw_info)
-                    non_configurable = True
-
-            # Non-configurable Activation
-            is_non_configurable_activation = n.name not in activation_mp_nodes and \
-                                             n.has_activation_quantization_enabled_candidate() and \
-                                             n.is_all_activation_candidates_equal()
-
-            if is_non_configurable_activation:
-                node_nbits = n.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits
-                node_activation_memory_in_bytes = _compute_node_activation_memory(n, node_nbits)
-                non_configurable = True
-
-            if non_configurable:
-                weights_activation_memory.append(
-                    np.array([node_weights_memory_in_bytes, node_activation_memory_in_bytes]))
-    else:
-        # Go over all nodes that should be taken into consideration when computing the weights or
-        # activation memory utilization (all configurable nodes).
-        for node_idx, n in enumerate(graph.get_configurable_sorted_nodes(fw_info)):
-            # TODO: currently considering only kernel attributes in weights memory utilization. When enabling multi-attribute
-            #  quantization we need to modify this method to count all attributes.
- - node_qc = n.candidates_quantization_cfg[mp_cfg[node_idx]] - - # Compute node's weights memory (if no weights to quantize then set to 0) - node_weights_memory_in_bytes = 0 - kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] - if kernel_attr is not None: - if n.is_weights_quantization_enabled(kernel_attr) and not n.is_all_weights_candidates_equal(kernel_attr): - node_weights_nbits = node_qc.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits - node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_weights_nbits, fw_info) - - # Compute node's activation memory (if node's activation are not being quantized then set to 0) - node_activation_nbits = node_qc.activation_quantization_cfg.activation_n_bits - node_activation_memory_in_bytes = 0 - if n.is_activation_quantization_enabled() and not n.is_all_activation_candidates_equal(): - node_activation_memory_in_bytes = _compute_node_activation_memory(n, node_activation_nbits) - - weights_activation_memory.append(np.array([node_weights_memory_in_bytes, node_activation_memory_in_bytes])) - - return np.array(weights_activation_memory) - - -def bops_utilization(mp_cfg: List[int], - graph: Graph, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation, - set_constraints: bool = True) -> np.ndarray: - """ - Computes a resource utilization vector with the respective bit-operations (BOPS) count for each configurable node, - according to the given mixed-precision configuration of a virtual graph with composed nodes. + Returns: + Activation utilization per node. + """ + if act_qcs: + target_criterion = TargetInclusionCriterion.QConfigurable + bitwidth_mode = BitwidthMode.QCustom + else: + target_criterion = TargetInclusionCriterion.QNonConfigurable + bitwidth_mode = BitwidthMode.QDefaultSP - Args: - mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node) - graph: Graph object. - fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize). - fw_impl: FrameworkImplementation object with specific framework methods implementation. - set_constraints: A flag for utilizing the method for resource utilization computation of a - given config not for LP formalization purposes. + _, nodes_util = self.ru_calculator.compute_activation_tensors_utilization(target_criterion=target_criterion, + bitwidth_mode=bitwidth_mode, + act_qcs=act_qcs) + return {n: u.bytes for n, u in nodes_util.items()} - Returns: A vector of node's BOPS count. - Note that the vector is not necessarily of the same length as the given config. + def _bops_utilization(self, mp_cfg: List[int]): + """ + Computes a resource utilization vector with the respective bit-operations (BOPS) count for each configurable node, + according to the given mixed-precision configuration of a virtual graph with composed nodes. - """ - - if not set_constraints: - return _bops_utilization(mp_cfg, - graph, - fw_info, - fw_impl) + Args: + mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node) - # BOPs utilization method considers non-configurable nodes, therefore, it doesn't need separate implementation - # for non-configurable nodes for setting a constraint (no need for separate implementation for len(mp_cfg) = 0). + Returns: A vector of node's BOPS count. + Note that the vector is not necessarily of the same length as the given config. 
- virtual_bops_nodes = [n for n in graph.get_topo_sorted_nodes() if isinstance(n, VirtualActivationWeightsNode)] + """ + # TODO keeping old implementation for now - mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) - bops = [n.get_bops_count(fw_impl, fw_info, candidate_idx=_get_node_cfg_idx(n, mp_cfg, mp_nodes)) for n in virtual_bops_nodes] + # BOPs utilization method considers non-configurable nodes, therefore, it doesn't need separate implementation + # for non-configurable nodes for setting a constraint (no need for separate implementation for len(mp_cfg) = 0). - return np.array(bops) + virtual_bops_nodes = [n for n in self.graph.get_topo_sorted_nodes() if isinstance(n, VirtualActivationWeightsNode)] + mp_nodes = self.graph.get_configurable_sorted_nodes_names(self.fw_info) -def _bops_utilization(mp_cfg: List[int], - graph: Graph, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation) -> np.ndarray: - """ - Computes a resource utilization vector with the respective bit-operations (BOPS) count for each configurable node, - according to the given mixed-precision configuration of an original graph. + bops = [n.get_bops_count(self.fw_impl, self.fw_info, candidate_idx=_get_node_cfg_idx(n, mp_cfg, mp_nodes)) + for n in virtual_bops_nodes] - Args: - mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node) - graph: Graph object. - fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize). - fw_impl: FrameworkImplementation object with specific framework methods implementation. - - Returns: A vector of node's BOPS count. - - """ - - mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) - - # Go over all nodes that should be taken into consideration when computing the BOPS utilization. - bops = [] - for n in graph.get_topo_sorted_nodes(): - if n.has_kernel_weight_to_quantize(fw_info) and not n.has_positional_weights: - # If node doesn't have weights then its MAC count is 0, and we shouldn't consider it in the BOPS count. 
- incoming_edges = graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX) - if len(incoming_edges) != 1: - Logger.critical(f"Unable to compute BOPS metric for node {n.name} due to multiple inputs.") # pragma: no cover - input_activation_node = incoming_edges[0].source_node - if len(graph.out_edges(input_activation_node)) > 1: - # In the case where the activation node has multiple outgoing edges - # we don't consider this edge in the BOPS utilization calculation - continue - - input_activation_node_cfg = input_activation_node.candidates_quantization_cfg[_get_node_cfg_idx(input_activation_node, mp_cfg, mp_nodes)] - - node_mac = fw_impl.get_node_mac_operations(n, fw_info) - - node_qc = n.candidates_quantization_cfg[_get_node_cfg_idx(n, mp_cfg, mp_nodes)] - kenrel_node_qc = node_qc.weights_quantization_cfg.get_attr_config(fw_info.get_kernel_op_attributes(n.type)[0]) - node_weights_nbits = kenrel_node_qc.weights_n_bits if \ - kenrel_node_qc.enable_weights_quantization else FLOAT_BITWIDTH - input_activation_nbits = input_activation_node_cfg.activation_quantization_cfg.activation_n_bits if \ - input_activation_node_cfg.activation_quantization_cfg.enable_activation_quantization else FLOAT_BITWIDTH - - node_bops = node_weights_nbits * input_activation_nbits * node_mac - bops.append(node_bops) - - return np.array(bops) + return np.array(bops) def _get_node_cfg_idx(node: BaseNode, mp_cfg: List[int], sorted_configurable_nodes_names: List[str]) -> int: @@ -414,115 +216,7 @@ def _get_node_cfg_idx(node: BaseNode, mp_cfg: List[int], sorted_configurable_nod if node.name in sorted_configurable_nodes_names: node_idx = sorted_configurable_nodes_names.index(node.name) return mp_cfg[node_idx] - else: + else: # pragma: no cover assert len(node.candidates_quantization_cfg) > 0, \ "Any node should have at least one candidate configuration." return 0 - - -def _get_origin_weights_node(n: BaseNode) -> BaseNode: - """ - In case we run a resource utilization computation on a virtual graph, - this method is used to retrieve the original node out of a virtual weights node, - - Args: - n: A possibly virtual node. - - Returns: A node from the original (non-virtual) graph which the given node represents. - - """ - - if isinstance(n, VirtualActivationWeightsNode): - return n.original_weights_node - if isinstance(n, VirtualSplitWeightsNode): - return n.origin_node - - return n - - -def _get_origin_activation_node(n: BaseNode) -> BaseNode: - """ - In case we run a resource utilization computation on a virtual graph, - this method is used to retrieve the original node out of a virtual activation node, - - Args: - n: A possibly virtual node. - - Returns: A node from the original (non-virtual) graph which the given node represents. - - """ - - if isinstance(n, VirtualActivationWeightsNode): - return n.original_activation_node - if isinstance(n, VirtualSplitActivationNode): - return n.origin_node - - return n - - -def _compute_node_weights_memory(n: BaseNode, node_nbits: int, fw_info: FrameworkInfo) -> float: - """ - Computes the weights' memory of the given node. - - Args: - n: A node to compute its weights' memory. - node_nbits: A bit-width in which the node's weights should be quantized. - fw_info: FrameworkInfo object about the specific framework. - - Returns: The total memory of the node's weights when quantized to the given bit-width. 
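For reference, the BOPS arithmetic in the (removed) code above, on hypothetical numbers: a layer with 10^9 MACs quantized to 4-bit weights and 8-bit input activations.

node_mac = 10 ** 9
weights_nbits, activation_nbits = 4, 8

node_bops = weights_nbits * activation_nbits * node_mac   # 3.2e10 bit-operations
float_bops = 32 * 32 * node_mac                           # float baseline: FLOAT_BITWIDTH ** 2 per MAC
assert float_bops / node_bops == 32                       # 32x fewer bit-ops than float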
- - """ - - origin_node = _get_origin_weights_node(n) - - node_num_weights_params = 0 - for attr in fw_info.get_kernel_op_attributes(origin_node.type): - if attr is not None: - node_num_weights_params += origin_node.get_weights_by_keys(attr).flatten().shape[0] - - return node_num_weights_params * node_nbits / BITS_TO_BYTES - - -def _compute_node_activation_memory(n: BaseNode, node_nbits: int) -> float: - """ - Computes the activation tensor memory of the given node. - - Args: - n: A node to compute its activation tensor memory. - node_nbits: A bit-width in which the node's weights should be quantized. - - Returns: The total memory of the node's activation tensor when quantized to the given bit-width. - - """ - - origin_node = _get_origin_activation_node(n) - node_output_size = origin_node.get_total_output_params() - - return node_output_size * node_nbits / BITS_TO_BYTES - - -class MpRuMetric(Enum): - """ - Defines resource utilization computation functions that can be used to compute bops_utilization for a given target - for a given mp config. The enum values can be used to call a function on a set of arguments. - - WEIGHTS_SIZE - applies the weights_size_utilization function - - ACTIVATION_MAXCUT_SIZE - applies the activation_maxcut_size_utilization function. - - ACTIVATION_OUTPUT_SIZE - applies the activation_output_size_utilization function - - TOTAL_WEIGHTS_ACTIVATION_SIZE - applies the total_weights_activation_utilization function - - BOPS_COUNT - applies the bops_utilization function - - """ - - WEIGHTS_SIZE = partial(weights_size_utilization) - ACTIVATION_MAXCUT_SIZE = partial(activation_maxcut_size_utilization) - ACTIVATION_OUTPUT_SIZE = partial(activation_output_size_utilization) - TOTAL_WEIGHTS_ACTIVATION_SIZE = partial(total_weights_activation_utilization) - BOPS_COUNT = partial(bops_utilization) - - def __call__(self, *args): - return self.value(*args) diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py index 1576c48ad..56ee0e5ca 100644 --- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py @@ -16,7 +16,7 @@ import numpy as np from pulp import * from tqdm import tqdm -from typing import Dict, List, Tuple, Callable +from typing import Dict, Tuple from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget @@ -218,13 +218,11 @@ def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager, np.sum(indicated_ru_matrix[i], axis=0) + # sum of metric values over all configurations in a row search_manager.min_ru[target][i] for i in range(indicated_ru_matrix.shape[0])]) - # search_manager.compute_ru_functions contains a pair of ru_metric and ru_aggregation for each ru target - # get aggregated ru, considering both configurable and non-configurable nodes - if non_conf_ru_vector is None or len(non_conf_ru_vector) == 0: - aggr_ru = search_manager.compute_ru_functions[target].aggregate_fn(ru_sum_vector) - else: - aggr_ru = search_manager.compute_ru_functions[target].aggregate_fn(np.concatenate([ru_sum_vector, non_conf_ru_vector])) + ru_vec = ru_sum_vector + if non_conf_ru_vector is not None and non_conf_ru_vector.size: + ru_vec = np.concatenate([ru_vec, 
non_conf_ru_vector])
+    aggr_ru = _aggregate_for_lp(ru_vec, target)
     for v in aggr_ru:
         if isinstance(v, float):
             if v > target_resource_utilization_value:
@@ -235,6 +233,31 @@ def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager,
             lp_problem += v <= target_resource_utilization_value
 
 
+def _aggregate_for_lp(ru_vec, target: RUTarget) -> list:
+    """
+    Aggregate resource utilization values for the LP.
+
+    Args:
+        ru_vec: a vector of resource utilization values.
+        target: resource utilization target.
+
+    Returns:
+        Aggregated resource utilization.
+    """
+    if target == RUTarget.TOTAL:
+        w = lpSum(v[0] for v in ru_vec)
+        return [w + v[1] for v in ru_vec]
+
+    if target in [RUTarget.WEIGHTS, RUTarget.BOPS]:
+        return [lpSum(ru_vec)]
+
+    if target == RUTarget.ACTIVATION:
+        # for max aggregation, each value constitutes a separate constraint
+        return list(ru_vec)
+
+    raise ValueError(f'Unexpected target {target}.')
+
+
 def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
                                     target_resource_utilization: ResourceUtilization,
                                     eps: float = EPS) -> Dict[int, Dict[int, float]]:
diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
index 4020d1350..99a50068c 100644
--- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
+++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py
@@ -113,11 +113,9 @@ def __init__(self,
         # in the new built MP model.
         self.baseline_model, self.model_mp, self.conf_node2layers = self._build_models()
 
-        # Build images batches for inference comparison
-        self.images_batches = self._get_images_batches(quant_config.num_of_images)
-
-        # Casting images tensors to the framework tensor type.
-        self.images_batches = [self.fw_impl.to_tensor(img) for img in self.images_batches]
+        # Build images batches for inference comparison and cast to framework tensor type
+        images_batches = self._get_images_batches(quant_config.num_of_images)
+        self.images_batches = [self.fw_impl.to_tensor(img) for img in images_batches]
 
         # Initiating baseline_tensors_list since it is not initiated in SensitivityEvaluationManager init.
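Why _aggregate_for_lp above returns one value per cut for RUTarget.ACTIVATION: a linear program cannot apply max() to decision variables, so a max-style bound is encoded as one constraint per element, while sum-style targets collapse to a single lpSum. A small runnable pulp sketch:

from pulp import LpProblem, LpMinimize, LpVariable, lpSum

prob = LpProblem('activation_budget', LpMinimize)
cut_utils = [LpVariable(f'cut_{i}', lowBound=0) for i in range(3)]
budget = 100

for u in cut_utils:                 # max over cuts <= budget  <=>  every cut <= budget
    prob += u <= budget

prob += lpSum(cut_utils) <= 3 * budget   # sum-style targets (WEIGHTS, BOPS): one constraint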
self.baseline_tensors_list = self._init_baseline_tensors_list() diff --git a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py index a9a1f9d6e..8b3c35597 100644 --- a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +++ b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py @@ -80,8 +80,8 @@ def greedy_solution_refinement_procedure(mp_solution: List[int], updated_ru.append(node_updated_ru) # filter out new configs that don't hold the resource utilization restrictions - node_filtered_ru = [(node_idx, ru) for node_idx, ru in zip(valid_candidates, updated_ru) if - target_resource_utilization.holds_constraints(ru)] + node_filtered_ru = [(node_idx, ru) for node_idx, ru in zip(valid_candidates, updated_ru) + if target_resource_utilization.is_satisfied_by(ru)] if len(node_filtered_ru) > 0: sorted_by_ru = sorted(node_filtered_ru, key=lambda node_ru: (node_ru[1].total_memory, diff --git a/model_compression_toolkit/core/common/quantization/bit_width_config.py b/model_compression_toolkit/core/common/quantization/bit_width_config.py index e057f0c54..887d828e1 100644 --- a/model_compression_toolkit/core/common/quantization/bit_width_config.py +++ b/model_compression_toolkit/core/common/quantization/bit_width_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from dataclasses import dataclass, field from typing import List, Union, Dict from model_compression_toolkit.core.common import Graph @@ -19,6 +20,7 @@ from model_compression_toolkit.logger import Logger +@dataclass class ManualBitWidthSelection: """ Class to encapsulate the manual bit width selection configuration for a specific filter. @@ -27,13 +29,11 @@ class ManualBitWidthSelection: filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation. bit_width (int): The bit width to be applied to the selected nodes. """ - def __init__(self, - filter: BaseNodeMatcher, - bit_width: int): - self.filter = filter - self.bit_width = bit_width + filter: BaseNodeMatcher + bit_width: int +@dataclass class BitWidthConfig: """ Class to manage manual bit-width configurations. @@ -41,13 +41,7 @@ class BitWidthConfig: Attributes: manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations. """ - def __init__(self, - manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = None): - self.manual_activation_bit_width_selection_list = [] if manual_activation_bit_width_selection_list is None else manual_activation_bit_width_selection_list - - def __repr__(self): - # Used for debugging, thus no cover. 
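The @dataclass conversion in this hunk is behavior-preserving: the generated __init__ and __repr__ replace the hand-written ones, and field(default_factory=list) replaces the None-then-[] idiom while avoiding the shared-mutable-default pitfall. A toy equivalent:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Selection:          # illustrative stand-in for ManualBitWidthSelection
    filter: object
    bit_width: int

@dataclass
class Config:             # illustrative stand-in for BitWidthConfig
    selections: List[Selection] = field(default_factory=list)

c1, c2 = Config(), Config()
c1.selections.append(Selection(filter=None, bit_width=4))
assert c2.selections == []    # each instance gets its own list
print(c1)                     # repr comes for free with @dataclass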
- return str(self.__dict__) # pragma: no cover + manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list) def set_manual_activation_bit_width(self, filters: Union[List[BaseNodeMatcher], BaseNodeMatcher], diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 1948f28c2..b204b408e 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -12,44 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from collections import namedtuple import copy - -from typing import Callable, Tuple, Any, List, Dict - import numpy as np +from typing import Callable, Any, List from model_compression_toolkit.core.common import FrameworkInfo +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser - +from model_compression_toolkit.core.common.graph.base_graph import Graph from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut, \ SchedulerInfo from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService +from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths from model_compression_toolkit.core.common.mixed_precision.mixed_precision_candidates_filter import \ filter_candidates_for_mixed_precision +from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \ + ResourceUtilization +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ + ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import \ requires_mixed_precision -from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner -from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner -from model_compression_toolkit.logger import Logger -from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation -from model_compression_toolkit.core.common.graph.base_graph import Graph -from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import ru_functions_mapping -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric -from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width from model_compression_toolkit.core.common.network_editors.edit_network 
import edit_network_graph from model_compression_toolkit.core.common.quantization.core_config import CoreConfig -from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities -from model_compression_toolkit.core.common.visualization.final_config_visualizer import \ - WeightsFinalBitwidthConfigVisualizer, \ - ActivationFinalBitwidthConfigVisualizer from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter, \ finalize_bitwidth_in_tb +from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner +from model_compression_toolkit.core.quantization_prep_runner import quantization_preparation_runner +from model_compression_toolkit.logger import Logger +from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \ + TargetPlatformCapabilities def core_runner(in_model: Any, @@ -88,7 +82,7 @@ def core_runner(in_model: Any, """ # Warn is representative dataset has batch-size == 1 - batch_data = iter(representative_data_gen()).__next__() + batch_data = next(iter(representative_data_gen())) if isinstance(batch_data, list): batch_data = batch_data[0] if batch_data.shape[0] == 1: @@ -96,7 +90,7 @@ def core_runner(in_model: Any, ' consider increasing the batch size') # Checking whether to run mixed precision quantization - if target_resource_utilization is not None: + if target_resource_utilization is not None and target_resource_utilization.is_any_restricted(): if core_config.mixed_precision_config is None: Logger.critical("Provided an initialized target_resource_utilization, that means that mixed precision quantization is " "enabled, but the provided MixedPrecisionQuantizationConfig is None.") @@ -177,7 +171,6 @@ def core_runner(in_model: Any, _set_final_resource_utilization(graph=tg, final_bit_widths_config=bit_widths_config, - ru_functions_dict=ru_functions_mapping, fw_info=fw_info, fw_impl=fw_impl) @@ -215,7 +208,6 @@ def core_runner(in_model: Any, def _set_final_resource_utilization(graph: Graph, final_bit_widths_config: List[int], - ru_functions_dict: Dict[RUTarget, Tuple[MpRuMetric, MpRuAggregation]], fw_info: FrameworkInfo, fw_impl: FrameworkImplementation): """ @@ -225,39 +217,21 @@ def _set_final_resource_utilization(graph: Graph, Args: graph: Graph to compute the resource utilization for. final_bit_widths_config: The final bit-width configuration to quantize the model accordingly. - ru_functions_dict: A mapping between a RUTarget and a pair of resource utilization method and resource utilization aggregation functions. fw_info: A FrameworkInfo object. fw_impl: FrameworkImplementation object with specific framework methods implementation. 
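The is_any_restricted() guard added above means the mixed-precision path is skipped when the caller passed a ResourceUtilization with no actual budget set. A toy model of that semantics (the real class presumably uses unbounded defaults; this sketch uses None as the "unrestricted" sentinel):

from dataclasses import dataclass
from typing import Optional

@dataclass
class TargetRU:   # simplified stand-in for ResourceUtilization
    weights_memory: Optional[float] = None
    activation_memory: Optional[float] = None

    def is_any_restricted(self) -> bool:
        # a target is "restricted" when the user actually set a budget for it
        return any(v is not None for v in (self.weights_memory, self.activation_memory))

assert not TargetRU().is_any_restricted()                 # nothing restricted -> skip MP search
assert TargetRU(weights_memory=1e6).is_any_restricted()   # a budget was set -> consider MP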
""" - - final_ru_dict = {} - for ru_target, ru_funcs in ru_functions_dict.items(): - ru_method, ru_aggr = ru_funcs - if ru_target == RUTarget.BOPS: - final_ru_dict[ru_target] = \ - ru_aggr(ru_method(final_bit_widths_config, graph, fw_info, fw_impl, False), False)[0] - else: - non_conf_ru = ru_method([], graph, fw_info, fw_impl) - conf_ru = ru_method(final_bit_widths_config, graph, fw_info, fw_impl) - if len(final_bit_widths_config) > 0 and len(non_conf_ru) > 0: - final_ru_dict[ru_target] = ru_aggr(np.concatenate([conf_ru, non_conf_ru]), False)[0] - elif len(final_bit_widths_config) > 0 and len(non_conf_ru) == 0: - final_ru_dict[ru_target] = ru_aggr(conf_ru, False)[0] - elif len(final_bit_widths_config) == 0 and len(non_conf_ru) > 0: - # final_bit_widths_config == 0 ==> no configurable nodes, - # thus, ru can be computed from non_conf_ru alone - final_ru_dict[ru_target] = ru_aggr(non_conf_ru, False)[0] - else: - # No relevant nodes have been quantized with affect on the given target - since we only consider - # in the model's final size the quantized layers size, this means that the final size for this target - # is zero. - Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded " - f"final ru for this target would be 0.") - final_ru_dict[ru_target] = 0 - - final_ru = ResourceUtilization() - final_ru.set_resource_utilization_by_target(final_ru_dict) - print(final_ru) + w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes} + a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes} + ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info) + final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom, + act_qcs=a_qcs, w_qcs=w_qcs) + + for ru_target, ru in final_ru.get_resource_utilization_dict().items(): + if ru == 0: + Logger.warning(f"No relevant quantized layers for the resource utilization target {ru_target} were found, " + f"the recorded final ru for this target would be 0.") + + Logger.info(f'Resource utilization (of quantized targets):\n {str(final_ru)}.') graph.user_info.final_resource_utilization = final_ru graph.user_info.mixed_precision_cfg = final_bit_widths_config diff --git a/tests/common_tests/function_tests/test_resource_utilization_object.py b/tests/common_tests/function_tests/test_resource_utilization_object.py index 94ad8a633..f7e3f9374 100644 --- a/tests/common_tests/function_tests/test_resource_utilization_object.py +++ b/tests/common_tests/function_tests/test_resource_utilization_object.py @@ -49,9 +49,5 @@ def test_representation(self): f"BOPS: {4}") def test_ru_hold_constraints(self): - self.assertTrue(default_ru.holds_constraints(custom_ru)) - self.assertFalse(custom_ru.holds_constraints(default_ru)) - self.assertFalse(custom_ru.holds_constraints({RUTarget.WEIGHTS: 1, - RUTarget.ACTIVATION: 1, - RUTarget.TOTAL: 1, - RUTarget.BOPS: 1})) + self.assertTrue(default_ru.is_satisfied_by(custom_ru)) + self.assertFalse(custom_ru.is_satisfied_by(default_ru)) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py index a82ddd149..47664608d 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py @@ 
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py
index a82ddd149..47664608d 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py
@@ -71,13 +71,9 @@ def get_max_resources_for_model(self, model):
         attach2keras = AttachTpcToKeras()
         tpc = attach2keras.attach(tpc, cc.quantization_config.custom_tpc_opset_to_layer)

-        return compute_resource_utilization_data(in_model=model,
-                                                 representative_data_gen=self.representative_data_gen(),
-                                                 core_config=cc,
-                                                 tpc=tpc,
-                                                 fw_info=DEFAULT_KERAS_INFO,
-                                                 fw_impl=KerasImplementation(),
-                                                 transformed_graph=None,
+        return compute_resource_utilization_data(in_model=model, representative_data_gen=self.representative_data_gen(),
+                                                 core_config=cc, tpc=tpc, fw_info=DEFAULT_KERAS_INFO,
+                                                 fw_impl=KerasImplementation(), transformed_graph=None,
                                                  mixed_precision_enable=False)

     def get_quantization_config(self):
diff --git a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
index 85c41581c..a922b68bd 100644
--- a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
+++ b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
@@ -24,8 +24,6 @@
     MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width, \
     BitWidthSearchMethod
-from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import \
-    RuFunctions
 from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
     mp_integer_programming_search
 from model_compression_toolkit.core.common.model_collector import ModelCollector
@@ -65,13 +63,9 @@ def __init__(self, layer_to_ru_mapping):
         self.compute_metric_fn = lambda x, y=None, z=None: {0: 2, 1: 1, 2: 0}[x[0]]
         self.min_ru = {RUTarget.WEIGHTS: [[1], [1], [1]],
                        RUTarget.ACTIVATION: [[1], [1], [1]],
-                       RUTarget.TOTAL: [[2], [2], [2]],
+                       RUTarget.TOTAL: [[1, 1], [1, 1], [1, 1]],
                        RUTarget.BOPS: [[1], [1], [1]]}  # minimal resource utilization in the test's layer_to_ru_mapping

-        self.compute_ru_functions = {RUTarget.WEIGHTS: RuFunctions(None, lambda v: [lpSum(v)]),
-                                     RUTarget.ACTIVATION: RuFunctions(None, lambda v: [i for i in v]),
-                                     RUTarget.TOTAL: RuFunctions(None, lambda v: [lpSum(v[0]) + i for i in v[1]]),
-                                     RUTarget.BOPS: RuFunctions(None, lambda v: [lpSum(v)])}
         self.max_ru_config = [0]
         self.config_reconstruction_helper = MockReconstructionHelper()
         self.non_conf_ru_dict = None
@@ -83,8 +77,8 @@ def compute_resource_utilization_matrix(self, target):
         elif target == RUTarget.ACTIVATION:
             ru_matrix = [np.flip(np.array([ru.activation_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]
         elif target == RUTarget.TOTAL:
-            ru_matrix = [np.flip(np.array([ru.weights_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()])),
-                         np.flip(np.array([ru.activation_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]
+            ru_matrix = [[np.flip(np.array([ru.weights_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()])),
+                          np.flip(np.array([ru.activation_memory - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]]
         elif target == RUTarget.BOPS:
             ru_matrix = [np.flip(np.array([ru.bops - 1 for _, ru in self.layer_to_ru_mapping[0].items()]))]
         else:
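
Note on the TOTAL changes above: the mock's TOTAL entries move from one pre-summed scalar per candidate ([[2], [2], [2]]) to a (weights, activation) pair per candidate ([[1, 1], [1, 1], [1, 1]]), and compute_resource_utilization_matrix now nests the two arrays inside a single row. A small numpy sketch with toy data (not the MCT solver) showing why the two layouts agree once the pair is summed at aggregation time:

    import numpy as np

    weights_ru = np.array([3., 2., 1.])      # per-candidate weights memory
    activation_ru = np.array([2., 2., 2.])   # per-candidate activation memory

    # old layout: TOTAL was pre-summed into one scalar per candidate
    total_old = weights_ru + activation_ru             # [5., 4., 3.]

    # new layout: keep the (weights, activation) pair, sum when aggregating
    total_new = np.stack([weights_ru, activation_ru])  # shape (2, n_candidates)
    assert np.allclose(total_new.sum(axis=0), total_old)

Keeping the pair separate until aggregation matches the removal of the per-target RuFunctions lambdas: the summing no longer happens in a hand-written aggregation function.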
diff --git a/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py b/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py
index 503d6751a..0ee734eb2 100644
--- a/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/multi_head_attention_test.py
@@ -22,6 +22,7 @@
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
 from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest

 """
@@ -55,6 +56,22 @@ def create_inputs_shape(self):
                 [self.val_batch_size] + list(self.key_input_shape),
                 [self.val_batch_size] + list(self.value_input_shape)]

+    def get_tpc(self):
+        tpc = {
+            'no_quantization': generate_test_tp_model({
+                'weights_n_bits': 32,
+                'activation_n_bits': 32,
+                'enable_weights_quantization': False,
+                'enable_activation_quantization': False
+            })
+        }
+        if self.num_heads < 5:
+            tpc['all_4bit'] = generate_test_tp_model({'weights_n_bits': 4,
+                                                      'activation_n_bits': 4,
+                                                      'enable_weights_quantization': True,
+                                                      'enable_activation_quantization': True})
+        return tpc
+

 class MHANet(nn.Module):
     # This network is based on a single MHA layer
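
The get_tpc addition above returns a dict of named TPC variants so the MHA test runs both without quantization and, for small head counts, fully at 4 bits. A hedged sketch of the contract a driver over such a dict would follow; run_all_variants and both lambdas are hypothetical stand-ins, and the real plumbing lives in BasePytorchTest:

    # hypothetical driver over a {name: tpc} dict like the one get_tpc() returns
    def run_all_variants(get_tpc_fn, run_one):
        for name, tpc in get_tpc_fn().items():
            print(f'running TPC variant: {name}')
            run_one(tpc)  # e.g. the quantization flow under test

    # toy stand-ins, just to show the shape of the contract
    run_all_variants(lambda: {'no_quantization': object(), 'all_4bit': object()},
                     lambda tpc: None)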
diff --git a/tests_pytest/core/__init__.py b/tests_pytest/core/__init__.py
new file mode 100644
index 000000000..5397dea24
--- /dev/null
+++ b/tests_pytest/core/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/tests_pytest/core/common/__init__.py b/tests_pytest/core/common/__init__.py
new file mode 100644
index 000000000..5397dea24
--- /dev/null
+++ b/tests_pytest/core/common/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/tests_pytest/core/common/mixed_precision/__init__.py b/tests_pytest/core/common/mixed_precision/__init__.py
new file mode 100644
index 000000000..5397dea24
--- /dev/null
+++ b/tests_pytest/core/common/mixed_precision/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/tests_pytest/core/common/mixed_precision/resource_utilization_tools/__init__.py b/tests_pytest/core/common/mixed_precision/resource_utilization_tools/__init__.py
new file mode 100644
index 000000000..5397dea24
--- /dev/null
+++ b/tests_pytest/core/common/mixed_precision/resource_utilization_tools/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================