Skip to content

Commit

Permalink
add simd padding to tpc
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Dec 3, 2023
1 parent 54e6fb3 commit 0e69771
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 32 deletions.
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def get_pruning_sections(self, fw_info, fw_impl) -> List[PruningSection]:
input_section_node, intermediate_nodes, output_section_node = self.get_section_nodes(prunable_node, fw_impl)
pruning_sections.append(PruningSection(entry_node=input_section_node,
intermediate_nodes=intermediate_nodes,
exit_nodes=output_section_node))
exit_node=output_section_node))

return pruning_sections

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def __init__(self,
score_by_node: Dict[BaseNode, np.ndarray],
target_kpi: KPI,
graph,
fw_impl):
fw_impl,
tpc):
"""
Initializes the GreedyMaskCalculator with the required information.
Expand All @@ -36,6 +37,7 @@ def __init__(self,
self.target_kpi = target_kpi
self.graph = graph
self.fw_impl = fw_impl
self.tpc = tpc

# Initialize the SIMD group indices and scores dictionaries.
self.simd_groups_indices = {}
Expand Down Expand Up @@ -77,7 +79,8 @@ def _compute_mask(self):

# Iteratively prune the graph while monitoring the memory footprint.
current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.mask,
fw_impl=self.fw_impl)
fw_impl=self.fw_impl,
include_null_channels=self.tpc)
if current_memory > self.target_kpi.weights_memory:
Logger.error(f"Minimal required memory is {current_memory}, but target KPI is {self.target_kpi.weights_memory}")

Expand All @@ -86,7 +89,9 @@ def _compute_mask(self):
# Select the best SIMD group to add based on the scores.
node_to_remain, group_to_remain_idx = self._get_best_simd_group_candidate()
self._update_simd_mask(node=node_to_remain, group_index=group_to_remain_idx, value=1)
current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.mask, fw_impl=self.fw_impl)
current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.mask,
fw_impl=self.fw_impl,
include_null_channels=self.tpc.is_simd_padding())

# If the target memory is exceeded, revert the last addition.
if current_memory > self.target_kpi.weights_memory:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
from typing import List, Dict

from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import \
PruningFrameworkImplementation
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection, PruningSectionMask


Expand All @@ -17,7 +18,7 @@ class MemoryCalculator:
def __init__(self,
graph: Graph,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation):
fw_impl: PruningFrameworkImplementation):
"""
Initialize the MemoryCalculator.
Expand All @@ -33,7 +34,7 @@ def __init__(self,
def get_pruned_graph_memory(self,
masks: Dict[BaseNode, np.ndarray],
fw_impl,
include_null_channels: bool = True) -> float:
include_null_channels: bool) -> int:

# Total number of parameters after pruning
total_nparams = 0.0
Expand Down Expand Up @@ -131,7 +132,7 @@ def get_nparams_of_nonpruned_nodes(self, pruning_sections, include_null_channels
for n in self.graph.nodes:
if n not in nodes_to_prune:
node_nparams = sum(n.get_num_parameters(self.fw_info))
if include_null_channels:
if include_null_channels: # TODO: rename to simd channels padding
num_oc = n.output_shape[-1]
nparams_per_oc = node_nparams/num_oc
num_oc_include_null_channels = np.ceil(num_oc/n.get_simd())*n.get_simd()
Expand Down
19 changes: 10 additions & 9 deletions model_compression_toolkit/core/common/pruning/prune_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def build_pruned_graph(graph: Graph,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation) -> Graph:
"""
Prunes the provided graph according to the given pruning masks.
Prunes the provided graph according to the given pruning output-channels masks.
Args:
graph: The original computational graph to be pruned.
Expand All @@ -29,19 +29,18 @@ def build_pruned_graph(graph: Graph,
# Create a deep copy of the graph to avoid modifying the original graph.
graph_to_prune = copy.deepcopy(graph)

# Get the prunable nodes and the corresponding pruning sections.
prunable_nodes = graph_to_prune.get_pruning_sections_entry_nodes(fw_info=fw_info, fw_impl=fw_impl)
# Get the pruning sections.
pruning_sections = graph_to_prune.get_pruning_sections(fw_info=fw_info, fw_impl=fw_impl)

# Check that each prunable node corresponds to a pruning section.
assert len(pruning_sections) == len(prunable_nodes)
# Check that each entry node corresponds to a pruning section has an output-channel mask.
assert len(pruning_sections) == len(masks)

# Apply the pruning masks to each pruning section.
for input_node, pruning_section in zip(prunable_nodes, pruning_sections):
for pruning_section in pruning_sections:

# Retrieve the corresponding mask using the node's name (since we use a graph's copy).
mask = [v for k, v in masks.items() if k.name == input_node.name]
assert len(mask) == 1, f"Expected to find a single node with name {input_node.name} in masks dictionary but found {len(mask)}"
mask = [v for k, v in masks.items() if k.name == pruning_section.entry_node.name]
assert len(mask) == 1, f"Expected to find a single node with name {pruning_section.entry_node.name} in masks dictionary but found {len(mask)}"
mask = mask[0]

# If the mask indicates that some channels are to be pruned, apply it.
Expand All @@ -50,7 +49,9 @@ def build_pruned_graph(graph: Graph,
entry_output_mask=mask,
exit_input_mask=mask,
exit_output_mask=None)
pruning_section.apply_inner_section_mask(section_mask, fw_impl, fw_info)
pruning_section.apply_inner_section_mask(section_mask,
fw_impl,
fw_info)

# Return the pruned graph.
return graph_to_prune
Expand Down
3 changes: 2 additions & 1 deletion model_compression_toolkit/core/common/pruning/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def get_pruned_graph(self) -> Graph:
self.scores,
self.target_kpi,
self.float_graph,
self.fw_impl)
self.fw_impl,
self.target_platform_capabilities)

# Calculate the mask that will be used to prune the graph.
self.mask = mask_calculator.get_mask()
Expand Down
53 changes: 42 additions & 11 deletions model_compression_toolkit/core/common/pruning/pruning_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,71 @@

from model_compression_toolkit.core.common.graph.base_node import BaseNode


class PruningSectionMask:
"""
Represents the masks to be applied to a pruning section of a neural network.
This includes masks for both input and output channels at the entry and exit nodes of the section.
Attributes:
entry_input_mask (np.ndarray): Mask for input channels of the entry node.
entry_output_mask (np.ndarray): Mask for output channels of the entry node.
exit_input_mask (np.ndarray): Mask for input channels of the exit node.
exit_output_mask (np.ndarray): Mask for output channels of the exit node.
"""

def __init__(self,
entry_input_mask: np.ndarray,
entry_input_mask: np.ndarray, # TODO:entry_node_ic_mask
entry_output_mask: np.ndarray,
exit_input_mask: np.ndarray,
exit_output_mask: np.ndarray):

self.entry_input_mask = entry_input_mask
self.entry_output_mask = entry_output_mask
self.exit_input_mask = exit_input_mask
self.exit_output_mask = exit_output_mask


class PruningSection:
"""
Represents a section in a graph to be pruned, consisting of an entry node,
intermediate nodes, and an exit node.
Attributes:
entry_node (BaseNode): The first node in the pruning section.
intermediate_nodes (List[BaseNode]): List of nodes between the entry and exit nodes.
exit_node (BaseNode): The last node in the pruning section.
"""

def __init__(self,
entry_node:BaseNode,
entry_node: BaseNode,
intermediate_nodes: List[BaseNode],
exit_nodes: BaseNode):
exit_node: BaseNode):
self.entry_node = entry_node
self.intermediate_nodes = intermediate_nodes
self.exit_node = exit_nodes
self.exit_node = exit_node

def get_all_nodes(self) -> List[BaseNode]:
"""
Returns a list of all nodes in the pruning section, including the entry,
intermediate, and exit nodes.
def get_all_nodes(self):
nodes = [self.entry_node]
nodes.extend(self.intermediate_nodes)
nodes.append(self.exit_node)
Returns:
List[BaseNode]: List of all nodes in the pruning section.
"""
nodes = [self.entry_node] + self.intermediate_nodes + [self.exit_node]
return nodes

def apply_inner_section_mask(self,
pruning_section_mask: PruningSectionMask,
fw_impl,
fw_info):
"""
Apply the provided pruning section mask to all nodes within the pruning section.
Args:
pruning_section_mask (PruningSectionMask): The mask to be applied to the pruning section.
fw_impl: Framework-specific implementation for applying the mask.
fw_info: Framework-specific information needed to apply the mask.
"""
fw_impl.prune_entry_node(node=self.entry_node,
output_mask=pruning_section_mask.entry_output_mask,
fw_info=fw_info)
Expand All @@ -51,4 +83,3 @@ def apply_inner_section_mask(self,
input_mask=pruning_section_mask.exit_input_mask,
fw_info=fw_info)


Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO


# TODO: Rethink if it's should be common or fw-specific
def is_keras_entry_node(node: BaseNode):
"""
Expand All @@ -15,7 +16,7 @@ def is_keras_entry_node(node: BaseNode):
return _is_keras_node_pruning_section_edge(node)


def is_keras_exit_node(node: BaseNode, dual_entry_node: BaseNode):
def is_keras_exit_node(node: BaseNode, match_entry_node: BaseNode):
"""
Args:
Expand All @@ -24,7 +25,7 @@ def is_keras_exit_node(node: BaseNode, dual_entry_node: BaseNode):
Returns:
"""
return _is_keras_node_pruning_section_edge(node) and _is_same_channels(node, dual_entry_node)
return _is_keras_node_pruning_section_edge(node) and _is_same_channels(node, match_entry_node)


def _is_same_channels(exit_node: BaseNode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_keras_pruned_node_num_params(node: BaseNode,
# For non-kernel operations, apply the output mask to the last axis.
# This part assumes that for non-kernel ops, all weights output channel axis is -1.
for w_attr, w in node.weights.items():
pruned_w = np.take(w, np.where(output_mask)[0], axis=-1)
pruned_w = np.take(w, np.where(output_mask)[0], axis=-1) # TODO: get axis from fw-specific function
total_params += pruned_w.size

if include_null_channels:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self,
f'Default QuantizationConfigOptions must contain only one option'
self.default_qco = default_qco
self.fusing_patterns = []
self.is_simd_padding=None

def get_config_options_by_operators_set(self,
operators_set_name: str) -> QuantizationConfigOptions:
Expand Down Expand Up @@ -232,3 +233,14 @@ def set_quantization_format(self,
quantization_format: A quantization format (fake-quant, int8 etc.) from enum QuantizationFormat.
"""
self.quantization_format = quantization_format

def set_simd_padding(self, is_simd_padding: bool):
"""
Args:
is_simd_padding:
Returns:
"""
self.is_simd_padding = is_simd_padding
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,6 @@ def raise_warnings(self):
"""
for op in self.__tp_model_opsets_not_used:
Logger.warning(f'{op} is defined in TargetPlatformModel, but is not used in TargetPlatformCapabilities.')

def is_simd_padding(self):
return self.tp_model.is_simd_padding
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
fixed_zero_point=None,
weights_multiplier_nbits=None,
simd_size=32)
# TODO: split op pruning config from op q config
# TODO: create pruning candidate config (or at least remove/disable all quantization configration from nodes)

# To quantize a model using mixed-precision, create
# a list with more than one OpQuantizationConfig.
Expand Down Expand Up @@ -123,6 +125,8 @@ def generate_tp_model(default_config: OpQuantizationConfig,
# Set quantization format to fakely quant
generated_tpc.set_quantization_format(QuantizationFormat.FAKELY_QUANT)

generated_tpc.set_simd_padding(is_simd_padding=True)

# May suit for operations like: Dropout, Reshape, etc.
tp.OperatorsSet("NoQuantization",
tp.get_default_quantization_config_options().clone_and_edit(
Expand Down

0 comments on commit 0e69771

Please sign in to comment.