Skip to content

Commit

Permalink
Merge branch 'main' into Konsti_Recorders
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed Jun 10, 2024
2 parents da20acf + e5b4af7 commit 35433db
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 0 deletions.
69 changes: 69 additions & 0 deletions CI/unit_tests/neural_state/test_neural_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
papyrus: a lightweight Python library to record neural learning.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Summary
-------
"""

from papyrus.neural_state import NeuralState


class TestNeuralState:

def test_init(self):

neural_state = NeuralState()
assert neural_state.loss is None
assert neural_state.accuracy is None
assert neural_state.predictions is None
assert neural_state.targets is None
assert neural_state.ntk is None

neural_state = NeuralState(
loss=[],
accuracy=[],
predictions=[],
targets=[],
ntk=[],
)
assert neural_state.loss == []
assert neural_state.accuracy == []
assert neural_state.predictions == []
assert neural_state.targets == []
assert neural_state.ntk == []

def test_get_dict(self):

neural_state = NeuralState()
assert neural_state.get_dict() == {}

neural_state = NeuralState(
loss=[],
accuracy=[],
predictions=[],
targets=[],
ntk=[],
)
assert neural_state.get_dict() == {
"loss": [],
"accuracy": [],
"predictions": [],
"targets": [],
"ntk": [],
}
65 changes: 65 additions & 0 deletions CI/unit_tests/neural_state/test_neural_state_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
papyrus: a lightweight Python library to record neural learning.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Summary
-------
"""

import numpy as np

from papyrus.neural_state import NeuralStateCreator


class TestNeuralStateCreator:
def test_init(self):
def network_apply_fn(params: dict, data: dict):
return np.arange(10)

def ntk_apply_fn(params: dict, data: dict):
return np.arange(10)

neural_state_creator = NeuralStateCreator(
network_apply_fn=network_apply_fn,
ntk_apply_fn=ntk_apply_fn,
)
assert neural_state_creator.apply_fns == {
"predictions": network_apply_fn,
"ntk": ntk_apply_fn,
}

def test_apply(self):
def network_apply_fn(params: dict, data: dict):
return np.arange(10)

def ntk_apply_fn(params: dict, data: dict):
return np.arange(10)

neural_state_creator = NeuralStateCreator(
network_apply_fn=network_apply_fn,
ntk_apply_fn=ntk_apply_fn,
)

neural_state = neural_state_creator(
params={},
data={},
loss=np.arange(5),
)
assert np.all(neural_state.predictions == np.arange(10))
assert np.all(neural_state.ntk == np.arange(10))
assert np.all(neural_state.loss == np.arange(5))
18 changes: 18 additions & 0 deletions papyrus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
papyrus: a lightweight Python library to record neural learning.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
Expand All @@ -20,6 +37,7 @@
Summary
-------
papyrus measurements api.
papyrus measurements api.
"""

from papyrus import measurements, recorders, utils
Expand Down
30 changes: 30 additions & 0 deletions papyrus/neural_state/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
papyrus: a lightweight Python library to record neural learning.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Summary
-------
"""

from papyrus.neural_state.neural_state import NeuralState
from papyrus.neural_state.neural_state_creator import NeuralStateCreator

__all__ = [
NeuralState.__name__,
NeuralStateCreator.__name__,
]
73 changes: 73 additions & 0 deletions papyrus/neural_state/neural_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
papyrus: a lightweight Python library to record neural learning.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Summary
-------
"""

from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class NeuralState:
"""
Data class to represent the state of a neural network.
A neural network state can be represented in various ways. NeuralState offers a
structured solution to represent the state of a neural network in terms of different
properties.
If the default properties are not sufficient, the user can extend this class to
include more. In general, a property of a neural state can be any type of data, as
long as it is formatted as `List[Any]` or `np.array[Any]`.
Attributes
----------
loss: Optional[List[np.ndarray]]
The loss of a neural network.
accuracy: Optional[List[np.ndarray]]
The accuracy of a neural network.
predictions: Optional[List[np.ndarray]]
The predictions of a neural network.
targets: Optional[List[np.ndarray]]
The targets of a neural network.
ntk: Optional[List[np.ndarray]]
The neural tangent kernel of a neural network.
"""

loss: Optional[List[np.ndarray]] = None
accuracy: Optional[List[np.ndarray]] = None
predictions: Optional[List[np.ndarray]] = None
targets: Optional[List[np.ndarray]] = None
ntk: Optional[List[np.ndarray]] = None

def get_dict(self) -> dict:
"""
Get a dictionary representation of the neural state.
Only return the properties that are not None.
Returns
-------
dict
A dictionary representation of the neural state.
"""
return {k: v for k, v in self.__dict__.items() if v is not None}
88 changes: 88 additions & 0 deletions papyrus/neural_state/neural_state_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
papyrus: a lightweight Python library to record neural learning.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Summary
-------
"""

from papyrus.neural_state.neural_state import NeuralState


class NeuralStateCreator:
"""
Class creating a neural state.
The NeuralStateCreator class serves as instance mapping data and parameter state to
a NeuralState instance using a set of apply functions. The apply functions.
These apply functions are e.g. the neural network forward pass or the neural tangent
kernel computation.
Attributes
----------
apply_fns : dict
A dictionary of apply functions that map the data and parameter state to a
NeuralState instance.
"""

def __init__(self, network_apply_fn: callable, ntk_apply_fn: callable):
"""
Initialize the NeuralStateCreator instance.
Parameters
----------
network_apply_fn : callable
The apply function that maps the data and parameter state to a
NeuralState instance.
ntk_apply_fn : callable
The apply function that maps the data and parameter state to a
NeuralState instance.
"""
self.apply_fns = {
"predictions": network_apply_fn,
"ntk": ntk_apply_fn,
}

def __call__(self, params: dict, data: dict, **kwargs) -> NeuralState:
"""
Call the NeuralStateCreator instance.
Parameters
----------
params : dict
A dictionary of parameters that are used in the apply functions.
data : dict
A dictionary of data that is used in the apply functions.
kwargs : Any
Additional keyword arguments that are directly added to the
neural state.
Returns
-------
NeuralState
The neural state that is created by the apply functions.
"""
neural_state = NeuralState()

for key, apply_fn in self.apply_fns.items():
neural_state.__setattr__(key, apply_fn(params, data))

for key, value in kwargs.items():
neural_state.__setattr__(key, value)

return neural_state

0 comments on commit 35433db

Please sign in to comment.