-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Include correct license in pyproject * Update gitignore to ignore specific vscode settings. * Include header in existing files. * Create BaseMeasurement Class * Deactivate Black on PRs * - Write BaseMeasurement class work Create a test for - BaseMeasurement * remove dummy test * add temporal test notebook * Start writing measurements * Start writing matrix utils and according tests * remove file for entropy measurement * run black * Fix signs in testing eigenspace computation * try fixing numpy installation * Fix missing numpy installation in developer mode * fix base_measurement test * Write matrix utils * Implement analysis utils and tests * remove unnessecary inputs from tests * make kwargs of shannon and von neuman entropy similar * write default measurements * run isort * Write init for measurements * Introduce neural states as a state representation of an NN in form of a dict of np.arrays. * Start writing measurements example notebook. * run black and isort * remove unnecessary import * declare type hint of neural state * remove unused imports * Add test for neural state keys * Add ntk to measurements * Write - neural state - neural state creator * fix pytest file search * Fix import issues * remove print from tests * Remove logging of clipping in eigenvalue calculation * rename function input to loss and accuracy * remove unused parts from test_base_measurement * remove print statement from test_base_measurement * Add shape check to BaseMeasurement class * Add shape checking to all ntk measurements * Include NTK measurement in imports * write loss derivative measurement. * Remove the `public` kwarg from measurements * remove public kwarg from loss derivative * Update test_matrix_utils.py Removed print statement * Re-name analysis utils test module * rename grammian to gramian --------- Co-authored-by: knikolaou <[email protected]> Co-authored-by: Samuel Tovey <[email protected]>
- Loading branch information
1 parent
da48b24
commit e5b4af7
Showing
23 changed files
with
2,482 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
name: Check black coding style | ||
|
||
on: [push, pull_request] | ||
on: [push] | ||
|
||
jobs: | ||
lint: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
""" | ||
papyrus: a lightweight Python library to record neural learning. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the Zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Summary | ||
------- | ||
Test the base measurement class. | ||
""" | ||
|
||
from typing import Optional | ||
|
||
import numpy as np | ||
import pytest | ||
from numpy.testing import assert_raises | ||
|
||
from papyrus.measurements import BaseMeasurement | ||
|
||
|
||
class DummyMeasurement(BaseMeasurement): | ||
""" | ||
Dummy measurement class for testing. | ||
""" | ||
|
||
def __init__(self, name: str, rank: int): | ||
super().__init__(name, rank) | ||
|
||
def apply( | ||
self, a: np.ndarray, b: np.ndarray, c: Optional[np.ndarray] = None | ||
) -> np.ndarray: | ||
if c is not None: | ||
return a + b + c | ||
return a + b | ||
|
||
|
||
class TestBaseMeasurement: | ||
""" | ||
Test the base measurement class. | ||
""" | ||
|
||
def test_init(self): | ||
""" | ||
Test the constructor method of the BaseMeasurement class. | ||
""" | ||
# Test the constructor method | ||
name = "test" | ||
rank = 1 | ||
measurement = BaseMeasurement(name, rank) | ||
assert measurement.name == name | ||
assert measurement.rank == rank | ||
|
||
# Test the rank parameter | ||
with pytest.raises(ValueError): | ||
BaseMeasurement(name, -1) | ||
|
||
def test_call(self): | ||
""" | ||
Test the call method of the BaseMeasurement class. | ||
""" | ||
# Test the call method | ||
name = "test" | ||
rank = 1 | ||
measurement = BaseMeasurement(name, rank) | ||
|
||
# Test the apply method | ||
with pytest.raises(NotImplementedError): | ||
measurement.apply() | ||
|
||
# Test the apply method | ||
a = np.array([[1, 2], [3, 4]]) | ||
b = np.array([[5, 6], [7, 8]]) | ||
c = np.array([[9, 10], [11, 12]]) | ||
|
||
# Test the call method with only arguments | ||
measurement = DummyMeasurement(name, rank) | ||
result = measurement(a, b) | ||
assert np.allclose(result, a + b) | ||
result = measurement(a, b, c) | ||
assert np.allclose(result, a + b + c) | ||
|
||
# Test the call method with only keyword arguments | ||
result = measurement(a=a, b=b) | ||
assert np.allclose(result, a + b) | ||
result = measurement(a=a, b=b, c=c) | ||
assert np.allclose(result, a + b + c) | ||
|
||
# Test the call method with both arguments and keyword arguments | ||
result = measurement(a, b, c=c) | ||
assert np.allclose(result, a + b + c) | ||
result = measurement(a, b=b) | ||
assert np.allclose(result, a + b) | ||
|
||
# Test error handling for wrong size of arguments | ||
a = np.array([1, 2, 3]) | ||
b = np.array([[4, 5, 6], [7, 8, 9]]) | ||
c = np.array([[10, 11, 12], [13, 14, 15]]) | ||
with assert_raises(ValueError): | ||
measurement(a, b, c) | ||
with assert_raises(ValueError): | ||
measurement(a=a, b=b, c=c) | ||
with assert_raises(ValueError): | ||
measurement(a, b=b, c=c) | ||
with assert_raises(ValueError): | ||
measurement(a, b, c=c) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
papyrus: a lightweight Python library to record neural learning. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the Zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Summary | ||
------- | ||
""" | ||
|
||
from papyrus.neural_state import NeuralState | ||
|
||
|
||
class TestNeuralState: | ||
|
||
def test_init(self): | ||
|
||
neural_state = NeuralState() | ||
assert neural_state.loss is None | ||
assert neural_state.accuracy is None | ||
assert neural_state.predictions is None | ||
assert neural_state.targets is None | ||
assert neural_state.ntk is None | ||
|
||
neural_state = NeuralState( | ||
loss=[], | ||
accuracy=[], | ||
predictions=[], | ||
targets=[], | ||
ntk=[], | ||
) | ||
assert neural_state.loss == [] | ||
assert neural_state.accuracy == [] | ||
assert neural_state.predictions == [] | ||
assert neural_state.targets == [] | ||
assert neural_state.ntk == [] | ||
|
||
def test_get_dict(self): | ||
|
||
neural_state = NeuralState() | ||
assert neural_state.get_dict() == {} | ||
|
||
neural_state = NeuralState( | ||
loss=[], | ||
accuracy=[], | ||
predictions=[], | ||
targets=[], | ||
ntk=[], | ||
) | ||
assert neural_state.get_dict() == { | ||
"loss": [], | ||
"accuracy": [], | ||
"predictions": [], | ||
"targets": [], | ||
"ntk": [], | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
""" | ||
papyrus: a lightweight Python library to record neural learning. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the Zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Summary | ||
------- | ||
""" | ||
|
||
import numpy as np | ||
|
||
from papyrus.neural_state import NeuralStateCreator | ||
|
||
|
||
class TestNeuralStateCreator: | ||
def test_init(self): | ||
def network_apply_fn(params: dict, data: dict): | ||
return np.arange(10) | ||
|
||
def ntk_apply_fn(params: dict, data: dict): | ||
return np.arange(10) | ||
|
||
neural_state_creator = NeuralStateCreator( | ||
network_apply_fn=network_apply_fn, | ||
ntk_apply_fn=ntk_apply_fn, | ||
) | ||
assert neural_state_creator.apply_fns == { | ||
"predictions": network_apply_fn, | ||
"ntk": ntk_apply_fn, | ||
} | ||
|
||
def test_apply(self): | ||
def network_apply_fn(params: dict, data: dict): | ||
return np.arange(10) | ||
|
||
def ntk_apply_fn(params: dict, data: dict): | ||
return np.arange(10) | ||
|
||
neural_state_creator = NeuralStateCreator( | ||
network_apply_fn=network_apply_fn, | ||
ntk_apply_fn=ntk_apply_fn, | ||
) | ||
|
||
neural_state = neural_state_creator( | ||
params={}, | ||
data={}, | ||
loss=np.arange(5), | ||
) | ||
assert np.all(neural_state.predictions == np.arange(10)) | ||
assert np.all(neural_state.ntk == np.arange(10)) | ||
assert np.all(neural_state.loss == np.arange(5)) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
papyrus: a lightweight Python library to record neural learning. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the Zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Summary | ||
------- | ||
""" | ||
|
||
import numpy as np | ||
from numpy.testing import assert_almost_equal | ||
|
||
from papyrus.utils import ( | ||
compute_shannon_entropy, | ||
compute_trace, | ||
compute_von_neumann_entropy, | ||
) | ||
|
||
|
||
class TestAnalysisUtils: | ||
""" | ||
Test suite for the analysis utils. | ||
""" | ||
|
||
def test_compute_trace(self): | ||
""" | ||
Test the computation of the trace. | ||
""" | ||
vector = np.random.rand(10) | ||
matrix = np.diag(vector) | ||
|
||
# Test the trace without normalization | ||
trace = compute_trace(matrix, normalize=False) | ||
assert trace == np.sum(vector) | ||
|
||
# Test the trace with normalization | ||
trace = compute_trace(matrix, normalize=True) | ||
assert trace == np.sum(vector) / 10 | ||
|
||
def test_shannon_entropy(self): | ||
""" | ||
Test the Shannon entropy. | ||
""" | ||
dist = np.array([0.2, 0.2, 0.2, 0.2, 0.2]) | ||
assert_almost_equal(compute_shannon_entropy(dist), np.log(5)) | ||
assert_almost_equal(compute_shannon_entropy(dist, effective=True), 1.0) | ||
|
||
dist = np.array([0, 0, 0, 0, 1]) | ||
assert compute_shannon_entropy(dist) == 0 | ||
assert compute_shannon_entropy(dist, effective=True) == 0 | ||
|
||
dist = np.array([0, 0, 0, 0.5, 0.5]) | ||
assert compute_shannon_entropy(dist) == np.log(2) | ||
s = compute_shannon_entropy(dist, effective=True) | ||
assert s == np.log(2) / np.log(5) | ||
|
||
def test_compute_von_neumann_entropy(self): | ||
""" | ||
Test the computation of the von-Neumann entropy. | ||
""" | ||
matrix = np.eye(2) * 0.5 | ||
entropy = compute_von_neumann_entropy(matrix=matrix, effective=False) | ||
assert entropy == np.log(2) | ||
|
||
entropy = compute_von_neumann_entropy(matrix=matrix, effective=True) | ||
assert entropy == 1 |
Oops, something went wrong.