Skip to content

Commit

Permalink
Merge branch 'Konsti_Measurements' into Konsti_Recorders
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 15, 2024
2 parents a4f1c86 + e387616 commit 17d9d1b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 27 deletions.
41 changes: 14 additions & 27 deletions CI/unit_tests/measurements/test_base_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import numpy as np
import pytest
from numpy.testing import assert_raises

from papyrus.measurements import BaseMeasurement

Expand Down Expand Up @@ -82,14 +83,7 @@ def test_call(self):
with pytest.raises(NotImplementedError):
measurement.apply()

# Set an exmaple apply method
def apply(
a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None
) -> np.ndarray:
if c is not None:
return a + b + c
return a + b

# Test the apply method
a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])
c = np.array([[9, 10], [11, 12]])
Expand All @@ -111,24 +105,17 @@ def apply(
result = measurement(a, b, c=c)
assert np.allclose(result, a + b + c)
result = measurement(a, b=b)
print(result)
assert np.allclose(result, a + b)

def test_apply(self):
"""
Test the apply method of the BaseMeasurement class.
"""
# Create an example apply method and initialize the measurement

def apply(
a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None
) -> np.ndarray:
if c is not None:
return a + b + c
return a + b

name = "test"
rank = 1
public = False
measurement = DummyMeasurement(name, rank, public)
assert measurement.neural_state_keys == ["a", "b", "c"]
# Test error handling for wrong size of arguments
a = np.array([1, 2, 3])
b = np.array([[4, 5, 6], [7, 8, 9]])
c = np.array([[10, 11, 12], [13, 14, 15]])
with assert_raises(ValueError):
measurement(a, b, c)
with assert_raises(ValueError):
measurement(a=a, b=b, c=c)
with assert_raises(ValueError):
measurement(a, b=b, c=c)
with assert_raises(ValueError):
measurement(a, b, c=c)
12 changes: 12 additions & 0 deletions papyrus/measurements/base_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,21 @@ def __call__(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
"""
# Get the number of arguments
num_args = len(args)

# Get the keys and values of the keyword arguments if any
keys = list(kwargs.keys())
vals = list(kwargs.values())

# Assert whether the length of dimension 0 of all inputs is the same
try:
inputs = args + tuple(vals)
assert all([len(i) == len(inputs[0]) for i in inputs])
except AssertionError:
raise ValueError(
f"The first dimension of all inputs to the {self.name} measurement "
"must be the same."
)

# Zip the arguments and values
z = zip(*args, *vals)

Expand Down

0 comments on commit 17d9d1b

Please sign in to comment.