Skip to content

Commit

Permalink
Merge branch 'sony:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
SumaiyaTarannumNoor authored Jan 15, 2025
2 parents 46b5a87 + 228c35f commit 441bb66
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self,

# To define RU Total constraints we need to compute weights and activations even if they have no constraints
# TODO currently this logic is duplicated in linear_programming.py
targets = target_resource_utilization.get_restricted_metrics()
targets = target_resource_utilization.get_restricted_targets()
if RUTarget.TOTAL in targets:
targets = targets.union({RUTarget.ACTIVATION, RUTarget.WEIGHTS}) - {RUTarget.TOTAL}
self.ru_targets_to_compute = targets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,31 @@ def is_satisfied_by(self, ru: 'ResourceUtilization') -> bool:
ru.total_memory <= self.total_memory and \
ru.bops <= self.bops)

def get_restricted_metrics(self) -> Set[RUTarget]:
def get_restricted_targets(self) -> Set[RUTarget]:
    """Return the utilization targets that carry a finite (i.e. actually restricted) value."""
    restricted = set()
    for target, value in self.get_resource_utilization_dict().items():
        if value < np.inf:
            restricted.add(target)
    return restricted

def is_any_restricted(self) -> bool:
    """Whether at least one utilization target is restricted (finite)."""
    # NOTE(review): this span is a diff hunk — the first `return` below is the
    # removed (pre-rename) line and the second is its replacement; only the
    # second exists in the post-commit source.
    return bool(self.get_restricted_metrics())
    return bool(self.get_restricted_targets())

def __repr__(self):
return f"Weights_memory: {self.weights_memory}, " \
f"Activation_memory: {self.activation_memory}, " \
f"Total_memory: {self.total_memory}, " \
f"BOPS: {self.bops}"
def get_summary_str(self, restricted: bool):
    """
    Generate summary string.

    Args:
        restricted: whether to include non-restricted targets.

    Returns:
        Summary string.
    """
    targets = self.get_restricted_targets() if restricted else list(RUTarget)
    # Fixed display order regardless of the order in `targets`.
    entries = ((RUTarget.WEIGHTS, f"Weights memory: {self.weights_memory}"),
               (RUTarget.ACTIVATION, f"Activation memory: {self.activation_memory}"),
               (RUTarget.TOTAL, f"Total memory: {self.total_memory}"),
               (RUTarget.BOPS, f"BOPS: {self.bops}"))
    return ', '.join(text for target, text in entries if target in targets)
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from collections import defaultdict
from copy import deepcopy
from enum import Enum, auto
from functools import lru_cache
from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence
from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence, Set

from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.core import FrameworkInfo
Expand Down Expand Up @@ -160,16 +159,17 @@ def compute_resource_utilization(self,
"""
ru_targets = set(ru_targets) if ru_targets else set(RUTarget)

if w_qcs is not None and not self.is_custom_weights_config_applicable(ru_targets):
raise ValueError('Weight configuration passed but no relevant metric requested.')
if act_qcs is not None and not self.is_custom_activation_config_applicable(ru_targets):
raise ValueError('Activation configuration passed but no relevant metric requested.')

w_total, a_total = None, None
if {RUTarget.WEIGHTS, RUTarget.TOTAL}.intersection(ru_targets):
w_total, *_ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs)
elif w_qcs is not None: # pragma: no cover
raise ValueError('Weight configuration passed but no relevant metric requested.')

if {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets):
a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs)
elif act_qcs is not None: # pragma: no cover
raise ValueError('Activation configuration passed but no relevant metric requested.')

ru = ResourceUtilization()
if RUTarget.WEIGHTS in ru_targets:
Expand All @@ -182,7 +182,7 @@ def compute_resource_utilization(self,
ru.bops, _ = self.compute_bops(target_criterion=target_criterion,
bitwidth_mode=bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs)

assert ru.get_restricted_metrics() == set(ru_targets), 'Mismatch between the number of requested and computed metrics'
assert ru.get_restricted_targets() == set(ru_targets), 'Mismatch between the number of requested and computed metrics'
return ru

def compute_weights_utilization(self,
Expand Down Expand Up @@ -464,6 +464,14 @@ def compute_node_bops(self,
node_bops = a_nbits * w_nbits * node_mac
return node_bops

def is_custom_weights_config_applicable(self, ru_targets: Set[RUTarget]) -> bool:
    """ Whether custom configuration for weights is compatible with the requested targets."""
    # Weights config only matters for targets that depend on weight bit-widths.
    weights_dependent = {RUTarget.WEIGHTS, RUTarget.TOTAL, RUTarget.BOPS}
    return any(target in weights_dependent for target in ru_targets)

def is_custom_activation_config_applicable(self, ru_targets: Set[RUTarget]) -> bool:
    """ Whether custom configuration for activations is compatible with the requested targets."""
    # Activation config only matters for targets that depend on activation bit-widths.
    activation_dependent = {RUTarget.ACTIVATION, RUTarget.TOTAL, RUTarget.BOPS}
    return any(target in activation_dependent for target in ru_targets)

def _get_cut_target_nodes(self, cut: Cut, target_criterion: TargetInclusionCriterion) -> List[BaseNode]:
"""
Retrieve target nodes from a cut filtered by a criterion.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def requires_mixed_precision(in_model: Any,

ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
max_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QMaxBit,
ru_targets=target_resource_utilization.get_restricted_metrics())
ru_targets=target_resource_utilization.get_restricted_targets())
return not target_resource_utilization.is_satisfied_by(max_ru)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _add_ru_constraints(search_manager: MixedPrecisionSearchManager,
"""
ru_indicated_vectors = {}
# targets to add constraints for
constraints_targets = target_resource_utilization.get_restricted_metrics()
constraints_targets = target_resource_utilization.get_restricted_targets()
# to add constraints for Total target we need to compute weight and activation
targets_to_compute = constraints_targets
if RUTarget.TOTAL in constraints_targets:
Expand Down
31 changes: 18 additions & 13 deletions model_compression_toolkit/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================

import copy
from typing import Callable, Any, List
from typing import Callable, Any, List, Optional

from model_compression_toolkit.core.common import FrameworkInfo
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
Expand Down Expand Up @@ -170,6 +170,7 @@ def core_runner(in_model: Any,

_set_final_resource_utilization(graph=tg,
final_bit_widths_config=bit_widths_config,
target_resource_utilization=target_resource_utilization,
fw_info=fw_info,
fw_impl=fw_impl)

Expand Down Expand Up @@ -207,6 +208,7 @@ def core_runner(in_model: Any,

def _set_final_resource_utilization(graph: Graph,
final_bit_widths_config: List[int],
target_resource_utilization: Optional[ResourceUtilization],
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation):
"""
Expand All @@ -216,21 +218,24 @@ def _set_final_resource_utilization(graph: Graph,
Args:
graph: Graph to compute the resource utilization for.
final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
target_resource_utilization: Requested target resource utilization if relevant.
fw_info: A FrameworkInfo object.
fw_impl: FrameworkImplementation object with specific framework methods implementation.
"""
w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom,
act_qcs=a_qcs, w_qcs=w_qcs)

for ru_target, ru in final_ru.get_resource_utilization_dict().items():
if ru == 0:
Logger.warning(f"No relevant quantized layers for the resource utilization target {ru_target} were found, "
f"the recorded final ru for this target would be 0.")

Logger.info(f'Resource utilization (of quantized targets):\n {str(final_ru)}.')
ru_targets = target_resource_utilization.get_restricted_targets() if target_resource_utilization else None
final_ru = None
if ru_targets:
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
w_qcs, a_qcs = None, None
if ru_calculator.is_custom_weights_config_applicable(ru_targets):
w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
if ru_calculator.is_custom_activation_config_applicable(ru_targets):
a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
BitwidthMode.QCustom,
act_qcs=a_qcs, w_qcs=w_qcs, ru_targets=ru_targets)
summary = final_ru.get_summary_str(restricted=True)
Logger.info(f'Resource utilization for quantized mixed-precision targets:\n {summary}.')
graph.user_info.final_resource_utilization = final_ru
graph.user_info.mixed_precision_cfg = final_bit_widths_config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

default_ru = ResourceUtilization()
custom_ru = ResourceUtilization(1, 2, 3, 4)
mixed_ru = ResourceUtilization(activation_memory=5, bops=10)


class TestResourceUtilizationObject(unittest.TestCase):
Expand All @@ -38,15 +39,17 @@ def test_default(self):
self.assertTrue(custom_ru.bops, 4)

def test_representation(self):
    """Check the human-readable summary of ResourceUtilization objects."""
    # NOTE(review): this span is a diff hunk. The first two assertEqual calls
    # (on `repr(...)`) are the removed pre-commit assertions; the
    # `get_summary_str(...)` assertions that follow are their replacements and
    # are the only ones present in the post-commit source.
    self.assertEqual(repr(default_ru), f"Weights_memory: {np.inf}, "
                                       f"Activation_memory: {np.inf}, "
                                       f"Total_memory: {np.inf}, "
                                       f"BOPS: {np.inf}")

    self.assertEqual(repr(custom_ru), f"Weights_memory: {1}, "
                                      f"Activation_memory: {2}, "
                                      f"Total_memory: {3}, "
                                      f"BOPS: {4}")
    # Unrestricted summary lists every target; restricted summary of a fully
    # default (all-inf) object is empty.
    self.assertEqual(default_ru.get_summary_str(restricted=False), f"Weights memory: {np.inf}, "
                                                                   f"Activation memory: {np.inf}, "
                                                                   f"Total memory: {np.inf}, "
                                                                   f"BOPS: {np.inf}")
    self.assertEqual(default_ru.get_summary_str(restricted=True), "")

    # `mixed_ru` restricts only activation memory and BOPS (see L211 fixture).
    self.assertEqual(mixed_ru.get_summary_str(restricted=False), f"Weights memory: {np.inf}, "
                                                                 "Activation memory: 5, "
                                                                 f"Total memory: {np.inf}, "
                                                                 "BOPS: 10")
    self.assertEqual(mixed_ru.get_summary_str(restricted=True), "Activation memory: 5, BOPS: 10")

def test_ru_hold_constraints(self):
self.assertTrue(default_ru.is_satisfied_by(custom_ru))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
# test with its current setup (therefore, we don't check the input layer's bitwidth)
self.unit_test.assertTrue((activation_bits == [4, 8]))

# Verify final resource utilization
self.unit_test.assertTrue(
quantization_info.final_resource_utilization.total_memory ==
quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
"Running weights and activation mixed-precision, "
"final total memory should be equal to sum of weights and activation memory.")


class MixedPrecisionActivationSearch2BitsAvgTest(MixedPrecisionActivationBaseTest):
def __init__(self, unit_test):
Expand All @@ -206,13 +199,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
activation_layers_idx=self.activation_layers_idx,
unique_tensor_values=4)

# Verify final resource utilization
self.unit_test.assertTrue(
quantization_info.final_resource_utilization.total_memory ==
quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
"Running weights and activation mixed-precision, "
"final total memory should be equal to sum of weights and activation memory.")


class MixedPrecisionActivationDepthwiseTest(MixedPrecisionActivationBaseTest):
def __init__(self, unit_test):
Expand Down Expand Up @@ -484,13 +470,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
activation_layers_idx=self.activation_layers_idx,
unique_tensor_values=16)

# Verify final ResourceUtilization
self.unit_test.assertTrue(
quantization_info.final_resource_utilization.total_memory ==
quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory,
"Running weights and activation mixed-precision, "
"final total memory should be equal to sum of weights and activation memory.")


class MixedPrecisionMultipleResourcesTightUtilizationSearchTest(MixedPrecisionActivationBaseTest):
def __init__(self, unit_test):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
self.unit_test.assertTrue(quantization_info.final_resource_utilization.activation_memory <=
self.target_total_ru.activation_memory)

self.unit_test.assertTrue(
quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
quantization_info.final_resource_utilization.total_memory,
"Running weights mixed-precision with unconstrained Resource Utilization, "
"final weights and activation memory sum should be equal to total memory.")


class MixedPrecisionSearchTotalMemoryNonConfNodesTest(MixedPrecisionBaseTest):
def __init__(self, unit_test):
Expand All @@ -362,11 +356,6 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
# we're only interested in the ResourceUtilization
self.unit_test.assertTrue(
quantization_info.final_resource_utilization.total_memory <= self.target_total_ru.total_memory)
self.unit_test.assertTrue(
quantization_info.final_resource_utilization.weights_memory + quantization_info.final_resource_utilization.activation_memory ==
quantization_info.final_resource_utilization.total_memory,
"Running weights mixed-precision with unconstrained ResourceUtilization, "
"final weights and activation memory sum should be equal to total memory.")


class MixedPrecisionDepthwiseTest(MixedPrecisionBaseTest):
Expand Down

0 comments on commit 441bb66

Please sign in to comment.