Skip to content

Commit

Permalink
add bounded symmetric selection test and removed all symmetric select…
Browse files Browse the repository at this point in the history
…ion tests
  • Loading branch information
ofirgo committed Jan 28, 2025
1 parent 0ac2cd8 commit a29ca35
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 136 deletions.

This file was deleted.

16 changes: 0 additions & 16 deletions tests/keras_tests/feature_networks_tests/test_features_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@
from tests.keras_tests.feature_networks_tests.feature_networks.softmax_shift_test import SoftmaxShiftTest
from tests.keras_tests.feature_networks_tests.feature_networks.split_concatenate_test import SplitConcatenateTest
from tests.keras_tests.feature_networks_tests.feature_networks.split_conv_bug_test import SplitConvBugTest
from tests.keras_tests.feature_networks_tests.feature_networks.symmetric_threshold_selection_activation_test import \
SymmetricThresholdSelectionActivationTest, SymmetricThresholdSelectionBoundedActivationTest
from tests.keras_tests.feature_networks_tests.feature_networks.test_depthwise_conv2d_replacement import \
DwConv2dReplacementTest
from tests.keras_tests.feature_networks_tests.feature_networks.test_kmeans_quantizer import \
Expand Down Expand Up @@ -754,20 +752,6 @@ def test_gptq_conv_group_dilation(self):
def test_split_conv_bug(self):
SplitConvBugTest(self).run_test()

def test_symmetric_threshold_selection_activation(self):
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.NOCLIPPING).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.MSE).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.MAE).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.LP).run_test()
SymmetricThresholdSelectionActivationTest(self, QuantizationErrorMethod.KL).run_test()

def test_symmetric_threshold_selection_softmax_activation(self):
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.NOCLIPPING).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.MSE).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.MAE).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.LP).run_test()
SymmetricThresholdSelectionBoundedActivationTest(self, QuantizationErrorMethod.KL).run_test()

def test_uniform_range_selection_activation(self):
UniformRangeSelectionActivationTest(self, QuantizationErrorMethod.NOCLIPPING).run_test()
UniformRangeSelectionActivationTest(self, QuantizationErrorMethod.MSE).run_test()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.core.keras.constants import KERNEL, DEPTHWISE_KERNEL
from model_compression_toolkit.ptq import keras_post_training_quantization
from model_compression_toolkit.target_platform_capabilities import AttributeQuantizationConfig, OpQuantizationConfig, \
Signedness
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
AttributeQuantizationConfig, Signedness
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
from tests.common_tests.helpers.tpcs_for_tests.v4.tpc import generate_tpc

Expand Down Expand Up @@ -120,9 +120,10 @@ def _verify_weights_quantizer_params(quant_method, weights_quantizer, params_sha


class TestPostTrainingQuantizationApi:
# TODO:
# [a, w&a]
# extend to also test with different settings? (bc, snc, etc.)
# TODO: add tests for:
# 1) activation only, W&A, LUT quantizer (separate)
# 2) extend to also test with different settings features (bc, snc, etc.)
# 3) advanced models and operators


def _verify_quantized_model_structure(self, model, q_model, quantization_info):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def hist():
return count, bins


@pytest.fixture
def bounded_hist():
np.random.seed(42)
size = (32, 32, 3)
num_bins = 2048
x = np.random.uniform(-7, 7, size=size).flatten()
e_x = np.exp(x - np.max(x))
x = (e_x / e_x.sum()) + 1
count, bins = np.histogram(x, bins=num_bins)

return count, bins


err_methods_to_test = [e.name for e in QuantizationErrorMethod if e != QuantizationErrorMethod.HMSE]


Expand All @@ -48,4 +61,17 @@ def test_symmetric_threshold_selection(error_method, hist):
assert THRESHOLD in search_res
assert SIGNED in search_res
assert np.isclose(search_res[THRESHOLD], 7, atol=0.4)
assert search_res[SIGNED] is True
assert search_res[SIGNED]


@pytest.mark.parametrize("error_method", err_methods_to_test)
def test_symmetric_threshold_selection_bounded_activation(error_method, bounded_hist):
counts, bins = bounded_hist

search_res = symmetric_selection_histogram(bins, counts, 2, 8, Mock(), Mock(), Mock(), Mock(),
MIN_THRESHOLD, QuantizationErrorMethod[error_method], False)

assert THRESHOLD in search_res
assert SIGNED in search_res
assert np.isclose(search_res[THRESHOLD], 1, atol=0.4)
assert not search_res[SIGNED]

0 comments on commit a29ca35

Please sign in to comment.