transpose bugfix
yardeny-sony committed Sep 26, 2024
1 parent f6ada2f commit d885ccc
Showing 1 changed file with 9 additions and 3 deletions.
@@ -72,7 +72,7 @@ def _get_transpose_k_node(self, attention_node_name: str, key_node: BaseNode) ->
                                        output_shape=output_shape,
                                        weights={},
                                        layer_class=torch.transpose,
-                                       op_call_args=[1, 2],  # axes to transpose
+                                       op_call_args=[-1, -2],  # axes to transpose
                                        op_call_kwargs={},
                                        functional_op=torch.transpose)
        return transpose_node
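For context on the change above, here is a minimal, self-contained sketch (plain torch only, not part of the repository) of why the transpose axes move from (1, 2) to (-1, -2): with a 4-D multi-head key tensor, the old call swaps the heads and sequence axes instead of building Kᵀ, while negative axes always address the last two dimensions regardless of the tensor's rank.

import torch

k = torch.randn(2, 8, 16, 64)        # (batch, heads, seq_len, head_dim)
old_kt = torch.transpose(k, 1, 2)    # (2, 16, 8, 64): swaps heads/seq_len, not a valid K^T
new_kt = torch.transpose(k, -1, -2)  # (2, 8, 64, 16): swaps the last two axes, the intended K^T
print(old_kt.shape, new_kt.shape)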
@@ -198,6 +198,7 @@ def substitute(self, graph: Graph, attention_node: FunctionalNode) -> Graph:
        :param attention_node: the node to replace
        :return: A graph after the substitution
        """
+       print("In scale_dot_product_attention substitution@@@@@@@@")
        input_nodes = self._get_attention_input_nodes(graph, attention_node)
        q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
        transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
@@ -217,9 +218,14 @@ def substitute(self, graph: Graph, attention_node: FunctionalNode) -> Graph:
        graph.add_node_with_in_edges(dropout_node, [softmax_node])
        graph.add_node_with_in_edges(matmul2_node, [dropout_node if dropout_node else softmax_node, v_node])

+       graph_outputs = graph.get_outputs()
+       for i, g_out in enumerate(graph_outputs):
+           if g_out.node == attention_node:
+               graph_outputs[i] = OutTensor(node=matmul2_node, node_out_index=g_out.node_out_index)
+
        graph.reconnect_out_edges(current_node=attention_node, new_node=matmul2_node)
        graph.remove_edge(q_node, attention_node)
        graph.remove_edge(k_node, attention_node)
        graph.remove_edge(v_node, attention_node)
-       graph.remove_node(attention_node, new_graph_outputs=[OutTensor(matmul2_node, 0)])
+       graph.remove_node(attention_node, new_graph_outputs=graph_outputs)
        return graph
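To make the second part of the fix easier to follow, here is a toy illustration of the output-remapping idea (hypothetical minimal classes, not the MCT Graph/OutTensor API): instead of unconditionally replacing the graph outputs with a single tensor of the new matmul node, only the outputs that actually pointed at the removed attention node are redirected, and their output indices are preserved.

from dataclasses import dataclass

@dataclass
class OutTensor:                      # stand-in for the real OutTensor, for illustration only
    node: object
    node_out_index: int

def remap_outputs(graph_outputs, old_node, new_node):
    # Redirect outputs that referenced the removed node; leave every other output untouched.
    return [OutTensor(new_node, o.node_out_index) if o.node is old_node else o
            for o in graph_outputs]

# usage sketch: new_outputs = remap_outputs(graph.get_outputs(), attention_node, matmul2_node)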
