Skip to content

Commit

Permalink
fix intermittently failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Oct 15, 2024
1 parent 898f499 commit 097f7f1
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/pytorch_tests/function_tests/test_hessian_info_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class basic_model(torch.nn.Module):
def __init__(self):
super(basic_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn = BatchNorm2d(3)
self.relu = ReLU()

Expand All @@ -50,10 +50,10 @@ def forward(self, inp):
class advanced_model(torch.nn.Module):
def __init__(self):
super(advanced_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn1 = BatchNorm2d(3)
self.relu1 = ReLU()
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv2 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn2 = BatchNorm2d(3)
self.relu2 = ReLU()
self.dense = Linear(8, 7)
Expand All @@ -72,10 +72,10 @@ def forward(self, inp):
class multiple_outputs_model(torch.nn.Module):
def __init__(self):
super(multiple_outputs_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn1 = BatchNorm2d(3)
self.relu1 = ReLU()
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv2 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn2 = BatchNorm2d(3)
self.relu2 = ReLU()
self.hswish = Hardswish()
Expand All @@ -96,8 +96,8 @@ def forward(self, inp):
class multiple_inputs_model(torch.nn.Module):
def __init__(self):
super(multiple_inputs_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.conv2 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)

def forward(self, inp1, inp2):
x1 = self.conv1(inp1)
Expand All @@ -108,7 +108,7 @@ def forward(self, inp1, inp2):
class reused_model(torch.nn.Module):
def __init__(self):
super(reused_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn1 = BatchNorm2d(3)
self.relu = ReLU()

Expand Down Expand Up @@ -138,7 +138,7 @@ def get_expected_shape(t_shape, granularity, n_samples):

class BaseHessianTraceBasicModelTest(BasePytorchTest):

def __init__(self, unit_test, model, n_iters=2):
def __init__(self, unit_test, model, n_iters=10):
super().__init__(unit_test)
self.val_batch_size = 1
self.model = model
Expand Down Expand Up @@ -241,7 +241,7 @@ def run_test(self, seed=0):

class WeightsHessianTraceAdvanceModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test, model=advanced_model, n_iters=3)
super().__init__(unit_test, model=advanced_model)
self.val_batch_size = 2

def run_test(self, seed=0):
Expand All @@ -265,7 +265,7 @@ def run_test(self, seed=0):

class WeightsHessianTraceMultipleOutputsModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test, model=multiple_outputs_model, n_iters=3)
super().__init__(unit_test, model=multiple_outputs_model)
self.val_batch_size = 1

def run_test(self, seed=0):
Expand All @@ -289,7 +289,7 @@ def run_test(self, seed=0):

class WeightsHessianTraceReuseModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test, model=reused_model, n_iters=3)
super().__init__(unit_test, model=reused_model)
self.val_batch_size = 1

def run_test(self, seed=0):
Expand Down

0 comments on commit 097f7f1

Please sign in to comment.