Skip to content

Commit

Permalink
- Make reading neural state keys a method
Browse files Browse the repository at this point in the history
- Introduce boolean overwrite kwarg
  • Loading branch information
knikolaou committed May 15, 2024
1 parent 822724f commit 6d15dd9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
18 changes: 17 additions & 1 deletion CI/unit_tests/recorders/test_base_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_init(self):
assert recorder.storage_path == storage_path
assert recorder.measurements == [self.measurement_1, self.measurement_2]
assert recorder.chunk_size == 10
assert recorder.overwrite is False

def test_neural_state_keys(self):
"""
Expand Down Expand Up @@ -181,7 +182,10 @@ def test_store(self):
name = "test"
storage_path = "temp/"
recorder = BaseRecorder(
name, storage_path, [self.measurement_1, self.measurement_2], 10
name,
storage_path,
[self.measurement_1, self.measurement_2],
10,
)

# Test storing
Expand All @@ -202,6 +206,18 @@ def test_store(self):
print(data["dummy_1"].shape)
assert_array_equal(data["dummy_1"], np.ones(shape=(2, 3, 10, 5)))
assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(2, 3, 10)))
# _results should be empty after storing
assert recorder._results == {"dummy_1": [], "dummy_2": []}

# test overwriting
recorder.overwrite = True
recorder._measure(**self.neural_state)
recorder._store(20)
data = recorder.load()

assert set(data.keys()) == {"dummy_1", "dummy_2"}
assert_array_equal(data["dummy_1"], np.ones(shape=(1, 3, 10, 5)))
assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(1, 3, 10)))

# Delete temporary directory
os.system("rm -r temp/")
28 changes: 22 additions & 6 deletions papyrus/recorders/base_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
storage_path: str,
measurements: List[BaseMeasurement],
chunk_size: int,
overwrite: bool = False,
):
"""
Constructor method of the BaseRecorder class.
Expand All @@ -78,21 +79,33 @@ def __init__(
The measurements that the recorder will apply.
chunk_size : int
The size of the chunks in which the data will be stored.
overwrite : bool (default=False)
Whether to overwrite the existing data in the database.
"""
self.name = name
self.storage_path = storage_path
self.measurements = measurements
self.chunk_size = chunk_size
self.overwrite = overwrite

# Read in neural state keys from measurements
self.neural_state_keys = []
for measurement in measurements:
self.neural_state_keys.extend(measurement.neural_state_keys)
self.neural_state_keys = list(set(self.neural_state_keys))
self._read_neural_state_keys()

# Temporary storage for results
self._init_results()

def _read_neural_state_keys(self):
"""
Read the neural state keys from the measurements.
Updates the neural_state_keys attribute of the recorder with the keys of the
neural state that the measurements take as input.
"""
self.neural_state_keys = []
for measurement in self.measurements:
self.neural_state_keys.extend(measurement.neural_state_keys)
self.neural_state_keys = list(set(self.neural_state_keys))

def _init_results(self):
"""
Initialize the temporary storage for the results.
Expand Down Expand Up @@ -169,8 +182,11 @@ def _store(self, epoch: int):
try:
data = self.load()
# Append the new data
for key in self._results.keys():
data[key] = np.append(data[key], self._results[key], axis=0)
if self.overwrite:
data = self._results
else:
for key in self._results.keys():
data[key] = np.append(data[key], self._results[key], axis=0)
# If the file does not exist, create a new one
except FileNotFoundError:
data = self._results
Expand Down

0 comments on commit 6d15dd9

Please sign in to comment.