Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Target Platform Capabilities - Phase 3 #1297

Merged
merged 7 commits into from
Dec 22, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ def set_tpc(self,
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature '
'request or an issue if you believe this should be supported.') # pragma: no cover
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_configurations]):
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover

self.tpc = tpc
6 changes: 3 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
@@ -582,12 +582,12 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list
_node_qc_options = node_qc_options.quantization_configurations
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])
for op_cfg in qc_opts.quantization_configurations])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
@@ -599,7 +599,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
# a qco from quantization_configurations with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
Original file line number Diff line number Diff line change
@@ -101,15 +101,15 @@ def filter_node_qco_by_graph(node: BaseNode,
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list
_node_qc_options = node_qc_options.quantization_configurations

# Build next_nodes list by appending to the node's next nodes list all nodes that are quantization preserving.
_next_nodes = graph.get_next_nodes(node)
next_nodes = []
while len(_next_nodes):
n = _next_nodes.pop(0)
qco = n.get_qco(tpc)
qp = [qc.quantization_preserving for qc in qco.quantization_config_list]
qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
if not all(qp) and any(qp):
Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizationConfigOptions in {n}.')
if qp[0]:
@@ -120,7 +120,7 @@ def filter_node_qco_by_graph(node: BaseNode,
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])
for op_cfg in qc_opts.quantization_configurations])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
@@ -132,7 +132,7 @@ def filter_node_qco_by_graph(node: BaseNode,
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
# a qco from quantization_configurations with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
Original file line number Diff line number Diff line change
@@ -392,7 +392,7 @@ def shift_negative_function(graph: Graph,
bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
graph.shift_stats_collector(bypass_node, np.array(shift_value))

add_node_qco = add_node.get_qco(graph.tpc).quantization_config_list
add_node_qco = add_node.get_qco(graph.tpc).quantization_configurations
for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
for attr in add_node.get_node_weights_attributes():
candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
@@ -535,7 +535,7 @@ def apply_shift_negative_correction(graph: Graph,
# Skip substitution if QuantizationMethod is uniform.
node_qco = n.get_qco(graph.tpc)
if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
for op_qc in node_qco.quantization_config_list]):
for op_qc in node_qco.quantization_configurations]):
continue

if snc_node_types.apply(n):
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema

OperatorSetNames = schema.OperatorSetNames
Signedness = schema.Signedness
AttributeQuantizationConfig = schema.AttributeQuantizationConfig
OpQuantizationConfig = schema.OpQuantizationConfig
Original file line number Diff line number Diff line change
@@ -64,10 +64,10 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant
Raises:
AssertionError: If the default quantization configuration list contains more than one configuration option.
"""
assert len(tp_model.default_qco.quantization_config_list) == 1, \
assert len(tp_model.default_qco.quantization_configurations) == 1, \
f"Default quantization configuration options must contain only one option, " \
f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." # pragma: no cover
return tp_model.default_qco.quantization_config_list[0]
f"but found {len(tp_model.default_qco.quantization_configurations)} configurations." # pragma: no cover
return tp_model.default_qco.quantization_configurations[0]


def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
@@ -82,7 +82,10 @@ def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
bool: True if an OperatorsSet with the given name exists in the target platform model,
otherwise False.
"""
return opset_name in [x.name for x in tp_model.operator_set]

Reviewer comment (Collaborator) — suggested guarding against `operator_set` being `None`:

    return tp_model.operator_set is not None and opset_name in [x.name for x in tp_model.operator_set]

Merged change:

    if tp_model.operator_set is None:
        return False
    else:
        return opset_name in [x.name for x in tp_model.operator_set]


def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
Loading