Skip to content

Commit

Permalink
fixes per code review
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Jan 12, 2025
1 parent 1036c10 commit ae3136b
Show file tree
Hide file tree
Showing 13 changed files with 227 additions and 196 deletions.
19 changes: 16 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,24 @@ def is_weights_quantization_enabled(self, attr_name: str) -> bool:
return False

def is_configurable_weight(self, attr_name: str) -> bool:
""" Checks whether the specific weight has a configurable quantization. """
"""
Checks whether the specific weight attribute has a configurable quantization.
Args:
attr_name: weight attribute name.
Returns:
Whether the weight attribute is configurable.
"""
return self.is_weights_quantization_enabled(attr_name) and not self.is_all_weights_candidates_equal(attr_name)

def has_configurable_activation(self):
""" Checks whether the activation has a configurable quantization. """
def has_configurable_activation(self) -> bool:
"""
Checks whether the activation has a configurable quantization.
Returns:
Whether the activation has a configurable quantization.
"""
return self.is_activation_quantization_enabled() and not self.is_all_activation_candidates_equal()

def __repr__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import \
MixPrecisionRUHelper
MixedPrecisionRUHelper
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
from model_compression_toolkit.logger import Logger

Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self,
self._cuts = None

self.ru_metrics = target_resource_utilization.get_restricted_metrics()
self.ru_helper = MixPrecisionRUHelper(graph, fw_info, fw_impl)
self.ru_helper = MixedPrecisionRUHelper(graph, fw_info, fw_impl)
self.target_resource_utilization = target_resource_utilization
self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
Expand Down Expand Up @@ -207,10 +207,9 @@ def compute_resource_utilization_for_config(self, config: List[int]) -> Resource
"""
act_qcs, w_qcs = self.ru_helper.get_configurable_qcs(config)
ru = self.ru_helper.ru_calculator.compute_resource_utilization(target_criterion=TargetInclusionCriterion.AnyQuantized,
bitwidth_mode=BitwidthMode.MpCustom,
act_qcs=act_qcs,
w_qcs=w_qcs)
ru = self.ru_helper.ru_calculator.compute_resource_utilization(
target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
w_qcs=w_qcs)
return ru

def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class ResourceUtilization:
total_memory: The sum of model's activation and weights memory in bytes.
bops: The total bit-operations in the model.
"""
# TODO the user facade actually computes size, not memory. Do we want to change fields names?
weights_memory: float = np.inf
activation_memory: float = np.inf
total_memory: float = np.inf
Expand Down Expand Up @@ -93,9 +92,3 @@ def get_restricted_metrics(self) -> Set[RUTarget]:

def is_any_restricted(self) -> bool:
return bool(self.get_restricted_metrics())

def __repr__(self):
return f"Weights_memory: {self.weights_memory}, " \
f"Activation_memory: {self.activation_memory}, " \
f"Total_memory: {self.total_memory}, " \
f"BOPS: {self.bops}"

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def compute_resource_utilization_data(in_model: Any,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
transformed_graph: Graph = None,
mixed_precision_enabled: bool = True) -> ResourceUtilization:
mixed_precision_enable: bool = True) -> ResourceUtilization:
"""
Compute Resource Utilization information that can be relevant for defining target ResourceUtilization for mixed precision search.
Calculates maximal activation tensor size, the sum of the model's weight parameters and the total memory combining both weights
Expand All @@ -49,7 +49,7 @@ def compute_resource_utilization_data(in_model: Any,
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
transformed_graph: An internal graph representation of the input model. Defaults to None.
If no graph is provided, a graph will be constructed using the specified model.
mixed_precision_enabled: Indicates if mixed precision is enabled, defaults to True.
mixed_precision_enable: Indicates if mixed precision is enabled, defaults to True.
If disabled, computes resource utilization using base quantization
configurations across all layers.
Expand All @@ -68,13 +68,12 @@ def compute_resource_utilization_data(in_model: Any,
fw_impl,
tpc,
bit_width_config=core_config.bit_width_config,
mixed_precision_enable=mixed_precision_enabled,
mixed_precision_enable=mixed_precision_enable,
running_gptq=False)

ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
BitwidthMode.Size,
metrics=set(RUTarget) - {RUTarget.BOPS})
ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.Q8Bit,
ru_targets=set(RUTarget) - {RUTarget.BOPS})
ru.bops, _ = ru_calculator.compute_bops(TargetInclusionCriterion.AnyQuantized, BitwidthMode.Float)
return ru

Expand Down Expand Up @@ -118,9 +117,8 @@ def requires_mixed_precision(in_model: Any,
running_gptq=False)

ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
max_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
BitwidthMode.MpMax,
metrics=target_resource_utilization.get_restricted_metrics())
max_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QMaxBit,
ru_targets=target_resource_utilization.get_restricted_metrics())
return not target_resource_utilization.is_satisfied_by(max_ru)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
VirtualSplitWeightsNode, VirtualSplitActivationNode
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
RUTarget
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
Expand All @@ -33,7 +32,9 @@
# TODO take into account Virtual nodes. Are candidates defined with respect to virtual or original nodes?
# Can we use the virtual graph only for bops and the original graph for everything else?

class MixPrecisionRUHelper:
class MixedPrecisionRUHelper:
""" Helper class for resource utilization computations for mixed precision optimization. """

def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation):
self.graph = graph
self.fw_info = fw_info
Expand All @@ -42,7 +43,10 @@ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImple

def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Optional[List[int]]) -> Dict[RUTarget, np.ndarray]:
"""
Compute utilization of requested targets for a specific configuration
Compute utilization of requested targets for a specific configuration in the format expected by LP problem
formulation, namely an array of ru values corresponding to graph's configurable nodes in the topological order.
For activation target, the array contains values for activation cuts in unspecified order (as long as it is
consistent between configurations).
Args:
ru_targets: resource utilization targets to compute.
Expand Down Expand Up @@ -112,10 +116,10 @@ def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantiz
"""
if w_qcs:
target_criterion = TargetInclusionCriterion.QConfigurable
bitwidth_mode = BitwidthMode.MpCustom
bitwidth_mode = BitwidthMode.QCustom
else:
target_criterion = TargetInclusionCriterion.QNonConfigurable
bitwidth_mode = BitwidthMode.SpDefault
bitwidth_mode = BitwidthMode.QDefaultSP

_, nodes_util, _ = self.ru_calculator.compute_weights_utilization(target_criterion=target_criterion,
bitwidth_mode=bitwidth_mode,
Expand All @@ -136,7 +140,7 @@ def _activation_maxcut_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeAc
"""
if act_qcs:
_, cuts_util, _ = self.ru_calculator.compute_cut_activation_utilization(TargetInclusionCriterion.AnyQuantized,
bitwidth_mode=BitwidthMode.MpCustom,
bitwidth_mode=BitwidthMode.QCustom,
act_qcs=act_qcs)
cuts_util = {c: u.bytes for c, u in cuts_util.items()}
return cuts_util
Expand All @@ -158,10 +162,10 @@ def _activation_tensor_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeAc
"""
if act_qcs:
target_criterion = TargetInclusionCriterion.QConfigurable
bitwidth_mode = BitwidthMode.MpCustom
bitwidth_mode = BitwidthMode.QCustom
else:
target_criterion = TargetInclusionCriterion.QNonConfigurable
bitwidth_mode = BitwidthMode.SpDefault
bitwidth_mode = BitwidthMode.QDefaultSP

_, nodes_util = self.ru_calculator.compute_activation_tensors_utilization(target_criterion=target_criterion,
bitwidth_mode=bitwidth_mode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# ==============================================================================

import numpy as np
import pulp
from pulp import *
from tqdm import tqdm
from typing import Dict, List, Tuple, Callable
from typing import Dict, Tuple

from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
ru_target_aggregation_fn, AggregationMethod
Expand Down Expand Up @@ -236,13 +235,23 @@ def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager,
lp_problem += v <= target_resource_utilization_value


def _aggregate_for_lp(ru_vec, target) -> list:
def _aggregate_for_lp(ru_vec, target: RUTarget) -> list:
"""
Aggregate resource utilization values for the LP.
Args:
ru_vec: a vector of resource utilization values.
target: resource utilization target.
Returns:
Aggregated resource utilization.
"""
if target == RUTarget.TOTAL:
w = pulp.lpSum(v[0] for v in ru_vec)
w = lpSum(v[0] for v in ru_vec)
return [w + v[1] for v in ru_vec]

if ru_target_aggregation_fn[target] == AggregationMethod.SUM:
return [pulp.lpSum(ru_vec)]
return [lpSum(ru_vec)]

if ru_target_aggregation_fn[target] == AggregationMethod.MAX:
return list(ru_vec)
Expand Down
8 changes: 4 additions & 4 deletions model_compression_toolkit/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,14 @@ def _set_final_resource_utilization(graph: Graph,
w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.MpCustom,
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom,
act_qcs=a_qcs, w_qcs=w_qcs)

for ru_target, ru in final_ru.get_resource_utilization_dict().items():
if ru == 0:
Logger.warning(f"No relevant quantized layers for the ru target {ru_target} were found, the recorded "
f"final ru for this target would be 0.")
Logger.warning(f"No relevant quantized layers for the resource utilization target {ru_target} were found, "
f"the recorded final ru for this target would be 0.")

print(final_ru)
Logger.info(f'Resource utilization (of quantized targets):\n {str(final_ru)}.')
graph.user_info.final_resource_utilization = final_ru
graph.user_info.mixed_precision_cfg = final_bit_widths_config
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,10 @@ def get_max_resources_for_model(self, model):
attach2keras = AttachTpcToKeras()
tpc = attach2keras.attach(tpc, cc.quantization_config.custom_tpc_opset_to_layer)

return compute_resource_utilization_data(in_model=model,
representative_data_gen=self.representative_data_gen(),
core_config=cc,
tpc=tpc,
fw_info=DEFAULT_KERAS_INFO,
fw_impl=KerasImplementation(),
transformed_graph=None,
mixed_precision_enabled=False)
return compute_resource_utilization_data(in_model=model, representative_data_gen=self.representative_data_gen(),
core_config=cc, tpc=tpc, fw_info=DEFAULT_KERAS_INFO,
fw_impl=KerasImplementation(), transformed_graph=None,
mixed_precision_enable=False)

def get_quantization_config(self):
return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE,
Expand Down
2 changes: 1 addition & 1 deletion tests_pytest/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion tests_pytest/core/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion tests_pytest/core/common/mixed_precision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit ae3136b

Please sign in to comment.