The type signature of `MultiMetric.compute` seems to be incorrect. It is currently:

Its implementation is:
```python
return {
    f'{metric_name}': getattr(self, metric_name).compute()
    for metric_name in self._metric_names
}
```
This means the second type parameter of the return type should be the return type of `Metric.compute()`. `Metric.compute` itself does not specify a return type, or even whether it returns anything. The provided implementations (besides `MultiMetric` itself!) return `jax.Array` and `nnx.training.metrics.Statistics` objects, which suggests (reasonably) that no restriction is placed on the return value.
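The mismatch can be reproduced without Flax. The sketch below uses hypothetical stand-in classes (`Metric`, `Average`, `Welford`, `Stats`, `Multi`; none of them are Flax's API) that follow the same dict-comprehension pattern as the implementation quoted above. Each sub-metric's `compute()` returns a different type (a plain `float` stands in for `jax.Array`), and the resulting dict holds those results, not `Metric` instances.

```python
from dataclasses import dataclass
from typing import Any


class Metric:
    """Hypothetical stand-in base class; not Flax's nnx.Metric."""

    def compute(self) -> Any:
        raise NotImplementedError


class Average(Metric):
    def __init__(self, total: float, count: int) -> None:
        self.total = total
        self.count = count

    def compute(self) -> float:  # plain float standing in for a jax.Array
        return self.total / self.count


@dataclass
class Stats:
    """Hypothetical stand-in for a statistics result object."""
    mean: float
    std: float


class Welford(Metric):
    def compute(self) -> Stats:  # returns a result object, not a Metric
        return Stats(mean=0.5, std=0.1)


class Multi(Metric):
    """Hypothetical container mirroring the quoted compute() body."""

    def __init__(self, **metrics: Metric) -> None:
        self._metric_names = list(metrics)
        for name, metric in metrics.items():
            setattr(self, name, metric)

    def compute(self):  # deliberately unannotated in this sketch
        return {
            f'{metric_name}': getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }


result = Multi(loss=Average(3.0, 2), stats=Welford()).compute()
print({name: type(value).__name__ for name, value in result.items()})
# {'loss': 'float', 'stats': 'Stats'}
```

None of the values is a `Metric` instance, and their types differ per entry, so the value type of the returned dict has to be at least as broad as whatever `Metric.compute()` may return.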
It looks like:

- the return type of `Metric.compute` should be either `Any` or `Optional[Any]`, depending on whether `Metric.compute()` has to return something (I would assume it does?);
- the return type of `MultiMetric.compute` should be `dict[str, Any]` or `dict[str, Optional[Any]]`.
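A minimal sketch of the proposed annotations, using the `Any` variant from both bullets. The classes are simplified hypothetical stand-ins, not Flax's actual `Metric`/`MultiMetric` source; `typing.get_type_hints` is used only to confirm that the annotations resolve as intended.

```python
from typing import Any, get_type_hints


class Metric:
    """Hypothetical simplified base class (first bullet: compute() -> Any)."""

    def compute(self) -> Any:
        raise NotImplementedError


class MultiMetric(Metric):
    """Hypothetical simplified container (second bullet: dict[str, Any])."""

    def __init__(self, **metrics: Metric) -> None:
        self._metric_names = list(metrics)
        for name, metric in metrics.items():
            setattr(self, name, metric)

    def compute(self) -> dict[str, Any]:
        # Values are whatever each sub-metric's compute() returns.
        return {
            f'{metric_name}': getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }


print(get_type_hints(Metric.compute)['return'])
print(get_type_hints(MultiMetric.compute)['return'])
```

Because return types are covariant, concrete subclasses can still declare a narrower return type (for example `jax.Array`) when overriding a base method annotated `-> Any`.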
System information

Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): 0.10.2