diff --git a/papyrus/measurements/base_measurement.py b/papyrus/measurements/base_measurement.py index 5630c30..e9bc8d4 100644 --- a/papyrus/measurements/base_measurement.py +++ b/papyrus/measurements/base_measurement.py @@ -23,6 +23,7 @@ from abc import ABC from inspect import signature +from typing import List import numpy as np @@ -81,7 +82,7 @@ def __init__(self, name: str, rank: int, public: bool = False): raise ValueError("Rank must be a positive integer.") # Get the neural state keys that the measurement takes as input - self.neural_state_keys = [] + self.neural_state_keys: List[str] = [] self.neural_state_keys.extend(signature(self.apply).parameters.keys()) def apply(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray: