Skip to content

Commit

Permalink
Hessian-based score w.r.t weights (Pytorch) (#836)
Browse files Browse the repository at this point in the history
* add computation of hessian based score w.r.t node weights (Pytorch)
  • Loading branch information
eladc-git authored Oct 30, 2023
1 parent df2e499 commit d6db398
Show file tree
Hide file tree
Showing 9 changed files with 419 additions and 31 deletions.
3 changes: 2 additions & 1 deletion model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,5 @@

# Hessian configuration default constants
HESSIAN_OUTPUT_ALPHA = 0.3
HESSIAN_NUM_ITERATIONS = 50
HESSIAN_NUM_ITERATIONS = 50
HESSIAN_EPS = 1e-6
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,14 @@ def compute(self) -> List[float]:
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement compute method.') # pragma: no cover

def _unfold_tensors_list(self, tensors_to_unfold: Any) -> List[Any]:
@staticmethod
def unfold_tensors_list(tensors_to_unfold: Any) -> List[Any]:
"""
Unfold (flatten) a nested tensors list.
Given a mixed list of single tensors and nested tensor lists,
this method returns a flattened list where nested lists are expanded.
Args:
tensors_to_unfold: Tensors to unfold.
Returns:
A flattened list of tensors.
"""
Expand All @@ -95,4 +93,4 @@ def _unfold_tensors_list(self, tensors_to_unfold: Any) -> List[Any]:
unfold_tensors += tensor
else:
unfold_tensors.append(tensor)
return unfold_tensors
return unfold_tensors
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _concat_tensors(self, tensors_to_concate: Union[tf.Tensor, List[tf.Tensor]])
tf.Tensor of the concatenation of the tensors.
"""
_unfold_tensors = self._unfold_tensors_list(tensors_to_concate)
_unfold_tensors = self.unfold_tensors_list(tensors_to_concate)
_r_tensors = [tf.reshape(tensor, shape=[tensor.shape[0], -1]) for tensor in _unfold_tensors]

# Ensure all tensors have the same shape for concatenation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from typing import List

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE, HESSIAN_EPS
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
Expand Down Expand Up @@ -124,7 +124,7 @@ def compute(self) -> np.ndarray:
# Compute new means and deltas
new_mean = tf.reduce_mean(tf.stack(approximation_per_iteration + approx), axis=0)
delta = new_mean - tf.reduce_mean(tf.stack(approximation_per_iteration), axis=0)
is_converged = np.all(np.abs(delta) / (np.abs(new_mean) + 1e-6) < JACOBIANS_COMP_TOLERANCE)
is_converged = np.all(np.abs(delta) / (np.abs(new_mean) + HESSIAN_EPS) < JACOBIANS_COMP_TOLERANCE)
if is_converged:
approximation_per_iteration.append(approx)
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.
# ==============================================================================

from typing import List, Tuple, Dict, Any
from typing import List

from torch import autograd
from tqdm import tqdm
import numpy as np

from model_compression_toolkit.constants import MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common import Graph
Expand Down Expand Up @@ -76,7 +75,7 @@ def compute(self) -> List[float]:

# Concat outputs
# First, we need to unfold all outputs that are given as list, to extract the actual output tensors
output = self._concat_tensors(output_tensors)
output = self.concat_tensors(output_tensors)

ipts_jac_trace_approx = []
for ipt in tqdm(model_grads_net.interest_points_tensors): # Per Interest point activation tensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@
# limitations under the License.
# ==============================================================================

from typing import Dict, List, Union
from typing import Union, List

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import TraceHessianRequest
from model_compression_toolkit.core.common.hessian.trace_hessian_calculator import TraceHessianCalculator

import torch

from model_compression_toolkit.logger import Logger
import torch


class TraceHessianCalculatorPytorch(TraceHessianCalculator):
Expand Down Expand Up @@ -52,18 +50,16 @@ def __init__(self,
trace_hessian_request=trace_hessian_request,
num_iterations_for_approximation=num_iterations_for_approximation)

def _concat_tensors(self, tensors_to_concate: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:

def concat_tensors(self, tensors_to_concate: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
"""
Concatenate model tensors into a single tensor.
Args:
tensors_to_concate: Tensors to concatenate.
Returns:
torch.Tensor of the concatenation of tensors.
"""
_unfold_tensors = self._unfold_tensors_list(tensors_to_concate)
_unfold_tensors = self.unfold_tensors_list(tensors_to_concate)
_r_tensors = [torch.reshape(tensor, shape=[tensor.shape[0], -1]) for tensor in _unfold_tensors]

concat_axis_dim = [o.shape[0] for o in _r_tensors]
Expand All @@ -73,4 +69,3 @@ def _concat_tensors(self, tensors_to_concate: Union[torch.Tensor, List[torch.Ten
"is not equal in all outputs.")

return torch.concat(_r_tensors, dim=1)

Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,119 @@
# ==============================================================================

from typing import List


import torch
from torch import autograd
import numpy as np
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import TraceHessianRequest
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
from model_compression_toolkit.core.pytorch.hessian.trace_hessian_calculator_pytorch import \
TraceHessianCalculatorPytorch
from model_compression_toolkit.logger import Logger
import torch
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_JACOBIANS_ITER, JACOBIANS_COMP_TOLERANCE, HESSIAN_EPS


class WeightsTraceHessianCalculatorPytorch(TraceHessianCalculatorPytorch):
"""
Pytorch-specific implementation of the Trace Hessian approximation computation w.r.t to a node's weights.
Pytorch-specific implementation of the Trace Hessian approximation computation w.r.t node's weights.
"""

def __init__(self,
graph: Graph,
input_images: List[torch.Tensor],
fw_impl,
trace_hessian_request: TraceHessianRequest):
trace_hessian_request: TraceHessianRequest,
num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
"""
Args:
graph: Computational graph for the float model.
input_images: List of input images for the computation.
fw_impl: Framework-specific implementation for trace Hessian computation.
trace_hessian_request: Configuration request for which to compute the trace Hessian approximation.
num_iterations_for_approximation: Number of iterations to use when approximating the Hessian trace.
"""
super(WeightsTraceHessianCalculatorPytorch, self).__init__(graph=graph,
input_images=input_images,
fw_impl=fw_impl,
trace_hessian_request=trace_hessian_request)
trace_hessian_request=trace_hessian_request,
num_iterations_for_approximation=num_iterations_for_approximation)


def compute(self) -> np.ndarray:
"""
Compute the Hessian-based scores w.r.t target node's weights.
The computed scores are returned in a numpy array. The shape of the result differs
according to the requested granularity. If for example the node is Conv2D with a kernel
shape of (2, 3, 3, 3) (namely, 3 input channels, 2 output channels and kernel size of 3x3)
and the required granularity is HessianInfoGranularity.PER_TENSOR the result shape will be (1,),
for HessianInfoGranularity.PER_OUTPUT_CHANNEL the shape will be (2,) and for
HessianInfoGranularity.PER_ELEMENT a shape of (2, 3, 3, 3).
Returns:
The computed scores as numpy ndarray for target node's weights.
"""

# Check if the target node's layer type is supported
if not DEFAULT_PYTORCH_INFO.is_kernel_op(self.hessian_request.target_node.type):
Logger.error(f"{self.hessian_request.target_node.type} is not supported for Hessian info w.r.t weights.") # pragma: no cover

# Float model
model, _ = FloatPyTorchModelBuilder(graph=self.graph).build_model()

# Get the weight attributes for the target node type
weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(self.hessian_request.target_node.type)

# Get the weight tensor for the target node
if len(weights_attributes) != 1:
Logger.error(f"Hessian scores w.r.t weights is supported, for now, for a single-weight node. Found {len(weights_attributes)}")

weights_tensor = getattr(getattr(model,self.hessian_request.target_node.name),weights_attributes[0])

# Get the output channel index
output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(self.hessian_request.target_node.type)
shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
if self.hessian_request.granularity == HessianInfoGranularity.PER_OUTPUT_CHANNEL:
shape_channel_axis.remove(output_channel_axis)
elif self.hessian_request.granularity == HessianInfoGranularity.PER_ELEMENT:
shape_channel_axis = ()

# Run model inference
outputs = model(self.input_images)
output_tensor = self.concat_tensors(outputs)
device = output_tensor.device

approximation_per_iteration = []
for j in range(self.num_iterations_for_approximation):
# Getting a random vector with normal distribution and the same shape as the model output
v = torch.randn_like(output_tensor, device=device)
f_v = torch.mean(torch.sum(v * output_tensor, dim=-1))
# Compute gradients of f_v with respect to the weights
f_v_grad = autograd.grad(outputs=f_v,
inputs=weights_tensor,
retain_graph=True)[0]

# Trace{A^T * A} = sum of all squares values of A
approx = f_v_grad ** 2
if len(shape_channel_axis) > 0:
approx = torch.sum(approx, dim=shape_channel_axis)

if j > MIN_JACOBIANS_ITER:
new_mean = (torch.sum(torch.stack(approximation_per_iteration), dim=0) + approx)/(j+1)
delta = new_mean - torch.mean(torch.stack(approximation_per_iteration), dim=0)
converged_tensor = torch.abs(delta) / (torch.abs(new_mean) + HESSIAN_EPS) < JACOBIANS_COMP_TOLERANCE
if torch.all(converged_tensor):
break

approximation_per_iteration.append(approx)

# Compute the mean of the approximations
final_approx = torch.mean(torch.stack(approximation_per_iteration), dim=0)

# Make sure all final shape are tensors and not scalar
if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
final_approx = final_approx.reshape(1)

return final_approx.detach().cpu().numpy()

def compute(self):
Logger.error(f"Hessian trace approx w.r.t weights is not supported for now")
11 changes: 11 additions & 0 deletions tests/pytorch_tests/function_tests/test_function_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
TestSetLayerToBitwidthActivation
from tests.pytorch_tests.function_tests.test_sensitivity_eval_output_replacement import \
TestSensitivityEvalWithArgmaxOutputReplacementNodes, TestSensitivityEvalWithSoftmaxOutputReplacementNodes
from tests.pytorch_tests.function_tests.test_hessian_info_weights import WeightsHessianTraceBasicModelTest, WeightsHessianTraceAdvanceModelTest, \
WeightsHessianTraceMultipleOutputsModelTest, WeightsHessianTraceReuseModelTest


class FunctionTestRunner(unittest.TestCase):
Expand Down Expand Up @@ -112,6 +114,15 @@ def test_model_gradients(self):
ModelGradientsNonDifferentiableNodeModelTest(self).run_test()
ModelGradientsSinglePointTest(self).run_test()

def test_weights_hessian_trace(self):
"""
This test checks the weighes hessian trace approximation in Pytorch.
"""
WeightsHessianTraceBasicModelTest(self).run_test()
WeightsHessianTraceAdvanceModelTest(self).run_test()
WeightsHessianTraceMultipleOutputsModelTest(self).run_test()
WeightsHessianTraceReuseModelTest(self).run_test()

def test_layer_fusing(self):
"""
This test checks the Fusion mechanism in Pytorch.
Expand Down
Loading

0 comments on commit d6db398

Please sign in to comment.