Added from_logits flag #109
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for the contribution Pavan! Left some feedback.
| """ | ||
|
|
||
| if from_logits: | ||
| predictions = jax.nn.softmax(predictions, axis=-1) |
This is inconsistent with the other metrics, where `_convert_logits_to_probabilities` is called.
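For context, here is a minimal sketch of what such a shared helper might look like. The helper name comes from the diff below, but the body is an assumption; it is written with NumPy so it runs standalone, whereas the real Metrax code would use `jax.nn.softmax`:

```python
import numpy as np

def _convert_logits_to_probabilities(predictions, from_logits):
    # Hypothetical sketch: only transform when the caller passed raw logits.
    if not from_logits:
        return predictions
    # Numerically stable softmax over the last axis (stand-in for
    # jax.nn.softmax(predictions, axis=-1)).
    shifted = predictions - predictions.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

probs = _convert_logits_to_probabilities(
    np.array([[2.0, 1.0, 0.0]]), from_logits=True)
```

Routing every metric through one helper like this keeps the logits handling consistent, which is the point of the comment above.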
        ValueError: If type of `labels` is wrong or the shapes of `predictions`
          and `labels` are incompatible.
      """
+     predictions = _convert_logits_to_probabilities(predictions, from_logits)
Can you update these so `_convert_logits_to_probabilities` is only called if `from_logits` is true?
      @classmethod
      def from_model_output(
-         cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5
+         cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5, from_logits: bool = False
Can you update the docstrings (here and below)?
        raise ValueError('The "Threshold" value must be between 0 and 1.')

+     # If the predictions are logits, convert them to probabilities

Remove extra newline
        ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
    )
-   def test_precision(self, y_true, y_pred, threshold):
+   def test_precision(self, y_true, y_pred, threshold,from_logits=False):
Fix spacing after comma
        ),
    )
-   def test_aucpr(self, inputs, dtype):
+   def test_aucpr(self, inputs, dtype, from_logits=False):
Remove default here as well
    def test_aucpr(self, inputs, dtype, from_logits=False):
      """Test that `AUC-PR` Metric computes correct values."""
-     y_true, y_pred, sample_weights = inputs
+     y_true, y_pred, sample_weights, from_logits = inputs
This is shadowing the `from_logits` parameter; dedupe.
      keras_aucpr = keras.metrics.AUC(curve='PR')
+     if from_logits:
+         y_pred = jax.nn.softmax(y_pred, axis=-1)
Please match 2 spaces per indentation style here and everywhere else
    def test_aucroc(self, inputs, dtype, from_logits=False):
      """Test that `AUC-ROC` Metric computes correct values."""
-     y_true, y_pred, sample_weights = inputs
+     y_true, y_pred, sample_weights,from_logits = inputs
Fix spacing
        ),
    )
-   def test_aucroc(self, inputs, dtype):
+   def test_aucroc(self, inputs, dtype, from_logits=False):
Remove default arg
This pull request introduces a new boolean parameter, from_logits, to Metrax classification metrics, enabling users to pass raw model logits directly without manually converting them to probabilities.
It directly resolves issue #105.
Background:
Currently, users must apply a softmax activation to logits before passing predictions to the metrics, which adds boilerplate and can lead to errors.
What’s New:
Added from_logits flag to classification metrics: Precision, Recall, F1Score, FBetaScore, Accuracy.
When from_logits=True, Metrax automatically applies:
Softmax for multi-class logits
Backward compatibility preserved (from_logits=False by default).
Test suites updated to cover both activated and raw logits inputs.
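As a usage sketch of the before/after workflow: the Metrax calls are shown in comments (using the from_model_output signature from the diff above), and the softmax below is a plain-NumPy stand-in for jax.nn.softmax so the sketch runs without JAX:

```python
import numpy as np

def softmax(x, axis=-1):
    # Stand-in for jax.nn.softmax, so this sketch runs standalone.
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

logits = np.array([[1.2, -0.3, 0.4], [0.1, 2.0, -1.1]])
labels = np.array([0, 1])

# Before this PR: manual activation boilerplate.
# probs = jax.nn.softmax(logits, axis=-1)
# metric = metrax.Precision.from_model_output(predictions=probs, labels=labels)

# After this PR (per the signature in the diff):
# metric = metrax.Precision.from_model_output(
#     predictions=logits, labels=labels, from_logits=True)

# Either way, the metric ends up seeing the same probabilities:
probs = softmax(logits)
```

The default from_logits=False keeps both call styles valid, which is what preserves backward compatibility.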
Benefits:
Simplifies user workflow by removing manual activation step.
Reduces common user errors with logits processing.
Improves consistency and usability.
Tests:
Added parameterized tests for from_logits=True scenarios.
Verified numerical equivalence with ground truth metrics.
Passed tests across multiple dtypes and classification settings.
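The equivalence strategy described above can be sketched as follows. The metric here is a toy accuracy standing in for the real Metrax metrics, and the fixtures are illustrative, not the PR's actual test data; the check is that feeding raw logits with from_logits=True matches feeding pre-activated probabilities:

```python
import numpy as np

def softmax(x, axis=-1):
    # Stand-in for jax.nn.softmax.
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def toy_accuracy(predictions, labels, from_logits=False):
    # Toy metric mirroring the PR's pattern: convert logits to
    # probabilities first when from_logits is set, then compute the metric.
    if from_logits:
        predictions = softmax(predictions)
    return float((predictions.argmax(axis=-1) == labels).mean())

logits = np.array([[3.0, 0.5], [0.2, 1.9], [2.2, 2.1]])
labels = np.array([0, 1, 1])

acc_from_logits = toy_accuracy(logits, labels, from_logits=True)
acc_from_probs = toy_accuracy(softmax(logits), labels)
```

Because softmax is monotone per row, both paths must agree exactly for argmax-based metrics; for threshold-based metrics like Precision the agreement is what the PR's parameterized tests verify numerically.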