all tests pass :)
yardeny-sony committed Sep 24, 2024
1 parent aa1d03a commit 466f592
Showing 3 changed files with 53 additions and 63 deletions.
@@ -2,6 +2,7 @@
import torch
import math
from copy import copy
+ import numpy as np
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.common import BaseSubstitution
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
@@ -23,6 +24,13 @@ def __init__(self):
super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
self.device = get_working_device()

+ def _get_input_by_name(self, attention_node, input_name, input_index, default_value):
+ if input_name in attention_node.op_call_kwargs:
+ return attention_node.op_call_kwargs[input_name]
+ elif len(attention_node.op_call_args) > input_index: # input order: [attn_mask, dropout_p, is_causal, scale]
+ return attention_node.op_call_args[input_index]
+ return default_value
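
Note (not part of the diff): torch.nn.functional.scaled_dot_product_attention accepts attn_mask, dropout_p, is_causal and scale either positionally or as keywords, so the new helper checks op_call_kwargs first and only then falls back to the positional op_call_args slot. A minimal standalone sketch of that lookup, using types.SimpleNamespace as a stand-in for the node fields it reads:

from types import SimpleNamespace

def get_input_by_name(node, name, index, default):
    # Keyword wins; otherwise fall back to the positional slot, if present.
    if name in node.op_call_kwargs:
        return node.op_call_kwargs[name]
    if len(node.op_call_args) > index:  # positional order: [attn_mask, dropout_p, is_causal, scale]
        return node.op_call_args[index]
    return default

kwargs_node = SimpleNamespace(op_call_args=[], op_call_kwargs={"scale": 0.25})
args_node = SimpleNamespace(op_call_args=[None, 0.0, False, 0.25], op_call_kwargs={})
assert get_input_by_name(kwargs_node, "scale", 3, None) == 0.25
assert get_input_by_name(args_node, "scale", 3, None) == 0.25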

def _get_attention_input_nodes(self, graph: Graph, attention_node: BaseNode) -> dict:
q, k, v = 0, 1, 2
prev_nodes = graph.get_prev_nodes(attention_node)
@@ -48,30 +56,25 @@ def _get_transpose_node(self, attention_node: BaseNode, node: BaseNode) -> BaseNode:
def _get_scale_node(self, attention_node, q_node, matmul_node):
scale_name = f'{attention_node.name}_scale'
q_embd_axis = -1
- scale_factor = math.sqrt(q_node.output_shape[0][q_embd_axis]) # todo: validate the dimention is correct
+ input_scale = self._get_input_by_name(attention_node, "scale", 3, None)
+ scale_factor = input_scale if input_scale else (1 / math.sqrt(q_node.output_shape[0][q_embd_axis]))
scale_node = FunctionalNode(name=scale_name,
framework_attr={},
input_shape=(matmul_node.output_shape),
output_shape=matmul_node.output_shape,
weights={},
- layer_class=torch.div,
+ layer_class=torch.mul,
op_call_args=[scale_factor],
op_call_kwargs={},
- functional_op=torch.div)
+ functional_op=torch.mul)
return scale_node
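
Note (not part of the diff): PyTorch's default SDPA scaling is 1/sqrt(E), where E is the last dimension of the query, and an explicit scale argument replaces that factor. The node above now applies the factor with torch.mul instead of dividing by sqrt(E); a standalone check that the two forms agree:

import math
import torch

q = torch.randn(3, 13, 8)   # illustrative (batch, target_seq_len, d_k)
k = torch.randn(3, 21, 8)   # illustrative (batch, source_seq_len, d_k)
d_k = q.shape[-1]

scores_div = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)        # old: torch.div by sqrt(d_k)
scores_mul = torch.matmul(q, k.transpose(-1, -2)) * (1 / math.sqrt(d_k))  # new: torch.mul by 1/sqrt(d_k)
assert torch.allclose(scores_div, scores_mul)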

def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
- q_batch_axis = 0
- q_and_k_embd_axis = -1 # d_k == d
- k_seq_axis = -2
- q_seq_axis = -2

matmul1_output_shape = copy(q_node.output_shape[0])
- matmul1_output_shape[-2] = q_node.output_shape[0][q_seq_axis]
+ matmul1_output_shape[-2] = q_node.output_shape[0][-2]
matmul1_output_shape[-1] = k_node.output_shape[-1]

matmul_name = f'{attention_node.name}_matmul1'
- matmul_node = FunctionalNode(name=matmul_name,
+ return FunctionalNode(name=matmul_name,
framework_attr={},
input_shape=(tuple(q_node.output_shape[0]), tuple(k_node.output_shape)),
output_shape=tuple(matmul1_output_shape),
@@ -80,10 +83,9 @@ def _get_mask_node(self, attention_node, q_node, k_node, scale_node):
op_call_args=[],
op_call_kwargs={},
functional_op=torch.matmul)
- return matmul_node

- def _get_mask_node(self, attention_node, q_node, k_node, scale_node):
- attention_mask_tensor = self._get_attention_mask_tensor(attention_node, q_node, k_node)
+ def _get_mask_node(self, attention_node, scale_node):
+ attention_mask_tensor = self._get_attention_mask_tensor(attention_node)
if attention_mask_tensor is None:
return None
mask_node_name = f'{attention_node.name}_mask'
@@ -121,51 +123,36 @@ def _get_matmul2_node(self, attention_node, softmax_node, v_node):
op_call_kwargs={},
functional_op=torch.matmul)

- def _get_attention_mask_tensor(self, attention_node, q_node, k_node):
- target_seq_len = q_node.output_shape[0][-2]
- source_seq_len = k_node.output_shape[0][-2]
- attn_bias = torch.zeros(target_seq_len, source_seq_len).to(self.device) # todo: check if need to add dtype=query.dtype
- # attn_mask = attention_node.op_call_kwargs['attn_mask']
- is_causal = attention_node.op_call_kwargs.get('is_causal', False)
+ def _get_attention_mask_tensor(self, attention_node):
+ is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
if is_causal:
raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
input_weights = list(attention_node.weights.values())
attn_mask = input_weights[0] if len(input_weights) > 0 else None
- # if is_causal:
- # assert attn_mask is None, "In case 'is_causal' is set to True, 'attn_mask' must be None"
- # temp_mask = torch.ones(target_seq_len, source_seq_len, dtype=torch.bool).tril(diagonal=0).to(self.device)
- # attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")).to(self.device)
- # attn_bias.to(torch.float32)
- #
- # if attn_mask is not None:
- # if attn_mask.dtype == torch.bool:
- # attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
- # else:
- # attn_bias += torch.from_numpy(attn_mask).to(self.device) # todo: ask why attn_mask have changed to numpy ndarray
- return attn_mask
+ if attn_mask is not None and (attn_mask.dtype == "bool"):
+ raise NotImplementedError("scaled_dot_product_attention attn_mask is of type boolean, which is not supported.")
+ if attn_mask is not None and (not np.isfinite(attn_mask).all()):
+ raise NotImplementedError("scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
+ return torch.from_numpy(attn_mask).to(self.device) if attn_mask is not None else None
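
Note (not part of the diff): the new checks reject boolean masks and masks containing infinite values; only a finite additive float mask is passed through to the add node. For context, a boolean SDPA mask marks the positions allowed to attend, and its additive equivalent would need -inf entries, which these checks exclude:

import torch

bool_mask = torch.tensor([[True, False],
                          [True, True]])
# Positions that may not attend would need -inf in the additive form,
# which the add node built by _get_mask_node does not model.
additive_equivalent = torch.zeros(2, 2).masked_fill(~bool_mask, float("-inf"))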

def _get_dropout_node(self, attention_node, in_out_shape):
- dropout_p = attention_node.op_call_kwargs['dropout_p']
- if dropout_p == 0:
- return None
+ dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
dropout_name = f'{attention_node.name}_dropout'
return BaseNode(name=dropout_name,
framework_attr={"p": attention_node.op_call_kwargs['dropout_p']},
input_shape=in_out_shape,
output_shape=in_out_shape,
weights={},
layer_class=nn.Dropout)
framework_attr={"p": dropout_p},
input_shape=in_out_shape,
output_shape=in_out_shape,
weights={},
layer_class=nn.Dropout)
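
Note (not part of the diff): a dropout node is now emitted unconditionally, with p defaulting to 0 when dropout_p is absent. nn.Dropout(p=0.0) is an identity (nothing is dropped and the 1/(1-p) rescaling is 1), which is consistent with the test below always expecting one Dropout node:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.0)
drop.train()                 # even in training mode, p=0 drops nothing
x = torch.randn(2, 4)
assert torch.equal(drop(x), x)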

- def substitute(self,
- graph: Graph,
- attention_node: BaseNode) -> Graph:
+ def substitute(self, graph: Graph, attention_node: BaseNode) -> Graph:

input_nodes = self._get_attention_input_nodes(graph, attention_node)
q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
transpose_k_node = self._get_transpose_node(attention_node, k_node)
matmul_node = self._get_matmul_node(attention_node, q_node, transpose_k_node)
scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
- mask_node = self._get_mask_node(attention_node, q_node, k_node, scale_node)
+ mask_node = self._get_mask_node(attention_node, scale_node)
softmax_node = self._get_softmax_node(attention_node, matmul_node.output_shape)
dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
matmul2_node = self._get_matmul2_node(attention_node, softmax_node, v_node)
@@ -176,12 +163,12 @@ def substitute(self,
if mask_node:
graph.add_node_with_in_edges(mask_node, [scale_node])
graph.add_node_with_in_edges(softmax_node, [mask_node if mask_node else scale_node])
- if dropout_node:
- graph.add_node_with_in_edges(dropout_node, [softmax_node])
+ graph.add_node_with_in_edges(dropout_node, [softmax_node])
graph.add_node_with_in_edges(matmul2_node, [dropout_node if dropout_node else softmax_node, v_node])

graph.remove_edge(q_node, attention_node)
graph.remove_edge(k_node, attention_node)
graph.remove_edge(v_node, attention_node)
graph.remove_node(attention_node, new_graph_outputs=[OutTensor(matmul2_node, 0)])
- return graph

+ return graph
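
Taken together, the substitution above rebuilds scaled_dot_product_attention as transpose -> matmul -> mul (scale) -> add (mask) -> softmax -> dropout -> matmul. A standalone sanity check of that chain against PyTorch, with dropout disabled so the comparison is deterministic (shapes are illustrative and mirror the test below):

import math
import torch
import torch.nn.functional as F

q = torch.randn(3, 13, 8)    # (batch, target_seq_len, d_k)
k = torch.randn(3, 21, 8)    # (batch, source_seq_len, d_k)
v = torch.randn(3, 21, 19)   # (batch, source_seq_len, d_v)
attn_mask = torch.zeros(13, 21)

scale = 1 / math.sqrt(q.shape[-1])                       # default SDPA scale
scores = torch.matmul(q, k.transpose(-1, -2)) * scale    # transpose + matmul1 + mul nodes
scores = scores + attn_mask                              # add (mask) node
weights = torch.softmax(scores, dim=-1)                  # softmax node
decomposed = torch.matmul(weights, v)                    # matmul2 node (dropout_p=0 omitted)

reference = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert torch.allclose(decomposed, reference, atol=1e-6)

The remaining hunks update the SDPA test module and the feature-models test runner accordingly.
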
@@ -3,17 +3,19 @@


class ScaledDotProductAttentionNet(nn.Module):
- def __init__(self, attn_mask=None, is_causal=False, dropout_p=0.0):
+ def __init__(self, dropout_p=0.0, scale=None, attn_mask=None, is_causal=False):
super().__init__()
self.dropout_p = dropout_p
+ self.scale = scale
self.attn_mask = attn_mask
self.is_causal = is_causal

def forward(self, q, k, v):
x = nn.functional.scaled_dot_product_attention(q, k, v,
- dropout_p=self.dropout_p,
attn_mask=self.attn_mask,
- is_causal=self.is_causal
+ dropout_p=self.dropout_p,
+ is_causal=self.is_causal,
+ scale=self.scale
)
return x

@@ -22,22 +24,26 @@ class ScaledDotProductAttentionTest(BasePytorchTest):
"""
This test checks the MultiHeadAttention as a single layer with add_bias_kv feature.
"""
- def __init__(self, unit_test, attn_mask=None, is_causal=False, dropout_p=0.0):
+ def __init__(self, unit_test, dropout_p=0.0, scale=None, attn_mask=None, is_causal=False):
super().__init__(unit_test)
self.use_fuzzy_validation = True # because SDPA contains sqrt operation which leads to sightly different output values compared to original torch model
self.dropout_p = dropout_p
+ self.scale = scale
self.attn_mask = attn_mask
self.is_causal = is_causal

def create_feature_network(self, input_shape):
- return ScaledDotProductAttentionNet(self.attn_mask, self.is_causal, self.dropout_p)
+ return ScaledDotProductAttentionNet(dropout_p=self.dropout_p,
+ attn_mask=self.attn_mask,
+ is_causal=self.is_causal,
+ scale=self.scale)

def create_inputs_shape(self):
batch_size, q_and_k_embd_size, v_embd_size, source_seq_len, target_seq_len = 3, 8, 19, 21, 13
q_shape = [batch_size, target_seq_len, q_and_k_embd_size]
k_shape = [batch_size, source_seq_len, q_and_k_embd_size]
v_shape = [batch_size, source_seq_len, v_embd_size]
- return [q_shape, k_shape, v_shape] # expected output shape: (batch_size, target_seq_len, v_embd_size) == (3, 13, 19)
+ return [q_shape, k_shape, v_shape]

def _test_substitution_structure_output(self, post_substitution_nodes):
"""
@@ -49,9 +55,9 @@ def _test_substitution_structure_output(self, post_substitution_nodes):
'DummyPlaceHolder': 3,
"transpose": 1,
"matmul": 2,
"div": 1, # scale operator
"mul": 1, # scale operator
"Softmax": 1,
"Dropout": 0 if self.dropout_p == 0 else 1,
"Dropout": 1,
"add": 0 if self.attn_mask is None else 1 # mask operator
}

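For reference (not part of the diff), the shapes produced by create_inputs_shape still give an output of (batch_size, target_seq_len, v_embd_size), as the comment removed above noted. A quick standalone check with the updated module signature:

import torch

net = ScaledDotProductAttentionNet(dropout_p=0.0, scale=None, attn_mask=None, is_causal=False)
q = torch.randn(3, 13, 8)    # (batch_size, target_seq_len, q_and_k_embd_size)
k = torch.randn(3, 21, 8)    # (batch_size, source_seq_len, q_and_k_embd_size)
v = torch.randn(3, 21, 19)   # (batch_size, source_seq_len, v_embd_size)
assert net(q, k, v).shape == (3, 13, 19)
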
13 changes: 5 additions & 8 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -108,6 +108,7 @@
from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest
from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \
Activation16BitMixedPrecisionTest
+ from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device


class FeatureModelsTestRunner(unittest.TestCase):
@@ -591,14 +592,10 @@ def test_scaled_dot_product_attention_layer(self):
"""
This test checks the ScaledDotProductDecomposition substitution feature.
"""
- # ScaledDotProductAttentionTest(self).run_test(seed=3)
- ScaledDotProductAttentionTest(self, dropout_p=0.5).run_test(seed=3)
-
- from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
- # attn_mask = torch.zeros(13, 21).to(get_working_device())
- # ScaledDotProductAttentionTest(self, attn_mask=attn_mask, is_causal=False).run_test(seed=3)
- # ScaledDotProductAttentionTest(self, attn_mask=None, is_causal=True).run_test(seed=3)
- #todo: add scale argument
+ ScaledDotProductAttentionTest(self).run_test(seed=3)
+ ScaledDotProductAttentionTest(self, dropout_p=0.0, scale=5).run_test(seed=3)
+ attn_mask = torch.zeros(13, 21).to(get_working_device())
+ ScaledDotProductAttentionTest(self, attn_mask=attn_mask).run_test(seed=3)
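
An observation on the new mask case (not part of the diff): a zero additive mask of shape (target_seq_len, source_seq_len) broadcasts over the batch dimension and leaves the attention output unchanged, so it exercises the decomposed add node without affecting the expected values:

import torch
import torch.nn.functional as F

q, k, v = torch.randn(3, 13, 8), torch.randn(3, 21, 8), torch.randn(3, 21, 19)
zero_mask = torch.zeros(13, 21)
assert torch.allclose(F.scaled_dot_product_attention(q, k, v, attn_mask=zero_mask),
                      F.scaled_dot_product_attention(q, k, v))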

def test_gptq(self):
"""
