
Commit

changes according to PR - part 1
yardeny-sony committed Sep 25, 2024
1 parent 671eb4b commit cf1e84f
Showing 6 changed files with 72 additions and 25 deletions.
@@ -1,3 +1,18 @@
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch.nn as nn
import torch
import math
@@ -45,14 +60,12 @@ def _get_attention_input_nodes(self, graph: Graph, attention_node: BaseNode) ->
q, k, v = 0, 1, 2
prev_nodes = graph.get_prev_nodes(attention_node)
q_node, k_node, v_node = prev_nodes[q], prev_nodes[k], prev_nodes[v]
assert q_node.name == "q" and k_node.name == "k" and v_node.name == "v", \
"Bad assumption on attention input nodes order"
return {"q": q_node, "k": k_node, "v": v_node}

def _get_transpose_node(self, attention_node: BaseNode, node: BaseNode) -> BaseNode:
input_shape, output_shape = copy(node.output_shape[0]), copy(node.output_shape[0])
def _get_transpose_k_node(self, attention_node: BaseNode, key_node: BaseNode) -> BaseNode:
input_shape, output_shape = copy(key_node.output_shape[0]), copy(key_node.output_shape[0])
output_shape[-2], output_shape[-1] = input_shape[-1], input_shape[-2]
transpose_node = FunctionalNode(name=f"{attention_node.name}_{node.name}_transpose",
transpose_node = FunctionalNode(name=f"{attention_node.name}_{key_node.name}_transpose",
framework_attr={},
input_shape=input_shape,
output_shape=output_shape,
@@ -118,11 +131,11 @@ def _get_softmax_node(self, attention_node, in_out_shape):
weights={},
layer_class=nn.Softmax)

def _get_matmul2_node(self, attention_node, softmax_node, v_node):
def _get_matmul2_node(self, attention_node_name, softmax_node, v_node):
matmul2_output_shape = list(copy(softmax_node.output_shape))
matmul2_output_shape[-2] = softmax_node.output_shape[-2]
matmul2_output_shape[-1] = v_node.output_shape[0][-1]
matmul2_name = f'{attention_node.name}_matmul2'
matmul2_name = f'{attention_node_name}_matmul2'
return FunctionalNode(name=matmul2_name,
framework_attr={},
input_shape=(tuple(softmax_node.output_shape), tuple(v_node.output_shape[0])),
@@ -162,13 +175,13 @@ def substitute(self, graph: Graph, attention_node: BaseNode) -> Graph:

input_nodes = self._get_attention_input_nodes(graph, attention_node)
q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
transpose_k_node = self._get_transpose_node(attention_node, k_node)
transpose_k_node = self._get_transpose_k_node(attention_node, k_node)
matmul_node = self._get_matmul_node(attention_node, q_node, transpose_k_node)
scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
mask_node = self._get_mask_node(attention_node, scale_node)
softmax_node = self._get_softmax_node(attention_node, matmul_node.output_shape)
dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
matmul2_node = self._get_matmul2_node(attention_node, softmax_node, v_node)
matmul2_node = self._get_matmul2_node(attention_node.name, softmax_node, v_node)

graph.add_node_with_in_edges(transpose_k_node, [k_node])
graph.add_node_with_in_edges(matmul_node, [q_node, transpose_k_node])
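For reference, the chain of nodes that substitute() wires into the graph mirrors the standard decomposition of torch.nn.functional.scaled_dot_product_attention into primitive ops. A minimal standalone sketch of that decomposition, assuming a float (additive) attn_mask; the helper name sdpa_reference is illustrative and not part of this commit:

import math
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, attn_mask=None, dropout_p=0.0, scale=None):
    # Transpose the last two dims of K (the _get_transpose_k_node step).
    k_t = k.transpose(-2, -1)
    # First matmul: Q @ K^T (the _get_matmul_node step).
    attn = torch.matmul(q, k_t)
    # Scale by 1/sqrt(d_k) unless an explicit scale is given (the _get_scale_node step).
    attn = attn * (scale if scale is not None else 1.0 / math.sqrt(q.shape[-1]))
    # Add the optional float mask (the _get_mask_node step).
    if attn_mask is not None:
        attn = attn + attn_mask
    # Softmax over the last axis (the _get_softmax_node step).
    attn = F.softmax(attn, dim=-1)
    # Optional dropout; a no-op when dropout_p == 0 (the _get_dropout_node step).
    attn = F.dropout(attn, p=dropout_p)
    # Second matmul with V (the _get_matmul2_node step).
    return torch.matmul(attn, v)

With dropout_p=0 and no mask this reduces to softmax(q @ k.T / sqrt(d_k)) @ v, which is the sequence of transpose, matmul, scale, mask, softmax, dropout and matmul2 nodes the substitution adds to the graph.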
6 changes: 3 additions & 3 deletions tests/common_tests/base_test.py
@@ -13,10 +13,10 @@ def __init__(self, unit_test,
num_calibration_iter=1,
num_of_inputs=1,
input_shape=(8, 8, 3),
use_fuzzy_validation=False
use_is_close_validation=False
):
"""
:param use_fuzzy_validation: Allow similar (instead of exact) outputs when comparing the original float
:param use_is_close_validation: Allow similar (instead of exact) outputs when comparing the original float
model output against the no_quantization model output.
"""

@@ -25,7 +25,7 @@ def __init__(self, unit_test,
self.num_calibration_iter = num_calibration_iter
self.num_of_inputs = num_of_inputs
self.input_shape = (val_batch_size,) + input_shape
self.use_fuzzy_validation = use_fuzzy_validation
self.use_is_close_validation = use_is_close_validation

def generate_inputs(self):
return [np.random.randn(*in_shape) for in_shape in self.get_input_shapes()]
2 changes: 1 addition & 1 deletion tests/pytorch_tests/model_tests/base_pytorch_test.py
@@ -97,7 +97,7 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
# Decomposition is not exactly like the sqrt in the C implementation of PyTorch.
float_model_operators = [type(module) for name, module in float_model.named_modules()]
if (torch.nn.BatchNorm2d in float_model_operators or
torch.nn.MultiheadAttention in float_model_operators or self.use_fuzzy_validation):
torch.nn.MultiheadAttention in float_model_operators or self.use_is_close_validation):
self.unit_test.assertTrue(np.all(np.isclose(torch_tensor_to_numpy(f), torch_tensor_to_numpy(q),
atol=self.float_reconstruction_error)))
else:
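In practice the renamed flag selects tolerance-based instead of bit-exact comparison. A simplified sketch of that check (the helper name compare_outputs and its atol default are illustrative; the real logic is the compare() branch above, which uses self.float_reconstruction_error as the tolerance):

import numpy as np

def compare_outputs(float_out, quant_out, use_is_close_validation, atol=1e-7):
    # Models whose decomposition reorders float ops (BatchNorm folding, the sqrt in
    # SDPA scaling) are not expected to match bit-exactly, only within a tolerance.
    if use_is_close_validation:
        return bool(np.all(np.isclose(float_out, quant_out, atol=atol)))
    # Otherwise require an exact match between the float and no-quantization outputs.
    return np.array_equal(float_out, quant_out)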
@@ -44,7 +44,7 @@ class BNFNetTest(BasePytorchTest):

def __init__(self, unit_test):
super().__init__(unit_test)
self.use_fuzzy_validation = True # because the net contains a BN layer
self.use_is_close_validation = True # because the net contains a BN layer

def create_inputs_shape(self):
return [[self.val_batch_size, 3, 32, 32], [self.val_batch_size, 3, 32, 32]]
@@ -1,3 +1,18 @@
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
from torch import nn

@@ -24,9 +39,18 @@ class ScaledDotProductAttentionTest(BasePytorchTest):
"""
This test checks the scaled_dot_product_attention (SDPA) substitution using a single SDPA layer.
"""
def __init__(self, unit_test, dropout_p=0.0, scale=None, attn_mask=None, is_causal=False):

def __init__(self, unit_test, batch_size: int, q_and_k_embd_size: int, v_embd_size: int, source_seq_len: int,
target_seq_len: int, dropout_p: float = 0.0, scale: float = None, attn_mask: float = None,
is_causal: bool = False):

super().__init__(unit_test)
self.use_fuzzy_validation = True # because SDPA contains a sqrt operation, which leads to slightly different output values compared to the original torch model
self.batch_size = batch_size
self.q_and_k_embd_size = q_and_k_embd_size
self.v_embd_size = v_embd_size
self.source_seq_len = source_seq_len
self.target_seq_len = target_seq_len
self.use_is_close_validation = True # because SDPA contains a sqrt operation, which leads to slightly different output values compared to the original torch model
self.dropout_p = dropout_p
self.scale = scale
self.attn_mask = attn_mask
@@ -39,10 +63,9 @@ def create_feature_network(self, input_shape):
scale=self.scale)

def create_inputs_shape(self):
batch_size, q_and_k_embd_size, v_embd_size, source_seq_len, target_seq_len = 3, 8, 19, 21, 13
q_shape = [batch_size, target_seq_len, q_and_k_embd_size]
k_shape = [batch_size, source_seq_len, q_and_k_embd_size]
v_shape = [batch_size, source_seq_len, v_embd_size]
q_shape = [self.batch_size, self.target_seq_len, self.q_and_k_embd_size]
k_shape = [self.batch_size, self.source_seq_len, self.q_and_k_embd_size]
v_shape = [self.batch_size, self.source_seq_len, self.v_embd_size]
return [q_shape, k_shape, v_shape]

def _test_substitution_structure_output(self, post_substitution_nodes):
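For orientation, the new constructor arguments map one-to-one onto the q/k/v shapes fed to torch.nn.functional.scaled_dot_product_attention. A quick standalone sanity check with the first shape combination exercised by the runner below (plain torch, not part of the test code):

import torch
import torch.nn.functional as F

batch_size, q_and_k_embd_size, v_embd_size = 3, 8, 19
source_seq_len, target_seq_len = 21, 13

q = torch.randn(batch_size, target_seq_len, q_and_k_embd_size)
k = torch.randn(batch_size, source_seq_len, q_and_k_embd_size)
v = torch.randn(batch_size, source_seq_len, v_embd_size)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([3, 13, 19]) -> (batch_size, target_seq_len, v_embd_size)

The output keeps the query's sequence length and the value's embedding size, which is why target_seq_len and v_embd_size can be varied independently of the other dimensions.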
21 changes: 16 additions & 5 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -419,7 +419,7 @@ def test_parameter_net(self):
"""
ParameterNetTest(self).run_test()

def test_reuse_layer_net(self): # yoyo
def test_reuse_layer_net(self):
"""
This test checks:
The reuse of a layer in a model.
@@ -595,10 +595,21 @@ def test_scaled_dot_product_attention_layer(self):
"""
This test checks the ScaledDotProductDecomposition substitution feature.
"""
ScaledDotProductAttentionTest(self).run_test(seed=3)
ScaledDotProductAttentionTest(self, dropout_p=0.0, scale=5).run_test(seed=3)
attn_mask = torch.ones(13, 21).to(get_working_device())
ScaledDotProductAttentionTest(self, attn_mask=attn_mask).run_test(seed=3)

batch_size = [3, 1, 5]
q_and_k_embd_size = [8, 9, 3]
v_embd_size = [19, 2, 6]
source_seq_len = [21, 4, 15]
target_seq_len = [13, 12, 9]
for i in range(len(batch_size)):
ScaledDotProductAttentionTest(self, batch_size[i], q_and_k_embd_size[i], v_embd_size[i], source_seq_len[i],
target_seq_len[i]).run_test(seed=3)
ScaledDotProductAttentionTest(self, batch_size[i], q_and_k_embd_size[i], v_embd_size[i], source_seq_len[i],
target_seq_len[i], dropout_p=0.0, scale=5).run_test(seed=3)
attn_mask = torch.ones(target_seq_len[i], source_seq_len[i]).to(get_working_device())
ScaledDotProductAttentionTest(self, batch_size[i], q_and_k_embd_size[i], v_embd_size[i], source_seq_len[i],
target_seq_len[i], attn_mask=attn_mask).run_test(seed=3)


def test_gptq(self):
"""
