From 1736b7f1d00902bb3f6181658bef7d2c34d96c46 Mon Sep 17 00:00:00 2001 From: knikolaou Date: Wed, 15 May 2024 21:55:12 +0200 Subject: [PATCH 1/3] remove unused parts from test_base_measurement --- .../measurements/test_base_measurement.py | 28 +------------------ 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/CI/unit_tests/measurements/test_base_measurement.py b/CI/unit_tests/measurements/test_base_measurement.py index 1004404..2b24e9f 100644 --- a/CI/unit_tests/measurements/test_base_measurement.py +++ b/CI/unit_tests/measurements/test_base_measurement.py @@ -82,14 +82,7 @@ def test_call(self): with pytest.raises(NotImplementedError): measurement.apply() - # Set an exmaple apply method - def apply( - a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None - ) -> np.ndarray: - if c is not None: - return a + b + c - return a + b - + # Test the apply method a = np.array([[1, 2], [3, 4]]) b = np.array([[5, 6], [7, 8]]) c = np.array([[9, 10], [11, 12]]) @@ -113,22 +106,3 @@ def apply( result = measurement(a, b=b) print(result) assert np.allclose(result, a + b) - - def test_apply(self): - """ - Test the apply method of the BaseMeasurement class. - """ - # Create an example apply method and initialize the measurement - - def apply( - a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None - ) -> np.ndarray: - if c is not None: - return a + b + c - return a + b - - name = "test" - rank = 1 - public = False - measurement = DummyMeasurement(name, rank, public) - assert measurement.neural_state_keys == ["a", "b", "c"] From 658fe50756d63fc577187a558019c297908c5e23 Mon Sep 17 00:00:00 2001 From: knikolaou Date: Wed, 15 May 2024 21:55:38 +0200 Subject: [PATCH 2/3] remove print statement from test_base_measurement --- CI/unit_tests/measurements/test_base_measurement.py | 1 - 1 file changed, 1 deletion(-) diff --git a/CI/unit_tests/measurements/test_base_measurement.py b/CI/unit_tests/measurements/test_base_measurement.py index 2b24e9f..abb64f8 100644 --- a/CI/unit_tests/measurements/test_base_measurement.py +++ b/CI/unit_tests/measurements/test_base_measurement.py @@ -104,5 +104,4 @@ def test_call(self): result = measurement(a, b, c=c) assert np.allclose(result, a + b + c) result = measurement(a, b=b) - print(result) assert np.allclose(result, a + b) From e387616fa6396d144cf80ae31c1338d18904d6ea Mon Sep 17 00:00:00 2001 From: knikolaou Date: Wed, 15 May 2024 22:20:15 +0200 Subject: [PATCH 3/3] Add shape check to BaseMeasurement class --- .../measurements/test_base_measurement.py | 14 ++++++++++++++ papyrus/measurements/base_measurement.py | 12 ++++++++++++ 2 files changed, 26 insertions(+) diff --git a/CI/unit_tests/measurements/test_base_measurement.py b/CI/unit_tests/measurements/test_base_measurement.py index abb64f8..a051042 100644 --- a/CI/unit_tests/measurements/test_base_measurement.py +++ b/CI/unit_tests/measurements/test_base_measurement.py @@ -26,6 +26,7 @@ import numpy as np import pytest +from numpy.testing import assert_raises from papyrus.measurements import BaseMeasurement @@ -105,3 +106,16 @@ def test_call(self): assert np.allclose(result, a + b + c) result = measurement(a, b=b) assert np.allclose(result, a + b) + + # Test error handling for wrong size of arguments + a = np.array([1, 2, 3]) + b = np.array([[4, 5, 6], [7, 8, 9]]) + c = np.array([[10, 11, 12], [13, 14, 15]]) + with assert_raises(ValueError): + measurement(a, b, c) + with assert_raises(ValueError): + measurement(a=a, b=b, c=c) + with assert_raises(ValueError): + measurement(a, b=b, c=c) + with assert_raises(ValueError): + measurement(a, b, c=c) diff --git a/papyrus/measurements/base_measurement.py b/papyrus/measurements/base_measurement.py index e9bc8d4..aa14f35 100644 --- a/papyrus/measurements/base_measurement.py +++ b/papyrus/measurements/base_measurement.py @@ -130,9 +130,21 @@ def __call__(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray: """ # Get the number of arguments num_args = len(args) + # Get the keys and values of the keyword arguments if any keys = list(kwargs.keys()) vals = list(kwargs.values()) + + # Assert whether the length of dimension 0 of all inputs is the same + try: + inputs = args + tuple(vals) + assert all([len(i) == len(inputs[0]) for i in inputs]) + except AssertionError: + raise ValueError( + f"The first dimension of all inputs to the {self.name} measurement " + "must be the same." + ) + # Zip the arguments and values z = zip(*args, *vals)