Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Op2d->add_const collapse substitution #878

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -235,6 +235,14 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: conv2d add const collapsing substitution
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_substitutions_statistics_correction(self, quant_config: QuantizationConfig) -> \
List[common.BaseSubstitution]:
11 changes: 10 additions & 1 deletion model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
@@ -79,7 +79,8 @@ def __init__(self,
def type(self):
"""
A function to get the node's layer_class op for convenient comparison
:return: the node's layer_class
Returns:
the node's layer_class
"""
return self.layer_class

@@ -130,6 +131,14 @@ def __repr__(self):
"""
return f'{self.type.__name__}:{self.name}'

def is_reused(self) -> bool:
"""
Check whether the node is reused or not
Returns:
True if node is reused, else False
"""
return self.reuse or self.reuse_group is not None

def get_weights_by_keys(self, name: str) -> np.ndarray:
"""
Get a node's weight by its name.
Original file line number Diff line number Diff line change
@@ -93,7 +93,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if conv_node.reuse or conv_node.reuse_group is not None:
if conv_node.is_reused():
return graph

bn_node = edge_nodes[1]
@@ -230,7 +230,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if conv_node.reuse or conv_node.reuse_group is not None or bn_node.reuse or bn_node.reuse_group is not None:
if conv_node.is_reused() or bn_node.is_reused():
return graph

if len(graph.get_next_nodes(bn_node)) > 1 or len(graph.get_prev_nodes(conv_node)) > 1:
Original file line number Diff line number Diff line change
@@ -79,7 +79,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if source_node.reuse or source_node.reuse_group is not None:
if source_node.is_reused():
for qc in source_node.candidates_quantization_cfg:
qc.weights_quantization_cfg.weights_second_moment_correction = False
return graph
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if source_node.reuse or source_node.reuse_group is not None:
if source_node.is_reused():
Logger.exception("If the linear operator is part of a reused group we should skip the the BN folding "
"substitution and SMC feature") # pragma: no cover

Original file line number Diff line number Diff line change
@@ -91,14 +91,11 @@ def substitute(self,
Graph after applying the substitution.
"""

first_node = edge_nodes[0]
second_node = edge_nodes[1]
first_node, second_node, _ = edge_nodes

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if first_node.reuse or first_node.reuse_group is not None:
return graph
if second_node.reuse or second_node.reuse_group is not None:
if first_node.is_reused() or second_node.is_reused():
return graph

# If there is an extra connection between these two nodes skip the substitution
@@ -182,3 +179,83 @@ def substitute(self,
assert num_edges_before_substition - len(graph.edges) == 1

return graph


class Op2DAddConstCollapsing(common.BaseSubstitution):
"""
Collapse Add-const into preceding Op2D (Not non-linear activation between them)
"""
def __init__(self,
first_node: NodeOperationMatcher,
second_node: NodeOperationMatcher,
op2d_collapsing_fn: Callable,
bias_str: str,
use_bias_str: str,
layer_name_str: str = None):
"""
Collapsing Add-const node (2nd node) to Op2D node (first node).
Args:
first_node: Node matcher for Op2d type nodes.
second_node: Node matcher for add type nodes.
op2d_collapsing_fn: Function for updating the convolution kernel and bias
bias_str: The framework specific attribute name of the convolution layer's bias.
use_bias_str: The framework specific attribute name of the convolution layer's bias flag.
layer_name_str: The framework specific attribute name of layer's name.
"""
super().__init__(matcher_instance=EdgeMatcher(first_node, second_node))
self.op2d_collapsing_fn = op2d_collapsing_fn
self.bias_str = bias_str
self.use_bias_str = use_bias_str
self.layer_name_str = layer_name_str

def substitute(self,
graph: Graph,
edge_nodes: Tuple[BaseNode, BaseNode]) -> Graph:
"""
Collapse linear layer into preceding linear layers.
Convolution condition:
|-------------------------| |------|
| Op2D | ---> | Add-const | -> | Op2D |
|-------------------------| |------|
Args:
graph: Graph we apply the substitution on.
edge_nodes: Tuple of linear node and add nodes
Returns:
Graph after applying the substitution.
"""

first_node, second_node, _ = edge_nodes

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if first_node.is_reused() or second_node.is_reused():
return graph

# If there is an extra connection between these two nodes skip the substitution
if len(graph.get_next_nodes(first_node)) > 1 or len(graph.get_prev_nodes(second_node)) > 1:
return graph

# New collapsed bias
bias = self.op2d_collapsing_fn(first_node, second_node, self.bias_str)

# New collapsed node
op2d_collapsed = copy.deepcopy(first_node)
op2d_collapsed_name = first_node.name + '_collapsed'
op2d_collapsed.name = op2d_collapsed_name
op2d_collapsed.framework_attr[self.use_bias_str] = True
op2d_collapsed.set_weights_by_keys(self.bias_str, bias)

if self.layer_name_str is not None:
op2d_collapsed.framework_attr[self.layer_name_str] = op2d_collapsed_name

# Update graph
graph.add_node(op2d_collapsed)
graph.reconnect_out_edges(current_node=second_node, new_node=op2d_collapsed)
graph.reconnect_in_edges(current_node=first_node, new_node=op2d_collapsed)
graph.replace_output_node(current_node=second_node, new_node=op2d_collapsed)

graph.remove_edge(first_node, second_node)
graph.remove_node(first_node)
graph.remove_node(second_node)

return graph
Original file line number Diff line number Diff line change
@@ -30,6 +30,9 @@ def linear_collapsing_substitute(graph: common.Graph,
Returns:
Transformed graph after applying all linear collapsing substitutions.
"""
# TODO: remove this if after adding Op2d-add_const collapse substitution in PyTorch
if linear_collapsing_substitution is None:
return graph
matched_nodes = graph.filter(linear_collapsing_substitution.matcher_instance)
matched_nodes_list = []
match_indicator = True
Original file line number Diff line number Diff line change
@@ -63,9 +63,7 @@ def substitute(self,

# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
# we should skip the substitution.
if first_node.reuse or first_node.reuse_group is not None:
return graph
if second_node.reuse or second_node.reuse_group is not None:
if first_node.is_reused() or second_node.is_reused():
return graph

# Check if convolution and residual satisfy the collapsing conditions, otherwise skip substitution
1 change: 1 addition & 0 deletions model_compression_toolkit/core/graph_prep_runner.py
Original file line number Diff line number Diff line change
@@ -129,6 +129,7 @@ def get_finalized_graph(initial_graph: Graph,
transformed_graph = substitute(graph, fw_impl.get_substitutions_pre_statistics_collection(quant_config))
if quant_config.linear_collapsing:
transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_linear_collapsing_substitution())
transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_op2d_add_const_collapsing_substitution())
if quant_config.residual_collapsing:
transformed_graph = substitute(transformed_graph, fw_impl.get_residual_collapsing_substitution())

Original file line number Diff line number Diff line change
@@ -15,10 +15,14 @@
from typing import Tuple
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
if tf.__version__ < "2.6":
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense
else:
from keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense

from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, NodeFrameworkAttrMatcher
from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing, Op2DAddConstCollapsing
from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \
ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
from model_compression_toolkit.logger import Logger
@@ -123,3 +127,69 @@ def keras_linear_collapsing() -> Conv2DCollapsing:
FILTERS,
data_format_str=DATA_FORMAT,
layer_name_str=LAYER_NAME)


def op2d_add_const_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
"""
Function generates matchers for matching:
(Op2D, Add(const)) -> Op2D. (Op2D is one of [DepthwiseConv2D, Conv2D, Conv2DTranspose, Dense)
Returns:
Matcher for Op2D followed by Add const
"""
first_node = NodeOperationMatcher(DepthwiseConv2D) | \
NodeOperationMatcher(Conv2D) | \
NodeOperationMatcher(Conv2DTranspose) | \
NodeOperationMatcher(Dense)
second_node = NodeOperationMatcher(tf.math.add)
return first_node, second_node


def op2d_add_const_collapsing_fn(op2d_node: BaseNode,
add_node: BaseNode,
bias_str: str) -> np.ndarray:
"""
Collapsing Add-Const to previous node's bias
Args:
op2d_node: Op2d layer node
add_node: Add layer to collapse
bias_str: The framework specific attribute name of the convolution layer's bias.
Returns:
The modified conv layer node's bias
"""
bias = op2d_node.get_weights_by_keys(bias_str)

# read constant from add node
if len(add_node.op_call_args) > 0:
const = add_node.op_call_args[0]
elif 'y' in add_node.op_call_kwargs:
const = add_node.op_call_kwargs['y']
else:
Logger.error(f'Unable to read constant from add node: {add_node.name}') # pragma: no cover

# convert constant to numpy array
if isinstance(const, tf.Tensor):
const = const.numpy()
elif isinstance(const, list):
const = np.array(const)
else:
Logger.error(f'Unable to convert constant to numpy array: {add_node.name}') # pragma: no cover

# return new bias
if bias is None:
return const
else:
return const + bias


def keras_op2d_add_const_collapsing() -> Op2DAddConstCollapsing:
"""
Returns:
An Op2DCollapsing initialized for Keras models.
"""
first_node, second_node = op2d_add_const_collapsing_node_matchers()
return Op2DAddConstCollapsing(first_node,
second_node,
op2d_add_const_collapsing_fn,
BIAS,
USE_BIAS,
layer_name_str=LAYER_NAME)
8 changes: 7 additions & 1 deletion model_compression_toolkit/core/keras/keras_implementation.py
Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_refusing import \
keras_batchnorm_refusing
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.linear_collapsing import \
keras_linear_collapsing
keras_linear_collapsing, keras_op2d_add_const_collapsing
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
keras_residual_collapsing
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
@@ -311,6 +311,12 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
return keras_linear_collapsing()

def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: Op2d add-const collapsing substitution
"""
return keras_op2d_add_const_collapsing()

def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) \
-> List[common.BaseSubstitution]:
"""
Original file line number Diff line number Diff line change
@@ -289,6 +289,12 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
return pytorch_linear_collapsing()

def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: None, as Op2d add-const substitution is not supported in torch yet
"""
return None

def get_substitutions_post_statistics_collection(self,
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
"""
Original file line number Diff line number Diff line change
@@ -14,9 +14,13 @@
# ==============================================================================

from abc import ABC
from packaging import version
import model_compression_toolkit as mct
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
if version.parse(tf.__version__) >= version.parse("2.13"):
from keras.src.layers.core import TFOpLambda
else:
from keras.layers.core import TFOpLambda

from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
@@ -51,7 +55,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, Conv2D)]) < len([l for l in float_model.layers if isinstance(l, Conv2D)]), msg=f'fail number of layers should decrease!')
self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')

@@ -69,7 +73,7 @@ def create_networks(self):
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
super().compare(quantized_model, float_model, input_x, quantization_info)
for layer in quantized_model.layers:
if type(layer) == Conv2D:
if type(layer) == layers.Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')

class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
@@ -86,7 +90,7 @@ def create_networks(self):
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
super().compare(quantized_model, float_model, input_x, quantization_info)
for layer in quantized_model.layers:
if type(layer) == Conv2D:
if type(layer) == layers.Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 1,msg=f'fail Bias should not appear in weights!!')


@@ -105,9 +109,10 @@ def create_networks(self):
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
super().compare(quantized_model, float_model, input_x, quantization_info)
for layer in quantized_model.layers:
if type(layer) == Conv2D:
if type(layer) == layers.Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')


class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
def __init__(self, unit_test):
super().__init__(unit_test)
@@ -125,5 +130,113 @@ def create_networks(self):
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
super().compare(quantized_model, float_model, input_x, quantization_info)
for layer in quantized_model.layers:
if type(layer) == Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')
if type(layer) == layers.Conv2D:
self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')


class Op2DAddConstCollapsingTest(BaseConv2DCollapsingTest):
def __init__(self, unit_test):
super().__init__(unit_test)

def create_networks(self):
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
# ########
# Cond2D #
# ########
# Collapse Conv2D with bias
x = layers.Conv2D(filters=7, kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=True, bias_initializer='glorot_uniform')(inputs)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))
x = layers.ReLU()(x)

# Collapse Conv2D without bias, const first argument of tf.math.add
x = layers.Conv2D(filters=5, kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=False)(x)
x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x)
x = layers.ReLU()(x)

# Collapse + operator to Conv2D without bias
# TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher)
x = layers.Conv2D(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=False)(x)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))

# #################
# DepthwiseConv2D #
# #################
# Collapse DepthwiseConv2D with bias
x = layers.DepthwiseConv2D(kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=True, bias_initializer='glorot_uniform')(x)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))
x = layers.ReLU()(x)

# Collapse DepthwiseConv2D without bias, const first argument of tf.math.add
x = layers.DepthwiseConv2D(kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=False)(x)
x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x)
x = layers.ReLU()(x)

# Collapse + operator to DepthwiseConv2D without bias
# TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher)
x = layers.DepthwiseConv2D(kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=False)(x)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))

# #################
# Conv2DTranspose #
# #################
# Collapse Conv2DTranspose with bias
x = layers.Conv2DTranspose(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=True, bias_initializer='glorot_uniform')(x)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))
x = layers.ReLU()(x)

# Collapse Conv2DTranspose without bias, const first argument of tf.math.add
x = layers.Conv2DTranspose(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=False)(x)
x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x)
x = layers.ReLU()(x)

# Collapse + operator to Conv2DTranspose without bias
# TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher)
x = layers.Conv2DTranspose(filters=9, kernel_size=(5, 5), strides=(1, 1), padding='same',
use_bias=False)(x)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))

# #######
# Dense #
# #######
x = layers.Reshape((-1,))(x)
# Collapse Dense with bias
x = layers.Dense(9, use_bias=True, bias_initializer='glorot_uniform')(x)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))
x = layers.ReLU()(x)

# Collapse Dense without bias, const first argument of tf.math.add
x = layers.Dense(9, use_bias=False)(x)
x = tf.math.add(tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype), x)
x = layers.ReLU()(x)

# Collapse + operator to Conv2DTranspose without bias
# TODO: replace add with + (currently using tf.math.add because below TF 2.14 creates TFOpLambda which fails ths node matcher)
x = layers.Dense(9, use_bias=False)(x)
x = tf.math.add(x, tf.constant(np.random.normal(size=x.shape[-1]), dtype=x.dtype))

# Don't collapse
x2 = layers.Dense(9, use_bias=True, bias_initializer='glorot_uniform')(x)
x = tf.math.add(x2, x)
y = layers.ReLU()(x)

return tf.keras.models.Model(inputs=inputs, outputs=y)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
super().compare(quantized_model, float_model, input_x, quantization_info)
num_adds = 0
for layer in quantized_model.layers:
if type(layer) in [layers.Conv2D, layers.DepthwiseConv2D, layers.Conv2DTranspose, layers.Dense]:
self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')
elif isinstance(layer, TFOpLambda) and layer.function is tf.add:
num_adds += 1

# check all "add"s were folded except the one with 2 tensor inputs
self.unit_test.assertTrue(num_adds == 1, msg=f'Only one add should remain in the quantized model')
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@
from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \
InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest
from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest
ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
LUTActivationQuantizerTest
from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_bops_test import \
@@ -531,6 +531,7 @@ def test_linear_collapsing(self):
ThreeConv2DCollapsingTest(self).run_test()
FourConv2DCollapsingTest(self).run_test()
SixConv2DCollapsingTest(self).run_test()
Op2DAddConstCollapsingTest(self).run_test()

def test_second_moment(self):
DepthwiseConv2DSecondMomentTest(self).run_test()