diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
index 1c51315ed..ab6cc020c 100644
--- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
+++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
@@ -2,6 +2,7 @@
 import torch
 import math
 from copy import copy
+import numpy as np
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.common import BaseSubstitution
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
@@ -23,6 +24,13 @@ def __init__(self):
         super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
         self.device = get_working_device()
 
+    def _get_input_by_name(self, attention_node, input_name, input_index, default_value):
+        if input_name in attention_node.op_call_kwargs:
+            return attention_node.op_call_kwargs[input_name]
+        elif len(attention_node.op_call_args) > input_index:  # input order: [attn_mask, dropout_p, is_causal, scale]
+            return attention_node.op_call_args[input_index]
+        return default_value
+
     def _get_attention_input_nodes(self, graph: Graph, attention_node: BaseNode) -> dict:
         q, k, v = 0, 1, 2
         prev_nodes = graph.get_prev_nodes(attention_node)
@@ -48,30 +56,25 @@ def _get_transpose_node(self, attention_node: BaseNode, node: BaseNode) -> BaseN
     def _get_scale_node(self, attention_node, q_node, matmul_node):
         scale_name = f'{attention_node.name}_scale'
         q_embd_axis = -1
-        scale_factor = math.sqrt(q_node.output_shape[0][q_embd_axis])  # todo: validate the dimention is correct
+        input_scale = self._get_input_by_name(attention_node, "scale", 3, None)
+        scale_factor = input_scale if input_scale else (1 / math.sqrt(q_node.output_shape[0][q_embd_axis]))
         scale_node = FunctionalNode(name=scale_name,
                                     framework_attr={},
                                     input_shape=(matmul_node.output_shape),
                                     output_shape=matmul_node.output_shape,
                                     weights={},
-                                    layer_class=torch.div,
+                                    layer_class=torch.mul,
                                     op_call_args=[scale_factor],
                                     op_call_kwargs={},
-                                    functional_op=torch.div)
+                                    functional_op=torch.mul)
         return scale_node
 
     def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
-        q_batch_axis = 0
-        q_and_k_embd_axis = -1  # d_k == d
-        k_seq_axis = -2
-        q_seq_axis = -2
-
         matmul1_output_shape = copy(q_node.output_shape[0])
-        matmul1_output_shape[-2] = q_node.output_shape[0][q_seq_axis]
+        matmul1_output_shape[-2] = q_node.output_shape[0][-2]
         matmul1_output_shape[-1] = k_node.output_shape[-1]
-
         matmul_name = f'{attention_node.name}_matmul1'
-        matmul_node = FunctionalNode(name=matmul_name,
+        return FunctionalNode(name=matmul_name,
                               framework_attr={},
                               input_shape=(tuple(q_node.output_shape[0]), tuple(k_node.output_shape)),
                               output_shape=tuple(matmul1_output_shape),
@@ -80,10 +83,9 @@ def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: B
                              op_call_args=[],
                              op_call_kwargs={},
                              functional_op=torch.matmul)
-        return matmul_node
 
-    def _get_mask_node(self, attention_node, q_node, k_node, scale_node):
-        attention_mask_tensor = self._get_attention_mask_tensor(attention_node, q_node, k_node)
+    def _get_mask_node(self, attention_node, scale_node):
+        attention_mask_tensor = self._get_attention_mask_tensor(attention_node)
         if attention_mask_tensor is None:
             return None
         mask_node_name = f'{attention_node.name}_mask'
@@ -121,51 +123,36 @@ def _get_matmul2_node(self, attention_node, softmax_node, v_node):
                              op_call_kwargs={},
                              functional_op=torch.matmul)
 
-    def _get_attention_mask_tensor(self, attention_node, q_node, k_node):
-        target_seq_len = q_node.output_shape[0][-2]
-        source_seq_len = k_node.output_shape[0][-2]
-        attn_bias = torch.zeros(target_seq_len, source_seq_len).to(self.device)  # todo: check if need to add dtype=query.dtype
-        # attn_mask = attention_node.op_call_kwargs['attn_mask']
-        is_causal = attention_node.op_call_kwargs.get('is_causal', False)
+    def _get_attention_mask_tensor(self, attention_node):
+        is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
         if is_causal:
             raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
         input_weights = list(attention_node.weights.values())
         attn_mask = input_weights[0] if len(input_weights) > 0 else None
 
-        # if is_causal:
-        #     assert attn_mask is None, "In case 'is_causal' is set to True, 'attn_mask' must be None"
-        #     temp_mask = torch.ones(target_seq_len, source_seq_len, dtype=torch.bool).tril(diagonal=0).to(self.device)
-        #     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")).to(self.device)
-        #     attn_bias.to(torch.float32)
-        #
-        # if attn_mask is not None:
-        #     if attn_mask.dtype == torch.bool:
-        #         attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-        #     else:
-        #         attn_bias += torch.from_numpy(attn_mask).to(self.device)  # todo: ask why attn_mask have changed to numpy ndarray
-        return attn_mask
+        if attn_mask is not None and (attn_mask.dtype == "bool"):
+            raise NotImplementedError("scaled_dot_product_attention attn_mask is of type boolean, which is not supported.")
+        if attn_mask is not None and (not np.isfinite(attn_mask).all()):
+            raise NotImplementedError("scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
+        return torch.from_numpy(attn_mask).to(self.device) if attn_mask is not None else None
 
     def _get_dropout_node(self, attention_node, in_out_shape):
-        dropout_p = attention_node.op_call_kwargs['dropout_p']
-        if dropout_p == 0:
-            return None
+        dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
         dropout_name = f'{attention_node.name}_dropout'
         return BaseNode(name=dropout_name,
-                            framework_attr={"p": attention_node.op_call_kwargs['dropout_p']},
-                            input_shape=in_out_shape,
-                            output_shape=in_out_shape,
-                            weights={},
-                            layer_class=nn.Dropout)
+                        framework_attr={"p": dropout_p},
+                        input_shape=in_out_shape,
+                        output_shape=in_out_shape,
+                        weights={},
+                        layer_class=nn.Dropout)
 
-    def substitute(self,
-                   graph: Graph,
-                   attention_node: BaseNode) -> Graph:
+    def substitute(self, graph: Graph, attention_node: BaseNode) -> Graph:
         input_nodes = self._get_attention_input_nodes(graph, attention_node)
         q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
         transpose_k_node = self._get_transpose_node(attention_node, k_node)
         matmul_node = self._get_matmul_node(attention_node, q_node, transpose_k_node)
         scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
-        mask_node = self._get_mask_node(attention_node, q_node, k_node, scale_node)
+        mask_node = self._get_mask_node(attention_node, scale_node)
         softmax_node = self._get_softmax_node(attention_node, matmul_node.output_shape)
         dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
         matmul2_node = self._get_matmul2_node(attention_node, softmax_node, v_node)
@@ -176,12 +163,12 @@ def substitute(self,
         if mask_node:
             graph.add_node_with_in_edges(mask_node, [scale_node])
         graph.add_node_with_in_edges(softmax_node, [mask_node if mask_node else scale_node])
-        if dropout_node:
-            graph.add_node_with_in_edges(dropout_node, [softmax_node])
+        graph.add_node_with_in_edges(dropout_node, [softmax_node])
        graph.add_node_with_in_edges(matmul2_node, [dropout_node if dropout_node else softmax_node, v_node])
 
         graph.remove_edge(q_node, attention_node)
         graph.remove_edge(k_node, attention_node)
         graph.remove_edge(v_node, attention_node)
         graph.remove_node(attention_node, new_graph_outputs=[OutTensor(matmul2_node, 0)])
-        return graph
+
+        return graph
\ No newline at end of file
diff --git a/tests/pytorch_tests/model_tests/feature_models/scaled_dot_product_attention_test.py b/tests/pytorch_tests/model_tests/feature_models/scaled_dot_product_attention_test.py
index 5a10e49c2..144ec1076 100644
--- a/tests/pytorch_tests/model_tests/feature_models/scaled_dot_product_attention_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/scaled_dot_product_attention_test.py
@@ -3,17 +3,19 @@
 
 
 class ScaledDotProductAttentionNet(nn.Module):
-    def __init__(self, attn_mask=None, is_causal=False, dropout_p=0.0):
+    def __init__(self, dropout_p=0.0, scale=None, attn_mask=None, is_causal=False):
         super().__init__()
         self.dropout_p = dropout_p
+        self.scale = scale
         self.attn_mask = attn_mask
         self.is_causal = is_causal
 
     def forward(self, q, k, v):
         x = nn.functional.scaled_dot_product_attention(q, k, v,
-                                                       dropout_p=self.dropout_p,
                                                        attn_mask=self.attn_mask,
-                                                       is_causal=self.is_causal
+                                                       dropout_p=self.dropout_p,
+                                                       is_causal=self.is_causal,
+                                                       scale=self.scale
                                                        )
         return x
 
@@ -22,22 +24,26 @@ class ScaledDotProductAttentionTest(BasePytorchTest):
     """
    This test checks the MultiHeadAttention as a single layer with add_bias_kv feature.
     """
-    def __init__(self, unit_test, attn_mask=None, is_causal=False, dropout_p=0.0):
+    def __init__(self, unit_test, dropout_p=0.0, scale=None, attn_mask=None, is_causal=False):
         super().__init__(unit_test)
         self.use_fuzzy_validation = True  # because SDPA contains sqrt operation which leads to sightly different output values compared to original torch model
         self.dropout_p = dropout_p
+        self.scale = scale
         self.attn_mask = attn_mask
         self.is_causal = is_causal
 
     def create_feature_network(self, input_shape):
-        return ScaledDotProductAttentionNet(self.attn_mask, self.is_causal, self.dropout_p)
+        return ScaledDotProductAttentionNet(dropout_p=self.dropout_p,
+                                            attn_mask=self.attn_mask,
+                                            is_causal=self.is_causal,
+                                            scale=self.scale)
 
     def create_inputs_shape(self):
         batch_size, q_and_k_embd_size, v_embd_size, source_seq_len, target_seq_len = 3, 8, 19, 21, 13
         q_shape = [batch_size, target_seq_len, q_and_k_embd_size]
         k_shape = [batch_size, source_seq_len, q_and_k_embd_size]
         v_shape = [batch_size, source_seq_len, v_embd_size]
-        return [q_shape, k_shape, v_shape]  # expected output shape: (batch_size, target_seq_len, v_embd_size) == (3, 13, 19)
+        return [q_shape, k_shape, v_shape]
 
     def _test_substitution_structure_output(self, post_substitution_nodes):
         """
@@ -49,9 +55,9 @@ def _test_substitution_structure_output(self, post_substitution_nodes):
             'DummyPlaceHolder': 3,
             "transpose": 1,
             "matmul": 2,
-            "div": 1,  # scale operator
+            "mul": 1,  # scale operator
             "Softmax": 1,
-            "Dropout": 0 if self.dropout_p == 0 else 1,
+            "Dropout": 1,
             "add": 0 if self.attn_mask is None else 1  # mask operator
         }
 
diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
index 82ed04bd4..95ead6a01 100644
--- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py
+++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -108,6 +108,7 @@
 from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest
 from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \
     Activation16BitMixedPrecisionTest
+from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
 
 
 class FeatureModelsTestRunner(unittest.TestCase):
@@ -591,14 +592,10 @@ def test_scaled_dot_product_attention_layer(self):
         """
         This test checks the ScaledDotProductDecomposition substitution feature.
         """
-        # ScaledDotProductAttentionTest(self).run_test(seed=3)
-        ScaledDotProductAttentionTest(self, dropout_p=0.5).run_test(seed=3)
-
-        from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
-        # attn_mask = torch.zeros(13, 21).to(get_working_device())
-        # ScaledDotProductAttentionTest(self, attn_mask=attn_mask, is_causal=False).run_test(seed=3)
-        # ScaledDotProductAttentionTest(self, attn_mask=None, is_causal=True).run_test(seed=3)
-        #todo: add scale argument
+        ScaledDotProductAttentionTest(self).run_test(seed=3)
+        ScaledDotProductAttentionTest(self, dropout_p=0.0, scale=5).run_test(seed=3)
+        attn_mask = torch.zeros(13, 21).to(get_working_device())
+        ScaledDotProductAttentionTest(self, attn_mask=attn_mask).run_test(seed=3)
 
     def test_gptq(self):
         """