diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index d5b16132b..98e053940 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -151,11 +151,24 @@ def is_weights_quantization_enabled(self, attr_name: str) -> bool: return False def is_configurable_weight(self, attr_name: str) -> bool: - """ Checks whether the specific weight has a configurable quantization. """ + """ + Checks whether the specific weight attribute has a configurable quantization. + + Args: + attr_name: weight attribute name. + + Returns: + Whether the weight attribute is configurable. + """ return self.is_weights_quantization_enabled(attr_name) and not self.is_all_weights_candidates_equal(attr_name) - def has_configurable_activation(self): - """ Checks whether the activation has a configurable quantization. """ + def has_configurable_activation(self) -> bool: + """ + Checks whether the activation has a configurable quantization. + + Returns: + Whether the activation has a configurable quantization. 
+ """ return self.is_activation_quantization_enabled() and not self.is_all_activation_candidates_equal() def __repr__(self): diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py index 3f17c51b8..047745ca7 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py @@ -28,7 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import \ - MixPrecisionRUHelper + MixedPrecisionRUHelper from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation from model_compression_toolkit.logger import Logger @@ -68,7 +68,7 @@ def __init__(self, self._cuts = None self.ru_metrics = target_resource_utilization.get_restricted_metrics() - self.ru_helper = MixPrecisionRUHelper(graph, fw_info, fw_impl) + self.ru_helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl) self.target_resource_utilization = target_resource_utilization self.min_ru_config = self.graph.get_min_candidates_config(fw_info) self.max_ru_config = self.graph.get_max_candidates_config(fw_info) @@ -207,10 +207,9 @@ def compute_resource_utilization_for_config(self, config: List[int]) -> Resource """ act_qcs, w_qcs = self.ru_helper.get_configurable_qcs(config) - ru = self.ru_helper.ru_calculator.compute_resource_utilization(target_criterion=TargetInclusionCriterion.AnyQuantized, - bitwidth_mode=BitwidthMode.MpCustom, - act_qcs=act_qcs, - w_qcs=w_qcs) + ru = self.ru_helper.ru_calculator.compute_resource_utilization( + 
target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs, + w_qcs=w_qcs) return ru def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]): diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py index d59ef5a6d..779b08a0a 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py @@ -45,7 +45,6 @@ class ResourceUtilization: total_memory: The sum of model's activation and weights memory in bytes. bops: The total bit-operations in the model. """ - # TODO the user facade actually computes size, not memory. Do we want to change fields names? weights_memory: float = np.inf activation_memory: float = np.inf total_memory: float = np.inf @@ -93,9 +92,3 @@ def get_restricted_metrics(self) -> Set[RUTarget]: def is_any_restricted(self) -> bool: return bool(self.get_restricted_metrics()) - - def __repr__(self): - return f"Weights_memory: {self.weights_memory}, " \ - f"Activation_memory: {self.activation_memory}, " \ - f"Total_memory: {self.total_memory}, " \ - f"BOPS: {self.bops}" diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py index dc1bcf59c..d96ce470f 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py @@ -1,4 +1,4 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. 
All rights reserved. +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ from copy import deepcopy from enum import Enum, auto from functools import lru_cache -from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal +from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.core import FrameworkInfo @@ -36,19 +36,22 @@ class BitwidthMode(Enum): """ Bit-width configuration for resource utilization computation. - Size: tensors sizes. - Float: float. - MpMax: maximal bit-width mixed-precision configuration. - MpMin: minimal bit-width mixed-precision configuration. - MpCustom: explicitly provided bit-width configuration. - SpDefault: single-precision configuration (for non-configurable quantization). + Float: original un-quantized configuration. Assumed to be 32-bit float. + QMaxBit: maximal bit-width configuration. Assigns each node its maximal available precision according to the + target platform capabilities. + QMinBit: minimal bit-width configuration. Assigns each node its minimal available precision according to the + target platform capabilities. + QCustom: explicitly provided bit-width configuration. + QDefaultSP: default single-precision bit-width configuration. Can be used either in a single-precision mode, + or along with TargetInclusionCriterion.QNonConfigurable, which computes the resource utilization only for + single-precision nodes. To compute custom single precision configuration, use QCustom. 
""" - Size = auto() Float = auto() - MpMax = auto() - MpMin = auto() - MpCustom = auto() - SpDefault = auto() + Q8Bit = auto() + QMaxBit = auto() + QMinBit = auto() + QCustom = auto() + QDefaultSP = auto() class TargetInclusionCriterion(Enum): @@ -78,21 +81,8 @@ class Utilization(NamedTuple): size: int bytes: Optional[float] - def by_bit_mode(self, bitwidth_mode: BitwidthMode) -> Union[int, float]: - """ Retrieve value corresponding to the bit-width mode. """ - if bitwidth_mode == BitwidthMode.Size: - return self.size - return self.bytes - - @staticmethod - def zero_utilization(bitwidth_mode: BitwidthMode) -> 'Utilization': - """ Construct zero utilization object. """ - return Utilization(0, bytes=None if bitwidth_mode == BitwidthMode.Size else 0) - def __add__(self, other: 'Utilization') -> 'Utilization': - self._validate_pair(self, other) - bytes_ = None if self.bytes is None else (self.bytes + other.bytes) - return Utilization(self.size + other.size, bytes_) + return Utilization(self.size + other.size, self.bytes + other.bytes) def __radd__(self, other: Union['Utilization', Literal[0]]): # Needed for sum (with default start_value=0). @@ -101,23 +91,12 @@ def __radd__(self, other: Union['Utilization', Literal[0]]): return self + other def __gt__(self, other: 'Utilization'): - # Needed for max. Compare by bytes, if not defined then by size. - self._validate_pair(self, other) - if self.bytes is not None: - return self.bytes > other.bytes - return self.size > other.size + # Needed for max. Compare by bytes. + return self.bytes > other.bytes def __lt__(self, other: 'Utilization'): - self._validate_pair(self, other) - # Needed for min. Compare by bytes, if not defined then by size. 
- if self.bytes is not None: - return self.bytes < other.bytes - return self.size < other.size - - @staticmethod - def _validate_pair(u1, u2): - if [u1.bytes, u2.bytes].count(None) == 1: - raise ValueError('bytes field must be set either by both or by none of the objects.') + # Needed for min. Compare by bytes. + return self.bytes < other.bytes class AggregationMethod(Enum): @@ -139,15 +118,14 @@ def __call__(self, *args, **kwarg): } -_bitwidth_mode_fn = { - BitwidthMode.MpMax: max, - BitwidthMode.MpMin: min -} - - class ResourceUtilizationCalculator: """ Resource utilization calculator. """ + _bitwidth_mode_fn = { + BitwidthMode.QMaxBit: max, + BitwidthMode.QMinBit: min, + } + def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo): self.graph = graph self.fw_impl = fw_impl @@ -167,49 +145,51 @@ def compute_resource_utilization(self, bitwidth_mode: BitwidthMode, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None, - metrics: Iterable[RUTarget] = None) -> ResourceUtilization: + ru_targets: Iterable[RUTarget] = None) -> ResourceUtilization: """ - Compute total resource utilization. + Compute network's resource utilization. Args: target_criterion: criterion to include targets for computation (applies to weights, activation). bitwidth_mode: bit-width mode for computation. - act_qcs: activation quantization candidates for custom bit-width mode. Must provide configuration for all - configurable activations. - w_qcs: weights quantization candidates for custom bit-width mode. Must provide configuration for all - configurable weights. - metrics: metrics to include for computation. If None, all metrics are calculated. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. 
For non-configurable + activations, if not provided, the default configuration will be extracted from the node. + w_qcs: custom weights quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable weights. For non-configurable + weights, if not provided, the default configuration will be extracted from the node. + ru_targets: metrics to include for computation. If None, all metrics are calculated. Returns: Resource utilization object. """ - metrics = set(metrics) if metrics else set(RUTarget) + ru_targets = set(ru_targets) if ru_targets else set(RUTarget) w_total, a_total = None, None - if {RUTarget.WEIGHTS, RUTarget.TOTAL}.intersection(metrics): + if {RUTarget.WEIGHTS, RUTarget.TOTAL}.intersection(ru_targets): w_total, *_ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs) elif w_qcs is not None: # pragma: no cover raise ValueError('Weight configuration passed but no relevant metric requested.') - if act_qcs and not {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(metrics): # pragma: no cover + if act_qcs and not {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets): # pragma: no cover raise ValueError('Activation configuration passed but no relevant metric requested.') - if RUTarget.ACTIVATION in metrics: - a_total, *_ = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs) + if RUTarget.ACTIVATION in ru_targets: + a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs) ru = ResourceUtilization() - if RUTarget.WEIGHTS in metrics: + if RUTarget.WEIGHTS in ru_targets: ru.weights_memory = w_total - if RUTarget.ACTIVATION in metrics: + if RUTarget.ACTIVATION in ru_targets: ru.activation_memory = a_total - if RUTarget.TOTAL in metrics: + if RUTarget.TOTAL in ru_targets: # TODO use maxcut act_tensors_total, *_ = self.compute_activation_tensors_utilization(target_criterion, bitwidth_mode, act_qcs) 
ru.total_memory = w_total + act_tensors_total - if RUTarget.BOPS in metrics: + if RUTarget.BOPS in ru_targets: ru.bops, _ = self.compute_bops(target_criterion=target_criterion, bitwidth_mode=bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs) - assert ru.get_restricted_metrics() == set(metrics), 'Mismatch between the number of requested and computed metrics' + assert ru.get_restricted_metrics() == set(ru_targets), 'Mismatch between the number of requested and computed metrics' return ru def compute_weights_utilization(self, @@ -223,15 +203,18 @@ def compute_weights_utilization(self, Args: target_criterion: criterion to include targets for computation. bitwidth_mode: bit-width mode for computation. - w_qcs: weights quantization config per node for the custom bit mode. Must provide configuration for all - configurable weights. + w_qcs: custom weights quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable weights. For non-configurable + weights, if not provided, the default configuration will be extracted from the node. Returns: - - Total weights utilization. - - Per node total utilization. Dict keys are nodes in a topological order. - - Detailed per node per weight utilization. Dict keys are nodes in a topological order. + - Total weights utilization of the network. + - Per node total weights utilization. Dict keys are nodes in a topological order. + - Detailed per node per weight attribute utilization. Dict keys are nodes in a topological order. 
""" nodes = self._get_target_weight_nodes(target_criterion, include_reused=False) + if not nodes: + return 0, {}, {} util_per_node: Dict[BaseNode, Utilization] = {} util_per_node_per_weight = {} @@ -244,8 +227,8 @@ def compute_weights_utilization(self, util_per_node_per_weight[n] = per_weight_util aggregate_fn = ru_target_aggregation_fn[RUTarget.WEIGHTS] - total_util = aggregate_fn(u.by_bit_mode(bitwidth_mode) for u in util_per_node.values()) - return total_util, util_per_node, util_per_node_per_weight + total_util = aggregate_fn(util_per_node.values()) + return total_util.bytes, util_per_node, util_per_node_per_weight def compute_node_weights_utilization(self, n: BaseNode, @@ -260,34 +243,46 @@ def compute_node_weights_utilization(self, n: node. target_criterion: criterion to include weights for computation. bitwidth_mode: bit-width mode for the computation. - qc: weight quantization config for the custom bit mode computation. Must provide configuration for all - configurable weights. + qc: custom weights quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable weights. For non-configurable + weights, if not provided, the default configuration will be extracted from the node. Returns: - - Total utilization. - - Detailed per weight utilization. + - Node's total weights utilization. + - Detailed per weight attribute utilization. 
""" weight_attrs = self._get_target_weight_attrs(n, target_criterion) if not weight_attrs: # pragma: no cover - return Utilization.zero_utilization(bitwidth_mode, ), {} + return Utilization(0, 0), {} attr_util = {} for attr in weight_attrs: size = self._params_cnt[n][attr] - bytes_ = None - if bitwidth_mode != BitwidthMode.Size: - nbits = self._get_weight_nbits(n, attr, bitwidth_mode, qc) - bytes_ = size * nbits / 8 + nbits = self._get_weight_nbits(n, attr, bitwidth_mode, qc) + bytes_ = size * nbits / 8 attr_util[attr] = Utilization(size, bytes_) - total_weights = sum(attr_util.values()) + total_weights: Utilization = sum(attr_util.values()) # type: ignore return total_weights, attr_util def compute_activations_utilization(self, target_criterion: TargetInclusionCriterion, bitwidth_mode: BitwidthMode, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None): - return self.compute_cut_activation_utilization(target_criterion, bitwidth_mode, act_qcs) + """ + Compute total activations utilization in the graph. + + Args: + target_criterion: criterion to include activations for computation. + bitwidth_mode: bit-width mode for the computation. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. + + Returns: + Total activation utilization of the network. + """ + return self.compute_cut_activation_utilization(target_criterion, bitwidth_mode, act_qcs)[0] def compute_cut_activation_utilization(self, target_criterion: TargetInclusionCriterion, @@ -295,18 +290,19 @@ def compute_cut_activation_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \ -> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]: """ - Calculate graph activation cuts utilization. 
+ Compute graph activation cuts utilization. Args: target_criterion: criterion to include weights for computation. bitwidth_mode: bit-width mode for the computation. - act_qcs: custom configuration for the custom bit mode. Must provide configuration for all configurable - activations. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. Returns: - - Total utilization. - - Total utilization per cut. - - Detailed utilization per cut per node. + - Total activation utilization of the network. + - Total activation utilization per cut. + - Detailed activation utilization per cut per node. """ if target_criterion != TargetInclusionCriterion.AnyQuantized: # pragma: no cover raise NotImplementedError('Computing MaxCut activation utilization is currently only supported for quantized targets.') @@ -339,8 +335,8 @@ def compute_cut_activation_utilization(self, util_per_cut[cut] = sum(util_per_cut_per_node[cut].values()) # type: ignore aggregate_fn = ru_target_aggregation_fn[RUTarget.ACTIVATION] - total_util = aggregate_fn(u.by_bit_mode(bitwidth_mode) for u in util_per_cut.values()) - return total_util, util_per_cut, util_per_cut_per_node + total_util = aggregate_fn(util_per_cut.values()) + return total_util.bytes, util_per_cut, util_per_cut_per_node def compute_activation_tensors_utilization(self, target_criterion: TargetInclusionCriterion, @@ -354,15 +350,19 @@ def compute_activation_tensors_utilization(self, Args: target_criterion: criterion to include weights for computation. bitwidth_mode: bit-width mode for the computation. - act_qcs: custom configuration for the custom bit mode. Must provide configuration for all configurable - activations. + act_qcs: custom activations quantization configuration. 
Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. include_reused: whether to include reused nodes. Returns: - - Total activation utilization. + - Total activation utilization of the network. - Detailed utilization per node. Dict keys are nodes in a topological order. """ nodes = self._get_target_activation_nodes(target_criterion, include_reused=include_reused) + if not nodes: + return 0, {} + util_per_node: Dict[BaseNode, Utilization] = {} for n in self._topo_sort(nodes): qc = act_qcs.get(n) if act_qcs else None @@ -370,8 +370,8 @@ def compute_activation_tensors_utilization(self, util_per_node[n] = util aggregate_fn = ru_target_aggregation_fn[RUTarget.ACTIVATION] - total_util = aggregate_fn(u.by_bit_mode(bitwidth_mode) for u in util_per_node.values()) - return total_util, util_per_node + total_util = aggregate_fn(util_per_node.values()) + return total_util.bytes, util_per_node def compute_node_activation_tensor_utilization(self, n: BaseNode, @@ -385,21 +385,20 @@ def compute_node_activation_tensor_utilization(self, n: node. target_criterion: criterion to include nodes for computation. If None, will skip the check. bitwidth_mode: bit-width mode for the computation. - qc: activation quantization config for the custom bit mode. Must be provided for a configurable activation. - + qc: activation quantization config for the node. Should be provided only in custom bit mode. + In custom mode, must be provided if the activation is configurable. For non-configurable activation, if + not passed, the default configuration will be extracted from the node. Returns: Node's activation utilization. 
""" if target_criterion: nodes = self._get_target_activation_nodes(target_criterion=target_criterion, include_reused=True, nodes=[n]) if not nodes: # pragma: no cover - return Utilization.zero_utilization(bitwidth_mode) + return Utilization(0, 0) size = self._act_tensors_size[n] - bytes_ = None - if bitwidth_mode != BitwidthMode.Size: - nbits = self._get_activation_nbits(n, bitwidth_mode, qc) - bytes_ = size * nbits / 8 + nbits = self._get_activation_nbits(n, bitwidth_mode, qc) + bytes_ = size * nbits / 8 return Utilization(size, bytes_) def compute_bops(self, @@ -410,25 +409,27 @@ def compute_bops(self, -> Tuple[int, Dict[BaseNode, int]]: """ Compute bit operations based on nodes with kernel. + Note that 'target_criterion' applies to weights, and BOPS are computed for the selected nodes regardless + of the input activation quantization or lack thereof. Args: target_criterion: criterion to include nodes for computation. bitwidth_mode: bit-width mode for computation. - act_qcs: activation quantization candidates for custom bit-width mode. Must provide configuration for all - configurable activations. - w_qcs: weights quantization candidates for custom bit-width mode. Must provide configuration for all - configurable weights. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. + w_qcs: custom weights quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable weights. For non-configurable + weights, if not provided, the default configuration will be extracted from the node. Returns: - - Total BOPS count. + - Total BOPS count of the network. - Detailed BOPS count per node. 
""" - # currently we compute bops for all nodes with quantized weights, regardless of whether the input - # activation is quantized. if target_criterion != TargetInclusionCriterion.AnyQuantized: # pragma: no cover raise NotImplementedError('BOPS computation is currently only supported for quantized targets.') - nodes = [n for n in self.graph.nodes if n.has_kernel_weight_to_quantize(self.fw_info)] + nodes = self._get_target_weight_nodes(target_criterion, include_reused=True) nodes_bops = {} for n in nodes: w_qc = w_qcs.get(n) if w_qcs else None @@ -441,23 +442,25 @@ def compute_node_bops(self, n: BaseNode, bitwidth_mode: BitwidthMode, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None, - w_qc: Optional[NodeWeightsQuantizationConfig] = None) -> int: + w_qc: Optional[NodeWeightsQuantizationConfig] = None) -> Union[float, int]: """ Compute Bit Operations of a node. Args: n: node. bitwidth_mode: bit-width mode for the computation. - act_qcs: nodes activation quantization configuration for the custom bit mode. Must provide configuration for all - configurable activations. - w_qc: weights quantization config for the node for the custom bit mode. Must provide configuration for all - configurable weights. + act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only. + In custom mode, must provide configuration for all configurable activations. For non-configurable + activations, if not provided, the default configuration will be extracted from the node. + w_qc: weights quantization config for the node. Should be provided only in custom bit mode. + Must provide configuration for all configurable weights. For non-configurable weights, will use the + provided configuration if found, or extract the default configuration from the node otherwise. Returns: - BOPS count. + Node's BOPS count. 
""" node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info) - if node_mac == 0 or bitwidth_mode == BitwidthMode.Size: # pragma: no cover + if node_mac == 0: # pragma: no cover return node_mac incoming_edges = self.graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX) @@ -549,8 +552,16 @@ def _get_target_weight_attrs(self, n: BaseNode, target_criterion: TargetInclusio raise ValueError(f'Unknown {target_criterion}') return weight_attrs - def _topo_sort(self, nodes): - """ Sort nodes in a topological order (based on graph's nodes). """ + def _topo_sort(self, nodes: Sequence[BaseNode]) -> List[BaseNode]: + """ + Sort nodes in a topological order (based on graph's nodes). + + Args: + nodes: nodes to sort. + + Returns: + Nodes in topological order. + """ graph_topo_nodes = self.graph.get_topo_sorted_nodes() topo_nodes = [n for n in graph_topo_nodes if n in nodes] if len(topo_nodes) != len(nodes): @@ -586,8 +597,9 @@ def _get_target_activation_nodes(self, nodes = [n for n in nodes if not n.reuse] return nodes - @staticmethod - def _get_activation_nbits(n: BaseNode, + @classmethod + def _get_activation_nbits(cls, + n: BaseNode, bitwidth_mode: BitwidthMode, act_qc: Optional[NodeActivationQuantizationConfig]) -> int: """ @@ -596,16 +608,15 @@ def _get_activation_nbits(n: BaseNode, Args: n: node. bitwidth_mode: bit-width mode for computation. - act_qc: quantization candidate for the custom bit mode. Must be provided for a configurable activation. + act_qc: activation quantization config for the node. Should be provided only in custom bit mode. + In custom mode, must be provided if the activation is configurable. For non-configurable activation, if + not passed, the default configuration will be extracted from the node. Returns: Activation bit-width. 
""" - if bitwidth_mode == BitwidthMode.Size: - raise ValueError(f'nbits is not defined for {bitwidth_mode}.') - if act_qc: - if bitwidth_mode != BitwidthMode.MpCustom or not n.is_activation_quantization_enabled(): + if bitwidth_mode != BitwidthMode.QCustom or not n.is_activation_quantization_enabled(): raise ValueError( f'Activation config is not expected for non-custom bit mode or for un-quantized activation.' f'Mode: {bitwidth_mode}, quantized activation: {n.is_activation_quantization_enabled()}' @@ -616,11 +627,14 @@ def _get_activation_nbits(n: BaseNode, if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled(): return FLOAT_BITWIDTH - if bitwidth_mode in _bitwidth_mode_fn: + if bitwidth_mode == BitwidthMode.Q8Bit: + return 8 + + if bitwidth_mode in cls._bitwidth_mode_fn: candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg] - return _bitwidth_mode_fn[bitwidth_mode](candidates_nbits) + return cls._bitwidth_mode_fn[bitwidth_mode](candidates_nbits) - if bitwidth_mode in [BitwidthMode.MpCustom, BitwidthMode.SpDefault]: + if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]: qcs = n.get_unique_activation_candidates() if len(qcs) != 1: raise ValueError(f'Could not retrieve the activation quantization candidate for node {n.name} ' @@ -629,8 +643,12 @@ def _get_activation_nbits(n: BaseNode, raise ValueError(f'Unknown mode {bitwidth_mode}') - @staticmethod - def _get_weight_nbits(n, w_attr: str, bitwidth_mode: BitwidthMode, w_qc: Optional[NodeWeightsQuantizationConfig]) -> int: + @classmethod + def _get_weight_nbits(cls, + n: BaseNode, + w_attr: str, + bitwidth_mode: BitwidthMode, + w_qc: Optional[NodeWeightsQuantizationConfig]) -> int: """ Get the bit-width of a specific weight of a node according to the requested bit-width mode. @@ -638,17 +656,15 @@ def _get_weight_nbits(n, w_attr: str, bitwidth_mode: BitwidthMode, w_qc: Optiona n: node. w_attr: weight attribute. 
bitwidth_mode: bit-width mode for the computation. - w_qc: weights quantization config for the node for the custom bit mode. Must provide configuration for all - configurable weights. + w_qc: weights quantization config for the node. Should be provided only in custom bit mode. + Must provide configuration for all configurable weights. For non-configurable weights, will use the + provided configuration if found, or extract the default configuration from the node otherwise. Returns: Weight bit-width. """ - if bitwidth_mode == BitwidthMode.Size: - raise ValueError(f'nbits is not defined for {bitwidth_mode}.') - if w_qc and w_qc.has_attribute_config(w_attr): - if bitwidth_mode != BitwidthMode.MpCustom or not n.is_weights_quantization_enabled(w_attr): + if bitwidth_mode != BitwidthMode.QCustom or not n.is_weights_quantization_enabled(w_attr): raise ValueError('Weight config is not expected for non-custom bit mode or for un-quantized weight.' f'Bit mode: {bitwidth_mode}, quantized attr {w_attr}: ' f'{n.is_weights_quantization_enabled(w_attr)}') @@ -659,12 +675,15 @@ def _get_weight_nbits(n, w_attr: str, bitwidth_mode: BitwidthMode, w_qc: Optiona if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr): return FLOAT_BITWIDTH + if bitwidth_mode == BitwidthMode.Q8Bit: + return 8 + node_qcs = n.get_unique_weights_candidates(w_attr) w_qcs = [qc.weights_quantization_cfg.get_attr_config(w_attr) for qc in node_qcs] - if bitwidth_mode in _bitwidth_mode_fn: - return _bitwidth_mode_fn[bitwidth_mode]([qc.weights_n_bits for qc in w_qcs]) + if bitwidth_mode in cls._bitwidth_mode_fn: + return cls._bitwidth_mode_fn[bitwidth_mode]([qc.weights_n_bits for qc in w_qcs]) - if bitwidth_mode in [BitwidthMode.MpCustom, BitwidthMode.SpDefault]: + if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]: # if configuration was not passed and the weight has only one candidate, use it if len(w_qcs) != 1: raise ValueError(f'Could not retrieve the 
quantization candidate for attr {w_attr} of node {n.name} ' diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py index 52d3b8683..03cda3961 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py @@ -33,7 +33,7 @@ def compute_resource_utilization_data(in_model: Any, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, transformed_graph: Graph = None, - mixed_precision_enabled: bool = True) -> ResourceUtilization: + mixed_precision_enable: bool = True) -> ResourceUtilization: """ Compute Resource Utilization information that can be relevant for defining target ResourceUtilization for mixed precision search. Calculates maximal activation tensor size, the sum of the model's weight parameters and the total memory combining both weights @@ -49,7 +49,7 @@ def compute_resource_utilization_data(in_model: Any, fw_impl: FrameworkImplementation object with a specific framework methods implementation. transformed_graph: An internal graph representation of the input model. Defaults to None. If no graph is provided, a graph will be constructed using the specified model. - mixed_precision_enabled: Indicates if mixed precision is enabled, defaults to True. + mixed_precision_enable: Indicates if mixed precision is enabled, defaults to True. If disabled, computes resource utilization using base quantization configurations across all layers. 
@@ -68,13 +68,12 @@ def compute_resource_utilization_data(in_model: Any, fw_impl, tpc, bit_width_config=core_config.bit_width_config, - mixed_precision_enable=mixed_precision_enabled, + mixed_precision_enable=mixed_precision_enable, running_gptq=False) ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info) - ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, - BitwidthMode.Size, - metrics=set(RUTarget) - {RUTarget.BOPS}) + ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.Q8Bit, + ru_targets=set(RUTarget) - {RUTarget.BOPS}) ru.bops, _ = ru_calculator.compute_bops(TargetInclusionCriterion.AnyQuantized, BitwidthMode.Float) return ru @@ -118,9 +117,8 @@ def requires_mixed_precision(in_model: Any, running_gptq=False) ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info) - max_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, - BitwidthMode.MpMax, - metrics=target_resource_utilization.get_restricted_metrics()) + max_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QMaxBit, + ru_targets=target_resource_utilization.get_restricted_metrics()) return not target_resource_utilization.is_satisfied_by(max_ru) diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py index a350aa0b9..b3605089f 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_methods.py @@ -20,8 +20,7 @@ from model_compression_toolkit.core.common import Graph, BaseNode from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from 
model_compression_toolkit.core.common.graph.memory_graph.cut import Cut -from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \ - VirtualSplitWeightsNode, VirtualSplitActivationNode +from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \ RUTarget from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ @@ -33,7 +32,9 @@ # TODO take into account Virtual nodes. Are candidates defined with respect to virtual or original nodes? # Can we use the virtual graph only for bops and the original graph for everything else? -class MixPrecisionRUHelper: +class MixedPrecisionRUHelper: + """ Helper class for resource utilization computations for mixed precision optimization. """ + def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation): self.graph = graph self.fw_info = fw_info @@ -42,7 +43,10 @@ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImple def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Optional[List[int]]) -> Dict[RUTarget, np.ndarray]: """ - Compute utilization of requested targets for a specific configuration + Compute utilization of requested targets for a specific configuration in the format expected by LP problem + formulation, namely an array of ru values corresponding to graph's configurable nodes in the topological order. + For activation target, the array contains values for activation cuts in unspecified order (as long as it is + consistent between configurations). Args: ru_targets: resource utilization targets to compute. 
@@ -112,10 +116,10 @@ def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantiz """ if w_qcs: target_criterion = TargetInclusionCriterion.QConfigurable - bitwidth_mode = BitwidthMode.MpCustom + bitwidth_mode = BitwidthMode.QCustom else: target_criterion = TargetInclusionCriterion.QNonConfigurable - bitwidth_mode = BitwidthMode.SpDefault + bitwidth_mode = BitwidthMode.QDefaultSP _, nodes_util, _ = self.ru_calculator.compute_weights_utilization(target_criterion=target_criterion, bitwidth_mode=bitwidth_mode, @@ -136,7 +140,7 @@ def _activation_maxcut_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeAc """ if act_qcs: _, cuts_util, _ = self.ru_calculator.compute_cut_activation_utilization(TargetInclusionCriterion.AnyQuantized, - bitwidth_mode=BitwidthMode.MpCustom, + bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs) cuts_util = {c: u.bytes for c, u in cuts_util.items()} return cuts_util @@ -158,10 +162,10 @@ def _activation_tensor_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeAc """ if act_qcs: target_criterion = TargetInclusionCriterion.QConfigurable - bitwidth_mode = BitwidthMode.MpCustom + bitwidth_mode = BitwidthMode.QCustom else: target_criterion = TargetInclusionCriterion.QNonConfigurable - bitwidth_mode = BitwidthMode.SpDefault + bitwidth_mode = BitwidthMode.QDefaultSP _, nodes_util = self.ru_calculator.compute_activation_tensors_utilization(target_criterion=target_criterion, bitwidth_mode=bitwidth_mode, diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py index a85d65378..52f82e237 100644 --- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py @@ -14,10 +14,9 @@ # ============================================================================== 
import numpy as np -import pulp from pulp import * from tqdm import tqdm -from typing import Dict, List, Tuple, Callable +from typing import Dict, Tuple from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ ru_target_aggregation_fn, AggregationMethod @@ -236,13 +235,23 @@ def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager, lp_problem += v <= target_resource_utilization_value -def _aggregate_for_lp(ru_vec, target) -> list: +def _aggregate_for_lp(ru_vec, target: RUTarget) -> list: + """ + Aggregate resource utilization values for the LP. + + Args: + ru_vec: a vector of resource utilization values. + target: resource utilization target. + + Returns: + Aggregated resource utilization. + """ if target == RUTarget.TOTAL: - w = pulp.lpSum(v[0] for v in ru_vec) + w = lpSum(v[0] for v in ru_vec) return [w + v[1] for v in ru_vec] if ru_target_aggregation_fn[target] == AggregationMethod.SUM: - return [pulp.lpSum(ru_vec)] + return [lpSum(ru_vec)] if ru_target_aggregation_fn[target] == AggregationMethod.MAX: return list(ru_vec) diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 8c585fb1f..b204b408e 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -224,14 +224,14 @@ def _set_final_resource_utilization(graph: Graph, w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes} a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes} ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info) - final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.MpCustom, + final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom, act_qcs=a_qcs, w_qcs=w_qcs) for ru_target, ru in final_ru.get_resource_utilization_dict().items(): if ru == 0: - 
Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded " - f"final ru for this target would be 0.") + Logger.warning(f"No relevant quantized layers for the resource utilization target {ru_target} were found, " + f"the recorded final ru for this target would be 0.") - print(final_ru) + Logger.info(f'Resource utilization (of quantized targets):\n {str(final_ru)}.') graph.user_info.final_resource_utilization = final_ru graph.user_info.mixed_precision_cfg = final_bit_widths_config diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py index 7a8dd0b08..47664608d 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision/requires_mixed_precision_test.py @@ -71,14 +71,10 @@ def get_max_resources_for_model(self, model): attach2keras = AttachTpcToKeras() tpc = attach2keras.attach(tpc, cc.quantization_config.custom_tpc_opset_to_layer) - return compute_resource_utilization_data(in_model=model, - representative_data_gen=self.representative_data_gen(), - core_config=cc, - tpc=tpc, - fw_info=DEFAULT_KERAS_INFO, - fw_impl=KerasImplementation(), - transformed_graph=None, - mixed_precision_enabled=False) + return compute_resource_utilization_data(in_model=model, representative_data_gen=self.representative_data_gen(), + core_config=cc, tpc=tpc, fw_info=DEFAULT_KERAS_INFO, + fw_impl=KerasImplementation(), transformed_graph=None, + mixed_precision_enable=False) def get_quantization_config(self): return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, diff --git a/tests_pytest/core/__init__.py b/tests_pytest/core/__init__.py index e11a7cc60..5397dea24 100644 --- a/tests_pytest/core/__init__.py +++ 
b/tests_pytest/core/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_pytest/core/common/__init__.py b/tests_pytest/core/common/__init__.py index e11a7cc60..5397dea24 100644 --- a/tests_pytest/core/common/__init__.py +++ b/tests_pytest/core/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_pytest/core/common/mixed_precision/__init__.py b/tests_pytest/core/common/mixed_precision/__init__.py index e11a7cc60..5397dea24 100644 --- a/tests_pytest/core/common/mixed_precision/__init__.py +++ b/tests_pytest/core/common/mixed_precision/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests_pytest/core/common/mixed_precision/resource_utilization_tools/__init__.py b/tests_pytest/core/common/mixed_precision/resource_utilization_tools/__init__.py index e11a7cc60..5397dea24 100644 --- a/tests_pytest/core/common/mixed_precision/resource_utilization_tools/__init__.py +++ b/tests_pytest/core/common/mixed_precision/resource_utilization_tools/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.