fix tf.nn.{conv2d,convolution} substitution #1275

Merged 2 commits · Nov 24, 2024

Changes from all commits
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf
@@ -30,7 +31,7 @@
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
    KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
    ACTIVATION, LINEAR
    ACTIVATION, LINEAR, DATA_FORMAT, GROUPS, CHANNELS_FORMAT_FIRST, CHANNELS_FORMAT_LAST


def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
@@ -136,42 +137,67 @@ def substitute(self,
        conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
        if len(conv_func_node.op_call_args) > 0:
            Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.")  # pragma: no cover
        if STRIDES in conv_func_node.op_call_kwargs:
            strides = conv_func_node.op_call_kwargs[STRIDES]
            if len(strides) == 4:
                if strides[0] > 1 or strides[3] > 1:
                    # Non-standard strides -> skip substitution.
                    return graph  # pragma: no cover
                conv_fw_attr[STRIDES] = strides[1:3]
            else:
                conv_fw_attr[STRIDES] = strides
        if PADDING in conv_func_node.op_call_kwargs:
            padding = conv_func_node.op_call_kwargs[PADDING]
            if not isinstance(padding, str):
                # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
                return graph  # pragma: no cover
            conv_fw_attr[PADDING] = padding
        if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
            dilations = conv_func_node.op_call_kwargs[DILATIONS]
            if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
                if dilations[0] > 1 or dilations[3] > 1:
                    # Non-standard dilations -> skip substitution.
                    return graph  # pragma: no cover
                conv_fw_attr[DILATION_RATE] = dilations[1:3]
            else:
                conv_fw_attr[DILATION_RATE] = dilations

        strides = self._parse_tf_stride_dilation(conv_func_node, STRIDES)
        if strides is None:
            # Non-standard strides -> skip substitution.
            return graph
        conv_fw_attr[STRIDES] = strides

        padding = conv_func_node.op_call_kwargs.get(PADDING) or 'VALID'
        if not isinstance(padding, str):
            # Non-standard padding; the layer only supports 'valid' or 'same' -> skip substitution.
            return graph  # pragma: no cover
        conv_fw_attr[PADDING] = padding

        dilations = self._parse_tf_stride_dilation(conv_func_node, DILATIONS)
        if dilations is None:
            # Non-standard dilations -> skip substitution.
            return graph
        conv_fw_attr[DILATION_RATE] = dilations

        if b is None:
            conv_fw_attr[USE_BIAS] = False
        else:
            weights[BIAS] = b

        data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
        conv_fw_attr[DATA_FORMAT] = {'NHWC': CHANNELS_FORMAT_LAST, 'NCHW': CHANNELS_FORMAT_FIRST}[data_format]

        conv_fw_attr[GROUPS] = 1

        _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
        conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
                             weights, Conv2D, **_reuse_params)

        replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
        return graph
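
[Editor's note] For context, a minimal standalone sketch of the equivalence this substitution relies on: a functional conv followed by tf.nn.bias_add behaves like a single Conv2D layer with use_bias=True, which is why the bias-add node is removed when a bias constant is found. This is not part of the PR; shapes and values are illustrative.

    import numpy as np
    import tensorflow as tf

    x = np.random.random((1, 8, 8, 3)).astype(np.float32)
    k = np.random.random((3, 3, 3, 4)).astype(np.float32)
    b = np.random.random((4,)).astype(np.float32)

    # Functional form: conv followed by a separate bias-add op.
    y_func = tf.nn.bias_add(tf.nn.conv2d(x, k, strides=1, padding='VALID'), b)

    # Layer form: one Conv2D carrying both kernel and bias.
    conv = tf.keras.layers.Conv2D(filters=4, kernel_size=(3, 3), padding='valid', use_bias=True)
    conv.build(x.shape)
    conv.set_weights([k, b])
    y_layer = conv(x)

    assert np.allclose(y_func.numpy(), y_layer.numpy(), atol=1e-5)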

    def _parse_tf_stride_dilation(self, node, key) -> Optional[Tuple[int, int]]:
        """
        Extract a stride/dilation param from a tf functional node and convert it to keras format
        (suitable for Conv2D).

        Args:
            node: functional conv node.
            key: param key (STRIDES or DILATIONS).

        Returns:
            Parsed value, or None if non-standard.
        """
        v = node.op_call_kwargs.get(key)
        if v is None:
            return 1, 1

Collaborator: this way you assume the defaults. why not return None?

Author: That's intentional. None wouldn't do, we need to fill in an explicit default. This method is specific to tf stride & dilation.


        if isinstance(v, int):
            return v, v
        if len(v) == 1:
            return v[0], v[0]
        if len(v) == 4:
            if v[0] > 1 or v[-1] > 1:
                # Non-trivial batch/channel stride or dilation -> non-standard.
                return None
            return tuple(v[1:3])
        return tuple(v)
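
[Editor's note] The branches above mirror what TF accepts: tf.nn.conv2d and tf.nn.convolution take strides/dilations as an int or a list of length 1, 2, or 4, while Conv2D wants a 2-tuple. A quick standalone check of this, with illustrative shapes (not PR code):

    import numpy as np
    import tensorflow as tf

    x = np.random.random((1, 8, 8, 3)).astype(np.float32)
    k = np.random.random((3, 3, 3, 4)).astype(np.float32)

    # All of these spell the same unit-stride convolution; the helper maps
    # each of them to the Keras-style (1, 1).
    y0 = tf.nn.conv2d(x, k, strides=1, padding='VALID')
    y1 = tf.nn.conv2d(x, k, strides=[1], padding='VALID')
    y2 = tf.nn.conv2d(x, k, strides=[1, 1], padding='VALID')
    y3 = tf.nn.conv2d(x, k, strides=[1, 1, 1, 1], padding='VALID')

    for y in (y1, y2, y3):
        assert np.allclose(y0.numpy(), y.numpy())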


class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
    """
@@ -39,6 +39,9 @@

class ConvFuncSubstitutionsTest(BaseKerasFeatureNetworkTest):

    def __init__(self, unit_test):
        super().__init__(unit_test, input_shape=(32, 32, 3))

    def get_tpc(self):
        tp = generate_test_tp_model({'enable_weights_quantization': False,
                                     'enable_activation_quantization': False})
@@ -67,6 +70,18 @@ def create_networks(self):
        x = tf.nn.convolution(x, np.random.random((3, 3, 2, 4)).astype(np.float32),
                              [2, 1], padding='SAME')
        x = tf.nn.bias_add(x, np.random.random((4,)).astype(np.float32))

        # default values and various formats
        x = tf.nn.conv2d(x, np.random.random((3, 3, 4, 8)), 1, 'VALID')
        x = tf.nn.conv2d(x, np.random.random((3, 3, 8, 16)), strides=[1], padding='SAME', dilations=1)
        x = tf.nn.conv2d(x, np.random.random((3, 3, 16, 8)), strides=[1, 1], padding='VALID', dilations=[1])
        x = tf.nn.conv2d(x, filters=np.random.random((3, 3, 8, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1])

        x = tf.nn.convolution(x, np.random.random((3, 3, 4, 16)).astype(np.float32))
        x = tf.nn.convolution(x, np.random.random((3, 3, 16, 32)).astype(np.float32), strides=[1], padding='SAME', dilations=1)
        x = tf.nn.convolution(x, np.random.random((3, 3, 32, 8)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1])
        x = tf.nn.convolution(x, filters=np.random.random((3, 3, 8, 4)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1, 1])

        return tf.keras.Model(inputs=_in, outputs=x)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
@@ -75,7 +90,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
        cs = cosine_similarity(out_float.numpy(), out_quant.numpy())
        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')

        self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 4,
        self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 12,
                                  "Not all conv functions were substituted.")
        self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, DepthwiseConv2D)) == 2,
                                  "Not all dw-conv functions were substituted.")
@@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
        y = float_model.predict(input_x)
        y_hat = quantized_model.predict(input_x)
        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
        # FIXME this doesn't test anything, the number of quantized convs in the network is exactly 0.
        # Even if it looked at correct layers it hardly checks anything.

Collaborator: then why not remove it?

        self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
        cs = cosine_similarity(y, y_hat)
        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
@@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
            if type(layer) == layers.Conv2D:
                self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
    def __init__(self, unit_test):
        super().__init__(unit_test)
@@ -107,9 +110,35 @@ def create_networks(self):

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        super().compare(quantized_model, float_model, input_x, quantization_info)
        for layer in quantized_model.layers:
            if type(layer) == layers.Conv2D:
                self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')
        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
        self.unit_test.assertTrue(len(convs) == 1)
        for layer in convs:
            self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class FuncConvCollapsingTest(FourConv2DCollapsingTest):
    def create_networks(self):
        # Tests the combination of the functional conv to Conv2D substitution with linear collapsing.
        # (With default values, the tf op doesn't carry these attributes, so they must be added
        # explicitly to the node's attributes dict, which is not covered by the substitution test.)
        h, w, c = self.get_input_shapes()[0][1:]
        inputs = layers.Input(shape=(h, w, c))
        x = tf.nn.conv2d(inputs, tf.random.uniform((3, 3, c, 16)), 1, 'SAME')
        x = tf.nn.convolution(x, tf.random.uniform((1, 1, 16, 8)))
        x = tf.nn.relu(x)
        x = tf.nn.convolution(x, tf.random.uniform((3, 3, 8, 32)))
        y = tf.nn.conv2d(x, tf.random.uniform((1, 1, 32, 4)), 1, 'VALID')
        return tf.keras.models.Model(inputs=inputs, outputs=y)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
        self.unit_test.assertTrue(len(convs) == 2)

        y = float_model.predict(input_x)
        y_hat = quantized_model.predict(input_x)
        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
        cs = cosine_similarity(y, y_hat)
        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
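
[Editor's note] The expectation of exactly 2 Conv2D layers follows from linear collapsing: consecutive convolutions with no nonlinearity between them compose into one linear op, so the two convs before the relu fuse into one layer and the two after it fuse into another. A minimal numpy sketch of that fusion for 1x1 kernels (illustrative only, not MCT's collapsing code):

    import numpy as np

    x = np.random.random((1, 5, 5, 8)).astype(np.float32)    # NHWC input
    w1 = np.random.random((1, 1, 8, 16)).astype(np.float32)  # first 1x1 conv kernel
    w2 = np.random.random((1, 1, 16, 4)).astype(np.float32)  # second 1x1 conv kernel

    # Apply the two convolutions one after the other (a 1x1 conv is a
    # matmul over the channel axis).
    y_two = np.einsum('bhwc,co->bhwo', np.einsum('bhwc,co->bhwo', x, w1[0, 0]), w2[0, 0])

    # Fuse the kernels into a single equivalent 1x1 conv and apply it once.
    w_fused = w1[0, 0] @ w2[0, 0]  # shape (8, 4)
    y_one = np.einsum('bhwc,co->bhwo', x, w_fused)

    assert np.allclose(y_two, y_one, atol=1e-5)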


class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
@@ -56,7 +56,8 @@
from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \
    InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest
from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \
    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest, \
    FuncConvCollapsingTest
from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
    LUTActivationQuantizerTest
from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \
@@ -605,6 +606,7 @@ def test_linear_collapsing(self):
        FourConv2DCollapsingTest(self).run_test()
        SixConv2DCollapsingTest(self).run_test()
        Op2DAddConstCollapsingTest(self).run_test()
        FuncConvCollapsingTest(self).run_test()

    def test_const_quantization(self):
        c = (np.ones((32, 32, 16)) + np.random.random((32, 32, 16))).astype(np.float32)