diff --git a/CI/unit_tests/measurements/test_base_measurement.py b/CI/unit_tests/measurements/test_base_measurement.py index 1004404..a051042 100644 --- a/CI/unit_tests/measurements/test_base_measurement.py +++ b/CI/unit_tests/measurements/test_base_measurement.py @@ -26,6 +26,7 @@ import numpy as np import pytest +from numpy.testing import assert_raises from papyrus.measurements import BaseMeasurement @@ -82,14 +83,7 @@ def test_call(self): with pytest.raises(NotImplementedError): measurement.apply() - # Set an exmaple apply method - def apply( - a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None - ) -> np.ndarray: - if c is not None: - return a + b + c - return a + b - + # Test the apply method a = np.array([[1, 2], [3, 4]]) b = np.array([[5, 6], [7, 8]]) c = np.array([[9, 10], [11, 12]]) @@ -111,24 +105,17 @@ def apply( result = measurement(a, b, c=c) assert np.allclose(result, a + b + c) result = measurement(a, b=b) - print(result) assert np.allclose(result, a + b) - def test_apply(self): - """ - Test the apply method of the BaseMeasurement class. - """ - # Create an example apply method and initialize the measurement - - def apply( - a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None - ) -> np.ndarray: - if c is not None: - return a + b + c - return a + b - - name = "test" - rank = 1 - public = False - measurement = DummyMeasurement(name, rank, public) - assert measurement.neural_state_keys == ["a", "b", "c"] + # Test error handling for wrong size of arguments + a = np.array([1, 2, 3]) + b = np.array([[4, 5, 6], [7, 8, 9]]) + c = np.array([[10, 11, 12], [13, 14, 15]]) + with assert_raises(ValueError): + measurement(a, b, c) + with assert_raises(ValueError): + measurement(a=a, b=b, c=c) + with assert_raises(ValueError): + measurement(a, b=b, c=c) + with assert_raises(ValueError): + measurement(a, b, c=c) diff --git a/papyrus/measurements/base_measurement.py b/papyrus/measurements/base_measurement.py index 4be24f6..fc1eda8 100644 --- a/papyrus/measurements/base_measurement.py +++ b/papyrus/measurements/base_measurement.py @@ -144,9 +144,21 @@ def __call__(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray: """ # Get the number of arguments num_args = len(args) + # Get the keys and values of the keyword arguments if any keys = list(kwargs.keys()) vals = list(kwargs.values()) + + # Assert whether the length of dimension 0 of all inputs is the same + try: + inputs = args + tuple(vals) + assert all([len(i) == len(inputs[0]) for i in inputs]) + except AssertionError: + raise ValueError( + f"The first dimension of all inputs to the {self.name} measurement " + "must be the same." + ) + # Zip the arguments and values z = zip(*args, *vals)