Add mask node & code cleanup
yardeny-sony committed Sep 24, 2024
1 parent d04d84f commit aa1d03a
Showing 3 changed files with 120 additions and 57 deletions.
@@ -7,6 +7,7 @@
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
from model_compression_toolkit.core.pytorch.constants import DIM
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device


class ScaledDotProductDecomposition(BaseSubstitution):
@@ -20,6 +21,7 @@ def __init__(self):
Matches scaled_dot_product_attention node.
"""
super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
self.device = get_working_device()

def _get_attention_input_nodes(self, graph: Graph, attention_node: BaseNode) -> dict:
q, k, v = 0, 1, 2
@@ -43,6 +45,20 @@ def _get_transpose_node(self, attention_node: BaseNode, node: BaseNode) -> BaseN
functional_op=torch.transpose)
return transpose_node

def _get_scale_node(self, attention_node, q_node, matmul_node):
scale_name = f'{attention_node.name}_scale'
q_embd_axis = -1
scale_factor = math.sqrt(q_node.output_shape[0][q_embd_axis]) # todo: validate the dimension is correct
scale_node = FunctionalNode(name=scale_name,
framework_attr={},
input_shape=(matmul_node.output_shape),
output_shape=matmul_node.output_shape,
weights={},
layer_class=torch.div,
op_call_args=[scale_factor],
op_call_kwargs={},
functional_op=torch.div)
return scale_node

def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: BaseNode) -> BaseNode:
q_batch_axis = 0
@@ -66,68 +82,106 @@ def _get_matmul_node(self, attention_node: BaseNode, q_node: BaseNode, k_node: B
functional_op=torch.matmul)
return matmul_node

def _get_mask_node(self, attention_node, q_node, k_node, scale_node):
attention_mask_tensor = self._get_attention_mask_tensor(attention_node, q_node, k_node)
if attention_mask_tensor is None:
return None
mask_node_name = f'{attention_node.name}_mask'
return FunctionalNode(name=mask_node_name,
framework_attr={},
input_shape=(scale_node.output_shape),
output_shape=scale_node.output_shape,
weights={},
layer_class=torch.add,
op_call_args=[],
op_call_kwargs={'other': attention_mask_tensor},
functional_op=torch.add)

def _get_softmax_node(self, attention_node, in_out_shape):
softmax_name = f'{attention_node.name}_softmax'
return BaseNode(name=softmax_name,
framework_attr={DIM: -1},
input_shape=in_out_shape,
output_shape=in_out_shape,
weights={},
layer_class=nn.Softmax)

def _get_matmul2_node(self, attention_node, softmax_node, v_node):
matmul2_output_shape = list(copy(softmax_node.output_shape))
matmul2_output_shape[-2] = softmax_node.output_shape[-2]
matmul2_output_shape[-1] = v_node.output_shape[0][-1]
matmul2_name = f'{attention_node.name}_matmul2'
return FunctionalNode(name=matmul2_name,
framework_attr={},
input_shape=(tuple(softmax_node.output_shape), tuple(v_node.output_shape[0])),
output_shape=tuple(matmul2_output_shape),
weights={},
layer_class=torch.matmul,
op_call_args=[],
op_call_kwargs={},
functional_op=torch.matmul)

def _get_attention_mask_tensor(self, attention_node, q_node, k_node):
target_seq_len = q_node.output_shape[0][-2]
source_seq_len = k_node.output_shape[0][-2]
attn_bias = torch.zeros(target_seq_len, source_seq_len).to(self.device) # todo: check if need to add dtype=query.dtype
# attn_mask = attention_node.op_call_kwargs['attn_mask']
is_causal = attention_node.op_call_kwargs.get('is_causal', False)
if is_causal:
raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
input_weights = list(attention_node.weights.values())
attn_mask = input_weights[0] if len(input_weights) > 0 else None
# if is_causal:
# assert attn_mask is None, "In case 'is_causal' is set to True, 'attn_mask' must be None"
# temp_mask = torch.ones(target_seq_len, source_seq_len, dtype=torch.bool).tril(diagonal=0).to(self.device)
# attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")).to(self.device)
# attn_bias.to(torch.float32)
#
# if attn_mask is not None:
# if attn_mask.dtype == torch.bool:
# attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
# else:
# attn_bias += torch.from_numpy(attn_mask).to(self.device) # todo: ask why attn_mask has changed to a numpy ndarray
return attn_mask

def _get_dropout_node(self, attention_node, in_out_shape):
dropout_p = attention_node.op_call_kwargs['dropout_p']
if dropout_p == 0:
return None
dropout_name = f'{attention_node.name}_dropout'
return BaseNode(name=dropout_name,
framework_attr={"p": attention_node.op_call_kwargs['dropout_p']},
input_shape=in_out_shape,
output_shape=in_out_shape,
weights={},
layer_class=nn.Dropout)

def substitute(self,
graph: Graph,
attention_node: BaseNode) -> Graph:

input_nodes = self._get_attention_input_nodes(graph, attention_node)
q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
transpose_k_node = self._get_transpose_node(attention_node, k_node)
graph.add_node_with_in_edges(transpose_k_node, [k_node])
matmul_node = self._get_matmul_node(attention_node, q_node, transpose_k_node)
graph.add_node_with_in_edges(matmul_node, [q_node, transpose_k_node])
scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
mask_node = self._get_mask_node(attention_node, q_node, k_node, scale_node)
softmax_node = self._get_softmax_node(attention_node, matmul_node.output_shape)
dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
matmul2_node = self._get_matmul2_node(attention_node, softmax_node, v_node)

graph.add_node_with_in_edges(scale_node, [matmul_node])
if mask_node:
graph.add_node_with_in_edges(mask_node, [scale_node])
graph.add_node_with_in_edges(softmax_node, [mask_node if mask_node else scale_node])
if dropout_node:
graph.add_node_with_in_edges(dropout_node, [softmax_node])
graph.add_node_with_in_edges(matmul2_node, [dropout_node if dropout_node else softmax_node, v_node])

graph.remove_edge(q_node, attention_node)
graph.remove_edge(k_node, attention_node)
graph.remove_edge(v_node, attention_node)
graph.remove_node(attention_node, new_graph_outputs=[OutTensor(matmul2_node, 0)])
return graph
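
For reference, the node chain built by substitute() above (transpose K, first matmul, division by sqrt(d), optional mask add, softmax, optional dropout, second matmul with V) follows the standard scaled-dot-product-attention formula. The sketch below is illustrative only and not part of the commit; the function name sdpa_reference, the transpose dimensions, and the use of the last query dimension as the embedding size are assumptions mirroring _get_scale_node:

# Illustrative reference computation the decomposed graph is meant to reproduce (a sketch, not the library code)
import math
import torch

def sdpa_reference(q, k, v, attn_mask=None, dropout_p=0.0):
    d_k = q.shape[-1]                                      # embedding size, as assumed by _get_scale_node
    scores = torch.matmul(q, torch.transpose(k, -1, -2))   # Q @ K^T: the transpose + first matmul nodes
    scores = torch.div(scores, math.sqrt(d_k))             # the scale node
    if attn_mask is not None:
        scores = torch.add(scores, attn_mask)              # the new mask node (torch.add)
    weights = torch.nn.Softmax(dim=-1)(scores)             # the softmax node
    if dropout_p > 0:
        weights = torch.nn.Dropout(p=dropout_p)(weights)   # the dropout node (skipped when dropout_p == 0)
    return torch.matmul(weights, v)                        # the second matmul node
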
@@ -3,15 +3,17 @@


class ScaledDotProductAttentionNet(nn.Module):
def __init__(self, attn_mask=None, is_causal=False, dropout_p=0.0):
super().__init__()
self.dropout_p = dropout_p
self.attn_mask = attn_mask
self.is_causal = is_causal

def forward(self, q, k, v):
x = nn.functional.scaled_dot_product_attention(q, k, v,
dropout_p=self.dropout_p,
attn_mask=self.attn_mask,
is_causal=self.is_causal
)
return x

@@ -20,14 +22,15 @@ class ScaledDotProductAttentionTest(BasePytorchTest):
"""
This test checks the scaled_dot_product_attention substitution as a single layer.
"""
def __init__(self, unit_test, attn_mask=None, is_causal=False, dropout_p=0.0):
super().__init__(unit_test)
self.use_fuzzy_validation = True # because SDPA contains a sqrt operation, which leads to slightly different output values compared to the original torch model
self.dropout_p = dropout_p
self.attn_mask = attn_mask
self.is_causal = is_causal

def create_feature_network(self, input_shape):
return ScaledDotProductAttentionNet(self.attn_mask, self.is_causal, self.dropout_p)

def create_inputs_shape(self):
batch_size, q_and_k_embd_size, v_embd_size, source_seq_len, target_seq_len = 3, 8, 19, 21, 13
@@ -48,8 +51,8 @@ def _test_substitution_structure_output(self, post_substitution_nodes):
"matmul": 2,
"div": 1, # scale operator
"Softmax": 1,
"Dropout": 1,
"mask": 0 if self.attn_mask is None else 1
"Dropout": 0 if self.dropout_p == 0 else 1,
"add": 0 if self.attn_mask is None else 1 # mask operator
}

for node in post_substitution_nodes:
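
As a quick usage illustration (not part of the commit), the test network above can also be run standalone; the tensor shapes below are assumptions based on the dimension names in create_inputs_shape:

# Hypothetical standalone run of ScaledDotProductAttentionNet (shapes assumed from the test's dimension names)
import torch

net = ScaledDotProductAttentionNet(attn_mask=None, is_causal=False, dropout_p=0.0)
q = torch.randn(3, 13, 8)    # (batch_size, target_seq_len, q_and_k_embd_size)
k = torch.randn(3, 21, 8)    # (batch_size, source_seq_len, q_and_k_embd_size)
v = torch.randn(3, 21, 19)   # (batch_size, source_seq_len, v_embd_size)
out = net(q, k, v)           # expected output shape: (3, 13, 19)
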
@@ -591,9 +591,15 @@ def test_scaled_dot_product_attention_layer(self):
"""
This test checks the ScaledDotProductDecomposition substitution feature.
"""
# ScaledDotProductAttentionTest(self).run_test(seed=3)
ScaledDotProductAttentionTest(self, dropout_p=0.5).run_test(seed=3)

from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
# attn_mask = torch.zeros(13, 21).to(get_working_device())
# ScaledDotProductAttentionTest(self, attn_mask=attn_mask, is_causal=False).run_test(seed=3)
# ScaledDotProductAttentionTest(self, attn_mask=None, is_causal=True).run_test(seed=3)
# todo: add scale argument

def test_gptq(self):
"""
This test checks the GPTQ feature.