Skip to content

Commit

Permalink
PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Sep 15, 2024
1 parent fb416db commit 2a7bfb7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def set_quantization_configuration_to_graph(graph: Graph,
return graph


def filter_node_qco_by_graph(node: BaseNode, tpc: TargetPlatformCapabilities,
graph: Graph, node_qc_options: QuantizationConfigOptions
def filter_node_qco_by_graph(node: BaseNode,
tpc: TargetPlatformCapabilities,
graph: Graph,
node_qc_options: QuantizationConfigOptions
) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
"""
Filter quantization config options that don't match the graph.
Expand All @@ -95,6 +97,8 @@ def filter_node_qco_by_graph(node: BaseNode, tpc: TargetPlatformCapabilities,
node_qc_options: Node's QuantizationConfigOptions.
Returns:
A base config (OpQuantizationConfig) and a config options list (list of OpQuantizationConfig)
that are compatible with next nodes supported input bit-widths.
"""
# Filter quantization config options that don't match the graph.
Expand Down Expand Up @@ -124,7 +128,7 @@ def filter_node_qco_by_graph(node: BaseNode, tpc: TargetPlatformCapabilities,
_node_qc_options = [_option for _option in _node_qc_options
if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
if len(_node_qc_options) == 0:
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.") # pragma: no cover
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
tf.compat.v1.gather,
tf.__operators__.getitem,
tf.strided_slice]
quantization_preserving_list_16bit_input = [Identity,
tf.identity,
Reshape,
quantization_preserving_list_16bit_input = [Reshape,
tf.reshape,
Permute,
tf.transpose,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
unbind,
gather,
MaxPool2d])
tp.OperationsSetToLayers("QuantizationPreserving16BitInout", [Identity,
Flatten,
tp.OperationsSetToLayers("QuantizationPreserving16BitInout", [Flatten,
flatten,
operator.getitem,
reshape,
Expand Down

0 comments on commit 2a7bfb7

Please sign in to comment.