Commit

add documentation
yardeny-sony committed Sep 25, 2024
1 parent 0f8aab7 commit 0648665
Showing 2 changed files with 58 additions and 45 deletions.
@@ -13,21 +13,32 @@

class ScaledDotProductDecomposition(BaseSubstitution):
"""
- Removes a MultiHeadAttention node from the graph,
- and replaces it with a compatible graph that consists of Conv, MatMul, Softmax and stacked layers.
+ Removes a scaled_dot_product_attention node from the graph, and replaces it with a compatible graph that
+ consists of:
+ MatMul (over q and transposed k)
+ Mul (for scaling)
+ Add (for masking; optional, used when attn_mask is given)
+ Dropout
+ Softmax
+ MatMul.
"""

def __init__(self):
"""
- Matches MultiHeadAttention node.
+ Matches scaled_dot_product_attention node.
"""
super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
self.device = get_working_device()

def _get_input_by_name(self, attention_node, input_name, input_index, default_value):
"""
Search for attention_node input value in op_call_kwargs (using input_name) and op_call_args (using input_index).
In case the input is not given, returns its default_value.
"""
if input_name in attention_node.op_call_kwargs:
return attention_node.op_call_kwargs[input_name]
- elif len(attention_node.op_call_args) > input_index:  # input order: [attn_mask, dropout_p, is_causal, scale]
+ elif len(attention_node.op_call_args) > input_index:  # input order: [attn_mask, dropout_p, is_causal]
return attention_node.op_call_args[input_index]
return default_value

@@ -43,14 +54,14 @@ def _get_transpose_node(self, attention_node: BaseNode, node: BaseNode) -> BaseNode:
input_shape, output_shape = copy(node.output_shape[0]), copy(node.output_shape[0])
output_shape[-2], output_shape[-1] = input_shape[-1], input_shape[-2]
transpose_node = FunctionalNode(name=f"{attention_node.name}_{node.name}_transpose",
- framework_attr={},
- input_shape=input_shape,
- output_shape=output_shape,
- weights={},
- layer_class=torch.transpose,
- op_call_args=[1, 2],
- op_call_kwargs={},
- functional_op=torch.transpose)
+ framework_attr={},
+ input_shape=input_shape,
+ output_shape=output_shape,
+ weights={},
+ layer_class=torch.transpose,
+ op_call_args=[1, 2],  # axes to transpose
+ op_call_kwargs={},
+ functional_op=torch.transpose)
return transpose_node

def _get_scale_node(self, attention_node, q_node, matmul_node):
Expand All @@ -59,14 +70,14 @@ def _get_scale_node(self, attention_node, q_node, matmul_node):
input_scale = self._get_input_by_name(attention_node, "scale", 3, None)
scale_factor = input_scale if input_scale else (1 / math.sqrt(q_node.output_shape[0][q_embd_axis]))
scale_node = FunctionalNode(name=scale_name,
- framework_attr={},
- input_shape=(matmul_node.output_shape),
- output_shape=matmul_node.output_shape,
- weights={},
- layer_class=torch.mul,
- op_call_args=[scale_factor],
- op_call_kwargs={},
- functional_op=torch.mul)
+ framework_attr={},
+ input_shape=(matmul_node.output_shape),
+ output_shape=matmul_node.output_shape,
+ weights={},
+ layer_class=torch.mul,
+ op_call_args=[scale_factor],
+ op_call_kwargs={},
+ functional_op=torch.mul)
return scale_node
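# Editor's sketch (not part of this commit): the scale factor chosen above — an
# explicit "scale" kwarg if the caller passed one, otherwise the PyTorch default
# of 1/sqrt(embed_dim) taken from the query's last axis. Values are illustrative.
import math
import torch

q = torch.randn(1, 13, 32)                  # (batch, seq_len, embed_dim)
explicit_scale = None                       # i.e. no "scale" kwarg was given
scale_factor = explicit_scale if explicit_scale else 1 / math.sqrt(q.shape[-1])
print(scale_factor)                         # 1/sqrt(32) ≈ 0.1768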

def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
Expand All @@ -75,14 +86,14 @@ def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: B
matmul1_output_shape[-1] = k_node.output_shape[-1]
matmul_name = f'{attention_node.name}_matmul1'
return FunctionalNode(name=matmul_name,
- framework_attr={},
- input_shape=(tuple(q_node.output_shape[0]), tuple(k_node.output_shape)),
- output_shape=tuple(matmul1_output_shape),
- weights={},
- layer_class=torch.matmul,
- op_call_args=[],
- op_call_kwargs={},
- functional_op=torch.matmul)
+ framework_attr={},
+ input_shape=(tuple(q_node.output_shape[0]), tuple(k_node.output_shape)),
+ output_shape=tuple(matmul1_output_shape),
+ weights={},
+ layer_class=torch.matmul,
+ op_call_args=[],
+ op_call_kwargs={},
+ functional_op=torch.matmul)
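# Editor's sketch (not part of this commit): the shape bookkeeping above — the first
# MatMul multiplies q by the already-transposed k, so its last axis becomes the key
# sequence length. Shapes below are illustrative.
import torch

q = torch.randn(1, 13, 32)                    # (batch, L_q, embed_dim)
k_t = torch.randn(1, 21, 32).transpose(1, 2)  # transposed k: (batch, embed_dim, L_k)
scores = torch.matmul(q, k_t)
assert scores.shape == (1, 13, 21)            # (batch, L_q, L_k)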

def _get_mask_node(self, attention_node, scale_node):
attention_mask_tensor = self._get_attention_mask_tensor(attention_node)
@@ -114,14 +125,14 @@ def _get_matmul2_node(self, attention_node, softmax_node, v_node):
matmul2_output_shape[-1] = v_node.output_shape[0][-1]
matmul2_name = f'{attention_node.name}_matmul2'
return FunctionalNode(name=matmul2_name,
- framework_attr={},
- input_shape=(tuple(softmax_node.output_shape), tuple(v_node.output_shape[0])),
- output_shape=tuple(matmul2_output_shape),
- weights={},
- layer_class=torch.matmul,
- op_call_args=[],
- op_call_kwargs={},
- functional_op=torch.matmul)
+ framework_attr={},
+ input_shape=(tuple(softmax_node.output_shape), tuple(v_node.output_shape[0])),
+ output_shape=tuple(matmul2_output_shape),
+ weights={},
+ layer_class=torch.matmul,
+ op_call_args=[],
+ op_call_kwargs={},
+ functional_op=torch.matmul)

def _get_attention_mask_tensor(self, attention_node):
is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
@@ -130,20 +141,22 @@ def _get_attention_mask_tensor(self, attention_node):
input_weights = list(attention_node.weights.values())
attn_mask = input_weights[0] if len(input_weights) > 0 else None
if attn_mask is not None and (attn_mask.dtype == "bool"):
raise NotImplementedError("scaled_dot_product_attention attn_mask is of type boolean, which is not supported.")
raise NotImplementedError(
"scaled_dot_product_attention attn_mask is of type boolean, which is not supported.")
if attn_mask is not None and (not np.isfinite(attn_mask).all()):
raise NotImplementedError("scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
raise NotImplementedError(
"scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
return torch.from_numpy(attn_mask).to(self.device) if attn_mask is not None else None
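# Editor's sketch (not part of this commit): boolean and non-finite (-inf) masks are
# rejected above, so a caller can pre-convert a boolean mask into a finite additive
# float mask before invoking scaled_dot_product_attention. Names are illustrative.
import torch

bool_mask = torch.tril(torch.ones(13, 13)).bool()                   # True = attend
additive_mask = torch.zeros(13, 13).masked_fill(~bool_mask, -1e9)   # large finite negative
# additive_mask can then be passed as attn_mask; per the class docstring, the
# substitution adds it to the scaled scores ahead of the Softmax node.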

def _get_dropout_node(self, attention_node, in_out_shape):
dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
dropout_name = f'{attention_node.name}_dropout'
return BaseNode(name=dropout_name,
framework_attr={"p": dropout_p},
input_shape=in_out_shape,
output_shape=in_out_shape,
weights={},
layer_class=nn.Dropout)
framework_attr={"p": dropout_p},
input_shape=in_out_shape,
output_shape=in_out_shape,
weights={},
layer_class=nn.Dropout)

def substitute(self, graph: Graph, attention_node: BaseNode) -> Graph:

@@ -171,4 +184,4 @@ def substitute(self, graph: Graph, attention_node: BaseNode) -> Graph:
graph.remove_edge(v_node, attention_node)
graph.remove_node(attention_node, new_graph_outputs=[OutTensor(matmul2_node, 0)])

- return graph
+ return graph
@@ -594,7 +594,7 @@ def test_scaled_dot_product_attention_layer(self):
"""
ScaledDotProductAttentionTest(self).run_test(seed=3)
ScaledDotProductAttentionTest(self, dropout_p=0.0, scale=5).run_test(seed=3)
- attn_mask = torch.zeros(13, 21).to(get_working_device())
+ attn_mask = torch.ones(13, 21).to(get_working_device())
ScaledDotProductAttentionTest(self, attn_mask=attn_mask).run_test(seed=3)

def test_gptq(self):
