add missing return types
yardeny-sony committed Sep 26, 2024
1 parent 89bd8d8 commit f6ada2f
Showing 2 changed files with 12 additions and 11 deletions.
Changed file 1 of 2: the scaled_dot_product_attention decomposition substitution.
@@ -44,7 +44,8 @@ def __init__(self):
         """
         super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
 
-    def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str, input_index: int, default_value: any):
+    def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str,
+                           input_index: int, default_value: any) -> any:
         """
         Search for attention_node input value in op_call_kwargs (using input_name) and op_call_args (using input_index).
         In case the input is not given, returns its default_value.
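The body of _get_input_by_name is outside this hunk, so the following is only a rough sketch of the lookup behaviour the docstring describes, with hypothetical names rather than the repository's implementation: prefer the keyword argument, then the positional argument, then the default.

# Sketch only: mirrors the docstring above, not the actual method body.
def get_input_by_name_sketch(op_call_kwargs: dict, op_call_args: list,
                             input_name: str, input_index: int, default_value):
    if input_name in op_call_kwargs:            # keyword argument wins
        return op_call_kwargs[input_name]
    if 0 <= input_index < len(op_call_args):    # otherwise try the positional slot
        return op_call_args[input_index]
    return default_value                        # input not supplied at all

# e.g. for scaled_dot_product_attention(q, k, v, attn_mask=mask):
# get_input_by_name_sketch({'attn_mask': mask}, [], 'attn_mask', 3, None) returns mask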
@@ -76,7 +77,7 @@ def _get_transpose_k_node(self, attention_node_name: str, key_node: BaseNode) ->
                                         functional_op=torch.transpose)
         return transpose_node
 
-    def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matmul_node: BaseNode):
+    def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matmul_node: BaseNode) -> FunctionalNode:
         """
         :return: multiplication node that represents multiplication by the scale factor
         """
@@ -110,7 +111,7 @@ def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, k_node: B
                               op_call_kwargs={},
                               functional_op=torch.matmul)
 
-    def _get_mask_node(self, attention_node: FunctionalNode, scale_node: FunctionalNode):
+    def _get_mask_node(self, attention_node: FunctionalNode, scale_node: FunctionalNode) -> FunctionalNode:
         """
         :return: Add operator node with the mask tensor as input. In case there is no mask tensor, returns None.
         """
@@ -128,7 +129,7 @@ def _get_mask_node(self, attention_node: FunctionalNode, scale_node: FunctionalN
                               op_call_kwargs={'other': attention_mask_tensor},
                               functional_op=torch.add)
 
-    def _get_softmax_node(self, attention_node_name: str, in_out_shape: tuple):
+    def _get_softmax_node(self, attention_node_name: str, in_out_shape: tuple) -> BaseNode:
         softmax_name = f'{attention_node_name}_softmax'
         return BaseNode(name=softmax_name,
                         framework_attr={DIM: -1},
@@ -137,7 +138,7 @@ def _get_softmax_node(self, attention_node_name: str, in_out_shape: tuple):
                         weights={},
                         layer_class=nn.Softmax)
 
-    def _get_matmul2_node(self, attention_node_name: str, softmax_node: BaseNode, v_node: BaseNode):
+    def _get_matmul2_node(self, attention_node_name: str, softmax_node: BaseNode, v_node: BaseNode) -> FunctionalNode:
         matmul2_output_shape = list(copy(softmax_node.output_shape))
         matmul2_output_shape[-2] = softmax_node.output_shape[-2]
         matmul2_output_shape[-1] = v_node.output_shape[0][-1]
@@ -152,7 +153,7 @@ def _get_matmul2_node(self, attention_node_name: str, softmax_node: BaseNode, v_
                               op_call_kwargs={},
                               functional_op=torch.matmul)
 
-    def _get_attention_mask_tensor(self, attention_node: FunctionalNode):
+    def _get_attention_mask_tensor(self, attention_node: FunctionalNode) -> torch.Tensor:
         """
         :return: mask tensor given as part of attention node input.
         Since MCT doesn't support infinite values, we don't support is_causal (torch.nn.scale_dot_product_attention
@@ -172,7 +173,7 @@ def _get_attention_mask_tensor(self, attention_node: FunctionalNode):
                 "scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
         return torch.from_numpy(attn_mask).to(device) if attn_mask is not None else None
 
-    def _get_dropout_node(self, attention_node: FunctionalNode, in_out_shape: tuple):
+    def _get_dropout_node(self, attention_node: FunctionalNode, in_out_shape: tuple) -> BaseNode:
         dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
         dropout_name = f'{attention_node.name}_dropout'
         return BaseNode(name=dropout_name,
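For orientation, the node-builder methods above decompose torch.nn.functional.scaled_dot_product_attention into primitive graph ops: transpose K, matmul, multiply by the scale, add the mask, softmax, dropout, and a second matmul. The following is a minimal plain-tensor sketch of that decomposition, assuming torch's default 1/sqrt(d_k) scale factor; is_causal is left out because, as the docstring above notes, the causal mask introduces infinite values that MCT does not support.

import math
import torch
import torch.nn as nn

def sdpa_reference(q, k, v, attn_mask=None, dropout_p=0.0):
    # Plain-tensor equivalent of the node graph built by the methods above (sketch only).
    scale = 1 / math.sqrt(q.shape[-1])          # default scale factor 1/sqrt(d_k)
    k_t = torch.transpose(k, -1, -2)            # _get_transpose_k_node
    scores = torch.matmul(q, k_t)               # _get_matmul_node
    scores = scores * scale                     # _get_scale_node
    if attn_mask is not None:
        scores = torch.add(scores, attn_mask)   # _get_mask_node
    probs = nn.Softmax(dim=-1)(scores)          # _get_softmax_node
    probs = nn.Dropout(p=dropout_p)(probs)      # _get_dropout_node
    return torch.matmul(probs, v)               # _get_matmul2_node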
Changed file 2 of 2: the scaled_dot_product_attention substitution test.
@@ -56,7 +56,7 @@ def __init__(self, unit_test, batch_size: int, q_and_k_embd_size: int, v_embd_si
         self.attn_mask = attn_mask
         self.is_causal = is_causal
 
-    def create_feature_network(self, input_shape):
+    def create_feature_network(self, input_shape) -> nn.Module:
 
         if version.parse(torch.__version__) >= version.parse("2.1"):
             return ScaledDotProductAttentionNet(dropout_p=self.dropout_p,
@@ -69,13 +69,13 @@ def create_feature_network(self, input_shape):
                                                 attn_mask=self.attn_mask,
                                                 is_causal=self.is_causal)
 
-    def create_inputs_shape(self):
+    def create_inputs_shape(self) -> list:
         q_shape = [self.batch_size, self.target_seq_len, self.q_and_k_embd_size]
         k_shape = [self.batch_size, self.source_seq_len, self.q_and_k_embd_size]
         v_shape = [self.batch_size, self.source_seq_len, self.v_embd_size]
         return [q_shape, k_shape, v_shape]
 
-    def _test_substitution_structure_output(self, post_substitution_nodes):
+    def _test_substitution_structure_output(self, post_substitution_nodes) -> None:
         """
         :param post_substitution_nodes: The graph nodes after the SDPA substitution
         raise Exception if case the post_substitution_nodes doesn't match the expected_nodes_counter
@@ -100,7 +100,7 @@ def _test_substitution_structure_output(self, post_substitution_nodes):
         if not (len(counter_results) == 1 and 0 in counter_results):  # validate that all values are zeros
             raise Exception(f"Post substitution graph contains unexpected nodes: {[k for k, v in expected_nodes_counter.items() if v != 0]}")
 
-    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
+    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None) -> None:
         super().compare(quantized_models, float_model, input_x, quantization_info)
         post_substitution_nodes = quantized_models['no_quantization'].node_sort
         self._test_substitution_structure_output(post_substitution_nodes)
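As a standalone illustration of the inputs the test constructs, the shapes below mirror create_inputs_shape (q and k share an embedding size, v has its own); the concrete sizes are arbitrary and not taken from the test's defaults.

import torch
import torch.nn.functional as F

# Illustrative sizes only; the test's actual sizes come from its constructor arguments.
batch_size, target_seq_len, source_seq_len = 2, 5, 7
q_and_k_embd_size, v_embd_size = 8, 16

q = torch.randn(batch_size, target_seq_len, q_and_k_embd_size)
k = torch.randn(batch_size, source_seq_len, q_and_k_embd_size)
v = torch.randn(batch_size, source_seq_len, v_embd_size)
attn_mask = torch.zeros(target_seq_len, source_seq_len)  # finite mask, no -inf entries

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 5, 16]): [batch, target_seq_len, v_embd_size]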
