Accuracy doesn't work for simple binary classifier #4456
Labels
Priority: P2 - no schedule
Best effort response and resolution. We have no plan to work on this at the moment.
Status: pull requests welcome
We agree with the direction proposed, feel free to give it a shot and file a pull request
Hello,
This is quite a simple issue, but `nnx.metrics.Accuracy` only works for a network that returns more than one output (one logit per class), because its update step computes `super().update(values=(logits.argmax(axis=-1) == labels))`. Not sure if this is by design. I was playing around with a simple binary classifier that only returned a single logit per observation, and `Accuracy` gives nonsense in this case: with logits of shape `(batch, 1)`, `argmax(axis=-1)` is always 0, so the metric just reports the fraction of labels equal to 0.
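To make the failure mode concrete, here is a minimal sketch that uses NumPy as a stand-in for `jax.numpy` (the functions used here, `argmax`, `mean`, `concatenate` and `zeros_like`, exist under the same names in both). The helpers `multiclass_accuracy`, `binary_accuracy` and `to_two_class_logits` are hypothetical names for illustration, not part of Flax. `multiclass_accuracy` mirrors the argmax rule quoted above; the other two show two possible workarounds for a single-logit model: thresholding the logit at 0, or prepending a zero logit so that an argmax-based accuracy behaves correctly.

```python
import numpy as np


def multiclass_accuracy(logits, labels):
    """Mirrors the quoted update rule: argmax over the last axis vs. labels."""
    return float(np.mean(np.argmax(logits, axis=-1) == labels))


def binary_accuracy(logits, labels):
    """Hypothetical helper for a model emitting one logit per example.

    sigmoid(x) > 0.5 exactly when x > 0, so thresholding the raw logit at 0
    yields the predicted class without computing the sigmoid.
    """
    preds = (np.squeeze(logits, axis=-1) > 0).astype(labels.dtype)
    return float(np.mean(preds == labels))


def to_two_class_logits(logits):
    """Hypothetical helper: prepend a zero logit, giving shape (batch, 2).

    softmax([0, x]) assigns class 1 the probability sigmoid(x), so argmax
    picks class 1 exactly when x > 0.
    """
    return np.concatenate([np.zeros_like(logits), logits], axis=-1)


logits = np.array([[2.0], [1.5], [-0.3], [0.4]])  # shape (batch, 1)
labels = np.array([1, 1, 0, 1], dtype=np.int32)

print(multiclass_accuracy(logits, labels))  # 0.25: argmax is always 0
print(binary_accuracy(logits, labels))  # 1.0
print(multiclass_accuracy(to_two_class_logits(logits), labels))  # 1.0
```

Ties at a logit of exactly 0 go to class 0 in both workarounds, so they agree everywhere. If `nnx.metrics.Accuracy.update` accepts `logits=` and `labels=` keyword arguments, as the quoted line suggests, the widened array could be passed to it directly; that is an assumption worth checking against the installed Flax version, which may also require integer labels.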
Either way, it might be worth noting in the docstring what inputs the function expects.
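As an illustration of what such a docstring could state, here is a hypothetical `accuracy` function (NumPy again standing in for `jax.numpy`; this is not Flax's implementation) that documents the expected shapes and rejects single-logit input with an explicit error instead of returning a misleading number.

```python
import numpy as np


def accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the label.

    Hypothetical helper, not Flax's API.

    Args:
      logits: float array of shape labels.shape + (num_classes,), with
        num_classes >= 2 (one score per class). A binary model that emits a
        single logit should threshold it at 0 instead.
      labels: integer class indices in [0, num_classes).

    Returns:
      A Python float in [0, 1].
    """
    if logits.shape[:-1] != labels.shape:
        # Also catches logits of shape (batch,) paired with labels (batch,).
        raise ValueError(
            "expected logits of shape labels.shape + (num_classes,), "
            f"got logits {logits.shape} and labels {labels.shape}")
    if logits.shape[-1] < 2:
        # argmax over a length-1 axis always returns 0, so the score would
        # just be the fraction of labels equal to 0.
        raise ValueError("need at least 2 logits per example for argmax accuracy")
    if not np.issubdtype(labels.dtype, np.integer):
        raise TypeError(f"labels must be integers, got dtype {labels.dtype}")
    return float(np.mean(np.argmax(logits, axis=-1) == labels))


print(accuracy(np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0])))  # 1.0
```

Raising early trades a little strictness for a clear message: the `(batch, 1)` case from this issue fails loudly rather than silently reporting the share of zero labels.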