Replace max tensor with max cut #1295

Merged 15 commits on Dec 25, 2024
Fix MaxCutAStar.
Fix PR comments.
elad-c committed Dec 23, 2024
commit 96b82ac26dcbbe28a2efaa4ab55ebec0b79a4387
@@ -223,8 +223,7 @@ def _get_cut_to_expand(self, open_list: List[Cut], costs: Dict[Cut, float], rout

"""
ordered_cuts_list = sorted(open_list,
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), len(routes[c])),
reverse=False)
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), -len(routes[c])))

assert len(ordered_cuts_list) > 0
return ordered_cuts_list[0]
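For context on the new sort key: the previous key used `len(routes[c])` with `reverse=False`, so equal-cost cuts with fewer routes were expanded first; negating the length flips that tie-break so the cut with more routes is preferred. A minimal sketch of the new ordering rule, with invented cut names, costs, and routes (not taken from the repository):

```python
# Made-up costs and routes, only to show the tie-break direction.
costs = {"a": 1.0, "b": 1.0, "c": 0.5}
routes = {"a": [1, 2, 3], "b": [1], "c": [1, 2]}

# Ascending accumulated cost; among equal costs, the cut with more routes
# sorts first because its negated length is smaller.
ordered = sorted(costs, key=lambda c: (costs[c], -len(routes[c])))
assert ordered == ["c", "a", "b"]
```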
@@ -66,11 +66,11 @@ def __init__(self,
self.sensitivity_evaluator = sensitivity_evaluator
self.layer_to_bitwidth_mapping = self.get_search_space()
self.compute_metric_fn = self.get_sensitivity_metric()
self._cuts = None

self.cuts = calc_graph_cuts(self.original_graph)

ru_types = [k for k, v in target_resource_utilization.get_resource_utilization_dict().items() if v < np.inf]
self.compute_ru_functions = {k: v for k, v in ru_functions.items() if k in ru_types}
ru_types = [ru_target for ru_target, ru_value in
target_resource_utilization.get_resource_utilization_dict().items() if ru_value < np.inf]
self.compute_ru_functions = {ru_target: ru_fn for ru_target, ru_fn in ru_functions.items() if ru_target in ru_types}
self.target_resource_utilization = target_resource_utilization
self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
@@ -80,6 +80,17 @@ def __init__(self,
self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph,
original_graph=self.original_graph)

@property
def cuts(self):
"""
Calculates graph cuts. Written as a property so it is only calculated once, and
only if cuts are needed.

"""
if self._cuts is None:
self._cuts = calc_graph_cuts(self.original_graph)
return self._cuts
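This is the standard lazy-property pattern: the expensive cut computation is deferred to the first access and the result is cached. A minimal self-contained sketch with illustrative names (the real code calls calc_graph_cuts):

```python
class LazyCutsSketch:
    """Illustrative stand-in for the search manager's lazy `cuts` property."""

    def __init__(self, graph):
        self.original_graph = graph
        self._cuts = None  # filled lazily, only if cuts are actually requested

    @property
    def cuts(self):
        if self._cuts is None:
            # Stand-in for calc_graph_cuts(self.original_graph).
            self._cuts = sorted(self.original_graph)
        return self._cuts


manager = LazyCutsSketch(graph=["b", "a"])
assert manager.cuts is manager.cuts  # computed once, then the cached result is reused
```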

def get_search_space(self) -> Dict[int, List[int]]:
"""
The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces
@@ -110,6 +121,21 @@ def get_sensitivity_metric(self) -> Callable:

return self.sensitivity_evaluator.compute_metric

def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg):
"""
Computes resource utilization for a given mixed-precision configuration.
The method computes a resource utilization vector for a specific resource utilization target.

Returns: resource utilization value.

"""
# ru_fn is a pair of resource utilization computation method and
# resource utilization aggregation method (in this method we only need the first one)
if ru_target is RUTarget.ACTIVATION:
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl)
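A hedged sketch of the dispatch above, using hypothetical stand-in types (RuFnSketch and RUTargetSketch are illustrative only, not the toolkit's real classes): only the activation target receives the precomputed graph cuts, every other target calls its metric function without them.

```python
from collections import namedtuple
from enum import Enum, auto


class RUTargetSketch(Enum):  # illustrative subset of the real RUTarget enum
    WEIGHTS = auto()
    ACTIVATION = auto()


# Hypothetical stand-in for the (metric_fn, aggregate_fn) pair mentioned in the comment.
RuFnSketch = namedtuple("RuFnSketch", ["metric_fn", "aggregate_fn"])


def calc_ru_sketch(ru_target, ru_fn, mp_cfg, graph, fw_info, fw_impl, cuts):
    # Only the activation target needs the precomputed graph cuts.
    if ru_target is RUTargetSketch.ACTIVATION:
        return ru_fn.metric_fn(mp_cfg, graph, fw_info, fw_impl, cuts)
    return ru_fn.metric_fn(mp_cfg, graph, fw_info, fw_impl)


weights_fn = RuFnSketch(metric_fn=lambda cfg, g, fi, impl: [len(cfg)], aggregate_fn=sum)
print(calc_ru_sketch(RUTargetSketch.WEIGHTS, weights_fn, [0, 1], None, None, None, cuts=None))
```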

def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:
"""
Computes a resource utilization vector with the values matching to the minimal mp configuration
@@ -122,13 +148,10 @@ def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:

"""
min_ru = {}
for ru_target, ru_fns in self.compute_ru_functions.items():
# ru_fns is a pair of resource utilization computation method and
for ru_target, ru_fn in self.compute_ru_functions.items():
# ru_fns is a pair of resource utilization computation method and
# resource utilization aggregation method (in this method we only need the first one)
if ru_target is RUTarget.ACTIVATION:
min_ru[ru_target] = ru_fns.metric_fn(self.min_ru_config, self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
min_ru[ru_target] = ru_fns.metric_fn(self.min_ru_config, self.graph, self.fw_info, self.fw_impl)
min_ru[ru_target] = self._calc_ru_fn(ru_target, ru_fn, self.min_ru_config)

return min_ru

@@ -219,10 +242,7 @@ def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int,

"""
cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
if target == RUTarget.ACTIVATION:
return self.compute_ru_functions[target].metric_fn(cfg, self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
return self.compute_ru_functions[target].metric_fn(cfg, self.graph, self.fw_info, self.fw_impl)
return self._calc_ru_fn(target, self.compute_ru_functions[target], cfg)

@staticmethod
def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
@@ -13,10 +13,12 @@
# limitations under the License.
# ==============================================================================
import copy
from collections import defaultdict

import numpy as np
from typing import Callable, Any, Dict, Tuple

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import FLOAT_BITWIDTH, BITS_TO_BYTES
from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod
from model_compression_toolkit.core.common import Graph
@@ -152,22 +154,19 @@ def compute_activation_output_maxcut_sizes(graph: Graph) -> Tuple[np.ndarray, np
cuts = calc_graph_cuts(graph)

# map nodes to cuts.
node_to_cat_mapping = {}
node_to_cat_mapping = defaultdict(list)
for i, cut in enumerate(cuts):
mem_element_names = [m.node_name for m in cut.mem_elements.elements]
for m_name in mem_element_names:
if len(graph.find_node_by_name(m_name)) > 0:
if m_name in node_to_cat_mapping:
node_to_cat_mapping[m_name].append(i)
else:
node_to_cat_mapping[m_name] = [i]
node_to_cat_mapping[m_name].append(i)
else:
raise Exception("Missing node")
Logger.critical(f"Missing node: {m_name}")

activation_outputs = np.zeros(len(cuts))
activation_outputs_bytes = np.zeros(len(cuts))
for n in graph.nodes:
# Go over all nodes that have configurable activation.
# Go over all nodes that have activation quantization enabled.
if n.has_activation_quantization_enabled_candidate():
# Fetch maximum bits required for activations quantization.
max_activation_bits = max([qc.activation_quantization_cfg.activation_n_bits for qc in n.candidates_quantization_cfg])
@@ -177,7 +176,6 @@ def compute_activation_output_maxcut_sizes(graph: Graph) -> Tuple[np.ndarray, np
# Calculate activation size in bytes and append to list
activation_outputs_bytes[cut_index] += node_output_size * max_activation_bits / BITS_TO_BYTES

del cuts
return activation_outputs_bytes, activation_outputs
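For intuition, a minimal self-contained sketch of the mapping pattern used above, with made-up node names, cut contents, sizes, and bit-widths: `defaultdict(list)` removes the explicit key-exists branch, and each node's maximal activation size contributes to every cut it belongs to.

```python
from collections import defaultdict

BITS_TO_BYTES = 8  # mirrors the constant imported above

# Made-up cuts: each cut is just a list of node names here.
cuts = [["conv1", "relu1"], ["relu1", "conv2"]]
node_output_size = {"conv1": 1024, "relu1": 1024, "conv2": 512}
max_activation_bits = {"conv1": 8, "relu1": 16, "conv2": 8}

node_to_cut_mapping = defaultdict(list)
for i, cut in enumerate(cuts):
    for name in cut:
        node_to_cut_mapping[name].append(i)

activation_bytes = [0.0] * len(cuts)
for name, cut_indices in node_to_cut_mapping.items():
    for i in cut_indices:
        activation_bytes[i] += node_output_size[name] * max_activation_bits[name] / BITS_TO_BYTES

print(activation_bytes)  # per-cut activation memory estimate in bytes: [3072.0, 2560.0]
```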


@@ -285,16 +283,17 @@ def requires_mixed_precision(in_model: Any,
total_weights_memory_bytes = 0 if len(weights_memory_by_layer_bytes) == 0 else sum(weights_memory_by_layer_bytes)

# Compute max activation tensor in bytes
activation_output_sizes_bytes, _ = compute_activation_output_maxcut_sizes(transformed_graph)
max_activation_tensor_size_bytes = 0 if len(activation_output_sizes_bytes) == 0 else max(activation_output_sizes_bytes)
activation_memory_estimation_bytes, _ = compute_activation_output_maxcut_sizes(transformed_graph)
max_activation_memory_estimation_bytes = 0 if len(activation_memory_estimation_bytes) == 0 \
else max(activation_memory_estimation_bytes)

# Compute BOPS utilization - total count of bit-operations for all configurable layers with kernel
bops_count = compute_total_bops(graph=transformed_graph, fw_info=fw_info, fw_impl=fw_impl)
bops_count = np.inf if len(bops_count) == 0 else sum(bops_count)

is_mixed_precision |= target_resource_utilization.weights_memory < total_weights_memory_bytes
is_mixed_precision |= target_resource_utilization.activation_memory < max_activation_tensor_size_bytes
is_mixed_precision |= target_resource_utilization.total_memory < total_weights_memory_bytes + max_activation_tensor_size_bytes
is_mixed_precision |= target_resource_utilization.activation_memory < max_activation_memory_estimation_bytes
is_mixed_precision |= target_resource_utilization.total_memory < total_weights_memory_bytes + max_activation_memory_estimation_bytes
is_mixed_precision |= target_resource_utilization.bops < bops_count
return is_mixed_precision
@@ -14,7 +14,7 @@
# ==============================================================================
from enum import Enum
from functools import partial
from typing import List
from typing import List, Optional
from copy import deepcopy

import numpy as np
@@ -104,23 +104,23 @@ def calc_graph_cuts(graph: Graph) -> List[Cut]:
_, _, cuts = compute_graph_max_cut(memory_graph)

if cuts is None:
return None
else:
# filter empty cuts and cuta that contain nodes with activation quantization disabled.
filtered_cuts = []
for cut in cuts:
if len(cut.mem_elements.elements) > 0 and any(
[graph.find_node_by_name(e.node_name)[0].has_activation_quantization_enabled_candidate()
for e in cut.mem_elements.elements]):
filtered_cuts.append(cut)
return filtered_cuts
Logger.critical("Failed to calculate activation memory cuts for graph.")
# filter empty cuts and cuts that contain only nodes with activation quantization disabled.
filtered_cuts = []
for cut in cuts:
cut_has_no_act_quant_nodes = any(
[graph.find_node_by_name(e.node_name)[0].has_activation_quantization_enabled_candidate()
for e in cut.mem_elements.elements])
if len(cut.mem_elements.elements) > 0 and cut_has_no_act_quant_nodes:
filtered_cuts.append(cut)
return filtered_cuts
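A small sketch (fabricated node flags, no Graph objects) of the filtering rule applied above: a cut survives if it is non-empty and at least one of its nodes has an activation-quantization-enabled candidate.

```python
# Fabricated example: each cut is a list of (node_name, act_quant_enabled) pairs.
cuts = [
    [("conv1", True), ("add1", False)],
    [("reshape1", False)],
    [],
]

filtered_cuts = [
    cut for cut in cuts
    if len(cut) > 0 and any(enabled for _, enabled in cut)
]
assert filtered_cuts == [[("conv1", True), ("add1", False)]]
```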


def activation_maxcut_size_utilization(mp_cfg: List[int],
graph: Graph,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
cuts: List[Cut] = None) -> np.ndarray:
cuts: Optional[List[Cut]] = None) -> np.ndarray:
"""
Computes a resource utilization vector with the respective output memory max-cut size for activation
nodes, according to the given mixed-precision configuration.
@@ -156,10 +156,7 @@ def activation_maxcut_size_utilization(mp_cfg: List[int],
cuts = calc_graph_cuts(graph)

for i, cut in enumerate(cuts):
mem_elements = []
for m in cut.mem_elements.elements:
mem_elements.append(m.node_name)

mem_elements = [m.node_name for m in cut.mem_elements.elements]
mem = 0
for op_name in mem_elements:
n = graph.find_node_by_name(op_name)[0]
@@ -21,7 +21,9 @@
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
from mct_quantizers.keras.activation_quantization_holder import KerasActivationQuantizationHolder
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
from tests.keras_tests.utils import get_layers_from_model_by_type

keras = tf.keras
layers = keras.layers
@@ -54,8 +56,8 @@ def create_networks(self):
return keras.Model(inputs=inputs, outputs=outputs)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
mul1_act_quant = quantized_model.layers[3]
mul2_act_quant = quantized_model.layers[11]
act_quant_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder)
mul1_act_quant, mul2_act_quant = act_quant_layers[1], act_quant_layers[5]
self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.num_bits == 16,
"1st mul activation bits should be 16 bits because of following concat node.")
self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.signed == True,
@@ -97,8 +99,8 @@ def create_networks(self):
return keras.Model(inputs=inputs, outputs=outputs)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
mul1_act_quant = quantized_model.layers[3]
mul2_act_quant = quantized_model.layers[10]
act_quant_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder)
mul1_act_quant, mul2_act_quant = act_quant_layers[1], act_quant_layers[4]
self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.num_bits == 8,
"1st mul activation bits should be 8 bits because of RU.")
self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.signed == False,
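The test updates above replace brittle positional indexing (e.g. quantized_model.layers[3]) with a type-based lookup via get_layers_from_model_by_type. A minimal sketch of the idea, independent of Keras and using stand-in classes, assuming the helper simply filters layers by type in model order:

```python
class Holder:  # stand-in for KerasActivationQuantizationHolder
    def __init__(self, num_bits):
        self.num_bits = num_bits


class Other:  # any other layer type
    pass


def layers_by_type(layers, layer_type):
    # Keep model order, so a test can index the k-th holder regardless of
    # unrelated layers being inserted or removed around it.
    return [layer for layer in layers if isinstance(layer, layer_type)]


model_layers = [Other(), Holder(16), Other(), Holder(8)]
holders = layers_by_type(model_layers, Holder)
assert holders[0].num_bits == 16 and holders[1].num_bits == 8
```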
@@ -322,6 +322,7 @@ def test_mixed_precision_bops_utilization(self):
MixedPrecisionBopsAllWeightsLayersTest(self).run_test()
MixedPrecisionWeightsOnlyBopsTest(self).run_test()
MixedPrecisionActivationOnlyBopsTest(self).run_test()
# TODO: uncomment these tests when the issue of combined BOPs and other RU metrics is solved.
# MixedPrecisionBopsAndWeightsUtilizationTest(self).run_test()
# MixedPrecisionBopsAndActivationUtilizationTest(self).run_test()
# MixedPrecisionBopsAndTotalUtilizationTest(self).run_test()
2 changes: 1 addition & 1 deletion tests/keras_tests/utils.py
@@ -22,7 +22,7 @@
from keras.layers import TFOpLambda


def get_layers_from_model_by_type(model:keras.Model,
def get_layers_from_model_by_type(model: keras.Model,
layer_type: type,
include_wrapped_layers: bool = True):
"""
@@ -605,6 +605,7 @@ def test_mixed_precision_bops_utilization(self):
MixedPrecisionBopsAllWeightsLayersTest(self).run_test()
MixedPrecisionWeightsOnlyBopsTest(self).run_test()
MixedPrecisionActivationOnlyBopsTest(self).run_test()
# TODO: uncomment these tests when the issue of combined BOPs and other RU metrics is solved.
# MixedPrecisionBopsAndWeightsMemoryUtilizationTest(self).run_test()
# MixedPrecisionBopsAndActivationMemoryUtilizationTest(self).run_test()
# MixedPrecisionBopsAndTotalMemoryUtilizationTest(self).run_test()