Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reshape error fix #1086

Merged
merged 8 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def _run_operation(n: BaseNode,

op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}

if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from torch import reshape

from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode


class ReshapeCallMethod(common.BaseSubstitution):
"""
Find "reshape" node to substitute new dimension argument if needed
"""

def __init__(self):
"""
Matches: 'reshape' node
"""
nodes = NodeOperationMatcher(reshape)
super().__init__(matcher_instance=nodes)

def substitute(self,
graph: Graph,
node: BaseNode) -> Graph:
"""
Wrap reshape args with tuple if it's missing

Args:
graph: Graph we apply the substitution on.
node: node that match the pattern in the substitution init.

Returns:
Graph after applying the substitution.
"""
if node.op_call_args and not isinstance(node.op_call_args[0], tuple):
if len(node.op_call_args)>1:
node.op_call_args = [node.op_call_args]
return graph
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
import MultiHeadAttentionDecomposition
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.permute_call_method import \
PermuteCallMethod
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.reshape_call_method import \
ReshapeCallMethod
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
FunctionalConvSubstitution
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
Expand Down Expand Up @@ -236,7 +238,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
"""
return [ReshapeWithStaticShapes(),
MultiHeadAttentionDecomposition(),
PermuteCallMethod(),
PermuteCallMethod(),
ReshapeCallMethod(),
FunctionalConvSubstitution(fw_info),
FunctionalBatchNorm(),
FunctionalLayerNorm(),
Expand Down
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this file can be avoided. It's very similar to permute_substitution_test.py. Can we change it to transform_fn_call_substitution_test.py and test both reshape and permute?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guessing it's also good to keep the 2 tests.. so we know we check both cases.. but not too critical either way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel maybe best to keep them separate as we needed to make a change to the original file to allow it to work with both transform types? eg - if isinstance(node.op_call_args[0], (list, tuple)).

Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
import torch.nn as nn
import model_compression_toolkit as mct
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model
from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc


tp = mct.target_platform


class BaseReshapeSubstitutionTest(BasePytorchFeatureNetworkTest):

def __init__(self, unit_test):
super().__init__(unit_test=unit_test)

def get_tpc(self):
tp = generate_test_tp_model({'weights_n_bits': 32,
'activation_n_bits': 32,
'enable_weights_quantization': False,
'enable_activation_quantization': False})
return generate_pytorch_tpc(name="permute_substitution_test", tp_model=tp)

def get_quantization_config(self):
return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING,
mct.core.QuantizationErrorMethod.NOCLIPPING, False, False)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
in_torch_tensor = to_torch_tensor(input_x[0])
set_model(float_model)
y = float_model(in_torch_tensor)
y_hat = quantized_model(in_torch_tensor)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')

class ReshapeSubstitutionTest(BaseReshapeSubstitutionTest):
def __init__(self, unit_test):
super().__init__(unit_test)

class ReshapeNet(nn.Module):
def __init__(self, ):
super().__init__()
self.gamma = nn.Parameter(1 * torch.ones((1, 3, 1, 1)))
def forward(self, x):
x=x.mul(self.gamma.reshape(1,-1,1,1))
return x
def create_networks(self):
return self.ReshapeNet()
12 changes: 10 additions & 2 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from tests.pytorch_tests.model_tests.feature_models.test_softmax_shift import SoftmaxLayerNetTest, \
SoftmaxFunctionNetTest
from tests.pytorch_tests.model_tests.feature_models.permute_substitution_test import PermuteSubstitutionTest
from tests.pytorch_tests.model_tests.feature_models.reshape_substitution_test import ReshapeSubstitutionTest
from tests.pytorch_tests.model_tests.feature_models.constant_conv_substitution_test import ConstantConvSubstitutionTest, \
ConstantConvReuseSubstitutionTest, ConstantConvTransposeSubstitutionTest
from tests.pytorch_tests.model_tests.feature_models.multi_head_attention_test import MHALayerNetTest, \
Expand Down Expand Up @@ -95,6 +96,7 @@


class FeatureModelsTestRunner(unittest.TestCase):

def test_remove_identity(self):
"""
This test checks that identity layers are removed from the model.
Expand Down Expand Up @@ -270,13 +272,19 @@ def test_const_representation(self):
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.ConvTranspose2d(16, 16, 1),
const=c_img, enable_weights_quantization=enable_weights_quantization).run_test()

def test_permute_substitution(self):
"""
This test checks the permute substitution feature
"""
PermuteSubstitutionTest(self).run_test()


def test_reshape_substitution(self):
"""
This test checks the reshape substitution feature
"""
ReshapeSubstitutionTest(self).run_test()

def test_constant_conv_substitution(self):
"""
This test checks the constant conv substitution feature
Expand Down
Loading