Skip to content

Commit

Permalink
reshape() takes 2 positional arguments but 5 were given - fix (#1086)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuel-wj-chapman authored Jun 3, 2024
1 parent 3bf6e9a commit 46647d5
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def _run_operation(n: BaseNode,

op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}

if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from torch import permute
from torch import permute, reshape
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode


class PermuteCallMethod(common.BaseSubstitution):
class TransformFunctionCallMethod(common.BaseSubstitution):
"""
Find "permute" node to substitute new dimension argument if needed
Find "permute or reshape" node to substitute new dimension argument if needed
"""

def __init__(self):
"""
Matches: 'permute' node
Matches: 'permute or reshape' node
"""
nodes = NodeOperationMatcher(permute)
nodes = NodeOperationMatcher(permute) | NodeOperationMatcher(reshape)
super().__init__(matcher_instance=nodes)

def substitute(self,
graph: Graph,
node: BaseNode) -> Graph:
"""
Wrap dimension of permute with tuple if it's missing
Wrap dimension of permute or reshape with tuple if it's missing
Args:
graph: Graph we apply the substitution on.
Expand All @@ -45,6 +45,6 @@ def substitute(self,
Graph after applying the substitution.
"""
# Check op_call_args is not empty and has its argument as a tuple
if node.op_call_args and not isinstance(node.op_call_args[0], tuple):
if node.op_call_args and not isinstance(node.op_call_args[0], (list,tuple)):
node.op_call_args = [node.op_call_args]
return graph
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
pytorch_linear_collapsing
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
import MultiHeadAttentionDecomposition
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.permute_call_method import \
PermuteCallMethod
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.transform_function_call_method import \
TransformFunctionCallMethod
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
FunctionalConvSubstitution
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
Expand Down Expand Up @@ -236,7 +236,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
"""
return [ReshapeWithStaticShapes(),
MultiHeadAttentionDecomposition(),
PermuteCallMethod(),
TransformFunctionCallMethod(),
FunctionalConvSubstitution(fw_info),
FunctionalBatchNorm(),
FunctionalLayerNorm(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
import torch.nn as nn
import model_compression_toolkit as mct
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model
from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc


tp = mct.target_platform


class BaseReshapeSubstitutionTest(BasePytorchFeatureNetworkTest):

def __init__(self, unit_test):
super().__init__(unit_test=unit_test)

def get_tpc(self):
tp = generate_test_tp_model({'weights_n_bits': 32,
'activation_n_bits': 32,
'enable_weights_quantization': False,
'enable_activation_quantization': False})
return generate_pytorch_tpc(name="permute_substitution_test", tp_model=tp)

def get_quantization_config(self):
return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING,
mct.core.QuantizationErrorMethod.NOCLIPPING, False, False)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
in_torch_tensor = to_torch_tensor(input_x[0])
set_model(float_model)
y = float_model(in_torch_tensor)
y_hat = quantized_model(in_torch_tensor)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')

class ReshapeSubstitutionTest(BaseReshapeSubstitutionTest):
def __init__(self, unit_test):
super().__init__(unit_test)

class ReshapeNet(nn.Module):
def __init__(self, ):
super().__init__()
self.gamma = nn.Parameter(1 * torch.ones((1, 3, 1, 1)))
def forward(self, x):
x=x.mul(self.gamma.reshape(1,-1,1,1))
return x
def create_networks(self):
return self.ReshapeNet()
12 changes: 10 additions & 2 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from tests.pytorch_tests.model_tests.feature_models.test_softmax_shift import SoftmaxLayerNetTest, \
SoftmaxFunctionNetTest
from tests.pytorch_tests.model_tests.feature_models.permute_substitution_test import PermuteSubstitutionTest
from tests.pytorch_tests.model_tests.feature_models.reshape_substitution_test import ReshapeSubstitutionTest
from tests.pytorch_tests.model_tests.feature_models.constant_conv_substitution_test import ConstantConvSubstitutionTest, \
ConstantConvReuseSubstitutionTest, ConstantConvTransposeSubstitutionTest
from tests.pytorch_tests.model_tests.feature_models.multi_head_attention_test import MHALayerNetTest, \
Expand Down Expand Up @@ -97,6 +98,7 @@


class FeatureModelsTestRunner(unittest.TestCase):

def test_remove_identity(self):
"""
This test checks that identity layers are removed from the model.
Expand Down Expand Up @@ -272,13 +274,19 @@ def test_const_representation(self):
enable_weights_quantization=enable_weights_quantization).run_test()
ConstRepresentationLinearLayerTest(self, func=nn.ConvTranspose2d(16, 16, 1),
const=c_img, enable_weights_quantization=enable_weights_quantization).run_test()

def test_permute_substitution(self):
"""
This test checks the permute substitution feature
"""
PermuteSubstitutionTest(self).run_test()


def test_reshape_substitution(self):
"""
This test checks the reshape substitution feature
"""
ReshapeSubstitutionTest(self).run_test()

def test_constant_conv_substitution(self):
"""
This test checks the constant conv substitution feature
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"source": [
"# YOLOv8n Object Detection PyTorch Model - Quantization for IMX500\n",
"\n",
"[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/pytorch/ptq/pytorch_yolov8n_seg_for_imx500.ipynb)\n",
"[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb)\n",
"\n",
"## Overview\n",
"\n",
Expand Down

0 comments on commit 46647d5

Please sign in to comment.