
Incorrect type annotation on Multimetric.compute() #4481

Closed
gabbard opened this issue Jan 13, 2025 · 1 comment · Fixed by #4485

Comments


gabbard commented Jan 13, 2025

The type signature of MultiMetric.compute appears to be incorrect. It is currently:

```python
def compute(self) -> dict[str, Metric]: ...
```

Its implementation is:

```python
return {
    f'{metric_name}': getattr(self, metric_name).compute()
    for metric_name in self._metric_names
}
```

This means the value type of the returned dict (the second type parameter of the return type) should be the return type of Metric.compute(), not Metric. Metric.compute itself does not specify a return type, or even whether it returns anything. The provided implementations (other than MultiMetric itself!) return jax.Array and nnx.training.metrics.Statistics objects, which suggests (reasonably) that no restriction is placed on the return value.
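The mismatch can be reproduced without Flax. The sketch below is a minimal illustration under stated assumptions: Metric, Average and MultiMetric here are hypothetical, simplified stand-ins that only copy the dictionary-comprehension body quoted above. They are not Flax's actual classes. The sketch shows that the dict values are whatever each sub-metric's compute() returns (a float in this case), not Metric instances.

```python
from typing import Any


class Metric:
    """Hypothetical stand-in for a metric base class (not Flax's nnx.Metric)."""

    def compute(self) -> Any:
        raise NotImplementedError


class Average(Metric):
    """Hypothetical running average whose compute() returns a plain float."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count


class MultiMetric(Metric):
    """Hypothetical container that mirrors the quoted compute() body."""

    def __init__(self, **metrics: Metric) -> None:
        self._metric_names = list(metrics)
        for name, metric in metrics.items():
            setattr(self, name, metric)

    # Same (incorrect) annotation as in the issue: the values are NOT Metrics.
    def compute(self) -> dict[str, Metric]:
        return {
            f'{metric_name}': getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }


loss = Average()
loss.update(1.0)
loss.update(3.0)
metrics = MultiMetric(loss=loss)
result = metrics.compute()
print(result)  # {'loss': 2.0}
print(isinstance(result['loss'], Metric))  # False
```

Type checkers treat getattr with a non-literal attribute name as returning Any. As a result, a checker such as mypy is unlikely to flag this body against the dict[str, Metric] annotation, which may be why the mismatch went unnoticed.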

It looks like:

  • the return type of Metric.compute should be either Any or Optional[Any], depending on whether Metric.compute() is required to return something (I would assume it is?).
  • the return type of MultiMetric.compute should correspondingly be dict[str, Any] or dict[str, Optional[Any]].
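The annotations proposed in the bullets above can be sketched as follows, using the plain Any variant. Because Any already accepts None, wrapping it in Optional would not change what a type checker accepts. The classes are hypothetical, simplified stand-ins rather than Flax's real nnx.Metric and nnx.MultiMetric. The standard library's typing.get_type_hints is used only to check that the annotations resolve as intended.

```python
from typing import Any, get_type_hints


class Metric:
    """Hypothetical base class showing the proposed compute() annotation."""

    def compute(self) -> Any:
        raise NotImplementedError


class MultiMetric(Metric):
    """Hypothetical container with the proposed dict[str, Any] annotation."""

    def __init__(self, **metrics: Metric) -> None:
        self._metric_names = list(metrics)
        for name, metric in metrics.items():
            setattr(self, name, metric)

    def compute(self) -> dict[str, Any]:
        # Each value is whatever the corresponding sub-metric's compute() returns.
        return {
            f'{metric_name}': getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }


print(get_type_hints(Metric.compute))
print(get_type_hints(MultiMetric.compute))
```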

System information

  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.10.2

cgarciae (Collaborator) commented:

Thanks @gabbard, I think it should be dict[str, Any]. Created #4485 to fix this.
