diff --git a/CI/unit_tests/recorders/test_base_recorder.py b/CI/unit_tests/recorders/test_base_recorder.py
index eff66ad..4708859 100644
--- a/CI/unit_tests/recorders/test_base_recorder.py
+++ b/CI/unit_tests/recorders/test_base_recorder.py
@@ -109,6 +109,7 @@ def test_init(self):
         assert recorder.storage_path == storage_path
         assert recorder.measurements == [self.measurement_1, self.measurement_2]
         assert recorder.chunk_size == 10
+        assert recorder.overwrite is False
 
     def test_neural_state_keys(self):
         """
@@ -181,7 +182,10 @@ def test_store(self):
         name = "test"
         storage_path = "temp/"
         recorder = BaseRecorder(
-            name, storage_path, [self.measurement_1, self.measurement_2], 10
+            name,
+            storage_path,
+            [self.measurement_1, self.measurement_2],
+            10,
         )
 
         # Test storing
@@ -202,6 +206,18 @@ def test_store(self):
         print(data["dummy_1"].shape)
         assert_array_equal(data["dummy_1"], np.ones(shape=(2, 3, 10, 5)))
         assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(2, 3, 10)))
+        # _results should be empty after storing
+        assert recorder._results == {"dummy_1": [], "dummy_2": []}
+
+        # test overwriting
+        recorder.overwrite = True
+        recorder._measure(**self.neural_state)
+        recorder._store(20)
+        data = recorder.load()
+
+        assert set(data.keys()) == {"dummy_1", "dummy_2"}
+        assert_array_equal(data["dummy_1"], np.ones(shape=(1, 3, 10, 5)))
+        assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(1, 3, 10)))
 
         # Delete temporary directory
         os.system("rm -r temp/")
diff --git a/papyrus/recorders/base_recorder.py b/papyrus/recorders/base_recorder.py
index 2d93e98..7712974 100644
--- a/papyrus/recorders/base_recorder.py
+++ b/papyrus/recorders/base_recorder.py
@@ -63,6 +63,7 @@ def __init__(
         storage_path: str,
         measurements: List[BaseMeasurement],
         chunk_size: int,
+        overwrite: bool = False,
     ):
         """
         Constructor method of the BaseRecorder class.
@@ -78,21 +79,33 @@ def __init__(
             The measurements that the recorder will apply.
         chunk_size : int
             The size of the chunks in which the data will be stored.
+        overwrite : bool (default=False)
+            Whether to overwrite the existing data in the database.
         """
         self.name = name
         self.storage_path = storage_path
         self.measurements = measurements
         self.chunk_size = chunk_size
+        self.overwrite = overwrite
 
         # Read in neural state keys from measurements
-        self.neural_state_keys = []
-        for measurement in measurements:
-            self.neural_state_keys.extend(measurement.neural_state_keys)
-        self.neural_state_keys = list(set(self.neural_state_keys))
+        self._read_neural_state_keys()
 
         # Temporary storage for results
         self._init_results()
 
+    def _read_neural_state_keys(self):
+        """
+        Read the neural state keys from the measurements.
+
+        Updates the neural_state_keys attribute of the recorder with the keys of the
+        neural state that the measurements take as input.
+        """
+        self.neural_state_keys = []
+        for measurement in self.measurements:
+            self.neural_state_keys.extend(measurement.neural_state_keys)
+        self.neural_state_keys = list(set(self.neural_state_keys))
+
     def _init_results(self):
         """
         Initialize the temporary storage for the results.
@@ -169,8 +182,11 @@ def _store(self, epoch: int):
         try:
             data = self.load()
             # Append the new data
-            for key in self._results.keys():
-                data[key] = np.append(data[key], self._results[key], axis=0)
+            if self.overwrite:
+                data = self._results
+            else:
+                for key in self._results.keys():
+                    data[key] = np.append(data[key], self._results[key], axis=0)
         # If the file does not exist, create a new one
         except FileNotFoundError:
             data = self._results
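
A minimal usage sketch of the new `overwrite` flag (not part of the patch), assuming `BaseRecorder` can be instantiated directly as in the test above; `measurement_1`, `measurement_2`, and `neural_state` are placeholders standing in for the test fixtures:

```python
# Illustrative sketch only. The measurement objects and the neural_state dict
# are hypothetical placeholders; the call sequence mirrors test_store above.
from papyrus.recorders.base_recorder import BaseRecorder

recorder = BaseRecorder(
    "test",                           # name of the recorder
    "temp/",                          # storage_path where results are written
    [measurement_1, measurement_2],   # measurements to apply
    10,                               # chunk_size
    overwrite=True,                   # new: replace stored data instead of appending
)

recorder._measure(**neural_state)  # evaluate the measurements on the current state
recorder._store(epoch=20)          # with overwrite=True, existing data is replaced
data = recorder.load()             # per the test, load() now holds only the latest chunk
```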