Skip to content

Commit

Permalink
unite small functions in graph_builders
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Sep 24, 2024
1 parent af16275 commit 5ca8673
Showing 1 changed file with 9 additions and 48 deletions.
57 changes: 9 additions & 48 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def nodes_builder(model: GraphModule,
# Check if this node's target has been seen before
reuse = False
reuse_group = None
node_group_key = create_reuse_group(node.target, generate_weight_signature(weights))
node_group_key = create_reuse_group(node.target, weights)
# We mark nodes as reused only if there are multiple nodes in the graph with same
# 'target' and it has some weights.
if node_group_key in seen_targets and len(weights) > 0:
Expand Down Expand Up @@ -393,63 +393,24 @@ def edges_builder(model: GraphModule,

return edges

from typing import Dict, Tuple, Optional, Any, Set

def generate_weight_signature(weights: Dict[str, Any]) -> Optional[Tuple[int, ...]]:
def create_reuse_group(target: Any, weights: Dict[str, Any]) -> str:
"""
Create a unique signature for the weights based on their instance IDs.
This function generates a tuple of sorted weight IDs, which serves as a unique
signature for a set of weights. This is useful for identifying identical weight
instances, particularly in the case of reused functional layers.
Args:
weights (Dict[str, Any]): A dictionary of weight names to weight values.
The values can be any type (typically tensors or arrays).
Returns:
Optional[Tuple[int, ...]]: A tuple of sorted weight IDs if weights are present,
or None if the weights dictionary is empty.
"""
if not weights:
return None
weight_ids = tuple(sorted(id(weight) for weight in weights.values()))
return weight_ids

def create_reuse_group(target: Any, weight_signature: Optional[Tuple[int, ...]]) -> str:
"""
Combine target and weight signature to create a unique reuse group identifier.
Combine target and weights to create a unique reuse group identifier.
This function creates a unique string identifier for a reuse group by combining
the target (typically a layer or operation name) with the weight signature.
the target (typically a layer or operation name) with the weights IDs.
Args:
target (Any): The target of the node, typically a string or callable representing
a layer or operation.
weight_signature (Optional[Tuple[int, ...]]): The weight signature generated by
generate_weight_signature function.
weights (Dict[str, Any]): A dictionary of weight names to weight values.
The values can be any type (typically tensors or arrays).
Returns:
str: A unique string identifier for the reuse group.
"""
if weight_signature is None:
if not weights:
return str(target)
return f"{target}_{weight_signature}"

def is_node_reusable(reuse_group: str, seen_targets: Set[str]) -> bool:
"""
Determine if the node is reusable based on its reuse group.
This function checks if a node's reuse group has been seen before, indicating
that the node is potentially reusable.
Args:
reuse_group (str): The reuse group identifier for the node, created by
the create_reuse_group function.
seen_targets (Set[str]): A set of previously seen reuse group identifiers.
Returns:
bool: True if the node is reusable (its reuse group has been seen before),
False otherwise.
"""
return reuse_group in seen_targets
weight_ids = tuple(sorted(id(weight) for weight in weights.values()))
return f"{target}_{weight_ids}"

0 comments on commit 5ca8673

Please sign in to comment.