Commit

check bug effect
yardeny-sony committed Sep 25, 2024
1 parent 282edd5 commit c2ab94d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/pytorch_tests/model_tests/base_pytorch_test.py
@@ -97,10 +97,11 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
         # Check if we have a BatchNorm or MultiheadAttention layer in the model.
         # If so, the outputs will not be the same, since the sqrt function in the
         # Decomposition is not exactly like the sqrt in the C implementation of PyTorch.
-        float_model_operators = [type(module) for name, module in float_model.named_modules()]
-        if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in float_model_operators\
-                or self.use_fuzzy_validation: # todo: add flag to batch norm and MHA
-            self.unit_test.assertTrue(np.all(np.isclose(torch_tensor_to_numpy(f), torch_tensor_to_numpy(q),
+        # float_model_operators = [type(module) for name, module in float_model.named_modules()]
+        # if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in float_model_operators\
+        #         or self.use_fuzzy_validation: # todo: add flag to batch norm and MHA
+        if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in [type(module) for name, module in float_model.named_modules()]:
+            self.unit_test.assertTrue(np.all(np.isclose(torch_tensor_to_numpy(f), torch_tensor_to_numpy(q),
                                                         atol=self.float_reconstruction_error)))
         else:
             self.unit_test.assertTrue(torch_tensor_to_numpy(torch.sum(torch.abs(f - q))) == 0,
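Note on the condition being probed here: in Python, "in" binds tighter than "or", so the expression torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in float_model_operators parses as (torch.nn.BatchNorm2d) or (torch.nn.MultiheadAttention in float_model_operators). A class object is always truthy, so the branch is taken for every model regardless of its layers, which appears to be the bug whose effect this commit checks. Below is a minimal sketch of the pitfall and one possible corrected membership test; the toy model and variable names are illustrative, not taken from the repository.

import torch

# Toy model with no BatchNorm or MultiheadAttention layers (illustrative only).
float_model = torch.nn.Sequential(torch.nn.Linear(4, 4))
operators = [type(module) for name, module in float_model.named_modules()]

# Buggy: parses as BatchNorm2d or (MultiheadAttention in operators);
# the class object BatchNorm2d is truthy, so the result is always truthy.
buggy = torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in operators
print(bool(buggy))  # True, regardless of the model's contents

# Corrected: test membership of each layer type explicitly.
fixed = any(op in operators for op in (torch.nn.BatchNorm2d, torch.nn.MultiheadAttention))
print(fixed)  # False: neither layer type is present in this model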
