Skip to content

Commit

Permalink
Add test for neural state keys
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 10, 2024
1 parent 4b63f0b commit 510ffb3
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion CI/unit_tests/measurements/test_base_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@
from papyrus.measurements.base_measurement import BaseMeasurement


class DummyMeasurement(BaseMeasurement):
"""
Dummy measurement class for testing.
"""

def __init__(self, name: str, rank: int, public: bool = False):
super().__init__(name, rank, public)

def apply(
self, a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None
) -> np.ndarray:
if c is not None:
return a + b + c
return a + b


class TestBaseMeasurement:
"""
Test the base measurement class.
Expand Down Expand Up @@ -79,7 +95,7 @@ def apply(
c = np.array([[9, 10], [11, 12]])

# Test the call method with only arguments
measurement.apply = apply
measurement = DummyMeasurement(name, rank, public)
result = measurement(a, b)
assert np.allclose(result, a + b)
result = measurement(a, b, c)
Expand All @@ -97,3 +113,22 @@ def apply(
result = measurement(a, b=b)
print(result)
assert np.allclose(result, a + b)

def test_apply(self):
"""
Test the apply method of the BaseMeasurement class.
"""
# Create an example apply method and initialize the measurement

def apply(
a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None
) -> np.ndarray:
if c is not None:
return a + b + c
return a + b

name = "test"
rank = 1
public = False
measurement = DummyMeasurement(name, rank, public)
assert measurement.neural_state_keys == ["a", "b", "c"]

0 comments on commit 510ffb3

Please sign in to comment.