From 228c35ff0dc097ef079200e52bf0af6282983fa1 Mon Sep 17 00:00:00 2001
From: Irena Byzalov <102301507+irenaby@users.noreply.github.com>
Date: Wed, 15 Jan 2025 08:46:23 +0200
Subject: [PATCH] Report final ru only for mixed precision constrained targets
 (#1326)

* report final ru only for mixed precision constrained targets
---
 .../mixed_precision_search_manager.py         |  2 +-
 .../resource_utilization.py                   | 30 +++++++++++++-----
 .../resource_utilization_calculator.py        | 22 ++++++++-----
 .../resource_utilization_data.py              |  2 +-
 .../search_methods/linear_programming.py      |  2 +-
 model_compression_toolkit/core/runner.py      | 31 +++++++++++--------
 .../test_resource_utilization_object.py       | 21 +++++++------
 .../feature_networks/mixed_precision_tests.py | 21 -------------
 .../weights_mixed_precision_tests.py          | 11 -------
 9 files changed, 71 insertions(+), 71 deletions(-)

diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
index 670dc11cc..c1bad8313 100644
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
@@ -69,7 +69,7 @@ def __init__(self,
 
         # To define RU Total constraints we need to compute weights and activations even if they have no constraints
         # TODO currently this logic is duplicated in linear_programming.py
-        targets = target_resource_utilization.get_restricted_metrics()
+        targets = target_resource_utilization.get_restricted_targets()
         if RUTarget.TOTAL in targets:
             targets = targets.union({RUTarget.ACTIVATION, RUTarget.WEIGHTS}) - {RUTarget.TOTAL}
         self.ru_targets_to_compute = targets
diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py
index 3da53184a..d2746da1b 100644
--- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py
+++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py
@@ -86,15 +86,31 @@ def is_satisfied_by(self, ru: 'ResourceUtilization') -> bool:
                     ru.total_memory <= self.total_memory and \
                     ru.bops <= self.bops)
 
-    def get_restricted_metrics(self) -> Set[RUTarget]:
+    def get_restricted_targets(self) -> Set[RUTarget]:
         d = self.get_resource_utilization_dict()
         return {k for k, v in d.items() if v < np.inf}
 
     def is_any_restricted(self) -> bool:
-        return bool(self.get_restricted_metrics())
+        return bool(self.get_restricted_targets())
 
-    def __repr__(self):
-        return f"Weights_memory: {self.weights_memory}, " \
-               f"Activation_memory: {self.activation_memory}, " \
-               f"Total_memory: {self.total_memory}, " \
-               f"BOPS: {self.bops}"
+    def get_summary_str(self, restricted: bool):
+        """
+        Generate summary string.
+
+        Args:
+            restricted: whether to include non-restricted targets.
+
+        Returns:
+            Summary string.
+        """
+        targets = self.get_restricted_targets() if restricted else list(RUTarget)
+        summary = []
+        if RUTarget.WEIGHTS in targets:
+            summary.append(f"Weights memory: {self.weights_memory}")
+        if RUTarget.ACTIVATION in targets:
+            summary.append(f"Activation memory: {self.activation_memory}")
+        if RUTarget.TOTAL in targets:
+            summary.append(f"Total memory: {self.total_memory}")
+        if RUTarget.BOPS in targets:
+            summary.append(f"BOPS: {self.bops}")
+        return ', '.join(summary)
diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
index aff19117f..2b118ee1b 100644
--- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
+++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
@@ -15,8 +15,7 @@
 from collections import defaultdict
 from copy import deepcopy
 from enum import Enum, auto
-from functools import lru_cache
-from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence
+from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence, Set
 
 from model_compression_toolkit.constants import FLOAT_BITWIDTH
 from model_compression_toolkit.core import FrameworkInfo
@@ -160,16 +159,17 @@ def compute_resource_utilization(self,
         """
         ru_targets = set(ru_targets) if ru_targets else set(RUTarget)
 
+        if w_qcs is not None and not self.is_custom_weights_config_applicable(ru_targets):
+            raise ValueError('Weight configuration passed but no relevant metric requested.')
+        if act_qcs is not None and not self.is_custom_activation_config_applicable(ru_targets):
+            raise ValueError('Activation configuration passed but no relevant metric requested.')
+
         w_total, a_total = None, None
         if {RUTarget.WEIGHTS, RUTarget.TOTAL}.intersection(ru_targets):
             w_total, *_ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs)
-        elif w_qcs is not None:    # pragma: no cover
-            raise ValueError('Weight configuration passed but no relevant metric requested.')
 
         if {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets):
             a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs)
-        elif act_qcs is not None:    # pragma: no cover
-            raise ValueError('Activation configuration passed but no relevant metric requested.')
 
         ru = ResourceUtilization()
         if RUTarget.WEIGHTS in ru_targets:
@@ -182,7 +182,7 @@ def compute_resource_utilization(self,
             ru.bops, _ = self.compute_bops(target_criterion=target_criterion,
                                            bitwidth_mode=bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs)
 
-        assert ru.get_restricted_metrics() == set(ru_targets), 'Mismatch between the number of requested and computed metrics'
+        assert ru.get_restricted_targets() == set(ru_targets), 'Mismatch between the number of requested and computed metrics'
         return ru
 
     def compute_weights_utilization(self,
@@ -464,6 +464,14 @@ def compute_node_bops(self,
         node_bops = a_nbits * w_nbits * node_mac
         return node_bops
 
+    def is_custom_weights_config_applicable(self, ru_targets: Set[RUTarget]) -> bool:
+        """ Whether custom configuration for weights is compatible with the requested targets."""
+        return bool({RUTarget.WEIGHTS, RUTarget.TOTAL, RUTarget.BOPS}.intersection(ru_targets))
+
+    def is_custom_activation_config_applicable(self, ru_targets: Set[RUTarget]) -> bool:
+        """ Whether custom configuration for activations is compatible with the requested targets."""
+        return bool({RUTarget.ACTIVATION, RUTarget.TOTAL, RUTarget.BOPS}.intersection(ru_targets))
+
     def _get_cut_target_nodes(self, cut: Cut, target_criterion: TargetInclusionCriterion) -> List[BaseNode]:
         """
         Retrieve target nodes from a cut filtered by a criterion.
diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
index 0564b5ddf..c61dbf6a1 100644
--- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
+++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
@@ -118,7 +118,7 @@ def requires_mixed_precision(in_model: Any,
 
     ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
     max_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QMaxBit,
-                                                        ru_targets=target_resource_utilization.get_restricted_metrics())
+                                                        ru_targets=target_resource_utilization.get_restricted_targets())
     return not target_resource_utilization.is_satisfied_by(max_ru)
 
 
diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
index bf89f1ff8..34e6fcbaa 100644
--- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
+++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
@@ -196,7 +196,7 @@ def _add_ru_constraints(search_manager: MixedPrecisionSearchManager,
     """
     ru_indicated_vectors = {}
     # targets to add constraints for
-    constraints_targets = target_resource_utilization.get_restricted_metrics()
+    constraints_targets = target_resource_utilization.get_restricted_targets()
     # to add constraints for Total target we need to compute weight and activation
     targets_to_compute = constraints_targets
     if RUTarget.TOTAL in constraints_targets:
diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py
index 0e678070c..d9812ff25 100644
--- a/model_compression_toolkit/core/runner.py
+++ b/model_compression_toolkit/core/runner.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 
 import copy
-from typing import Callable, Any, List
+from typing import Callable, Any, List, Optional
 
 from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -170,6 +170,7 @@ def core_runner(in_model: Any,
 
     _set_final_resource_utilization(graph=tg,
                                     final_bit_widths_config=bit_widths_config,
+                                    target_resource_utilization=target_resource_utilization,
                                     fw_info=fw_info,
                                     fw_impl=fw_impl)
 
@@ -207,6 +208,7 @@ def core_runner(in_model: Any,
 
 def _set_final_resource_utilization(graph: Graph,
                                     final_bit_widths_config: List[int],
+                                    target_resource_utilization: Optional[ResourceUtilization],
                                     fw_info: FrameworkInfo,
                                     fw_impl: FrameworkImplementation):
     """
@@ -216,21 +218,24 @@ def _set_final_resource_utilization(graph: Graph,
     Args:
         graph: Graph to compute the resource utilization for.
         final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
+        target_resource_utilization: Requested target resource utilization if relevant.
         fw_info: A FrameworkInfo object.
         fw_impl: FrameworkImplementation object with specific framework methods implementation.
 
     """
-    w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
-    a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
-    ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
-    final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom,
-                                                          act_qcs=a_qcs, w_qcs=w_qcs)
-
-    for ru_target, ru in final_ru.get_resource_utilization_dict().items():
-        if ru == 0:
-            Logger.warning(f"No relevant quantized layers for the resource utilization target {ru_target} were found, "
-                           f"the recorded final ru for this target would be 0.")
-
-    Logger.info(f'Resource utilization (of quantized targets):\n {str(final_ru)}.')
+    ru_targets = target_resource_utilization.get_restricted_targets() if target_resource_utilization else None
+    final_ru = None
+    if ru_targets:
+        ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
+        w_qcs, a_qcs = None, None
+        if ru_calculator.is_custom_weights_config_applicable(ru_targets):
+            w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
+        if ru_calculator.is_custom_activation_config_applicable(ru_targets):
+            a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
+        final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
+                                                              BitwidthMode.QCustom,
+                                                              act_qcs=a_qcs, w_qcs=w_qcs, ru_targets=ru_targets)
+        summary = final_ru.get_summary_str(restricted=True)
+        Logger.info(f'Resource utilization for quantized mixed-precision targets:\n {summary}.')
     graph.user_info.final_resource_utilization = final_ru
     graph.user_info.mixed_precision_cfg = final_bit_widths_config
diff --git a/tests/common_tests/function_tests/test_resource_utilization_object.py b/tests/common_tests/function_tests/test_resource_utilization_object.py
index f7e3f9374..d9b783240 100644
--- a/tests/common_tests/function_tests/test_resource_utilization_object.py
+++ b/tests/common_tests/function_tests/test_resource_utilization_object.py
@@ -22,6 +22,7 @@
 
 default_ru = ResourceUtilization()
 custom_ru = ResourceUtilization(1, 2, 3, 4)
+mixed_ru = ResourceUtilization(activation_memory=5, bops=10)
 
 
 class TestResourceUtilizationObject(unittest.TestCase):
@@ -38,15 +39,17 @@ def test_default(self):
         self.assertTrue(custom_ru.bops, 4)
 
     def test_representation(self):
-        self.assertEqual(repr(default_ru), f"Weights_memory: {np.inf}, "
-                                           f"Activation_memory: {np.inf}, "
-                                           f"Total_memory: {np.inf}, "
-                                           f"BOPS: {np.inf}")
-
-        self.assertEqual(repr(custom_ru), f"Weights_memory: {1}, "
-                                          f"Activation_memory: {2}, "
-                                          f"Total_memory: {3}, "
-                                          f"BOPS: {4}")
+        self.assertEqual(default_ru.get_summary_str(restricted=False), f"Weights memory: {np.inf}, "
+                                                                       f"Activation memory: {np.inf}, "
+                                                                       f"Total memory: {np.inf}, "
+                                                                       f"BOPS: {np.inf}")
+        self.assertEqual(default_ru.get_summary_str(restricted=True), "")
+
+        self.assertEqual(mixed_ru.get_summary_str(restricted=False), f"Weights memory: {np.inf}, "
+                                                                     "Activation memory: 5, "
+                                                                     f"Total memory: {np.inf}, "
+                                                                     "BOPS: 10")
+        self.assertEqual(mixed_ru.get_summary_str(restricted=True), "Activation memory: 5, BOPS: 10")
 
     def test_ru_hold_constraints(self):
         self.assertTrue(default_ru.is_satisfied_by(custom_ru))
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
index f89aff1fa..7f0bed284 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
@@ -174,13 +174,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         # test with its current setup (therefore, we don't check the input layer's bitwidth)
         self.unit_test.assertTrue((activation_bits == [4, 8]))
 
-        # Verify final resource utilization
-        self.unit_test.assertTrue(
-            quantization_info.final_resource_utilization.total_memory ==
-            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-            "Running weights and activation mixed-precision, "
-            "final total memory should be equal to sum of weights and activation memory.")
-
 
 class MixedPrecisionActivationSearch2BitsAvgTest(MixedPrecisionActivationBaseTest):
     def __init__(self, unit_test):
@@ -206,13 +199,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
                                  activation_layers_idx=self.activation_layers_idx,
                                  unique_tensor_values=4)
 
-        # Verify final resource utilization
-        self.unit_test.assertTrue(
-            quantization_info.final_resource_utilization.total_memory ==
-            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-            "Running weights and activation mixed-precision, "
-            "final total memory should be equal to sum of weights and activation memory.")
-
 
 class MixedPrecisionActivationDepthwiseTest(MixedPrecisionActivationBaseTest):
     def __init__(self, unit_test):
@@ -484,13 +470,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
                                  activation_layers_idx=self.activation_layers_idx,
                                  unique_tensor_values=16)
 
-        # Verify final ResourceUtilization
-        self.unit_test.assertTrue(
-            quantization_info.final_resource_utilization.total_memory ==
-            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
-            "Running weights and activation mixed-precision, "
-            "final total memory should be equal to sum of weights and activation memory.")
-
 
 class MixedPrecisionMultipleResourcesTightUtilizationSearchTest(MixedPrecisionActivationBaseTest):
     def __init__(self, unit_test):
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
index b6e5344e2..328e3674d 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
@@ -341,12 +341,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue(quantization_info.final_resource_utilization.activation_memory <=
                                   self.target_total_ru.activation_memory)
 
-        self.unit_test.assertTrue(
-            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-            quantization_info.final_resource_utilization.total_memory,
-            "Running weights mixed-precision with unconstrained Resource Utilization, "
-            "final weights and activation memory sum should be equal to total memory.")
-
 
 class MixedPrecisionSearchTotalMemoryNonConfNodesTest(MixedPrecisionBaseTest):
     def __init__(self, unit_test):
@@ -362,11 +356,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         # we're only interested in the ResourceUtilization
         self.unit_test.assertTrue(
             quantization_info.final_resource_utilization.total_memory <= self.target_total_ru.total_memory)
-        self.unit_test.assertTrue(
-            quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
-            quantization_info.final_resource_utilization.total_memory,
-            "Running weights mixed-precision with unconstrained ResourceUtilization, "
-            "final weights and activation memory sum should be equal to total memory.")
 
 
 class MixedPrecisionDepthwiseTest(MixedPrecisionBaseTest):