From 510ffb3de0e53071e29e0f2d430f976678fce60b Mon Sep 17 00:00:00 2001 From: knikolaou Date: Fri, 10 May 2024 21:30:49 +0200 Subject: [PATCH] Add test for neural state keys --- .../measurements/test_base_measurement.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/CI/unit_tests/measurements/test_base_measurement.py b/CI/unit_tests/measurements/test_base_measurement.py index f6c7866..48cbc07 100644 --- a/CI/unit_tests/measurements/test_base_measurement.py +++ b/CI/unit_tests/measurements/test_base_measurement.py @@ -30,6 +30,22 @@ from papyrus.measurements.base_measurement import BaseMeasurement +class DummyMeasurement(BaseMeasurement): + """ + Dummy measurement class for testing. + """ + + def __init__(self, name: str, rank: int, public: bool = False): + super().__init__(name, rank, public) + + def apply( + self, a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None + ) -> np.ndarray: + if c is not None: + return a + b + c + return a + b + + class TestBaseMeasurement: """ Test the base measurement class. @@ -79,7 +95,7 @@ def apply( c = np.array([[9, 10], [11, 12]]) # Test the call method with only arguments - measurement.apply = apply + measurement = DummyMeasurement(name, rank, public) result = measurement(a, b) assert np.allclose(result, a + b) result = measurement(a, b, c) @@ -97,3 +113,22 @@ def apply( result = measurement(a, b=b) print(result) assert np.allclose(result, a + b) + + def test_apply(self): + """ + Test the apply method of the BaseMeasurement class. + """ + # Create an example apply method and initialize the measurement + + def apply( + a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None + ) -> np.ndarray: + if c is not None: + return a + b + c + return a + b + + name = "test" + rank = 1 + public = False + measurement = DummyMeasurement(name, rank, public) + assert measurement.neural_state_keys == ["a", "b", "c"]