Merge with main

reuvenp committed Sep 23, 2024
2 parents 460b893 + 34d24a4 commit af16275

Showing 69 changed files with 3,972 additions and 532 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/run_pytorch_tests.yml
@@ -23,11 +23,12 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime onnxruntime-extensions
pip install pytest
- name: Run unittests
run: |
python -m unittest discover tests/pytorch_tests -v
pytest tests_pytest/pytorch
3 changes: 3 additions & 0 deletions .github/workflows/run_tests_suite_coverage.yml
@@ -24,6 +24,7 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install coverage
pip install pytest
- name: Prepare TF env
run: pip install tensorflow==2.13.*
- name: Run tensorflow testsuite
@@ -32,6 +33,8 @@
run: pip uninstall tensorflow -y && pip install torch==2.0.* torchvision onnx onnxruntime onnxruntime-extensions
- name: Run torch testsuite
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest tests/test_suite.py -v
- name: Run torch pytest
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/pytorch
- name: Combine Multiple Coverage Files
run: coverage combine
- name: Run Coverage HTML
3 changes: 3 additions & 0 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -40,6 +40,7 @@ def __init__(self,
layer_class: type,
reuse: bool = False,
reuse_group: str = None,
inputs_as_list: bool = False,
quantization_attr: Dict[str, Any] = None,
has_activation: bool = True,
is_custom: bool = False
@@ -58,6 +59,7 @@ def __init__(self,
layer_class: Class path of the layer this node represents.
reuse: Whether this node was duplicated and represents a reused layer.
reuse_group: Name of group of nodes from the same reused layer.
inputs_as_list: Whether the node's input tensors should be passed as a list when calling the layer.
quantization_attr: Attributes the node holds regarding how it should be quantized.
has_activation: Whether the node has activations that we might want to quantize.
is_custom: Whether the node is a custom layer or not.
@@ -71,6 +73,7 @@ def __init__(self,
self.layer_class = layer_class
self.reuse = reuse
self.reuse_group = reuse_group
self.inputs_as_list = inputs_as_list
self.final_weights_quantization_cfg = None
self.final_activation_quantization_cfg = None
self.candidates_quantization_cfg = None
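For context on the new inputs_as_list flag: some Keras layers must be called with their input tensors packed in a single list, while most take a single tensor argument. A minimal standalone sketch of the two calling conventions (hypothetical tensors, not MCT code):

import tensorflow as tf

x1 = tf.keras.Input(shape=(8,))
x2 = tf.keras.Input(shape=(8,))

# Merge layers such as Concatenate expect their inputs as one list...
y = tf.keras.layers.Concatenate(axis=-1)([x1, x2])

# ...while most layers take a single tensor argument.
z = tf.keras.layers.Dense(4)(y)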
@@ -55,13 +55,13 @@ def __init__(self,
layer_class,
reuse,
reuse_group,
inputs_as_list,
quantization_attr,
has_activation=has_activation)

self.op_call_kwargs = op_call_kwargs
self.op_call_args = list(op_call_args)
self.functional_op = functional_op
self.inputs_as_list = inputs_as_list
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

@property
@@ -79,6 +79,72 @@ def set_quantization_configuration_to_graph(graph: Graph,
return graph


def filter_node_qco_by_graph(node: BaseNode,
tpc: TargetPlatformCapabilities,
graph: Graph,
node_qc_options: QuantizationConfigOptions
) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
"""
Filter quantization config options that don't match the graph.
A node may have several quantization config options with different 'activation_n_bits' values, while
the next nodes in the graph may support only certain bit-widths as input activations. This function
filters out quantization configs that don't comply with these constraints.
Args:
node: Node for filtering.
tpc: TPC to extract the QuantizationConfigOptions for the next nodes.
graph: Graph object.
node_qc_options: Node's QuantizationConfigOptions.
Returns:
A base config (OpQuantizationConfig) and a config options list (list of OpQuantizationConfig)
that are compatible with the next nodes' supported input bit-widths.
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list

# Build the next_nodes list by walking forward from the node, continuing the traversal through nodes that are quantization preserving.
_next_nodes = graph.get_next_nodes(node)
next_nodes = []
while len(_next_nodes):
n = _next_nodes.pop(0)
qco = n.get_qco(tpc)
qp = [qc.quantization_preserving for qc in qco.quantization_config_list]
if not all(qp) and any(qp):
Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizationConfigOptions in {n}.')
if qp[0]:
_next_nodes.extend(graph.get_next_nodes(n))
next_nodes.append(n)

if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
if len(_node_qc_options) == 0:
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
for qc_opt in next_nodes_qc_options]):
# The base_config activation bits don't match the next nodes' supported input bit-width -> replace it
# with the option from quantization_config_list that has the maximal activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
Logger.warning(f"Node {node} base quantization config changed to match Graph and TPC configuration.\nCause: {node} -> {next_nodes}.")
else:
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.") # pragma: no cover

return _base_config, _node_qc_options
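The bit-width compatibility rule at the heart of this function can be illustrated with a toy sketch (SimpleNamespace stand-ins for the real OpQuantizationConfig objects, not MCT's API):

from types import SimpleNamespace

# Hypothetical: the node offers 2/4/8-bit output activations...
node_options = [SimpleNamespace(activation_n_bits=b) for b in (2, 4, 8)]
# ...while the strictest next node accepts at most 4-bit input activations.
next_nodes_supported_input_bitwidth = 4

filtered = [opt for opt in node_options
            if opt.activation_n_bits <= next_nodes_supported_input_bitwidth]
print([opt.activation_n_bits for opt in filtered])  # [2, 4]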


def set_quantization_configs_to_node(node: BaseNode,
graph: Graph,
quant_config: QuantizationConfig,
@@ -99,7 +165,7 @@ def set_quantization_configs_to_node(node: BaseNode,
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
"""
node_qc_options = node.get_qco(tpc)
base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options)
base_config, node_qc_options_list = filter_node_qco_by_graph(node, tpc, graph, node_qc_options)

# If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
# and update base_config accordingly.
@@ -308,7 +308,7 @@ def _run_operation(self,
else:
# If the operator expects a single input tensor, it cannot be a list, as it should
# have a dtype field.
if len(input_tensors) == 1:
if len(input_tensors) == 1 and not n.inputs_as_list:
input_tensors = input_tensors[0]
out_tensors_of_n_float = op_func(input_tensors)

24 changes: 23 additions & 1 deletion model_compression_toolkit/core/keras/reader/node_builder.py
@@ -30,10 +30,12 @@
from keras.src.layers.core import TFOpLambda, SlicingOpLambda
from keras.src.engine.keras_tensor import KerasTensor
from keras.src.engine.node import Node as KerasNode
from keras.src.layers.merging.base_merge import _Merge
else:
from keras.layers.core import TFOpLambda, SlicingOpLambda
from keras.engine.keras_tensor import KerasTensor
from keras.engine.node import Node as KerasNode
from keras.layers.merging.base_merge import _Merge

from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -287,6 +289,7 @@ def build_node(node: KerasNode,
for i, arg in enumerate(op_call_args[0]):
if is_const(arg):
weights.update({i: to_numpy(arg, is_single_tensor=True)})
inputs_as_list = __is_node_inputs_a_list(op_call_args, keras_layer)

node = BaseNode(node_name,
layer_config,
@@ -296,6 +299,7 @@
layer_class,
is_reused,
reuse_group,
inputs_as_list,
is_custom=is_keras_custom_layer(layer_class))

node_name_to_node[node_name] = node
@@ -316,6 +320,24 @@ def __is_functional_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
"""

return (keras_layer.symbol in
[TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol,TFOpLambda(tf.add_n).symbol] and
[TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol, TFOpLambda(tf.add_n).symbol] and
len(op_call_args) > 0 and
isinstance(op_call_args[0], list))


def __is_node_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
"""
Check whether the input tensors should be passed as a list or not. This is relevant
only for layers that inherit from _Merge, such as Concatenate and Add.
Args:
op_call_args: Arguments list to check.
keras_layer: Keras layer.
Returns:
Whether the input tensors should be passed as a list or not.
"""

return (isinstance(keras_layer, _Merge) and
len(op_call_args) > 0 and
isinstance(op_call_args[0], (list, tuple)))
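As a quick illustration of what this check detects (standalone Keras; the _Merge import path follows the version-dependent imports above):

import tensorflow as tf
from keras.layers.merging.base_merge import _Merge

x1 = tf.keras.Input(shape=(8,))
x2 = tf.keras.Input(shape=(8,))
add_layer = tf.keras.layers.Add()
y = add_layer([x1, x2])

# Add inherits from _Merge and its call args start with a list of tensors,
# so __is_node_inputs_a_list would return True for this node.
print(isinstance(add_layer, _Merge))  # True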
@@ -139,7 +139,11 @@ def _run_operation(n: BaseNode,
_tensor_input_allocs = None

if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
if isinstance(op_func, PytorchQuantizationWrapper):
# In wrapped nodes, the op args & kwargs are already stored in the PytorchQuantizationWrapper.
out_tensors_of_n_float = op_func(*input_tensors)
else:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
tensor_input_allocs=_tensor_input_allocs)
50 changes: 32 additions & 18 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -79,16 +79,19 @@ def _build_input_alloc_and_call_args(n: Node, input_tensors_in_node_kwargs: Dict
tensor_input_alloc = []
op_call_args = list(n.args)
if inputs_as_list:
op_call_args.pop(0)
# Input tensors are passed as a list in the first argument -> remove it from op_call_args
# and iterate over the tensors in that list.
_args = op_call_args.pop(0)
else:
for in_node in n.all_input_nodes:
# The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
for i, arg in enumerate(n.args):
if arg == in_node:
tensor_input_alloc.append(i)
for k, arg in input_tensors_in_node_kwargs.items():
if arg == in_node:
tensor_input_alloc.append(k)
_args = n.args
for in_node in n.all_input_nodes:
# The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
for i, arg in enumerate(_args):
if arg == in_node:
tensor_input_alloc.append(i)
for k, arg in input_tensors_in_node_kwargs.items():
if arg == in_node:
tensor_input_alloc.append(k)

return op_call_args, tensor_input_alloc
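To see what tensor_input_alloc records, consider torch.add(x, x): both positional slots refer to the same fx input node, so both indices are collected. A toy sketch with string stand-ins for fx nodes (the real code compares node objects):

args = ('x', 'x')          # positional args of the call, both the same input node
all_input_nodes = ['x']    # fx lists each distinct input node once

tensor_input_alloc = []
for in_node in all_input_nodes:
    for i, arg in enumerate(args):
        if arg == in_node:
            tensor_input_alloc.append(i)
print(tensor_input_alloc)  # [0, 1] -> the single input tensor fills both slots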

@@ -229,10 +232,19 @@ def nodes_builder(model: GraphModule,

# Add constants to weights dictionary.
if node.op != PLACEHOLDER:
for i, input_node in enumerate(node.all_input_nodes):
if input_node in consts_dict:
used_consts.add(input_node)
weights.update({i: consts_dict[input_node]})
if len(node.args) and isinstance(node.args[0], (list, tuple)):
# Handle weights in nodes with a list input, especially when a tensor appears more than once
# in the input list (e.g. torch.concat([const1, x, const2, x, const3], 1)).
for input_node in node.all_input_nodes:
for i, input_arg in enumerate(node.args[0]):
if input_node is input_arg and input_node in consts_dict:
used_consts.add(input_node)
weights.update({i: consts_dict[input_node]})
else:
for i, input_node in enumerate(node.all_input_nodes):
if input_node in consts_dict:
used_consts.add(input_node)
weights.update({i: consts_dict[input_node]})

# Extract input and output shapes of the node.
input_shape, output_shape = _extract_input_and_output_shapes(node)
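The new list-input branch can be traced with a toy sketch (string stand-ins for fx nodes, plain floats for constant tensors; the real code compares node identity):

# torch.concat([const1, x, const2, x, const3], 1) -- constants interleaved with a duplicated tensor.
args0 = ['const1', 'x', 'const2', 'x', 'const3']
all_input_nodes = ['const1', 'x', 'const2', 'const3']
consts_dict = {'const1': 1.0, 'const2': 2.0, 'const3': 3.0}

weights, used_consts = {}, set()
for input_node in all_input_nodes:
    for i, input_arg in enumerate(args0):
        if input_node == input_arg and input_node in consts_dict:
            used_consts.add(input_node)
            weights[i] = consts_dict[input_node]
print(weights)  # {0: 1.0, 2: 2.0, 4: 3.0} -- each constant keeps its position in the list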
@@ -278,11 +290,8 @@ def nodes_builder(model: GraphModule,
node_kwargs[k] = v

# Check if node's first input argument is a list of input fx nodes, such as torch.cat:
is_first_input_list_of_nodes = is_instance_first_arg(node, (list, tuple)) and all(
inputs_as_list = is_instance_first_arg(node, (list, tuple)) and all(
[isinstance(n, Node) for n in node.args[0]])
is_placeholder_a_list = is_instance_first_arg(node, Node) and \
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple)
inputs_as_list = is_first_input_list_of_nodes or is_placeholder_a_list

# Build tensor_input_alloc required for the model builder. All input nodes are received as a list in the builder,
# so tensor_input_alloc is used to allocate each input tensor in the correct place in the node's args & kwargs.
@@ -360,7 +369,12 @@ def edges_builder(model: GraphModule,
if input_node in fx_node_2_graph_node:
# n_edges_for_input_node handles the case where the input node appears more than
# once as an input of the node, for example add(x, x).
n_edges_for_input_node = sum([1 for a in node.args if input_node == a])
if node in fx_node_2_graph_node and isinstance(fx_node_2_graph_node[node], FunctionalNode) and \
fx_node_2_graph_node[node].inputs_as_list:
_args = node.args[0]
else:
_args = node.args
n_edges_for_input_node = sum([1 for a in _args if input_node == a])
n_edges_for_input_node = max(n_edges_for_input_node, 1)

dst_index = node.all_input_nodes.index(input_node)
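A toy sketch of the duplicate-input edge counting above (string stand-ins for fx nodes):

input_node = 'x'

# torch.add(x, x): the same input node fills two argument slots -> two edges.
node_args = ('x', 'x')
print(max(sum(1 for a in node_args if a == input_node), 1))  # 2

# torch.cat([x, c, x], 1): for inputs_as_list nodes, the list itself is inspected.
node_args = (['x', 'c', 'x'], 1)
print(max(sum(1 for a in node_args[0] if a == input_node), 1))  # 2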
@@ -89,12 +89,12 @@ def export(self) -> None:
else:
Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")

model_input = to_torch_tensor(next(self.repr_dataset())[0])
model_input = to_torch_tensor(next(self.repr_dataset()))

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
@@ -107,7 +107,7 @@
onnx.save_model(onnx_model, self.save_model_path)
else:
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
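The tuple() wrapping follows torch.onnx.export's convention that multiple positional inputs are passed as a single args tuple. A minimal standalone sketch (hypothetical two-input module, not the MCT exporter):

import torch

class TwoInputs(torch.nn.Module):
    def forward(self, a, b):
        return a + b

# E.g., two tensors produced by a representative dataset generator.
model_input = [torch.randn(1, 3), torch.randn(1, 3)]
torch.onnx.export(TwoInputs(),
                  tuple(model_input),  # tuple -> exported as two positional inputs
                  'two_inputs.onnx',
                  opset_version=17)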
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Tuple, Callable
from typing import Tuple, Callable, Union
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.verify_packages import FOUND_TF
@@ -25,10 +25,12 @@
import tensorflow as tf
from tensorflow.keras.layers import Layer
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from mct_quantizers import KerasQuantizationWrapper
from mct_quantizers import KerasActivationQuantizationHolder
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS

def _get_wrapper(node: common.BaseNode,
def _get_wrapper(node: Union[common.BaseNode, FunctionalNode],
layer: Layer,
fw_impl=None) -> Layer:
"""
@@ -45,9 +47,16 @@ def _get_wrapper(node: common.BaseNode,
# for positional weights we need to extract the weight's value.
weights_values = {attr: node.get_weights_by_keys(attr)
for attr in weights_quantizers if isinstance(attr, int)}
# When wrapping functional nodes, we need to set the call args/kwargs in the wrapper,
# because they are used during the wrapper's call method.
func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
OP_CALL_KWARGS: node.op_call_kwargs
} if isinstance(node, FunctionalNode) else {}
return KerasQuantizationWrapper(layer,
weights_quantizers,
weights_values)
weights_values,
is_inputs_as_list=node.inputs_as_list,
**func_node_kwargs)
return layer
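A sketch of what gets packed for a functional node such as tf.concat (hypothetical values; OP_CALL_ARGS and OP_CALL_KWARGS are the mct_quantizers constants imported above):

# Hypothetical functional node: tf.concat(inputs, axis=-1) with inputs_as_list=True.
func_node_kwargs = {OP_CALL_ARGS: [],             # no extra positional args
                    OP_CALL_KWARGS: {'axis': -1}}
# The wrapper stores these and replays the original call signature,
# passing the (quantized) input tensors as a list when is_inputs_as_list=True.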

