diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py index 059304ba5..3e5ed188c 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py @@ -99,25 +99,6 @@ def __lt__(self, other: 'Utilization'): return self.bytes < other.bytes -class AggregationMethod(Enum): - SUM = sum - MAX = lambda seq: max(seq) if (seq := list(seq)) else 0 # walrus op for empty generator - - def __call__(self, *args, **kwarg): - return self.value(*args, **kwarg) - - -# default aggregation methods -# TODO This is used by mp to use the same aggregation. Except that for total it must do its own thing (add indicators -# to weights before summation). So maybe just get rid of it altogether? If it ever becomes configurable we can add it. -ru_target_aggregation_fn = { - RUTarget.WEIGHTS: AggregationMethod.SUM, - RUTarget.ACTIVATION: AggregationMethod.MAX, - RUTarget.TOTAL: AggregationMethod.SUM, - RUTarget.BOPS: AggregationMethod.SUM -} - - class ResourceUtilizationCalculator: """ Resource utilization calculator. """ @@ -226,8 +207,7 @@ def compute_weights_utilization(self, util_per_node[n] = node_weights_util util_per_node_per_weight[n] = per_weight_util - aggregate_fn = ru_target_aggregation_fn[RUTarget.WEIGHTS] - total_util = aggregate_fn(util_per_node.values()) + total_util = sum(util_per_node.values()) return total_util.bytes, util_per_node, util_per_node_per_weight def compute_node_weights_utilization(self, @@ -334,8 +314,7 @@ def compute_cut_activation_utilization(self, bitwidth_mode, qc) util_per_cut[cut] = sum(util_per_cut_per_node[cut].values()) # type: ignore - aggregate_fn = ru_target_aggregation_fn[RUTarget.ACTIVATION] - total_util = aggregate_fn(util_per_cut.values()) + total_util = max(util_per_cut.values()) return total_util.bytes, util_per_cut, util_per_cut_per_node def compute_activation_tensors_utilization(self, @@ -369,8 +348,7 @@ def compute_activation_tensors_utilization(self, util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc) util_per_node[n] = util - aggregate_fn = ru_target_aggregation_fn[RUTarget.ACTIVATION] - total_util = aggregate_fn(util_per_node.values()) + total_util = max(util_per_node.values()) return total_util.bytes, util_per_node def compute_node_activation_tensor_utilization(self, @@ -438,8 +416,7 @@ def compute_bops(self, w_qc = w_qcs.get(n) if w_qcs else None nodes_bops[n] = self.compute_node_bops(n, bitwidth_mode, act_qcs=act_qcs, w_qc=w_qc) - aggregate_fn = ru_target_aggregation_fn[RUTarget.BOPS] - return aggregate_fn(nodes_bops.values()), nodes_bops + return sum(nodes_bops.values()), nodes_bops def compute_node_bops(self, n: BaseNode, @@ -621,8 +598,7 @@ def _get_activation_nbits(cls, if act_qc: if bitwidth_mode != BitwidthMode.QCustom: raise ValueError(f'Activation config is not expected for non-custom bit mode {bitwidth_mode}') - assert act_qc.enable_activation_quantization or act_qc.activation_n_bits == FLOAT_BITWIDTH - return act_qc.activation_n_bits + return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled(): return FLOAT_BITWIDTH @@ -667,8 +643,7 @@ def _get_weight_nbits(cls, if bitwidth_mode != BitwidthMode.QCustom: raise ValueError('Weight config is not expected for non-custom bit mode {bitwidth_mode}') attr_cfg = w_qc.get_attr_config(w_attr) - assert attr_cfg.enable_weights_quantization or attr_cfg.weights_n_bits == FLOAT_BITWIDTH - return attr_cfg.weights_n_bits + return attr_cfg.weights_n_bits if attr_cfg.enable_weights_quantization else FLOAT_BITWIDTH if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr): return FLOAT_BITWIDTH diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py index 52f82e237..56ee0e5ca 100644 --- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py @@ -18,8 +18,6 @@ from tqdm import tqdm from typing import Dict, Tuple -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ - ru_target_aggregation_fn, AggregationMethod from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager @@ -250,14 +248,14 @@ def _aggregate_for_lp(ru_vec, target: RUTarget) -> list: w = lpSum(v[0] for v in ru_vec) return [w + v[1] for v in ru_vec] - if ru_target_aggregation_fn[target] == AggregationMethod.SUM: + if target in [RUTarget.WEIGHTS, RUTarget.BOPS]: return [lpSum(ru_vec)] - if ru_target_aggregation_fn[target] == AggregationMethod.MAX: + if target == RUTarget.ACTIVATION: + # for max aggregation, each value constitutes a separate constraint return list(ru_vec) - raise NotImplementedError(f'Cannot define lp constraints with unsupported aggregation function ' - f'{ru_target_aggregation_fn[target]}') # pragma: no cover + raise ValueError(f'Unexpected target {target}.') def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,