From 5d74d6a9abb612be81a8eee4d0f2cda760f97571 Mon Sep 17 00:00:00 2001
From: yardeny-sony
Date: Thu, 26 Sep 2024 19:04:14 +0300
Subject: [PATCH] bugfix - sort inputs by sink indices

---
 .../substitutions/scaled_dot_product_attention.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
index 7d9dff097..ed4b9ec5c 100644
--- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
+++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
@@ -59,7 +59,7 @@ def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str,
 
     def _get_attention_input_nodes(self, graph: Graph, attention_node: FunctionalNode) -> dict:
         q, k, v = 0, 1, 2
-        prev_nodes = graph.get_prev_nodes(attention_node)
+        prev_nodes = graph.get_prev_nodes(attention_node, sink_index_sorted=True)
         q_node, k_node, v_node = prev_nodes[q], prev_nodes[k], prev_nodes[v]
         return {"q": q_node, "k": k_node, "v": v_node}
 
@@ -96,14 +96,14 @@ def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matm
                                     functional_op=torch.mul)
         return scale_node
 
-    def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
+    def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transposed_k_node: BaseNode) -> BaseNode:
         matmul1_output_shape = copy(q_node.output_shape[0])
         matmul1_output_shape[-2] = q_node.output_shape[0][-2]
-        matmul1_output_shape[-1] = k_node.output_shape[-1]
+        matmul1_output_shape[-1] = transposed_k_node.output_shape[-1]
         matmul_name = f'{attention_node_name}_matmul1'
         return FunctionalNode(name=matmul_name,
                               framework_attr={},
-                              input_shape=(tuple(q_node.output_shape[0]), tuple(k_node.output_shape)),
+                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
                               output_shape=tuple(matmul1_output_shape),
                               weights={},
                               layer_class=torch.matmul,
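
Note on the fix: the substitution assumes that positions 0/1/2 of the predecessor list map to the query, key, and value inputs, which only holds if the predecessors are ordered by the input ("sink") index they feed. The sketch below illustrates that idea with hypothetical data structures; it is not MCT's actual `Graph` implementation, only `get_prev_nodes` and the `sink_index_sorted` flag are taken from the patch.

```python
# Minimal sketch of a sink-index-sorted predecessor lookup (hypothetical
# Edge/graph structures; only the function name and flag come from the patch).
from dataclasses import dataclass
from typing import List


@dataclass
class Edge:
    source: str      # name of the node producing the tensor
    sink_index: int  # which input of the attention node this edge feeds (0=q, 1=k, 2=v)


def get_prev_nodes(in_edges: List[Edge], sink_index_sorted: bool = False) -> List[str]:
    """Return predecessor node names, optionally ordered by the input index they feed."""
    edges = sorted(in_edges, key=lambda e: e.sink_index) if sink_index_sorted else in_edges
    return [e.source for e in edges]


# Edges may be stored in arbitrary order; without sorting, q/k/v can end up swapped.
in_edges = [Edge("value_proj", 2), Edge("query_proj", 0), Edge("key_proj", 1)]
assert get_prev_nodes(in_edges, sink_index_sorted=True) == ["query_proj", "key_proj", "value_proj"]
```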