Skip to content

Commit

Permalink
[nnx] fix MultiMetric typing
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 16, 2025
1 parent 1961c12 commit ca4dc8c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flax/nnx/training/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def update(self, **updates) -> None:
for metric_name in self._metric_names:
getattr(self, metric_name).update(**updates)

def compute(self) -> dict[str, Metric]:
def compute(self) -> dict[str, tp.Any]:
"""Compute and return the value of all underlying ``Metric``'s. This method
will return a dictionary, mapping strings (defined by the key-word arguments
``**metrics`` passed to the constructor) to the corresponding metric value.
Expand Down

0 comments on commit ca4dc8c

Please sign in to comment.