diff --git a/papyrus/measurements/base_measurement.py b/papyrus/measurements/base_measurement.py index a9af5d6..a9ed676 100644 --- a/papyrus/measurements/base_measurement.py +++ b/papyrus/measurements/base_measurement.py @@ -22,8 +22,8 @@ """ from abc import ABC -from typing import List from inspect import signature +from typing import List import numpy as np @@ -80,11 +80,10 @@ def __init__(self, name: str, rank: int, public: bool = False): if not isinstance(self.rank, int) or self.rank < 0: raise ValueError("Rank must be a positive integer.") - + # Get the neural state keys that the measurement takes as input self.neural_state_keys = [] self.neural_state_keys.extend(signature(self.apply).parameters.keys()) - def apply(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray: """ diff --git a/papyrus/measurements/measurements.py b/papyrus/measurements/measurements.py index c4e26b5..4aa0128 100644 --- a/papyrus/measurements/measurements.py +++ b/papyrus/measurements/measurements.py @@ -22,8 +22,7 @@ Module containing default measurements for recording neural learning. """ -from typing import Callable, Optional -from typing import List +from typing import Callable, List, Optional import numpy as np @@ -46,13 +45,13 @@ class Loss(BaseMeasurement): Neural State Keys ----------------- predictions : np.ndarray - The predictions of the neural network. Required if the loss function is + The predictions of the neural network. Required if the loss function is provided. Needs to be combined with the targets key. - targets : np.ndarray - The target values of the neural network. Required if the loss function is + targets : np.ndarray + The target values of the neural network. Required if the loss function is provided. Needs to be combined with the predictions key. loss : float - The loss of the neural network. Required if the loss function is not + The loss of the neural network. Required if the loss function is not provided. Allows to measure precomputed loss values. """ @@ -69,17 +68,17 @@ def __init__( Parameters ---------- name : str (default="loss") - The name of the measurement, defining how the instance in the database + The name of the measurement, defining how the instance in the database will be identified. rank : int (default=0) - The rank of the measurement, defining the tensor order of the + The rank of the measurement, defining the tensor order of the measurement. public : bool (default=False) - Boolean flag to indicate whether the measurement resutls will be + Boolean flag to indicate whether the measurement resutls will be accessible via a public attribute of the recorder. loss_fn : Optional[Callable] (default=None) The loss function to be used to compute the loss of the neural network. - If the loss function is not provided, the apply method will assume that + If the loss function is not provided, the apply method will assume that the loss is used as the input. If the loss function is provided, the apply method will assume that the neural network outputs and the target values are used as inputs. @@ -136,13 +135,13 @@ class Accuracy(BaseMeasurement): Neural State Keys ----------------- predictions : np.ndarray - The predictions of the neural network. Required if the loss function is + The predictions of the neural network. Required if the loss function is provided. Needs to be combined with the targets key. - targets : np.ndarray - The target values of the neural network. Required if the loss function is + targets : np.ndarray + The target values of the neural network. Required if the loss function is provided. Needs to be combined with the predictions key. accuracy : float - The accuracy of the neural network. Required if the accuracy function is not + The accuracy of the neural network. Required if the accuracy function is not provided. Allows to measure precomputed loss values. """ @@ -246,10 +245,10 @@ def __init__( The name of the measurement, defining how the instance in the database will be identified. rank : int (default=1) - The rank of the measurement, defining the tensor order of the + The rank of the measurement, defining the tensor order of the measurement. public : bool (default=False) - Boolean flag to indicate whether the measurement resutls will be + Boolean flag to indicate whether the measurement resutls will be accessible via a public attribute of the recorder. normalize : bool (default=True) Boolean flag to indicate whether the trace of the NTK will be normalized @@ -382,7 +381,7 @@ def __init__( The name of the measurement, defining how the instance in the database will be identified. rank : int (default=1) - The rank of the measurement, defining the tensor order of the + The rank of the measurement, defining the tensor order of the measurement. public : bool (default=False) Boolean flag to indicate whether the measurement resutls will be @@ -448,7 +447,7 @@ def __init__( The name of the measurement, defining how the instance in the database will be identified. rank : int (default=1) - The rank of the measurement, defining the tensor order of the + The rank of the measurement, defining the tensor order of the measurement. public : bool (default=False) Boolean flag to indicate whether the measurement resutls will be @@ -509,7 +508,7 @@ def __init__( The name of the measurement, defining how the instance in the database will be identified. rank : int (default=1) - The rank of the measurement, defining the tensor order of the + The rank of the measurement, defining the tensor order of the measurement. public : bool (default=False) Boolean flag to indicate whether the measurement resutls will be