diff --git a/flax/nnx/training/metrics.py b/flax/nnx/training/metrics.py index 4facf42787..1804490228 100644 --- a/flax/nnx/training/metrics.py +++ b/flax/nnx/training/metrics.py @@ -373,7 +373,7 @@ def update(self, **updates) -> None: for metric_name in self._metric_names: getattr(self, metric_name).update(**updates) - def compute(self) -> dict[str, Metric]: + def compute(self) -> dict[str, tp.Any]: """Compute and return the value of all underlying ``Metric``'s. This method will return a dictionary, mapping strings (defined by the key-word arguments ``**metrics`` passed to the constructor) to the corresponding metric value.