Skip to content

Commit

Permalink
run black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 10, 2024
1 parent 5d2b429 commit 06313df
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
5 changes: 2 additions & 3 deletions papyrus/measurements/base_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
"""

from abc import ABC
from typing import List
from inspect import signature
from typing import List

import numpy as np

Expand Down Expand Up @@ -80,11 +80,10 @@ def __init__(self, name: str, rank: int, public: bool = False):

if not isinstance(self.rank, int) or self.rank < 0:
raise ValueError("Rank must be a positive integer.")

# Get the neural state keys that the measurement takes as input
self.neural_state_keys = []
self.neural_state_keys.extend(signature(self.apply).parameters.keys())


def apply(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
"""
Expand Down
37 changes: 18 additions & 19 deletions papyrus/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
Module containing default measurements for recording neural learning.
"""

from typing import Callable, Optional
from typing import List
from typing import Callable, List, Optional

import numpy as np

Expand All @@ -46,13 +45,13 @@ class Loss(BaseMeasurement):
Neural State Keys
-----------------
predictions : np.ndarray
The predictions of the neural network. Required if the loss function is
The predictions of the neural network. Required if the loss function is
provided. Needs to be combined with the targets key.
targets : np.ndarray
The target values of the neural network. Required if the loss function is
targets : np.ndarray
The target values of the neural network. Required if the loss function is
provided. Needs to be combined with the predictions key.
loss : float
The loss of the neural network. Required if the loss function is not
The loss of the neural network. Required if the loss function is not
provided. Allows to measure precomputed loss values.
"""

Expand All @@ -69,17 +68,17 @@ def __init__(
Parameters
----------
name : str (default="loss")
The name of the measurement, defining how the instance in the database
The name of the measurement, defining how the instance in the database
will be identified.
rank : int (default=0)
The rank of the measurement, defining the tensor order of the
The rank of the measurement, defining the tensor order of the
measurement.
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
Boolean flag to indicate whether the measurement resutls will be
accessible via a public attribute of the recorder.
loss_fn : Optional[Callable] (default=None)
The loss function to be used to compute the loss of the neural network.
If the loss function is not provided, the apply method will assume that
If the loss function is not provided, the apply method will assume that
the loss is used as the input.
If the loss function is provided, the apply method will assume that the
neural network outputs and the target values are used as inputs.
Expand Down Expand Up @@ -136,13 +135,13 @@ class Accuracy(BaseMeasurement):
Neural State Keys
-----------------
predictions : np.ndarray
The predictions of the neural network. Required if the loss function is
The predictions of the neural network. Required if the loss function is
provided. Needs to be combined with the targets key.
targets : np.ndarray
The target values of the neural network. Required if the loss function is
targets : np.ndarray
The target values of the neural network. Required if the loss function is
provided. Needs to be combined with the predictions key.
accuracy : float
The accuracy of the neural network. Required if the accuracy function is not
The accuracy of the neural network. Required if the accuracy function is not
provided. Allows to measure precomputed loss values.
"""

Expand Down Expand Up @@ -246,10 +245,10 @@ def __init__(
The name of the measurement, defining how the instance in the database
will be identified.
rank : int (default=1)
The rank of the measurement, defining the tensor order of the
The rank of the measurement, defining the tensor order of the
measurement.
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
Boolean flag to indicate whether the measurement resutls will be
accessible via a public attribute of the recorder.
normalize : bool (default=True)
Boolean flag to indicate whether the trace of the NTK will be normalized
Expand Down Expand Up @@ -382,7 +381,7 @@ def __init__(
The name of the measurement, defining how the instance in the database
will be identified.
rank : int (default=1)
The rank of the measurement, defining the tensor order of the
The rank of the measurement, defining the tensor order of the
measurement.
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
Expand Down Expand Up @@ -448,7 +447,7 @@ def __init__(
The name of the measurement, defining how the instance in the database
will be identified.
rank : int (default=1)
The rank of the measurement, defining the tensor order of the
The rank of the measurement, defining the tensor order of the
measurement.
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
Expand Down Expand Up @@ -509,7 +508,7 @@ def __init__(
The name of the measurement, defining how the instance in the database
will be identified.
rank : int (default=1)
The rank of the measurement, defining the tensor order of the
The rank of the measurement, defining the tensor order of the
measurement.
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
Expand Down

0 comments on commit 06313df

Please sign in to comment.