From ae13f235b6d66e4dd1b7e98acc9beac23a3f951d Mon Sep 17 00:00:00 2001 From: knikolaou Date: Fri, 10 May 2024 21:03:54 +0200 Subject: [PATCH] declare type hint of neural state --- papyrus/measurements/base_measurement.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/papyrus/measurements/base_measurement.py b/papyrus/measurements/base_measurement.py index 5630c30..e9bc8d4 100644 --- a/papyrus/measurements/base_measurement.py +++ b/papyrus/measurements/base_measurement.py @@ -23,6 +23,7 @@ from abc import ABC from inspect import signature +from typing import List import numpy as np @@ -81,7 +82,7 @@ def __init__(self, name: str, rank: int, public: bool = False): raise ValueError("Rank must be a positive integer.") # Get the neural state keys that the measurement takes as input - self.neural_state_keys = [] + self.neural_state_keys: List[str] = [] self.neural_state_keys.extend(signature(self.apply).parameters.keys()) def apply(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray: