Skip to content

Conversation

@PavanTummala
Copy link

@PavanTummala PavanTummala commented Aug 19, 2025

This pull request introduces a new boolean parameter from_logits to Metrax classification metrics, enabling users to pass raw model logits directly without manually converting them to probabilities.
It directly resolves issue #105

Background:

Currently, users must apply softmax activation on logits before passing predictions to metrics, which adds boilerplate and can lead to errors.

What’s New:

Added from_logits flag to classification metrics: Precision, Recall, F1Score, FBetaScore, Accuracy.

When from_logits=True, Metrax automatically applies:

  • Softmax for multi-class logits

  • Backward compatibility preserved (from_logits=False by default).

  • Test suites updated to cover both activated and raw logits inputs.

Benefits:

  • Simplifies user workflow by removing manual activation step.

  • Reduces common user errors with logits processing.

  • Improves consistency and usability.

Tests:

  • Added parameterized tests for from_logits=True scenarios.

  • Verified numerical equivalence with ground truth metrics.

  • Passed tests across multiple dtypes and classification settings.

@google-cla
Copy link

google-cla bot commented Aug 19, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Collaborator

@jeffcarp jeffcarp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution Pavan! Left some feedback.

"""

if from_logits:
predictions = jax.nn.softmax(predictions, axis=-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is inconsistent with the other metrics where _convert_logits_to_probabilities is called

ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible.
"""
predictions = _convert_logits_to_probabilities(predictions, from_logits)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update these so it only calls _convert_logits_to_probabilities if from_logits is true?

@classmethod
def from_model_output(
cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5
cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5, from_logits: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the docstrings (here and below)?

raise ValueError('The "Threshold" value must be between 0 and 1.')

# If the predictions are logits, convert them to probabilities

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove extra newline

('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
)
def test_precision(self, y_true, y_pred, threshold):
def test_precision(self, y_true, y_pred, threshold,from_logits=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix spacing after comma

),
)
def test_aucpr(self, inputs, dtype):
def test_aucpr(self, inputs, dtype, from_logits=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove default here as well

def test_aucpr(self, inputs, dtype, from_logits=False):
"""Test that `AUC-PR` Metric computes correct values."""
y_true, y_pred, sample_weights = inputs
y_true, y_pred, sample_weights, from_logits = inputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is shadowing the from_logits variable - dedupe


keras_aucpr = keras.metrics.AUC(curve='PR')
if from_logits:
y_pred = jax.nn.softmax(y_pred, axis=-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please match 2 spaces per indentation style here and everywhere else

def test_aucroc(self, inputs, dtype, from_logits=False):
"""Test that `AUC-ROC` Metric computes correct values."""
y_true, y_pred, sample_weights = inputs
y_true, y_pred, sample_weights,from_logits = inputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix spacing

),
)
def test_aucroc(self, inputs, dtype):
def test_aucroc(self, inputs, dtype, from_logits=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove default arg

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants