Add documentation & types
yardeny-sony committed Sep 25, 2024
1 parent 9116ae7 commit c6771fc
Showing 2 changed files with 52 additions and 28 deletions.
@@ -1,4 +1,4 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,14 +28,14 @@

class ScaledDotProductDecomposition(BaseSubstitution):
"""
-    Removes a scaled_dot_product_attention node from the graph, and replaces it with a compatible graph that
-    consists of:
-    MatMul(over q and transposed k)
-    Mul (for scaling)
-    Add (for masking. optional operation, used in cases that attn_mask ig given)
-    Dropout
-    Softmax
-    Matmul.
+    Decomposes torch.nn.functional.scaled_dot_product_attention into its base operators:
+    Transpose (over k)
+    MatMul (over q and transposed k)
+    Mul (for scaling)
+    Add (for masking; optional, used when attn_mask is given)
+    Softmax
+    Dropout
+    MatMul.
"""

def __init__(self):
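For reference, the decomposed chain reproduces the original operator. Below is a minimal standalone sketch in plain torch, independent of MCT's graph classes; the shapes and the float mask are illustrative assumptions, and Dropout is omitted since it is the identity at p=0:

import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)          # (batch, heads, L, E)
k = torch.randn(2, 4, 10, 16)         # (batch, heads, S, E)
v = torch.randn(2, 4, 10, 16)         # (batch, heads, S, Ev)
attn_mask = torch.randn(2, 4, 8, 10)  # float additive mask

k_t = torch.transpose(k, -1, -2)                        # Transpose (over k)
scores = torch.matmul(q, k_t)                           # MatMul (q x k^T)
scores = torch.mul(scores, 1 / math.sqrt(q.shape[-1]))  # Mul (default scale 1/sqrt(E))
scores = torch.add(scores, attn_mask)                   # Add (masking)
weights = torch.softmax(scores, dim=-1)                 # Softmax
out = torch.matmul(weights, v)                          # MatMul (weights x v)

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert torch.allclose(out, ref, atol=1e-5)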
@@ -44,7 +44,7 @@ def __init__(self):
"""
super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))

-def _get_input_by_name(self, attention_node, input_name, input_index, default_value):
+def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str, input_index: int, default_value: Any):
"""
Search for attention_node input value in op_call_kwargs (using input_name) and op_call_args (using input_index).
In case the input is not given, returns its default_value.
@@ -56,16 +56,16 @@ def _get_input_by_name(self, attention_node, input_name, input_index, default_va
return attention_node.op_call_args[input_index]
return default_value

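The (input_name, input_index) pairs used elsewhere in this diff ("is_causal", 2 and "scale", 3) index the arguments that follow q, k, v in F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None), i.e. attn_mask -> 0, dropout_p -> 1, is_causal -> 2, scale -> 3. A standalone sketch of the same lookup (hypothetical helper mirroring the method above):

def get_input_by_name(op_call_kwargs: dict, op_call_args: list, input_name: str, input_index: int, default_value):
    # Keyword form wins; otherwise fall back to the positional slot, then to the default.
    if input_name in op_call_kwargs:
        return op_call_kwargs[input_name]
    if len(op_call_args) > input_index:
        return op_call_args[input_index]
    return default_value

mask = object()  # stand-in for an attn_mask tensor passed positionally
assert get_input_by_name({}, [mask, 0.1], "attn_mask", 0, None) is mask
assert get_input_by_name({}, [mask, 0.1], "is_causal", 2, False) is False  # not given -> default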
-def _get_attention_input_nodes(self, graph: Graph, attention_node: BaseNode) -> dict:
+def _get_attention_input_nodes(self, graph: Graph, attention_node: FunctionalNode) -> dict:
q, k, v = 0, 1, 2
prev_nodes = graph.get_prev_nodes(attention_node)
q_node, k_node, v_node = prev_nodes[q], prev_nodes[k], prev_nodes[v]
return {"q": q_node, "k": k_node, "v": v_node}

-def _get_transpose_k_node(self, attention_node: BaseNode, key_node: BaseNode) -> BaseNode:
+def _get_transpose_k_node(self, attention_node_name: str, key_node: BaseNode) -> BaseNode:
input_shape, output_shape = copy(key_node.output_shape[0]), copy(key_node.output_shape[0])
output_shape[-2], output_shape[-1] = input_shape[-1], input_shape[-2]
-transpose_node = FunctionalNode(name=f"{attention_node.name}_{key_node.name}_transpose",
+transpose_node = FunctionalNode(name=f"{attention_node_name}_{key_node.name}_transpose",
framework_attr={},
input_shape=input_shape,
output_shape=output_shape,
@@ -76,7 +76,10 @@ def _get_transpose_k_node(self, attention_node: BaseNode, key_node: BaseNode) ->
functional_op=torch.transpose)
return transpose_node

-def _get_scale_node(self, attention_node, q_node, matmul_node):
+def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matmul_node: BaseNode):
+"""
+:return: multiplication node that represents multiplication by the scale factor
+"""
scale_name = f'{attention_node.name}_scale'
q_embd_axis = -1
input_scale = self._get_input_by_name(attention_node, "scale", 3, None)
@@ -92,11 +95,11 @@ def _get_scale_node(self, attention_node, q_node, matmul_node):
functional_op=torch.mul)
return scale_node

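The scale computation itself is truncated by the diff; per the PyTorch documentation, when no explicit scale is passed, scaled_dot_product_attention defaults to 1/sqrt(E), where E is the query's embedding size (its last axis, matching q_embd_axis = -1 above). A sketch of the intended fallback, assumed from those docs:

import math
import torch

q = torch.randn(2, 4, 8, 16)
input_scale = None  # "scale" was not passed to scaled_dot_product_attention
scale = input_scale if input_scale is not None else 1 / math.sqrt(q.shape[-1])
assert scale == 0.25  # 1/sqrt(16); this constant feeds the Mul node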
-def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
+def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
matmul1_output_shape = copy(q_node.output_shape[0])
matmul1_output_shape[-2] = q_node.output_shape[0][-2]
matmul1_output_shape[-1] = k_node.output_shape[-1]
-matmul_name = f'{attention_node.name}_matmul1'
+matmul_name = f'{attention_node_name}_matmul1'
return FunctionalNode(name=matmul_name,
framework_attr={},
input_shape=(tuple(q_node.output_shape[0]), tuple(k_node.output_shape)),
@@ -107,7 +110,10 @@ def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: B
op_call_kwargs={},
functional_op=torch.matmul)

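The shape bookkeeping above follows standard batched-matmul rules: q of shape (..., L, E) times the transposed k of shape (..., E, S) yields (..., L, S), i.e. q's shape with its last axis replaced by k_t's last axis. For example:

import torch

q = torch.randn(2, 4, 8, 16)                               # (..., L=8, E=16)
k_t = torch.transpose(torch.randn(2, 4, 10, 16), -1, -2)   # (..., E=16, S=10)
scores = torch.matmul(q, k_t)
assert list(scores.shape) == [2, 4, 8, 10]                 # (..., L, S)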
-def _get_mask_node(self, attention_node, scale_node):
+def _get_mask_node(self, attention_node: FunctionalNode, scale_node: FunctionalNode):
+"""
+:return: Add operator node with the mask tensor as input. In case there is no mask tensor, returns None.
+"""
attention_mask_tensor = self._get_attention_mask_tensor(attention_node)
if attention_mask_tensor is None:
return None
@@ -122,16 +128,16 @@ def _get_mask_node(self, attention_node, scale_node):
op_call_kwargs={'other': attention_mask_tensor},
functional_op=torch.add)

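Masking is thus realized as a plain elementwise Add of a float mask onto the scaled scores (passed via the 'other' kwarg above): large negative entries drive the corresponding softmax weights toward zero. A small sketch with an illustrative mask:

import torch

scores = torch.randn(1, 1, 4, 4)
attn_mask = torch.zeros(4, 4)
attn_mask[:, 2:] = -1e9  # large finite negatives; -inf itself is rejected (see _get_attention_mask_tensor)
weights = torch.softmax(torch.add(scores, attn_mask), dim=-1)
assert weights[..., 2:].max() < 1e-6  # masked key positions receive ~zero attention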
-def _get_softmax_node(self, attention_node, in_out_shape):
-softmax_name = f'{attention_node.name}_softmax'
+def _get_softmax_node(self, attention_node_name: str, in_out_shape: tuple):
+softmax_name = f'{attention_node_name}_softmax'
return BaseNode(name=softmax_name,
framework_attr={DIM: -1},
input_shape=in_out_shape,
output_shape=in_out_shape,
weights={},
layer_class=nn.Softmax)

-def _get_matmul2_node(self, attention_node_name, softmax_node, v_node):
+def _get_matmul2_node(self, attention_node_name: str, softmax_node: BaseNode, v_node: BaseNode):
matmul2_output_shape = list(copy(softmax_node.output_shape))
matmul2_output_shape[-2] = softmax_node.output_shape[-2]
matmul2_output_shape[-1] = v_node.output_shape[0][-1]
@@ -146,7 +152,12 @@ def _get_matmul2_node(self, attention_node_name, softmax_node, v_node):
op_call_kwargs={},
functional_op=torch.matmul)

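Symmetrically to the first matmul, the output shape here is the softmax shape with its last axis replaced by v's: (..., L, S) x (..., S, Ev) -> (..., L, Ev). For example:

import torch

weights = torch.softmax(torch.randn(2, 4, 8, 10), dim=-1)  # (..., L=8, S=10)
v = torch.randn(2, 4, 10, 16)                              # (..., S=10, Ev=16)
out = torch.matmul(weights, v)
assert list(out.shape) == [2, 4, 8, 16]                    # (..., L, Ev)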
-def _get_attention_mask_tensor(self, attention_node):
+def _get_attention_mask_tensor(self, attention_node: FunctionalNode):
+"""
+:return: mask tensor given as part of the attention node's input.
+Since MCT doesn't support infinite values, we don't support is_causal (a torch.nn.functional.scaled_dot_product_attention
+argument) or a boolean mask tensor, as both require -inf values.
+"""
device = get_working_device()
is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
if is_causal:
@@ -161,7 +172,7 @@ def _get_attention_mask_tensor(self, attention_node):
"scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
return torch.from_numpy(attn_mask).to(device) if attn_mask is not None else None

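The -inf limitation explains why is_causal is rejected: per the PyTorch documentation it is shorthand for an additive mask whose upper triangle is -inf (and boolean masks are likewise converted to -inf/0.0). A sketch of the mask that is_causal=True implies:

import torch

L, S = 4, 4
causal_mask = torch.zeros(L, S)
causal_mask.masked_fill_(torch.ones(L, S, dtype=torch.bool).triu(diagonal=1), float("-inf"))
# Each query may attend only to itself and earlier keys:
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]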
-def _get_dropout_node(self, attention_node, in_out_shape):
+def _get_dropout_node(self, attention_node: FunctionalNode, in_out_shape: tuple):
dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
dropout_name = f'{attention_node.name}_dropout'
return BaseNode(name=dropout_name,
@@ -171,15 +182,28 @@ def _get_dropout_node(self, attention_node, in_out_shape):
weights={},
layer_class=nn.Dropout)

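Note that dropout_p falls back to 0 when absent from op_call_kwargs, and nn.Dropout is then the identity, so the decomposition is numerically transparent whenever dropout is unused:

import torch
import torch.nn as nn

dropout = nn.Dropout(p=0)  # dropout_p absent -> 0
dropout.eval()             # in eval mode Dropout is a no-op regardless of p
x = torch.randn(2, 4, 8, 10)
assert torch.equal(dropout(x), x)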
-def substitute(self, graph: Graph, attention_node: BaseNode) -> Graph:
+def substitute(self, graph: Graph, attention_node: FunctionalNode) -> Graph:
+"""
+Removes a scaled_dot_product_attention node from the graph and replaces it with a compatible subgraph
+consisting of:
+Transpose (over k)
+MatMul (over q and transposed k)
+Mul (for scaling)
+Add (for masking; optional, used when attn_mask is given)
+Softmax
+Dropout
+MatMul.
+:param graph: a Graph to apply the substitution on
+:param attention_node: the node to replace
+:return: the graph after the substitution
+"""
input_nodes = self._get_attention_input_nodes(graph, attention_node)
q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
-transpose_k_node = self._get_transpose_k_node(attention_node, k_node)
-matmul_node = self._get_matmul_node(attention_node, q_node, transpose_k_node)
+transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
+matmul_node = self._get_matmul_node(attention_node.name, q_node, transpose_k_node)
scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
mask_node = self._get_mask_node(attention_node, scale_node)
-softmax_node = self._get_softmax_node(attention_node, matmul_node.output_shape)
+softmax_node = self._get_softmax_node(attention_node.name, matmul_node.output_shape)
dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
matmul2_node = self._get_matmul2_node(attention_node.name, softmax_node, v_node)

@@ -1,4 +1,4 @@
-# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.