Commit c15341f: PR comments

elad-c committed Jun 23, 2024
1 parent d35c5da commit c15341f

Showing 8 changed files with 100 additions and 42 deletions.

@@ -44,7 +44,7 @@ def __init__(self,
functional_op: The op the node implements.
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
has_activation: Whether the node has activations that we might want to quantize.
tensor_input_allocs: A list of indices and strings for allocatins input tensors in the node's args and kwargs
tensor_input_allocs: A list of indices and strings for allocating input tensors in the node's args and kwargs.
"""

super().__init__(name,
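
A minimal standalone sketch of what tensor_input_allocs encodes (a simplified illustration, not code from this repository): each entry is either an integer position in args or a string kwarg name, and the builder uses it to put the node's input tensors back where the original call expected them.

import torch

def place_inputs(input_tensors, tensor_input_allocs, op_call_args, op_call_kwargs):
    # Integer entries are positions in the positional args; string entries are kwarg names.
    args = list(op_call_args)
    kwargs = dict(op_call_kwargs)
    for alloc, tensor in zip(tensor_input_allocs, input_tensors):
        if isinstance(alloc, str):
            kwargs[alloc] = tensor      # the tensor was originally passed as a keyword argument
        else:
            args.insert(alloc, tensor)  # the tensor was originally a positional argument
    return args, kwargs

x = torch.randn(2, 3)
args, kwargs = place_inputs([x], tensor_input_allocs=[0], op_call_args=[], op_call_kwargs={'dim': 1})
out = torch.softmax(*args, **kwargs)    # equivalent to torch.softmax(x, dim=1)
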
@@ -106,7 +106,7 @@ def _run_operation(n: BaseNode,
input_tensors: List,
op_func: Any,
quantize_node_activation_fn,
use_activation_quantization: bool) -> Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]:
use_activation_quantization: bool) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Applying the layer (op_func) to the input tensors (input_tensors).
If quantized is set to True, and the layer's corresponding node (n) has quantization
@@ -126,11 +126,11 @@ def _run_operation(n: BaseNode,
op_call_args = n.op_call_args if isinstance(n, FunctionalNode) else []
functional_kwargs = n.op_call_kwargs if isinstance(n, FunctionalNode) else {}

# Insert positional weights only when not a quantized functional node, because quantized functional nodes
# insert the quantized weights in the wrapper.
if isinstance(n, FunctionalNode) and isinstance(op_func, PytorchQuantizationWrapper):
_tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]
else:
# Insert positional weights only when not a quantized functional node, because quantized functional nodes
# insert the quantized weights in the wrapper.
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
# convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
@@ -152,6 +152,8 @@ def _run_operation(n: BaseNode,
out_tensors_of_n_float = torch.cat(out_tensors_of_n_float, dim=0)
out_tensors_of_n = quantize_node_activation_fn(out_tensors_of_n_float)

if not isinstance(out_tensors_of_n, list):
out_tensors_of_n, out_tensors_of_n_float = [out_tensors_of_n], [out_tensors_of_n_float]
return out_tensors_of_n, out_tensors_of_n_float


@@ -318,12 +320,8 @@ def forward(self,
quantize_node_activation_fn=activation_quantization_fn,
use_activation_quantization=use_activation_quantization)

if isinstance(out_tensors_of_n, list):
node_to_output_tensors_dict.update({node: out_tensors_of_n})
node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})
else:
node_to_output_tensors_dict.update({node: [out_tensors_of_n]})
node_to_output_tensors_dict_float.update({node: [out_tensors_of_n_float]})
node_to_output_tensors_dict.update({node: out_tensors_of_n})
node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})

if self.append2output:
outputs = _generate_outputs(self.append2output,
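
The two hunks above move the list normalization into _run_operation so forward() no longer branches on the output type. A simplified sketch of the pattern with plain PyTorch and invented helper names (not the toolkit's actual signatures):

from typing import Callable, List, Tuple
import torch

def run_op(op_func: Callable, input_tensors: List[torch.Tensor],
           quantize_fn: Callable) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    out_float = op_func(*input_tensors)
    out_quant = quantize_fn(out_float)
    # Normalize single-tensor outputs to one-element lists inside the runner,
    # so no call site has to branch on the output type.
    if not isinstance(out_quant, list):
        out_quant, out_float = [out_quant], [out_float]
    return out_quant, out_float

out_q, out_f = run_op(torch.relu, [torch.randn(2, 3)], lambda t: torch.round(t * 4) / 4)
node_to_output_tensors = {'relu_node': out_q}   # a plain dict update, no isinstance(list) check needed
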
@@ -61,8 +61,12 @@ def substitute(self,
# Create new node of layer convolution
if 1 not in func_node.weights:
Logger.critical(f'Weight input missing for node {func_node.name}.') # pragma: no cover
# Extract index of kernel and bias according to tensor_input_allocs if they were input as kwargs. If
# they were input as args, use their fixed positions.
weight_index = func_node.tensor_input_allocs.index(KERNEL) if KERNEL in func_node.tensor_input_allocs else 1
bias_index = func_node.tensor_input_allocs.index(BIAS) if BIAS in func_node.tensor_input_allocs else 2
if weight_index not in func_node.weights:
Logger.critical(f'Mismatch between tensor_input_allocs and weight index in node {func_node.name}.') # pragma: no cover
weight = func_node.weights[weight_index]
bias = func_node.weights.get(bias_index)
framework_attr = func_node.op_call_kwargs
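
A small sketch of the kwargs-vs-args index lookup added above. KERNEL and BIAS stand in for the string keys stored in tensor_input_allocs, and the weights dictionary layout is simplified for illustration:

KERNEL, BIAS = 'weight', 'bias'   # stand-ins for the keys stored in tensor_input_allocs

def resolve_conv_weight_and_bias(tensor_input_allocs, weights):
    # If kernel/bias were passed as kwargs they appear by name in tensor_input_allocs;
    # otherwise fall back to their fixed positions in F.conv2d(input, weight, bias, ...).
    weight_index = tensor_input_allocs.index(KERNEL) if KERNEL in tensor_input_allocs else 1
    bias_index = tensor_input_allocs.index(BIAS) if BIAS in tensor_input_allocs else 2
    if weight_index not in weights:
        raise ValueError('Mismatch between tensor_input_allocs and weight index.')
    return weights[weight_index], weights.get(bias_index)

# Positional call F.conv2d(x, w, b): kernel sits at index 1 and bias at index 2.
w, b = resolve_conv_weight_and_bias([0], {1: 'kernel_tensor', 2: 'bias_tensor'})
# Keyword call F.conv2d(x, weight=w, bias=b): the indices come from tensor_input_allocs itself.
w, b = resolve_conv_weight_and_bias([0, 'weight', 'bias'], {1: 'kernel_tensor', 2: 'bias_tensor'})
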
@@ -40,7 +40,10 @@ def __init__(self):
@staticmethod
def get_attributes_from_weights(node: FunctionalNode) -> Dict:
"""
convert functional batch_norm positional weights to BatchNorm2d weights
Convert functional batch_norm positional weights to BatchNorm2d weights. Extract indices of gamma
and beta according to tensor_input_allocs if they were input as kwargs. If they were input as args,
use their fixed positions.
Args:
node: functional batch_norm node.
@@ -41,7 +41,9 @@ def __init__(self):
@staticmethod
def get_attributes_from_weights(node: FunctionalNode, normalized_shape: [Tuple, List, int]) -> Dict:
"""
Parse layer_norm(input, normalized_shape, weight=None, bias=None)
Convert functional layer_norm positional weights to LayerNorm weights. Extract indices of gamma
and beta according to tensor_input_allocs if they were input as kwargs. If they were input as args,
use their fixed positions.
Args:
node: Node that matches the pattern in the substitution init.
normalized_shape: nn.LayerNorm "normalized_shape" argument
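
The batch_norm and layer_norm substitutions above apply the same lookup. A sketch for torch.nn.functional.layer_norm(input, normalized_shape, weight=None, bias=None), where gamma and beta sit at positional indices 2 and 3 when passed as args (illustrative helper, not the toolkit's code):

GAMMA, BETA = 'weight', 'bias'   # illustrative names for the kwarg keys

def gamma_beta_indices(tensor_input_allocs):
    gamma_index = tensor_input_allocs.index(GAMMA) if GAMMA in tensor_input_allocs else 2
    beta_index = tensor_input_allocs.index(BETA) if BETA in tensor_input_allocs else 3
    return gamma_index, beta_index

print(gamma_beta_indices([0, 1]))                    # (2, 3): positional call, fixed positions
print(gamma_beta_indices([0, 1, 'weight', 'bias']))  # (2, 3): keyword call, positions read from the alloc list
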
106 changes: 76 additions & 30 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -47,24 +47,75 @@ def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule],
return {**named_parameters, **named_buffers}


def is_instance_first_arg(n: Node, expected_type: Any) -> bool:
"""
Check whether the first argument of the node is of the expected type.
Args:
n: fx node.
expected_type: Expected 1st argument type.
Returns:
True if the node has at least one positional argument and its first argument is an instance of expected_type.
"""
return len(n.args) > 0 and isinstance(n.args[0], expected_type)


def _build_input_alloc_and_call_args(n: Node, input_tensors_in_node_kwargs: Dict,
inputs_as_list: bool) -> Tuple[List, List]:
"""
Build the tensor inputs list and op_call_args of the functional node.
Args:
n: fx node.
input_tensors_in_node_kwargs: A dictionary of node kwarg name and input fx node.
inputs_as_list: Whether the node's inputs are passed as a list.
Returns:
A list of updated op_call args.
A list of tensor allocations in node's inputs.
"""

tensor_input_alloc = []
op_call_args = list(n.args)
if inputs_as_list:
op_call_args.pop(0)
else:
for in_node in n.all_input_nodes:
# The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
for i, arg in enumerate(n.args):
if arg == in_node:
tensor_input_alloc.append(i)
for k, arg in input_tensors_in_node_kwargs.items():
if arg == in_node:
tensor_input_alloc.append(k)

return op_call_args, tensor_input_alloc
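
A short runnable illustration of why the nested loop over n.args is needed, using only torch.fx (this snippet is not part of the commit): for torch.add(x, x), FX deduplicates all_input_nodes, but the same node occupies two positions in args, and both positions must be recorded in tensor_input_alloc.

import torch
import torch.fx as fx

class AddSelf(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)

gm = fx.symbolic_trace(AddSelf())
add_node = [n for n in gm.graph.nodes if n.op == 'call_function'][0]
print(add_node.all_input_nodes)   # [x]: deduplicated list of input nodes
print(add_node.args)              # (x, x): the same node appears in two positions

tensor_input_alloc = []
for in_node in add_node.all_input_nodes:
    for i, arg in enumerate(add_node.args):
        if arg == in_node:
            tensor_input_alloc.append(i)
print(tensor_input_alloc)         # [0, 1]: both positions are recorded
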


def _extract_torch_layer_data(node_module: torch.nn.Module,
to_numpy: Callable) -> Tuple[Any, Dict[str, np.ndarray], Dict]:
"""
Extract required data from the node to rebuild the PyTorch layer.
Extract required data from a non-functional node to rebuild the PyTorch layer.
Args:
node_module: Torch layer, such as nn.Conv2d, nn.Linear, etc.
to_numpy: Function to convert framework's tensor to a Numpy array.
Returns:
Node layer class.
PyTorch layer weights and named buffers to a weights dictionary.
framework_attr dictionary required to rebuild the layer with the layer class.
A mapping from the layer's named parameters and buffers to their tensor values.
A framework_attr dictionary required to instantiate the node with the layer class.
"""
node_type = type(node_module)
framework_attr = node_module.__dict__
if not isinstance(node_module, torch.nn.Module):
Logger.error(f"Expected an instance of torch.nn.Module for node {node_module.name}, but got {node_type}")
# Extract the instance framework_attr (i.e. the arguments the class instance was initialized with). "fullargspec"
# is a list of the layer's attribute names that will be used as keys of the framework_attr dictionary. We take the
# values from the layer instance.
fullargspec = inspect.getfullargspec(node_type.__init__).args
framework_attr = {k: v for k, v in framework_attr.items() if k in fullargspec}
framework_attr = {k: v for k, v in node_module.__dict__.items() if k in fullargspec}
# The "bias" argument doesn't appear in the node_module.__dict__, so we add it manually.
if hasattr(node_module, BIAS) and BIAS in fullargspec:
framework_attr[BIAS] = False if node_module.bias is None else True
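
A standalone sketch of the fullargspec-based attribute extraction described in the comment above, using only torch and inspect (the printed values assume a stock nn.Conv2d):

import inspect
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False)
fullargspec = inspect.getfullargspec(type(conv).__init__).args
framework_attr = {k: v for k, v in conv.__dict__.items() if k in fullargspec}
# 'bias' is registered as a parameter (or None), so it never shows up in __dict__ and is added by hand.
framework_attr['bias'] = conv.bias is not None
print(framework_attr['in_channels'], framework_attr['out_channels'], framework_attr['bias'])  # 3 8 False
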

@@ -126,7 +177,7 @@ def nodes_builder(model: GraphModule,
consts_dict = {}
used_consts = set()

# Init parameters & buffers dictionary
# Init parameters & buffers dictionary of the entire model. We later extract the constant values from this dictionary.
model_parameters_and_buffers = _extract_parameters_and_buffers(model, to_numpy)

for node in model.graph.nodes:
@@ -189,37 +240,32 @@ def nodes_builder(model: GraphModule,
if node.op in [CALL_METHOD, CALL_FUNCTION]:
graph_node_type = FunctionalNode

# Filter nodes from node_kwargs, we replace these attributes with nx graph nodes
node_kwargs, nodes_in_kwargs = {}, {}
# Filter FX nodes from node_kwargs. These FX nodes are tensor inputs to the node that are part of the
# model's graph. We remove them because the node_kwargs should not include input tensors of the node.
# These input tensors will be inserted in the kwargs according to the tensor_input_alloc which is used
# to convert the input_tensors list in the builder to the node's args & kwargs.
node_kwargs, input_tensors_in_node_kwargs = {}, {}
for k, v in node.kwargs.items():
if isinstance(v, Node):
nodes_in_kwargs[k] = v
input_tensors_in_node_kwargs[k] = v
else:
node_kwargs[k] = v

# Check if node inputs is a list of nodes, such as torch.cat.
inputs_as_list1 = len(node.args) > 0 and isinstance(node.args[0], (list, tuple)) and all(
# Check if node's first input argument is a list of input fx nodes, such as torch.cat:
is_first_input_list_of_nodes = is_instance_first_arg(node, (list, tuple)) and all(
[isinstance(n, Node) for n in node.args[0]])
inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
is_placeholder_a_list = is_instance_first_arg(node, Node) and \
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple)
inputs_as_list = is_first_input_list_of_nodes or is_placeholder_a_list

# Build tensor_input_alloc required for the model builder. All input nodes are received as a list in the builder,
# so tensor_input_alloc is used to allocate each input tensor in the correct place in the node's args & kwargs.
tensor_input_alloc = []
op_call_args = list(node.args)
if inputs_as_list:
op_call_args.pop(0)
else:
for in_node in node.all_input_nodes:
# The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
for i, arg in enumerate(node.args):
if arg == in_node:
tensor_input_alloc.append(i)
for k, arg in nodes_in_kwargs.items():
if arg == in_node:
tensor_input_alloc.append(k)

# Remove torch.fx.node.Node from inputs to the functional node. FX nodes are input tensors in the builder.
op_call_args, tensor_input_alloc = _build_input_alloc_and_call_args(node, input_tensors_in_node_kwargs,
inputs_as_list)

# Remove torch.fx.node.Node from inputs to the functional node. FX nodes are input tensors in the builder,
# so they are removed from the op_call_args (same as op_call_kwargs) and are inserted back according to the
# tensor_input_alloc list.
op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
# Convert torch.fx.immutable_collections.immutable_list to tuple.
op_call_args = [tuple(arg) if isinstance(arg, torch.fx.immutable_collections.immutable_list) else arg
@@ -232,7 +278,7 @@ def nodes_builder(model: GraphModule,
TENSOR_INPUT_ALLOCS: tensor_input_alloc}
else:
if not all([not isinstance(v, Node) for v in framework_attr.values()]):
Logger.error(f'Found FX nodes in framework attributes of {node.name}. This node type should not contain any.')
Logger.critical(f'Found FX nodes in framework attributes of {node.name}. This node type should not contain any.') # pragma: no cover

graph_node_type = BaseNode
kwargs = {}
@@ -257,7 +303,7 @@ def nodes_builder(model: GraphModule,
# Check whether all extracted constants were used in the graph.
not_connected_consts = [c for c in consts_dict if c not in used_consts]
if not_connected_consts:
Logger.warning(f'Error reading graph: These constants are not connected in the graph: {not_connected_consts}.')
Logger.critical(f'Error reading graph: These constants are not connected in the graph: {not_connected_consts}.') # pragma: no cover

# Generate graph outputs list.
for node in output_nodes:
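
A runnable illustration of the inputs_as_list detection for ops such as torch.cat, using only torch.fx (the variable names mirror the new helper, but this is not the toolkit's code):

import torch
import torch.fx as fx

class CatModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x, y], dim=1)

gm = fx.symbolic_trace(CatModel())
cat_node = [n for n in gm.graph.nodes if n.target is torch.cat][0]
first_arg = cat_node.args[0]
inputs_as_list = isinstance(first_arg, (list, tuple)) and all(isinstance(a, fx.Node) for a in first_arg)
print(inputs_as_list)   # True: the op receives all of its input tensors packed in one list argument
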
@@ -87,6 +87,7 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
self.unit_test.assertTrue(self.fold_applied is not is_bn_in_model)

# TODO: remove this and set atol=1e-5 for all tests after fixing the issue with ConvTranspose2d.
# https://github.com/sony/model_optimization/issues/1115
is_convtranspose2d_in_model = nn.ConvTranspose2d in [type(module) for name, module in quant_model.named_modules()]
atol = 1e-3 if is_convtranspose2d_in_model else 1e-5
self.unit_test.assertTrue(np.isclose(out_quant, out_float, atol=atol, rtol=1e-4).all())
@@ -167,6 +167,10 @@ def __init__(self, input_shape):
self.bn = nn.BatchNorm2d(16)
self.register_buffer('sub_const', 10 * torch.rand((1, 16, 64)))

# Register unused buffer and parameter to test that FX doesn't include them in the FX graph.
self.register_buffer('unused_buffer', 10 * torch.rand((1, 16, 64)))
self.register_parameter('unused_parameter', nn.Parameter(10 * torch.rand((1, 16, 64))))

def forward(self, x):
_shape = x.shape[2:]
x = self.conv2d(input=x)
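
A minimal demonstration of the behaviour this test targets, assuming only torch.fx: buffers and parameters that the forward never touches produce no get_attr nodes in the traced graph.

import torch
import torch.fx as fx

class BufferModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('used_const', torch.ones(4))
        self.register_buffer('unused_buffer', torch.zeros(4))
        self.register_parameter('unused_parameter', torch.nn.Parameter(torch.zeros(4)))

    def forward(self, x):
        return x + self.used_const

gm = fx.symbolic_trace(BufferModel())
print([n.target for n in gm.graph.nodes if n.op == 'get_attr'])   # ['used_const'] only
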
