Include hdf5 data storage [unfinished]
knikolaou committed Jun 7, 2024
1 parent 0fa1a4d commit 1662a66
Showing 7 changed files with 268 additions and 47 deletions.
56 changes: 32 additions & 24 deletions CI/unit_tests/recorders/test_base_recorder.py
@@ -110,6 +110,7 @@ def test_init(self):
assert recorder.measurements == [self.measurement_1, self.measurement_2]
assert recorder.chunk_size == 10
assert recorder.overwrite is False
assert recorder._data_storage.database_path == f"{storage_path}{name}.h5"

def test_neural_state_keys(self):
"""
@@ -149,29 +150,29 @@ def test_measure(self):
)
assert_array_equal(recorder._results["dummy_2"], 10 * np.ones(shape=(2, 3, 10)))

def test_write_read(self):
"""
Test the write and read methods of the BaseRecorder class.
"""
# Create a temporary directory
temp_dir = tempfile.TemporaryDirectory()
name = "test"
storage_path = temp_dir.name
recorder = BaseRecorder(
name, storage_path, [self.measurement_1, self.measurement_2], 10
)

# Test writing and reading
recorder._measure(**self.neural_state)
recorder._write(recorder._results)
data = recorder.load()

assert set(data.keys()) == {"dummy_1", "dummy_2"}
assert_array_equal(data["dummy_1"], np.ones(shape=(1, 3, 10, 5)))
assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(1, 3, 10)))

# Delete temporary directory
temp_dir.cleanup()
# def test_write_read(self):
# """
# Test the write and read methods of the BaseRecorder class.
# """
# # Create a temporary directory
# temp_dir = tempfile.TemporaryDirectory()
# name = "test"
# storage_path = temp_dir.name
# recorder = BaseRecorder(
# name, storage_path, [self.measurement_1, self.measurement_2], 10
# )

# # Test writing and reading
# recorder._measure(**self.neural_state)
# recorder._write(recorder._results)
# data = recorder.load()

# assert set(data.keys()) == {"dummy_1", "dummy_2"}
# assert_array_equal(data["dummy_1"], np.ones(shape=(1, 3, 10, 5)))
# assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(1, 3, 10)))

# # Delete temporary directory
# temp_dir.cleanup()

def test_store(self):
"""
@@ -342,7 +343,7 @@ def test_overwrite(self):

# Measure and save data
recorder._measure(**self.neural_state)
recorder._write(recorder._results)
recorder._data_storage.write(recorder._results)
data = recorder.load()
assert set(data.keys()) == {"dummy_1", "dummy_2"}
assert_array_equal(data["dummy_1"], np.ones(shape=(1, 3, 10, 5)))
@@ -356,6 +357,7 @@
overwrite=True,
)
data = recorder.load()
print(data)
assert set(data.keys()) == {"dummy_1", "dummy_2"}
assert_array_equal(data["dummy_1"], [])
assert_array_equal(data["dummy_2"], [])
@@ -372,3 +374,9 @@
assert set(data.keys()) == {"dummy_1", "dummy_2"}
assert_array_equal(data["dummy_1"], np.ones(shape=(2, 3, 10, 5)))
assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(2, 3, 10)))

def test_recording_order(self):
"""
Test the order of the recordings.
"""
pass
94 changes: 94 additions & 0 deletions CI/unit_tests/recorders/test_data_storage.py
@@ -0,0 +1,94 @@
"""
papyrus: a lightweight Python library to record neural learning.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Summary
-------
"""

import tempfile
from os import path
from pathlib import Path

import h5py as hf
import numpy as onp
from numpy import testing

from papyrus.recorders.data_storage import DataStorage


class TestDataStorage:
"""
Test suite for the storage module.
"""

@classmethod
def setup_class(cls):
"""
Set up the test.
"""
cls.vector_data = onp.random.uniform(size=(100,))
cls.tensor_data = onp.random.uniform(size=(100, 10, 10))

cls.data = {"vector_data": cls.vector_data, "tensor_data": cls.tensor_data}

def test_database_construction(self):
"""
Test that database groups are built properly.
"""
# Create temporary directory for safe testing.
with tempfile.TemporaryDirectory() as directory:
database_path = path.join(directory, "test_creation")
data_storage = DataStorage(Path(database_path))
data_storage.write(self.data) # write some data to empty DB.

with hf.File(data_storage.database_path, "r") as db:
# Test correct dataset creation.
keys = list(db.keys())
testing.assert_equal(keys, ["tensor_data", "vector_data"])
vector_data = onp.array(db["vector_data"])
tensor_data = onp.array(db["tensor_data"])

# Check data structure within the db.
assert vector_data.shape == (100,)
assert vector_data.sum() != 0.0

assert tensor_data.shape == (100, 10, 10)
assert tensor_data.sum() != 0.0

def test_resize_dataset_standard(self):
"""
Test if the datasets are resized properly.
"""
with tempfile.TemporaryDirectory() as directory:
database_path = path.join(directory, "test_resize")
data_storage = DataStorage(Path(database_path))
data_storage.write(self.data) # write some data to empty DB.
data_storage.write(self.data) # force resize.

with hf.File(data_storage.database_path, "r") as db:
# Test correct dataset creation.
vector_data = onp.array(db["vector_data"])
tensor_data = onp.array(db["tensor_data"])

# Check data structure within the db.
assert vector_data.shape == (200,)
assert vector_data[100:].sum() != 0.0

assert tensor_data.shape == (200, 10, 10)
assert tensor_data[100:].sum() != 0.0
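
The diff for papyrus/recorders/data_storage.py itself is not shown on this page, only its tests. As a rough orientation, a minimal sketch of an HDF5-backed store consistent with the tests above might look as follows; the class body, the maxshape/resize mechanics, and the load(keys) signature are inferred from the test and recorder code, not copied from the commit:

from pathlib import Path
from typing import Union

import h5py as hf
import numpy as np


class DataStorage:
    """Sketch of an HDF5-backed store; an assumption, not the committed code."""

    def __init__(self, database_path: Union[str, Path]):
        # The recorder tests expect ".h5" to be appended to the given path.
        self.database_path = f"{database_path}.h5"

    def write(self, data: dict):
        """Append each array in data along axis 0, creating datasets on first write."""
        with hf.File(self.database_path, "a") as db:
            for key, value in data.items():
                value = np.asarray(value)
                if key not in db:
                    # Unlimited first axis so later writes can extend the dataset.
                    db.create_dataset(key, data=value, maxshape=(None, *value.shape[1:]))
                else:
                    dataset = db[key]
                    dataset.resize(dataset.shape[0] + value.shape[0], axis=0)
                    dataset[-value.shape[0]:] = value

    def load(self, keys) -> dict:
        """Read the requested datasets back into memory as numpy arrays."""
        with hf.File(self.database_path, "r") as db:
            return {key: np.array(db[key]) for key in keys}

Under this sketch, writing the same dict twice, as test_resize_dataset_standard does, grows vector_data from (100,) to (200,) by resizing along the first axis instead of rewriting the file.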
2 changes: 1 addition & 1 deletion papyrus/measurements/__init__.py
@@ -26,12 +26,12 @@
NTK,
Accuracy,
Loss,
LossDerivative,
NTKEigenvalues,
NTKEntropy,
NTKMagnitudeDistribution,
NTKSelfEntropy,
NTKTrace,
LossDerivative,
)

__all__ = [
3 changes: 2 additions & 1 deletion papyrus/recorders/__init__.py
@@ -22,5 +22,6 @@
"""

from papyrus.recorders.base_recorder import BaseRecorder
from papyrus.recorders.data_storage import DataStorage

__all__ = [BaseRecorder.__name__]
__all__ = [BaseRecorder.__name__, DataStorage.__name__]
29 changes: 9 additions & 20 deletions papyrus/recorders/base_recorder.py
@@ -28,6 +28,7 @@
import numpy as np

from papyrus.measurements.base_measurement import BaseMeasurement
from papyrus.recorders.data_storage import DataStorage


class BaseRecorder(ABC):
@@ -92,6 +93,9 @@ def __init__(
self.chunk_size = chunk_size
self.overwrite = overwrite

# Initialize the data storage
self._data_storage = DataStorage(self.storage_path + self.name)

# Read in neural state keys from measurements
self.neural_state_keys = self._read_neural_state_keys()

@@ -103,7 +107,7 @@ def __init__(
try:
self.load()
# If overwrite is True, delete the existing data
self._write(self._results)
self._data_storage.write(self._results)
except FileNotFoundError:
pass

@@ -127,19 +131,6 @@ def _init_internals(self):
# Reset the counter
self._counter = 0

def _write(self, data: dict):
"""
Write data to the database using np.savez
TODO: Change this method to use another type of storage.
Parameters
----------
data : dict
The data to be written to the database.
"""
np.savez(self.storage_path + self.name, **data)

def load(self):
"""
Load the data from the database using np.load.
@@ -151,9 +142,7 @@ def load(self):
data : dict
The data loaded from the database.
"""
# By combining storage path and name, we can load the data
data = np.load(self.storage_path + self.name + ".npz")
return dict(data)
return self._data_storage.load(self._results.keys())

def _measure(self, **neural_state):
"""
@@ -190,17 +179,17 @@ def store(self, ignore_chunk_size=True):
TODO: Change this method to use another type of storage.
"""
if self._counter % self.chunk_size == 0 or ignore_chunk_size:
# Gather the data
data = self.gather()
# Write the data back to the database
self._write(data)
self._data_storage.write(self._results)
# Reinitialize the temporary storage
self._init_internals()

def gather(self):
"""
Gather the results from the temporary storage and the database.
TODO: Change this method to use another type of storage.
Returns
-------
data : dict
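
With persistence delegated to DataStorage, the recorder round trip from the tests reads roughly as below. This is hypothetical usage mirroring the unit tests, not a documented API; name, storage_path, the measurement objects, and neural_state are placeholders taken from the test suite:

recorder = BaseRecorder(name, storage_path, [measurement_1, measurement_2], 10)

for _ in range(20):
    recorder._measure(**neural_state)  # accumulate results in memory
    recorder.store(ignore_chunk_size=False)  # flush to HDF5 every chunk_size-th call

data = recorder.load()  # dict of arrays read back from <storage_path><name>.h5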