Skip to content

Commit

Permalink
bugfix - sort inputs by sink indices
Browse files Browse the repository at this point in the history
  • Loading branch information
yardeny-sony committed Sep 26, 2024
1 parent d885ccc commit 5d74d6a
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str,

def _get_attention_input_nodes(self, graph: Graph, attention_node: FunctionalNode) -> dict:
q, k, v = 0, 1, 2
prev_nodes = graph.get_prev_nodes(attention_node)
prev_nodes = graph.get_prev_nodes(attention_node, sink_index_sorted=True)
q_node, k_node, v_node = prev_nodes[q], prev_nodes[k], prev_nodes[v]
return {"q": q_node, "k": k_node, "v": v_node}

Expand Down Expand Up @@ -96,14 +96,14 @@ def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matm
functional_op=torch.mul)
return scale_node

def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transposed_k_node: BaseNode) -> BaseNode:
matmul1_output_shape = copy(q_node.output_shape[0])
matmul1_output_shape[-2] = q_node.output_shape[0][-2]
matmul1_output_shape[-1] = k_node.output_shape[-1]
matmul1_output_shape[-1] = transposed_k_node.output_shape[-1]
matmul_name = f'{attention_node_name}_matmul1'
return FunctionalNode(name=matmul_name,
framework_attr={},
input_shape=(tuple(q_node.output_shape[0]), tuple(k_node.output_shape)),
input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
output_shape=tuple(matmul1_output_shape),
weights={},
layer_class=torch.matmul,
Expand Down

0 comments on commit 5d74d6a

Please sign in to comment.