Skip to content

Commit

Permalink
update docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Jan 7, 2025
1 parent a3073db commit 627fef0
Showing 1 changed file with 64 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from copy import deepcopy
from enum import Enum, auto
from functools import lru_cache
from typing import Dict, Any, NamedTuple, Callable, Optional, Tuple, List, Iterable, Union, Literal
from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal

from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.core import FrameworkInfo
Expand Down Expand Up @@ -164,8 +164,10 @@ def compute_resource_utilization(self,
Args:
target_criterion: criterion to include targets for computation (applies to weights, activation).
bitwidth_mode: bit-width mode for computation.
act_qcs: activation quantization candidates for custom bit-width mode.
w_qcs: weights quantization candidates for custom bit-width mode.
act_qcs: activation quantization candidates for custom bit-width mode. Must provide configuration for all
configurable activations.
w_qcs: weights quantization candidates for custom bit-width mode. Must provide configuration for all
configurable weights.
metrics: metrics to include for computation. If None, all metrics are calculated.
Returns:
Expand Down Expand Up @@ -211,7 +213,8 @@ def compute_weights_utilization(self,
Args:
target_criterion: criterion to include targets for computation.
bitwidth_mode: bit-width mode for computation.
w_qcs: weights quantization config per node for custom bit mode. Must contain all configurable weights.
w_qcs: weights quantization config per node for the custom bit mode. Must provide configuration for all
configurable weights.
Returns:
- Total weights utilization.
Expand Down Expand Up @@ -247,10 +250,11 @@ def compute_node_weights_utilization(self,
n: node.
target_criterion: criterion to include weights for computation.
bitwidth_mode: bit-width mode for the computation.
qc: weight quantization config for custom bit mode computation. Must contain all configurable weights.
qc: weight quantization config for the custom bit mode computation. Must provide configuration for all
configurable weights.
Returns:
- Total utilization
- Total utilization.
- Detailed per weight utilization.
"""
weight_attrs = self._get_target_weight_attrs(n, target_criterion)
Expand Down Expand Up @@ -286,10 +290,11 @@ def compute_cut_activation_utilization(self,
Args:
target_criterion: criterion to include activations for computation.
bitwidth_mode: bit-width mode for the computation.
act_qcs: custom configuration for BitwidthMode.MpCustom. Must contain all configurable nodes.
act_qcs: custom configuration for the custom bit mode. Must provide configuration for all configurable
activations.
Returns:
- Total utilization
- Total utilization.
- Total utilization per cut.
- Detailed utilization per cut per node.
"""
Expand Down Expand Up @@ -333,10 +338,12 @@ def compute_activation_tensors_utilization(self,
Args:
target_criterion: criterion to include activations for computation.
bitwidth_mode: bit-width mode for the computation.
act_qcs: custom configuration for BitwidthMode.MpCustom. Must contain all configurable nodes.
include_reused: whether to consider reused nodes.
act_qcs: custom configuration for the custom bit mode. Must provide configuration for all configurable
activations.
include_reused: whether to include reused nodes.
Returns:
Total activation utilization and a dict containing utilization per node.
- Total activation utilization.
- Detailed utilization per node. Dict keys are nodes in a topological order.
"""
nodes = self._get_target_activation_nodes(target_criterion, include_reused=include_reused)
Expand All @@ -362,7 +369,7 @@ def compute_node_activation_tensor_utilization(self,
n: node.
target_criterion: criterion to include nodes for computation. If None, will skip the check.
bitwidth_mode: bit-width mode for the computation.
qc: activation quantization config for custom bit mode. Must be passed for a configurable activation.
qc: activation quantization config for the custom bit mode. Must be provided for a configurable activation.
Returns:
Node's activation utilization.
Expand Down Expand Up @@ -391,8 +398,10 @@ def compute_bops(self,
Args:
target_criterion: criterion to include nodes for computation.
bitwidth_mode: bit-width mode for computation.
act_qcs: activation quantization candidates for custom bit-width mode.
w_qcs: weights quantization candidates for custom bit-width mode.
act_qcs: activation quantization candidates for custom bit-width mode. Must provide configuration for all
configurable activations.
w_qcs: weights quantization candidates for custom bit-width mode. Must provide configuration for all
configurable weights.
Returns:
- Total BOPS count.
Expand Down Expand Up @@ -423,8 +432,10 @@ def compute_node_bops(self,
Args:
n: node.
bitwidth_mode: bit-width mode for the computation.
act_qcs: nodes activation quantization configuration for custom bit mode. Must contain all configurable nodes.
w_qc: weights quantization config for the node for custom bit mode. Must be passed for configurable weights.
act_qcs: nodes activation quantization configuration for the custom bit mode. Must provide configuration for all
configurable activations.
w_qc: weights quantization config for the node for the custom bit mode. Must provide configuration for all
configurable weights.
Returns:
BOPS count.
Expand Down Expand Up @@ -500,14 +511,14 @@ def _get_target_weight_nodes(self,

def _get_target_weight_attrs(self, n: BaseNode, target_criterion: TargetInclusionCriterion) -> List[str]:
"""
Filter node's weight attributes per criterion.
Collect weight attributes of a node per criterion.
Args:
n: node.
target_criterion: selection criterion.
Returns:
A list of selected weight attributes names.
Selected weight attributes names.
"""
weight_attrs = n.get_node_weights_attributes()
if target_criterion == TargetInclusionCriterion.QConfigurable:
Expand Down Expand Up @@ -539,12 +550,12 @@ def _get_target_activation_nodes(self,
Collect nodes to include in activation utilization computation.
Args:
target_criterion: inclusion for computation criteria.
target_criterion: criterion to include activations for computation.
include_reused: whether to include reused nodes.
nodes: nodes to filter target nodes from. By default, uses the graph nodes.
Returns:
Target nodes.
Selected nodes.
"""
nodes = nodes or self.graph.nodes
if target_criterion == TargetInclusionCriterion.QConfigurable:
Expand All @@ -560,55 +571,69 @@ def _get_target_activation_nodes(self,
return nodes

@staticmethod
def _get_activation_nbits(n: BaseNode, mode: BitwidthMode, qc: Optional[NodeActivationQuantizationConfig]) -> int:
def _get_activation_nbits(n: BaseNode,
bitwidth_mode: BitwidthMode,
act_qc: Optional[NodeActivationQuantizationConfig]) -> int:
"""
Get activation bit-width for a node with accordance to bit-width mode.
Get activation bit-width for a node according to the requested bit-width mode.
Args:
n: node.
mode: bit-width mode for computation.
qc: quantization candidate for BitwidthMode.MpCustom mode. Can be skipped if the node has exactly one candidate.
bitwidth_mode: bit-width mode for computation.
act_qc: quantization candidate for the custom bit mode. Must be provided for a configurable activation.
Returns:
Activation bit-width.
"""
if mode == BitwidthMode.Float:
if bitwidth_mode == BitwidthMode.Float:
return FLOAT_BITWIDTH

if mode in _bitwidth_mode_fn:
if bitwidth_mode in _bitwidth_mode_fn:
candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg]
return _bitwidth_mode_fn[mode](candidates_nbits)
return _bitwidth_mode_fn[bitwidth_mode](candidates_nbits)

if mode == BitwidthMode.MpCustom and qc:
return qc.activation_n_bits
if bitwidth_mode == BitwidthMode.MpCustom and act_qc:
return act_qc.activation_n_bits

if mode in [BitwidthMode.MpCustom, BitwidthMode.SpDefault]:
if bitwidth_mode in [BitwidthMode.MpCustom, BitwidthMode.SpDefault]:
qcs = n.get_unique_activation_candidates()
if len(qcs) != 1:
raise ValueError(f'Could not retrieve the default activation quantization candidate for node {n.name} '
f'as it has {len(qcs)}!=1 unique candidates .')
return qcs[0].activation_quantization_cfg.activation_n_bits

raise ValueError(f'Unknown mode {mode}')
raise ValueError(f'Unknown mode {bitwidth_mode}')

@staticmethod
def _get_weight_nbits(n, attr: str, bitwidth_mode: BitwidthMode,
w_qc: Optional[NodeWeightsQuantizationConfig]) -> int:
if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(attr):
def _get_weight_nbits(n, w_attr: str, bitwidth_mode: BitwidthMode, w_qc: Optional[NodeWeightsQuantizationConfig]) -> int:
"""
Get the bit-width of a specific weight of a node according to the requested bit-width mode.
Args:
n: node.
w_attr: weight attribute.
bitwidth_mode: bit-width mode for the computation.
w_qc: weights quantization config for the node for the custom bit mode. Must provide configuration for all
configurable weights.
Returns:
Weight bit-width.
"""
if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr):
return FLOAT_BITWIDTH

if bitwidth_mode == BitwidthMode.MpCustom and w_qc and w_qc.has_attribute_config(attr):
return w_qc.get_attr_config(attr).weights_n_bits
if bitwidth_mode == BitwidthMode.MpCustom and w_qc and w_qc.has_attribute_config(w_attr):
return w_qc.get_attr_config(w_attr).weights_n_bits

node_qcs = n.get_unique_weights_candidates(attr)
w_qcs = [qc.weights_quantization_cfg.get_attr_config(attr) for qc in node_qcs]
node_qcs = n.get_unique_weights_candidates(w_attr)
w_qcs = [qc.weights_quantization_cfg.get_attr_config(w_attr) for qc in node_qcs]
if bitwidth_mode in _bitwidth_mode_fn:
return _bitwidth_mode_fn[bitwidth_mode]([qc.weights_n_bits for qc in w_qcs])

if bitwidth_mode in [BitwidthMode.MpCustom, BitwidthMode.SpDefault]:
# if configuration was not passed and the weight has only one candidate, use it
if len(w_qcs) != 1:
raise ValueError(f'Could not retrieve the quantization candidate for attr {attr} of node {n.name} '
raise ValueError(f'Could not retrieve the quantization candidate for attr {w_attr} of node {n.name} '
f'as it {len(w_qcs)}!=1 unique candidates.')
return w_qcs[0].weights_n_bits

Expand Down

0 comments on commit 627fef0

Please sign in to comment.