
Commit

Increasing code coverage to 98% (#1140)
Modified code and added tests to increase coverage to 98%.
Major changes include:

1. Remove unused functions (e.g., in base_graph, base_node, node_quantization_config...)
2. Extend mixed precision tests and fix Keras mixed precision tests that didn't run as expected.
3. Fix an issue with layers' distance functions and axis retrieval for the mixed precision metric computation.
4. Extend testing for const quantization, TPC, and LUT.
5. Exclude unnecessary lines from coverage (see the exclusion sketch below).
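
A minimal, hypothetical sketch of the exclusion mechanism referenced in item 5: coverage.py's default exclusion comment, # pragma: no cover, keeps a marked line (and the branch it guards) out of the reported percentage. The function below is illustrative and not taken from this repository.

# Hypothetical example of the coverage.py exclusion comment used throughout this commit.
def pick_simd(simd_list):
    if len(simd_list) == 0:  # pragma: no cover
        # Defensive branch that tests never reach; excluded from the coverage report.
        raise ValueError("No SIMD option is available")
    return min(simd_list)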

---------

Co-authored-by: Ofir Gordon <[email protected]>
ofirgo and Ofir Gordon authored Jul 31, 2024
1 parent ab9ef22 commit 9771975
Showing 37 changed files with 576 additions and 333 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests_suite_coverage.yml
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 80
env:
COVERAGE_THRESHOLD: 97
COVERAGE_THRESHOLD: 98
steps:
- uses: actions/checkout@v2
- name: Install Python 3
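The only change in this workflow raises COVERAGE_THRESHOLD from 97 to 98. As a hedged illustration, an environment-driven gate of this kind can be enforced with the coverage.py API; the workflow's actual enforcement step is not shown in this hunk, so the script below is an assumption rather than the repository's mechanism.

# Hypothetical gate script: fail if total coverage is below COVERAGE_THRESHOLD.
import os
import sys
import coverage

threshold = float(os.environ.get("COVERAGE_THRESHOLD", "98"))
cov = coverage.Coverage()
cov.load()            # load the .coverage data produced by the test run
total = cov.report()  # prints a report and returns the total percentage
sys.exit(0 if total >= threshold else 1)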
@@ -348,24 +348,20 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover

def get_mp_node_distance_fn(self, layer_class: type,
framework_attrs: Dict[str, Any],
compute_distance_fn: Callable = None,
axis: int = None,
norm_mse: bool = False) -> Callable:
def get_mp_node_distance_fn(self, n: BaseNode,
compute_distance_fn: Callable = None,
norm_mse: bool = False) -> Tuple[Callable, int]:
"""
A mapping between layers' types and a distance function for computing the distance between
two tensors in mixed precision (for loss computation purposes). Returns a specific function if a node of a
specific type is given, or a default (normalized MSE) function otherwise.
Args:
layer_class: Class path of a model's layer.
framework_attrs: Framework attributes the layer had which the graph node holds.
n: Node to choose distance function for.
compute_distance_fn: An optional distance function to use globally for all nodes.
axis: The axis on which the operation is performed (if specified).
norm_mse: whether to normalize the MSE distance function.
Returns: A distance function between two tensors.
Returns: A distance function between two tensors and an axis on which the distance is computed (if one exists).
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
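The hunk above moves get_mp_node_distance_fn to a node-based signature that returns both the distance function and the axis. A minimal sketch of how a concrete framework implementation might satisfy the new contract, assuming stand-in distance helpers, a hypothetical SOFTMAX_LAYERS set, and an 'axis' framework attribute; none of these names are taken from the repository.

from typing import Callable, Optional, Tuple
import numpy as np

# Stand-in distance helpers; the repository provides its own implementations.
def mse_stub(x: np.ndarray, y: np.ndarray, norm: bool = False) -> float:
    d = float(np.mean((x - y) ** 2))
    return d / (float(np.mean(x ** 2)) + 1e-8) if norm else d

def kl_stub(x: np.ndarray, y: np.ndarray, axis: Optional[int] = None) -> float:
    eps = 1e-8
    p = np.abs(x) + eps
    q = np.abs(y) + eps
    p = p / p.sum(axis=axis, keepdims=True)
    q = q / q.sum(axis=axis, keepdims=True)
    return float(np.mean(np.sum(p * np.log(p / q), axis=axis)))

SOFTMAX_LAYERS = set()  # hypothetical: softmax-like layer classes of the framework

def get_mp_node_distance_fn(n, compute_distance_fn: Callable = None,
                            norm_mse: bool = False) -> Tuple[Callable, Optional[int]]:
    # The axis is now read from the node itself instead of being passed by the caller.
    axis = n.framework_attr.get('axis')
    if compute_distance_fn is not None:
        return compute_distance_fn, axis  # global override supplied by the caller
    if n.layer_class in SOFTMAX_LAYERS:
        return kl_stub, axis              # KL divergence is the only axis-aware distance here
    return (lambda x, y: mse_stub(x, y, norm=norm_mse)), axis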
24 changes: 1 addition & 23 deletions model_compression_toolkit/core/common/graph/base_graph.py
@@ -440,7 +440,7 @@ def remove_node(self,

output_nodes = [ot.node for ot in self.get_outputs()] # get output nodes from namedtuples
if node_to_remove in output_nodes: # If node is in the graph's outputs, the outputs should be updated
if new_graph_outputs is None:
if new_graph_outputs is None: # pragma: no cover
Logger.critical(
f"{node_to_remove.name} is among the graph outputs; however, it cannot be removed without providing a new output.") # pragma: no cover
self.set_outputs(new_graph_outputs)
@@ -506,28 +506,6 @@ def out_edges(self,
output_edges.sort(key=lambda e: getattr(e, sort_by_attr))
return output_edges

def get_memory(self) -> float:
"""
Returns: Total memory consumption of the graph in bytes.
"""
memory = 0
for n in self.nodes:
memory += n.get_memory_bytes(self.fw_info)
return memory

def get_float_memory(self) -> float:
"""
Returns: Total memory consumption of the float graph in bytes.
"""
memory = 0
for n in self.nodes:
memory += n.get_float_memory_bytes(self.fw_info)
return memory

def get_configurable_sorted_nodes_names(self,
fw_info: FrameworkInfo,
include_reused_nodes: bool = False) -> List[str]:
30 changes: 2 additions & 28 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -297,18 +297,6 @@ def get_memory_bytes(self, fw_info) -> float:

return memory

def get_float_memory_bytes(self, fw_info) -> float:
"""
Compute the number of bytes the node's memory requires.
Args:
fw_info: Framework info to decide which attributes should be quantized.
Returns: Number of bytes the node's memory requires when in floating point (32 bit).
"""
q_params, f_params = self.get_num_parameters(fw_info)
return (f_params + q_params) * FP32_BYTES_PER_PARAMETER

def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]:
"""
@@ -436,20 +424,6 @@ def get_total_output_params(self) -> float:

return sum([np.prod([x for x in output_shape if x is not None]) for output_shape in output_shapes])

def get_total_input_params(self) -> float:
"""
Calculates the total parameters in the node's input tensors.
Returns: Input size (i.e., total number of parameters).
"""

input_shapes = self.input_shape if isinstance(self.input_shape, List) else [self.input_shape]

# remove batch size (first element) from input shape
input_shapes = [s[1:] for s in input_shapes]

return sum([np.prod([x for x in input_shape if x is not None]) for input_shape in input_shapes])

def find_min_candidates_indices(self) -> List[int]:
"""
Returns a list with potential minimal candidates.
@@ -644,10 +618,10 @@ def get_simd(self) -> int:
if len(simd_list) > 1:
Logger.warning(f"More than one pruning SIMD option is available."
f" Min SIMD is used: {min(simd_list)}")
if len(simd_list) == 0:
if len(simd_list) == 0: # pragma: no cover
Logger.critical(f"No SIMD option is available for {self}")
_simd = min(simd_list)
if _simd <= 0 or int(_simd) != _simd:
if _simd <= 0 or int(_simd) != _simd: # pragma: no cover
Logger.critical(f"SIMD is expected to be a non-positive integer but found: {_simd}")
return _simd

@@ -77,7 +77,8 @@ def __init__(self,
self.disable_activation_for_metric = disable_activation_for_metric
if self.quant_config.use_hessian_based_scores:
if not isinstance(hessian_info_service, HessianInfoService):
Logger.critical(f"When using Hessian-based approximations for sensitivity evaluation, a valid HessianInfoService object is required; found {type(hessian_info_service)}.")
Logger.critical(
f"When using Hessian-based approximations for sensitivity evaluation, a valid HessianInfoService object is required; found {type(hessian_info_service)}.")
self.hessian_info_service = hessian_info_service

self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names(self.fw_info)
@@ -94,7 +95,8 @@ def __init__(self,
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)

self.output_points = get_output_nodes_for_metric(graph)
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points, use_normalized_mse)
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points,
use_normalized_mse)

# Setting lists with relative position of the interest points
# and output points in the list of all mp model activation tensors
@@ -130,7 +132,8 @@ def __init__(self,
self.interest_points_hessians = self._compute_hessian_based_scores()
self.quant_config.distance_weighting_method = lambda d: self.interest_points_hessians

def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = False) -> Tuple[List[Callable], List[int]]:
def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = False) -> Tuple[
List[Callable], List[int]]:
"""
Initiates required lists for future use when computing the sensitivity metric.
Each point on which the metric is computed uses a dedicated distance function based on its type.
@@ -146,16 +149,12 @@ def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = Fal
distance_fns_list = []
axis_list = []
for n in points:
axis = n.framework_attr.get(AXIS) if not isinstance(n, FunctionalNode) else n.op_call_kwargs.get(AXIS)
distance_fn = self.fw_impl.get_mp_node_distance_fn(
layer_class=n.layer_class,
framework_attrs=n.framework_attr,
compute_distance_fn=self.quant_config.compute_distance_fn,
axis=axis,
norm_mse=norm_mse)
distance_fn, axis = self.fw_impl.get_mp_node_distance_fn(n,
compute_distance_fn=self.quant_config.compute_distance_fn,
norm_mse=norm_mse)
distance_fns_list.append(distance_fn)
# Axis is needed only for KL Divergence calculation, otherwise we use per-tensor computation
axis_list.append(axis if distance_fn==compute_kl_divergence else None)
axis_list.append(axis if distance_fn == compute_kl_divergence else None)
return distance_fns_list, axis_list

def compute_metric(self,
@@ -300,7 +299,8 @@ def _configure_node_bitwidth(self,
node_name = sorted_configurable_nodes_names[node_idx_to_configure]
layers_to_config = self.conf_node2layers.get(node_name, None)
if layers_to_config is None:
Logger.critical(f"Matching layers for node {node_name} not found in the mixed precision model configuration.") # pragma: no cover
Logger.critical(
f"Matching layers for node {node_name} not found in the mixed precision model configuration.") # pragma: no cover

for current_layer in layers_to_config:
self.set_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure])
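For context, a compressed, hypothetical sketch of how the (distance function, axis) pairs built by _init_metric_points_lists might be consumed when the sensitivity metric is computed; the helper name and the aligned tensor lists are assumptions, not repository code.

# Hypothetical consumption of the per-point lists: the axis is forwarded only
# when it was kept (KL divergence); all other points use per-tensor distances.
def point_distances(baseline_tensors, mp_tensors, distance_fns, axes):
    distances = []
    for base_t, mp_t, fn, axis in zip(baseline_tensors, mp_tensors, distance_fns, axes):
        distances.append(fn(base_t, mp_t, axis=axis) if axis is not None else fn(base_t, mp_t))
    return distances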
@@ -58,7 +58,7 @@ def __init__(self,
if activation_quantization_cfg is not None:
self.activation_quantization_cfg = activation_quantization_cfg
else:
if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)):
if any(v is None for v in (qc, op_cfg, activation_quantization_fn, activation_quantization_params_fn)): # pragma: no cover
Logger.critical(
"Missing required arguments to initialize a node activation quantization configuration. "
"Ensure QuantizationConfig, OpQuantizationConfig, activation quantization function, "
@@ -72,7 +72,7 @@ def __init__(self,
if weights_quantization_cfg is not None:
self.weights_quantization_cfg = weights_quantization_cfg
else:
if any(v is None for v in (qc, op_cfg, node_attrs_list)):
if any(v is None for v in (qc, op_cfg, node_attrs_list)): # pragma: no cover
Logger.critical("Missing required arguments to initialize a node weights quantization configuration. "
"Ensure QuantizationConfig, OpQuantizationConfig, weights quantization function, "
"parameters function, and weights attribute quantization config are provided.")
@@ -178,20 +178,6 @@ def set_activation_quantization_param(self,
for param_name, param_value in activation_params.items():
self.activation_quantization_params[param_name] = param_value

def has_activation_quantization_params(self) -> bool:
"""
Returns: Whether NodeQuantizationConfig has a activation quantization params or not.
"""
return len(self.activation_quantization_params) > 0

def no_quantization(self) -> bool:
"""
Returns: Whether NodeQuantizationConfig does not have activation params.
"""
return (not self.has_activation_quantization_params())

def __eq__(self, other: Any) -> bool:
"""
Compares the object to another object to find if they are equal.
@@ -203,7 +189,7 @@ def __eq__(self, other: Any) -> bool:
"""
if not isinstance(other, NodeActivationQuantizationConfig):
return False
return False # pragma: no cover

return self.activation_quantization_fn == other.activation_quantization_fn and \
self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
@@ -340,14 +326,6 @@ def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshol
else:
self.set_weights_quantization_param({})

def has_weights_quantization_params(self) -> bool:
"""
Returns: Whether NodeQuantizationConfig has weights quantization params or not.
"""
return len(self.weights_quantization_params) > 0

def __eq__(self, other: Any) -> bool:
"""
Compares the object to another object to find if they are equal.
@@ -359,7 +337,7 @@ def __eq__(self, other: Any) -> bool:
"""
if not isinstance(other, WeightsAttrQuantizationConfig):
return False
return False # pragma: no cover

return self.weights_quantization_fn == other.weights_quantization_fn and \
self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
@@ -419,11 +397,11 @@ def __init__(self, qc: QuantizationConfig,
# In Tensorflow, the attribute name is composed of the framework attribute name and the layer name,
# therefore, we need to look for the attribute in the op_cfg that is contained in the node attribute's name.
attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if k in attr}
if len(attrs_included_in_name) > 1:
Logger.error(f"Found multiple attribute in TPC OpConfig that are contained "
f"in the attribute name '{attr}'."
f"Please fix the TPC attribute names mapping such that each operator's attribute would "
f"have a unique matching name.")
if len(attrs_included_in_name) > 1: # pragma: no cover
Logger.critical(f"Found multiple attribute in TPC OpConfig that are contained "
f"in the attribute name '{attr}'."
f"Please fix the TPC attribute names mapping such that each operator's attribute would "
f"have a unique matching name.")
if len(attrs_included_in_name) == 0:
attr_cfg = op_cfg.default_weight_attr_config
else:
@@ -446,8 +424,8 @@ def get_attr_config(self, attr_name: Union[str, int]) -> WeightsAttrQuantization
Returns: An attribute quantization configuration.
"""
if attr_name is None:
Logger.error("Got 'None' attribute name for retrieving weights attribute quantization configuration.")
if attr_name is None: # pragma: no cover
Logger.critical("Got 'None' attribute name for retrieving weights attribute quantization configuration.")

if isinstance(attr_name, int):
# this is a positional attribute
@@ -463,8 +441,8 @@ def get_attr_config(self, attr_name: Union[str, int]) -> WeightsAttrQuantization
# If no attribute with the exact name then an error would be thrown
attr_cfg = self.attributes_config_mapping.get(attr_name)

if attr_cfg is None:
Logger.error(f"Weight attribute '{attr_name}' config could not be found.")
if attr_cfg is None: # pragma: no cover
Logger.critical(f"Weight attribute '{attr_name}' config could not be found.")

return attr_cfg

@@ -519,8 +497,8 @@ def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, Weigh
f"{list(attrs_with_name.keys())}.")
return attrs_with_name

def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, attr_name: str = None,
*args: List[Any], **kwargs: Dict[str, Any]):
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
attr_name: Union[str, int] = None, *args: List[Any], **kwargs: Dict[str, Any]):
"""
This method overrides the parent class set_quant_config_attr to enable setting a specific weights
attribute config parameter.
@@ -546,8 +524,8 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
else:
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
f"weights attribute {attr_name} and was not updated!")
else:
Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
else: # pragma: no cover
Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")

def __eq__(self, other: Any) -> bool:
"""
@@ -560,7 +538,7 @@ def __eq__(self, other: Any) -> bool:
"""
if not isinstance(other, NodeWeightsQuantizationConfig):
return False
return False # pragma: no cover

return self.min_threshold == other.min_threshold and \
self.simd_size == other.simd_size and \
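The hunks above widen attr_name to Union[str, int] and promote the failure paths to Logger.critical. A minimal sketch of the name-or-index lookup pattern they rely on, with hypothetical mapping arguments standing in for the real config objects.

from typing import Any, Dict, Union

# Hypothetical helper mirroring the lookup behaviour: positional (int) attributes
# are addressed by index, named (str) attributes by their mapping key.
def resolve_attr_cfg(named_cfgs: Dict[str, Any],
                     positional_cfgs: Dict[int, Any],
                     attr_name: Union[str, int]) -> Any:
    if attr_name is None:
        raise ValueError("Got 'None' attribute name for retrieving a weights attribute config.")
    if isinstance(attr_name, int):
        return positional_cfgs[attr_name]   # positional attribute
    return named_cfgs[attr_name]            # named attribute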
@@ -22,10 +22,7 @@
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
WeightsAttrQuantizationConfig

# If the quantization config does not contain kernel channel mapping or the weights
# quantization is not per-channel, we use a dummy channel mapping.
dummy_channel_mapping = DefaultDict(default_value=(None, None))
from model_compression_toolkit.logger import Logger


def get_weights_qparams(weights_attr_values: np.ndarray,
@@ -64,29 +61,10 @@ def get_weights_qparams(weights_attr_values: np.ndarray,
node=node,
hessian_info_service=hessian_info_service,
num_hessian_samples=num_hessian_samples)
else:
else: # pragma: no cover
Logger.error(f"Requested weights quantization parameters computation for node {node.name} without providing a "
f"weights_quantization_params_fn."
f"Returning an empty dictionary since no quantization parameters were computed.")
weights_params = {}

return weights_params, output_channels_axis


def _get_kernel_channels_mapping(fw_info:FrameworkInfo,
use_dummy: bool) -> DefaultDict:
"""
Get a kernel channel mapping from the framework info, or use dummy mapping (which returns a
tuple of Nones) if use_use_dummy is True.
Args:
fw_info: Framework info which contains a kernel channels mapping.
use_dummy: Whether to use a dummy mapping or not.
Returns:
Kernel channels mapping.
"""

# Set a kernel channels mapping
if use_dummy: # If kernel mapping is missing, we use a dummy channels mapping
kernel_channels_mapping = dummy_channel_mapping
else:
kernel_channels_mapping = fw_info.kernel_channels_mapping
return kernel_channels_mapping