Skip to content

Commit

Permalink
Improve parsing of functional layers arguments. Support nn.functional…
Browse files Browse the repository at this point in the history
….layer_norm and substract
  • Loading branch information
Idan-BenAmi committed Jan 16, 2024
1 parent 28d412d commit aad03b4
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self,
quantization_attr: Dict[str, Any] = None,
functional_op: Any = None,
inputs_as_list: bool = False,
op_call_args_order: List[Any] = None,
has_activation: bool = True):
"""
Init a FunctionalNode object.
Expand All @@ -43,7 +42,7 @@ def __init__(self,
functional_op: The op the node implements.
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
has_activation: Whether the node has activations that we might want to quantize.
op_call_args_order: List of indexes to sort the inputs + arguments when using the functional operation.
"""

super().__init__(name,
Expand All @@ -61,7 +60,6 @@ def __init__(self,
self.op_call_args = op_call_args
self.functional_op = functional_op
self.inputs_as_list = inputs_as_list
self.op_call_args_order = op_call_args_order

@property
def type(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,8 @@ def _run_operation(n: BaseNode,

op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}

# If function inputs and arguments have the following pattern: op_func([inputs list], arg1, arg2,...)
if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)

# If function inputs and arguments requires specific order, sort it before the call
elif isinstance(n, FunctionalNode) and n.op_call_args_order and not (len(n.op_call_args_order) > len(input_tensors + op_call_args)):
args = [(input_tensors + op_call_args)[ind] for ind in n.op_call_args_order]
out_tensors_of_n_float = op_func(*args, **functional_kwargs)

# Assume the rest of the cases correspond the following pattern: op_func(input1, input2, arg1, arg2,...)
else:
out_tensors_of_n_float = op_func(*input_tensors + op_call_args, **functional_kwargs)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
from torch import nn
import torch.nn.functional as F

from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.pytorch.constants import *
from model_compression_toolkit.logger import Logger


class FunctionalLayerNorm(common.BaseSubstitution):
"""
Replace functional batch_norm with BatchNorm2d.
"""

def __init__(self):
"""
Matches: functional batch_norm
"""
bn_node = NodeOperationMatcher(F.layer_norm)
super().__init__(matcher_instance=bn_node)

def get_attributes_from_inputs(self, graph: Graph, node: BaseNode, normalized_shape) -> dict:
input_nodes = graph.get_prev_nodes(node, sink_index_sorted=True)

if len(input_nodes) == 3:
return {
GAMMA: list(input_nodes[1].weights.values())[0],
BETA: list(input_nodes[2].weights.values())[0]
}
else:
Logger.warning(f'functional batch_norm is only folded in the 5 inputs case (input, mean, var, gamma, beta),'
f'got {len(input_nodes)}')
return {
GAMMA: nn.Parameter(torch.ones(normalized_shape)),
BETA: nn.Parameter(torch.ones(normalized_shape))
}

def substitute(self,
graph: Graph,
node: BaseNode) -> Graph:
"""
Substitute functional.batch_norm and its inputs with BatchNorm2d.
Args:
graph: Graph we apply the substitution on.
node: node that match the pattern in the substitution init.
Returns:
Graph after applying the substitution.
"""
# if the input is not a 4D tensor, we can't substitute it with BatchNorm2d
# if len(node.input_shape[0]) != 4:
# return graph
normalized_shape = node.input_shape[0][-1]

bn_node_weights = self.get_attributes_from_inputs(graph, node, normalized_shape)
if not bn_node_weights:
return graph
new_layernorm = BaseNode(name=node.name + '_into_LayerNorm2d',
framework_attr={'normalized_shape': normalized_shape,
EPSILON: EPSILON_VAL,
'elementwise_affine': True,
},
input_shape=node.output_shape,
output_shape=node.output_shape,
weights=bn_node_weights,
layer_class=nn.LayerNorm)

num_nodes_before_substitution = len(graph.nodes)
num_edges_before_substitution = len(graph.edges)

batch_norm_consts = graph.get_prev_nodes(node)[1:]
for const in batch_norm_consts:
graph.remove_edge(const, node)
graph.remove_node(const)

graph.replace_node(node, new_layernorm)

assert num_nodes_before_substitution - len(graph.nodes) == len(batch_norm_consts)
assert num_edges_before_substitution - len(graph.edges) == len(batch_norm_consts)

return graph
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
pytorch_batchnorm_refusing
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_batch_norm import \
FunctionalBatchNorm
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_layer_norm import \
FunctionalLayerNorm
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
pytorch_linear_collapsing
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
Expand Down Expand Up @@ -246,7 +248,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
MultiHeadAttentionDecomposition(),
PermuteCallMethod(),
ConstantHolderConv(fw_info),
FunctionalBatchNorm()]
FunctionalBatchNorm(),
FunctionalLayerNorm()]

def get_substitutions_pre_statistics_collection(self,
quant_config: QuantizationConfig
Expand Down
35 changes: 5 additions & 30 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from model_compression_toolkit.core.common.graph.edge import Edge
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, GET_ATTR, CONSTANT, BUFFER, \
OP_CALL_ARGS_ORDER
CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, GET_ATTR, CONSTANT, BUFFER
from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, ConstantHolder, BufferHolder
from model_compression_toolkit.logger import Logger

Expand Down Expand Up @@ -177,43 +176,20 @@ def nodes_builder(model: GraphModule,
# initiate graph nodes
if node.op in [CALL_METHOD, CALL_FUNCTION]:
graph_node_type = FunctionalNode

# Check if function inputs and arguments have the following pattern: op_func([inputs list], arg1, arg2,...)
inputs_as_list1 = len(node.args) > 0 and isinstance(node.args[0], (list, tuple)) and all(
[isinstance(n, torch.fx.node.Node) for n in node.args[0]])
inputs_as_list = inputs_as_list1 or \
(len(node.args) > 0 and hasattr(node.args[0], 'op') and node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))

# If the pattern as above, list the arguments in op_call_args = [arg1, arg2, ...]
op_call_args_order = []
(len(node.args) > 0 and node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
if inputs_as_list:
num_inputs = 1
op_call_args = list(node.args[num_inputs:])
else:
# Otherwise, we need to check which of the "node.args" are inputs and which are arguments.
input_counter = 0
for in_node in node.all_input_nodes:
for arg in node.args:
if arg == in_node:
input_counter += 1
num_inputs = max(len(node.all_input_nodes), input_counter)

# If there are arguments in addition to the inputs, set op_call_args = [arg1, arg2, ...]
if num_inputs != len(node.args):
op_call_args = [arg for arg in node.args if arg not in node.all_input_nodes]
else:
op_call_args = []

# Next, we keep the correct order of the inputs and arguments in op_call_args_order to be used later in "model_builder"
# We keep op_call_args_order empty in case it's standard pattern op_func(input1, input2, arg1, arg2,...)
sorted_args = list(node.args) + node.all_input_nodes[input_counter:]
# for arg in node.args:
# if isinstance(arg, (List, Tuple)):
# sorted_args = list(node.args)
all_args = node.all_input_nodes + op_call_args
if op_call_args and not sorted_args==all_args:
for a in sorted_args:
op_call_args_order.append(all_args.index(a))
op_call_args = list(node.args[num_inputs:])

# remove torch.fx.node.Node from inputs to graph_node_type
for arg in op_call_args:
Expand All @@ -223,8 +199,7 @@ def nodes_builder(model: GraphModule,
kwargs = {FUNCTIONAL_OP: node_type,
OP_CALL_ARGS: op_call_args,
OP_CALL_KWARGS: node_kwargs,
INPUTS_AS_LIST: inputs_as_list,
OP_CALL_ARGS_ORDER: op_call_args_order}
INPUTS_AS_LIST: inputs_as_list}
else:
graph_node_type = BaseNode
kwargs = {}
Expand Down Expand Up @@ -289,4 +264,4 @@ def edges_builder(model: GraphModule,
edges.append(
Edge(fx_node_2_graph_node[node], fx_node_2_graph_node[out_node], src_index, dst_index))

return edges
return edges

0 comments on commit aad03b4

Please sign in to comment.