Skip to content

Commit

Permalink
Make Pytorch Hessian computation tests run on the advanced models (#1078
Browse files Browse the repository at this point in the history
)

Co-authored-by: Ofir Gordon <[email protected]>
  • Loading branch information
ofirgo and Ofir Gordon authored May 20, 2024
1 parent f7de840 commit 3236a53
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions tests/pytorch_tests/function_tests/test_hessian_info_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self):
self.relu1 = ReLU()
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.bn2 = BatchNorm2d(3)
self.relu2 = ReLU()
self.hswish = Hardswish()
self.dense = Linear(8, 7)

Expand All @@ -84,6 +85,7 @@ def forward(self, inp):
x1 = self.relu1(x)
x2 = self.conv2(x1)
x2 = self.bn2(x2)
x2 = self.relu2(x2)
x3 = self.hswish(x2)
x3 = self.dense(x3)
return x1, x2, x3
Expand Down Expand Up @@ -129,9 +131,10 @@ def get_expected_shape(t_shape, granularity):

class BaseHessianTraceBasicModelTest(BasePytorchTest):

def __init__(self, unit_test):
def __init__(self, unit_test, model):
super().__init__(unit_test)
self.val_batch_size = 1
self.model = model

def create_inputs_shape(self):
return [[self.val_batch_size, 3, 8, 8]]
Expand Down Expand Up @@ -165,7 +168,7 @@ def test_hessian_trace_approx(self,
f"Tensor shape is expected to be {expected_shape} but has shape {score.shape}")

def _setup(self):
model_float = basic_model()
model_float = self.model()
pytorch_impl = PytorchImplementation()
graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO,
self.representative_data_gen, generate_pytorch_tpc)
Expand All @@ -175,7 +178,7 @@ def _setup(self):

class WeightsHessianTraceBasicModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=basic_model)
self.val_batch_size = 1

def run_test(self, seed=0):
Expand All @@ -201,7 +204,7 @@ def run_test(self, seed=0):

class WeightsHessianTraceAdvanceModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=advanced_model)
self.val_batch_size = 2

def run_test(self, seed=0):
Expand Down Expand Up @@ -230,7 +233,7 @@ def run_test(self, seed=0):

class WeightsHessianTraceMultipleOutputsModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=multiple_outputs_model)
self.val_batch_size = 1

def run_test(self, seed=0):
Expand Down Expand Up @@ -259,7 +262,7 @@ def run_test(self, seed=0):

class WeightsHessianTraceReuseModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=reused_model)
self.val_batch_size = 1

def run_test(self, seed=0):
Expand Down Expand Up @@ -288,7 +291,7 @@ def run_test(self, seed=0):

class ActivationHessianTraceBasicModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=basic_model)
self.val_batch_size = 1

def run_test(self, seed=0):
Expand All @@ -306,15 +309,17 @@ def run_test(self, seed=0):

class ActivationHessianTraceAdvanceModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=advanced_model)
self.val_batch_size = 2

def run_test(self, seed=0):
graph, pytorch_impl = self._setup()
hessian_service = hessian_common.HessianInfoService(graph=graph,
representative_dataset=self.representative_data_gen,
fw_impl=pytorch_impl)
ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0]

# removing last layer cause we do not allow activation Hessian computation for the output layer
ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0][:-1]
for ipt in ipts:
self.test_hessian_trace_approx(hessian_service,
interest_point=ipt,
Expand All @@ -325,15 +330,17 @@ def run_test(self, seed=0):

class ActivationHessianTraceMultipleOutputsModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=multiple_outputs_model)
self.val_batch_size = 1

def run_test(self, seed=0):
graph, pytorch_impl = self._setup()
hessian_service = hessian_common.HessianInfoService(graph=graph,
representative_dataset=self.representative_data_gen,
fw_impl=pytorch_impl)
ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0]

# removing last layer cause we do not allow activation Hessian computation for the output layer
ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0][:-1]
for ipt in ipts:
self.test_hessian_trace_approx(hessian_service,
interest_point=ipt,
Expand All @@ -344,14 +351,15 @@ def run_test(self, seed=0):

class ActivationHessianTraceReuseModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=reused_model)
self.val_batch_size = 1

def run_test(self, seed=0):
graph, pytorch_impl = self._setup()
hessian_service = hessian_common.HessianInfoService(graph=graph,
representative_dataset=self.representative_data_gen,
fw_impl=pytorch_impl)

ipts = [n for n in graph.get_topo_sorted_nodes() if len(n.weights) > 0]
for ipt in ipts:
self.test_hessian_trace_approx(hessian_service,
Expand All @@ -360,11 +368,9 @@ def run_test(self, seed=0):
granularity=hessian_common.HessianInfoGranularity.PER_TENSOR,
mode=hessian_common.HessianMode.ACTIVATION)



class ActivationHessianOutputExceptionTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test)
super().__init__(unit_test, model=basic_model)
self.val_batch_size = 1

def run_test(self, seed=0):
Expand Down

0 comments on commit 3236a53

Please sign in to comment.