
Commit 671eb4b
Merge branch 'refs/heads/main' into substitution/scaled_dot_product_attention
yardeny-sony committed Sep 25, 2024
2 parents e2e3e6f + a0aa989
Showing 30 changed files with 843 additions and 469 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_pytorch_tests.yml
@@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime onnxruntime-extensions
pip install pytest
- name: Run unittests
run: |
@@ -258,12 +258,10 @@ def fetch_hessian(self,
f"{hessian_scores_request.target_nodes}.")

# Replace each reused target node with a representing node from its 'reuse group'.
for n in hessian_scores_request.target_nodes:
if n.reuse_group:
rep_node = self._get_representing_of_reuse_group(n)
hessian_scores_request.target_nodes.remove(n)
if rep_node not in hessian_scores_request.target_nodes:
hessian_scores_request.target_nodes.append(rep_node)
hessian_scores_request.target_nodes = [
self._get_representing_of_reuse_group(node) if node.reuse else node
for node in hessian_scores_request.target_nodes
]

# Ensure the saved info has the required number of approximations
self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
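
For context, a minimal, self-contained sketch (hypothetical stand-in node class, simplified logic) of why mutating a list while iterating over it, as in the removed loop, can skip elements, and how a list comprehension like the one above avoids it:

from dataclasses import dataclass

@dataclass
class StubNode:  # hypothetical stand-in for a graph node
    name: str
    reuse: bool = False

def remove_in_place(nodes):
    # Removing from the list being iterated shifts later elements left,
    # so the element that slides into the removed slot is never visited.
    for n in nodes:
        if n.reuse:
            nodes.remove(n)
    return nodes

def map_with_comprehension(nodes, representative):
    # Building a new list visits every element exactly once.
    return [representative if n.reuse else n for n in nodes]

nodes = [StubNode("a", reuse=True), StubNode("b", reuse=True), StubNode("c")]
print([n.name for n in remove_in_place(list(nodes))])                    # ['b', 'c'] - 'b' was skipped
print([n.name for n in map_with_comprehension(nodes, StubNode("rep"))])  # ['rep', 'rep', 'c']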
@@ -145,9 +145,14 @@ def _run_operation(n: BaseNode,
else:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
tensor_input_allocs=_tensor_input_allocs)
out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
if isinstance(op_func, PytorchQuantizationWrapper) and isinstance(n, FunctionalNode) and n.functional_op is not torch.gather:
# In wrapped nodes, the op args & kwargs are already stored in the PytorchQuantizationWrapper.
# Temporary patch: for torch.gather this is not the case, so the inputs still need to be merged.
out_tensors_of_n_float = op_func(*input_tensors)
else:
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
tensor_input_allocs=_tensor_input_allocs)
out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)

# Add a fake quant node if the node has an activation threshold.
out_tensors_of_n = out_tensors_of_n_float
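
As a rough illustration of the comment above, a minimal sketch (toy wrapper class, not MCT's real PytorchQuantizationWrapper) of the idea that once a wrapper stores a functional op's call args, the caller should pass only the input tensors, otherwise the extra args would be supplied twice:

import torch

class ArgStoringWrapper(torch.nn.Module):
    # Toy stand-in for a wrapper that keeps a functional op's extra call args.
    def __init__(self, op, *op_call_args, **op_call_kwargs):
        super().__init__()
        self.op = op
        self.op_call_args = op_call_args
        self.op_call_kwargs = op_call_kwargs

    def forward(self, *inputs):
        # The wrapper supplies the stored args itself, so the caller passes
        # only the input tensors.
        return self.op(*inputs, *self.op_call_args, **self.op_call_kwargs)

wrapped_flatten = ArgStoringWrapper(torch.flatten, 1)  # start_dim=1 stored inside the wrapper
x = torch.randn(2, 3, 4)
print(wrapped_flatten(x).shape)  # torch.Size([2, 12]) - tensors only at call time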
@@ -226,6 +231,7 @@ def __init__(self,
self.return_float_outputs = return_float_outputs
self.wrapper = wrapper
self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
self.reuse_groups = {}
self._add_modules()

# todo: Move to parent class BaseModelBuilder
@@ -283,7 +289,19 @@ def _add_modules(self):
Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
"""
for node in self.node_sort:
node_op = self.wrap(node)
if node.reuse:
# If the node is reused, retrieve the original module
if node.reuse_group not in self.reuse_groups:
Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}")

node_op = self.reuse_groups[node.reuse_group]
else:
# If it's not reused, create a new module
node_op = self.wrap(node)
if node.reuse_group:
# Store the module for future reuse
self.reuse_groups[node.reuse_group] = node_op

if isinstance(node, FunctionalNode):
# for functional layers
setattr(self, node.name, node_op)
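
A minimal, self-contained sketch (hypothetical names, simplified to a get-or-create helper) of the caching pattern used above: a module is built once per reuse group, and every reused node is bound to that same instance, so the weights are shared:

import torch

reuse_groups = {}  # reuse-group key -> module instance

def get_or_create_module(reuse_group, build_module):
    # Build the module the first time the group is seen, reuse it afterwards.
    if reuse_group not in reuse_groups:
        reuse_groups[reuse_group] = build_module()
    return reuse_groups[reuse_group]

conv_a = get_or_create_module("conv_group_0", lambda: torch.nn.Conv2d(3, 8, 3))
conv_b = get_or_create_module("conv_group_0", lambda: torch.nn.Conv2d(3, 8, 3))
print(conv_a is conv_b)                # True - a single module instance
print(conv_a.weight is conv_b.weight)  # True - shared weights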
@@ -80,7 +80,9 @@ def substitute(self,
output_shape=func_node.output_shape,
weights={KERNEL: weight} if bias is None else {KERNEL: weight, BIAS: bias},
layer_class=new_layer,
has_activation=func_node.has_activation)
has_activation=func_node.has_activation,
reuse=func_node.reuse,
reuse_group=func_node.reuse_group)
graph.add_node(new_node)
graph.reconnect_out_edges(current_node=func_node, new_node=new_node)
graph.reconnect_in_edges(current_node=func_node, new_node=new_node)
72 changes: 62 additions & 10 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -30,8 +30,7 @@
from model_compression_toolkit.logger import Logger


def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule],
to_numpy: Callable) -> Dict[str, np.ndarray]:
def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule]) -> Dict[str, np.ndarray]:
"""
Extract parameters & buffers from input module to a dictionary.
Args:
@@ -41,8 +40,8 @@ def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule],
Dictionary containing module parameters and buffers by name.
"""

named_parameters = {name: to_numpy(parameter) for name, parameter in module.named_parameters()}
named_buffers = {name: to_numpy(buffer) for name, buffer in module.named_buffers()}
named_parameters = {name: parameter for name, parameter in module.named_parameters()}
named_buffers = {name: buffer for name, buffer in module.named_buffers()}

return {**named_parameters, **named_buffers}

@@ -97,14 +96,12 @@ def _build_input_alloc_and_call_args(n: Node, input_tensors_in_node_kwargs: Dict
return op_call_args, tensor_input_alloc


def _extract_torch_layer_data(node_module: torch.nn.Module,
to_numpy: Callable) -> Tuple[Any, Dict[str, np.ndarray], Dict]:
def _extract_torch_layer_data(node_module: torch.nn.Module) -> Tuple[Any, Dict[str, np.ndarray], Dict]:
"""
Extract required data from a non-functional node to rebuild the PyTorch layer.
Args:
node_module: Torch layer, such as nn.Conv2d, nn.Linear, etc.
to_numpy: Function to convert framework's tensor to a Numpy array.
Returns:
Node layer class.
@@ -124,7 +121,7 @@ def _extract_torch_layer_data(node_module: torch.nn.Module,
framework_attr[BIAS] = False if node_module.bias is None else True

# Extract layer weights and named buffers.
weights = {n: w for n, w in _extract_parameters_and_buffers(node_module, to_numpy).items() if len(w.shape) > 0}
weights = {n: w for n, w in _extract_parameters_and_buffers(node_module).items() if len(w.shape) > 0}
return node_type, weights, framework_attr


@@ -181,8 +178,11 @@ def nodes_builder(model: GraphModule,
consts_dict = {}
used_consts = set()

# Dictionary to track seen targets and their corresponding nodes to mark reused nodes
seen_targets = {}

# Init parameters & buffers dictionary of the entire model. We later extract the constants values from this dictionary.
model_parameters_and_buffers = _extract_parameters_and_buffers(model, to_numpy)
model_parameters_and_buffers = _extract_parameters_and_buffers(model)

for node in model.graph.nodes:

@@ -195,7 +195,7 @@

if node.target in module_dict.keys():
# PyTorch module node, such as nn.Conv2d or nn.Linear.
node_type, weights, framework_attr = _extract_torch_layer_data(module_dict[node.target], to_numpy)
node_type, weights, framework_attr = _extract_torch_layer_data(module_dict[node.target])

elif node.op == CALL_FUNCTION:
# Node is a function that handles a parameter/buffer in the model.
@@ -249,6 +249,31 @@ def nodes_builder(model: GraphModule,
# Extract input and output shapes of the node.
input_shape, output_shape = _extract_input_and_output_shapes(node)

# Check if this node's target has been seen before
reuse = False
reuse_group = None
node_group_key = create_reuse_group(node.target, weights)
# We mark nodes as reused only if there are multiple nodes in the graph with the same
# 'target' and the node has some weights.
if node_group_key in seen_targets and len(weights) > 0:
reuse = True
reuse_group = node_group_key
# Update the 'base/main' node with the same reuse group as all the other nodes in its group.
fx_node_2_graph_node[seen_targets[node_group_key]].reuse_group = reuse_group
else:
seen_targets[node_group_key] = node

# Convert weights to numpy arrays after reuse marking
# We delay this conversion to preserve the original tensor instances during the reuse identification process.
# This is crucial for correctly identifying identical weight instances in reused functional layers.
# By keeping the original PyTorch tensors until this point, we ensure that:
# 1. Reused layers with the same weight instances are correctly marked as reused.
# 2. The instance-based weight signature generation works as intended, using the memory
# addresses of the original tensors.
# Only after all reuse marking is complete do we convert to numpy arrays.
for weight_name, weight_value in weights.items():
weights[weight_name] = to_numpy(weight_value)

# Initiate graph nodes.
if node.op in [CALL_METHOD, CALL_FUNCTION]:
graph_node_type = FunctionalNode
@@ -300,6 +325,8 @@ def nodes_builder(model: GraphModule,
weights=weights,
layer_class=node_type,
has_activation=node_has_activation,
reuse=reuse,
reuse_group=reuse_group,
**kwargs)

# Generate graph inputs list.
@@ -365,3 +392,28 @@ def edges_builder(model: GraphModule,
Edge(fx_node_2_graph_node[node], fx_node_2_graph_node[out_node], src_index, dst_index))

return edges


def create_reuse_group(target: Any, weights: Dict[str, Any]) -> str:
"""
Combine target and weights to create a unique reuse group identifier.
The weights are considered part of the group identifier because, for functional layers, they are
not part of a module; if a functional layer uses the same weight instances multiple times, it is
considered reused.
This function creates a unique string identifier for a reuse group by combining
the target (typically a layer or operation name) with the IDs of its weights.
Args:
target (Any): The target of the node, typically a string or callable representing
a layer or operation.
weights (Dict[str, Any]): A dictionary of weight names to weight values.
The values can be any type (typically tensors or arrays).
Returns:
str: A unique string identifier for the reuse group.
"""
if not weights:
return str(target)
weight_ids = tuple(sorted(id(weight) for weight in weights.values()))
return f"{target}_{weight_ids}"
@@ -89,12 +89,12 @@ def export(self) -> None:
else:
Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")

model_input = to_torch_tensor(next(self.repr_dataset())[0])
model_input = to_torch_tensor(next(self.repr_dataset()))

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
Expand All @@ -107,7 +107,7 @@ def export(self) -> None:
onnx.save_model(onnx_model, self.save_model_path)
else:
torch.onnx.export(self.model,
model_input,
tuple(model_input) if isinstance(model_input, list) else model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
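
For context, a minimal, self-contained sketch (toy two-input model, illustrative names) of passing a multi-input example batch to torch.onnx.export as a tuple, which is what the list-to-tuple conversion above enables:

import io
import torch

class TwoInputModel(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# e.g. one batch from a representative dataset of a model with two inputs
example_inputs = [torch.randn(1, 3), torch.randn(1, 3)]

buffer = io.BytesIO()
torch.onnx.export(TwoInputModel(),
                  tuple(example_inputs),  # list -> tuple: each element becomes a separate positional input
                  buffer,
                  input_names=['x', 'y'],
                  output_names=['sum'])
print(len(buffer.getvalue()) > 0)  # True - serialized ONNX bytes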
@@ -50,9 +50,11 @@ def fully_quantized_wrapper(node: common.BaseNode,
for attr in weight_quantizers if isinstance(attr, int)}
# When wrapping functional nodes, the call args/kwargs need to be set in the wrapper, because they
# are used during the wrapper's call method.
# Temporary patch: for torch.gather this is not the case, so args & kwargs shouldn't be
# saved in the wrapper.
func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
OP_CALL_KWARGS: node.op_call_kwargs
} if isinstance(node, FunctionalNode) else {}
} if isinstance(node, FunctionalNode) and node.functional_op is not torch.gather else {}
return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
is_inputs_as_list=node.inputs_as_list,
**func_node_kwargs)
@@ -24,6 +24,23 @@

tp = mct.target_platform

OPSET_NO_QUANTIZATION = "NoQuantization"
OPSET_QUANTIZATION_PRESERVING = "QuantizationPreserving"
OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS = "DimensionManipulationOpsWithWeights"
OPSET_DIMENSION_MANIPULATION_OPS = "DimensionManipulationOps"
OPSET_MERGE_OPS = "MergeOps"
OPSET_CONV = "Conv"
OPSET_FULLY_CONNECTED = "FullyConnected"
OPSET_ANY_RELU = "AnyReLU"
OPSET_ADD = "Add"
OPSET_SUB = "Sub"
OPSET_MUL = "Mul"
OPSET_DIV = "Div"
OPSET_PRELU = "PReLU"
OPSET_SWISH = "Swish"
OPSET_SIGMOID = "Sigmoid"
OPSET_TANH = "Tanh"


def get_tp_model() -> TargetPlatformModel:
"""
@@ -189,6 +206,10 @@ def generate_tp_model(default_config: OpQuantizationConfig,
const_config_input16_per_tensor],
base_config=const_config_input16_per_tensor)

qpreserving_const_config = const_config.clone_and_edit(enable_activation_quantization=False,
quantization_preserving=True)
qpreserving_const_config_options = tp.QuantizationConfigOptions([qpreserving_const_config])

# Create a TargetPlatformModel and set its default quantization config.
# This default configuration will be used for all operations
# unless specified otherwise (see OperatorsSet, for example):
@@ -207,39 +228,40 @@ def generate_tp_model(default_config: OpQuantizationConfig,

# May suit for operations like: Dropout, Reshape, etc.
default_qco = tp.get_default_quantization_config_options()
tp.OperatorsSet("NoQuantization",
tp.OperatorsSet(OPSET_NO_QUANTIZATION,
default_qco.clone_and_edit(enable_activation_quantization=False)
.clone_and_edit_weight_attribute(enable_weights_quantization=False))
tp.OperatorsSet("QuantizationPreserving",
tp.OperatorsSet(OPSET_QUANTIZATION_PRESERVING,
default_qco.clone_and_edit(enable_activation_quantization=False,
quantization_preserving=True)
.clone_and_edit_weight_attribute(enable_weights_quantization=False))
tp.OperatorsSet("DimensionManipulationOps",
tp.OperatorsSet(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, qpreserving_const_config_options)
tp.OperatorsSet(OPSET_DIMENSION_MANIPULATION_OPS,
default_qco.clone_and_edit(enable_activation_quantization=False,
quantization_preserving=True,
supported_input_activation_n_bits=(8, 16))
.clone_and_edit_weight_attribute(enable_weights_quantization=False))
tp.OperatorsSet("MergeOps", const_configuration_options_inout16_per_tensor)
tp.OperatorsSet(OPSET_MERGE_OPS, const_configuration_options_inout16_per_tensor)

# Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list,
base_config=base_config)

# Define operator sets that use mixed_precision_configuration_options:
conv = tp.OperatorsSet("Conv", mixed_precision_configuration_options)
fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options)
conv = tp.OperatorsSet(OPSET_CONV, mixed_precision_configuration_options)
fc = tp.OperatorsSet(OPSET_FULLY_CONNECTED, mixed_precision_configuration_options)

# Define operations sets without quantization configuration
# options (useful for creating fusing patterns, for example):
any_relu = tp.OperatorsSet("AnyReLU")
add = tp.OperatorsSet("Add", const_configuration_options_inout16)
sub = tp.OperatorsSet("Sub", const_configuration_options_inout16)
mul = tp.OperatorsSet("Mul", const_configuration_options_inout16)
div = tp.OperatorsSet("Div", const_configuration_options)
prelu = tp.OperatorsSet("PReLU")
swish = tp.OperatorsSet("Swish")
sigmoid = tp.OperatorsSet("Sigmoid")
tanh = tp.OperatorsSet("Tanh")
any_relu = tp.OperatorsSet(OPSET_ANY_RELU)
add = tp.OperatorsSet(OPSET_ADD, const_configuration_options_inout16)
sub = tp.OperatorsSet(OPSET_SUB, const_configuration_options_inout16)
mul = tp.OperatorsSet(OPSET_MUL, const_configuration_options_inout16)
div = tp.OperatorsSet(OPSET_DIV, const_configuration_options)
prelu = tp.OperatorsSet(OPSET_PRELU)
swish = tp.OperatorsSet(OPSET_SWISH)
sigmoid = tp.OperatorsSet(OPSET_SIGMOID)
tanh = tp.OperatorsSet(OPSET_TANH)

# Combine multiple operators into a single operator to avoid quantization between
# them. To do this we define fusing patterns using the OperatorsSets that were created.