
reduce multi head attention test runtime
irenaby committed Jan 8, 2025
1 parent 72cadff commit 1254e6f
Showing 1 changed file with 17 additions and 0 deletions.
@@ -22,6 +22,7 @@
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
 from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
 
 """
@@ -55,6 +56,22 @@ def create_inputs_shape(self):
                 [self.val_batch_size] + list(self.key_input_shape),
                 [self.val_batch_size] + list(self.value_input_shape)]
 
+    def get_tpc(self):
+        tpc = {
+            'no_quantization': generate_test_tp_model({
+                'weights_n_bits': 32,
+                'activation_n_bits': 32,
+                'enable_weights_quantization': False,
+                'enable_activation_quantization': False
+            })
+        }
+        if self.num_heads < 5:
+            tpc['all_4bit'] = generate_test_tp_model({'weights_n_bits': 4,
+                                                      'activation_n_bits': 4,
+                                                      'enable_weights_quantization': True,
+                                                      'enable_activation_quantization': True})
+        return tpc
+

 class MHANet(nn.Module):
     # This network is based on a single MHA layer
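The runtime saving comes from the new get_tpc override: the float 'no_quantization' configuration is always returned, while the heavier 'all_4bit' configuration is added only when num_heads < 5, so large multi-head variants skip the 4-bit quantization run entirely. Below is a minimal sketch of how a dict of named TPC configurations like this might drive a per-configuration test loop; the run_all_tpc_variants function and the run_test entry point are assumptions for illustration, not the actual MCT base-test API.

# Hypothetical sketch of a per-configuration test loop (function and method
# names here are assumptions, not the real MCT test-harness API).
def run_all_tpc_variants(test_case):
    for tpc_name, tpc in test_case.get_tpc().items():
        # Each entry is an independent quantization setting; when num_heads >= 5
        # the dict simply has no 'all_4bit' entry, so that run never happens.
        print(f"running {type(test_case).__name__} with tpc '{tpc_name}'")
        test_case.run_test(tpc)  # hypothetical per-config entry point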
