
Incorrect threshold indices in _binary_clf_curve due to F.pad on MPS #3289

@ashwinvaidya17

Description


πŸ› Bug

Hi, I am one of the maintainers of the Anomalib library.
We use TorchMetrics' BinaryPrecisionRecallCurve to compute F1AdaptiveThreshold.

We observe that F.pad produces corrupted/garbage values when padding int64 tensors on the MPS device. This affects the following line in the _binary_clf_curve function:

threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
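For context, here is a simplified sketch of how a binary classification curve derives threshold_idxs and tps around the padding line above. The function name binary_clf_curve_sketch is hypothetical, and it omits sample weights and other details of the real _binary_clf_curve. On CPU, the padded last index is always target.size(0) - 1, so the last tps entry equals the total number of positives.

```python
import torch
import torch.nn.functional as F


def binary_clf_curve_sketch(preds: torch.Tensor, target: torch.Tensor):
    """Hypothetical, simplified sketch (no sample weights) of the threshold logic."""
    # Sort scores in descending order and reorder the targets to match.
    order = torch.argsort(preds, descending=True)
    preds = preds[order]
    target = target[order]
    # Positions where the score changes, i.e. the last index of each run of equal scores.
    distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
    # Append the final index so the lowest threshold covers every sample.
    threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
    # Cumulative true positives at each threshold; false positives follow from the count.
    tps = torch.cumsum(target, dim=0)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    return fps, tps, preds[threshold_idxs]


fps, tps, thresholds = binary_clf_curve_sketch(
    torch.tensor([0.9, 0.8, 0.8, 0.1]), torch.tensor([1, 0, 1, 0])
)
print(tps, fps, thresholds)
```

If F.pad returns a corrupted tensor here, every later index into the cumulative sums is wrong, which matches the symptom described in this issue.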

The corrupted padding produces invalid threshold_idxs, which in turn leaves tps with zeros at the end.

For example:

# After the corrupted padding is used for indexing:
tps = _cumsum(target * weight, dim=0)[threshold_idxs]
# Result: tensor([314511., 314511., 314511.,  ...,      0.,      0.,      0.], device='mps:0')

These trailing zeros in turn produce NaN values in recall, since recall is normalized by the last entry of tps.
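To illustrate why trailing zeros in tps surface as NaN, the sketch below assumes recall is computed as tps / tps[-1]. The tensors are made-up illustrative values, not output from the failing run.

```python
import torch

# Illustrative tps values: a healthy curve vs. one whose tail was zeroed out.
tps_ok = torch.tensor([1.0, 2.0, 3.0, 4.0])
tps_bad = torch.tensor([1.0, 2.0, 0.0, 0.0])

# Recall divides by the final cumulative count of true positives.
recall_ok = tps_ok / tps_ok[-1]
recall_bad = tps_bad / tps_bad[-1]  # divides by zero

print(recall_ok)   # tensor([0.2500, 0.5000, 0.7500, 1.0000])
print(recall_bad)  # tensor([inf, inf, nan, nan])
```

Nonzero entries divided by zero become inf, and the zeroed entries become 0 / 0 = NaN.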

To Reproduce

Code sample
import torch
import torch.nn.functional as F

# Create int64 tensor on MPS
distinct_indices = torch.linspace(0, 5439486, 4921573, device="mps", dtype=torch.int64)
target = 5439487

# Attempt to pad
thresh_idxs_mps = F.pad(distinct_indices, [0, 1], value=target)
print("MPS result (last 10 values):", thresh_idxs_mps[-10:])

# Compare with CPU
thresh_idxs_cpu = F.pad(distinct_indices.to("cpu"), [0, 1], value=target)
print("CPU result (last 10 values):", thresh_idxs_cpu[-10:])
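A variant of the reproduction that reports a single pass/fail result may be easier to run across machines. This is a sketch under our own assumptions: pad_matches_cpu is a hypothetical helper, it uses torch.arange instead of torch.linspace, and its default size only matches the element count of the repro above. We have not verified that this exact input triggers the corruption.

```python
import torch
import torch.nn.functional as F


def pad_matches_cpu(device: str, n: int = 4_921_573) -> bool:
    """Pad an int64 index tensor on `device` and compare it with the CPU result."""
    idx = torch.arange(n, dtype=torch.int64)
    fill = n  # padding value one past the last index, as in the repro
    expected = F.pad(idx, [0, 1], value=fill)
    actual = F.pad(idx.to(device), [0, 1], value=fill).cpu()
    return torch.equal(actual, expected)


if __name__ == "__main__":
    if torch.backends.mps.is_available():
        print("MPS pad matches CPU:", pad_matches_cpu("mps"))
    else:
        print("MPS not available; CPU sanity check:", pad_matches_cpu("cpu"))
```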
Environment
  • TorchMetrics version (if built from source, add commit SHA): 1.8.2
  • Python & PyTorch version: Python 3.10.17, PyTorch 2.8.0
  • OS: macOS 26.0.1 (arm64)
  • Apple Silicon (M1/M2/M3) with MPS backend

Additional context

Expected Behavior

CPU (correct):

tensor([      0,       1,       2,  ..., 5439485, 5439486, 5439487])

The tensor is properly padded with the target value 5439487 at the end.

Actual Behavior

MPS (incorrect):

The behavior is inconsistent and produces corrupted values. Observed behaviors include:

Scenario 1:

tensor([5439487, 5439487, 5439487,  ..., 0, 0, 0], device='mps:0')

The entire tensor is corrupted: the original values are replaced by the padding value or by zeros.

Scenario 2:

tensor([5439487, 5439487, 5439487,  ..., 4594404260803942997, 4596536720656838091, 5439487],
       device='mps:0')

Some positions contain garbage/overflow values that are far outside the expected range.
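Both scenarios above violate simple invariants that padded threshold indices from _binary_clf_curve should satisfy: they are strictly increasing, lie in [0, n - 1], and end with n - 1, where n is the number of samples. The hypothetical check below is a sketch based on that assumption; it could flag corruption without needing a CPU reference.

```python
import torch


def threshold_idxs_look_valid(threshold_idxs: torch.Tensor, n: int) -> bool:
    """Hypothetical sanity check for padded threshold indices over n samples."""
    idxs = threshold_idxs.cpu()
    if idxs.numel() == 0:
        return False
    # Every index must point at a real sample.
    in_range = bool(((idxs >= 0) & (idxs <= n - 1)).all())
    # Indices mark run ends of sorted scores, so they must strictly increase.
    increasing = bool((idxs[1:] > idxs[:-1]).all())
    # The padded final entry should be the last sample index.
    ends_at_last = int(idxs[-1]) == n - 1
    return in_range and increasing and ends_at_last
```

Scenario 1 fails the ordering check, and Scenario 2 fails the range check.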

Current workaround

Our current workaround is to move the predictions and targets to the CPU before computing the metric on macOS.
https://github.com/open-edge-platform/anomalib/blob/163f8bc653df1f0c893b62302688bf8e4fc00c12/src/anomalib/metrics/threshold/f1_adaptive_threshold.py#L51
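A generic form of this workaround might look like the sketch below. The helper name to_cpu_if_mps is hypothetical and this is not the code from the linked Anomalib file; it also assumes the metric object itself is kept on the CPU so its state and inputs share a device.

```python
import torch


def to_cpu_if_mps(*tensors: torch.Tensor) -> tuple:
    """Hypothetical helper: move any MPS tensors to CPU, leave others untouched."""
    return tuple(t.cpu() if t.device.type == "mps" else t for t in tensors)


# Usage sketch: call before updating the precision-recall metric.
preds, target = to_cpu_if_mps(torch.rand(8), torch.randint(0, 2, (8,)))
```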

Solution

This appears to originate in PyTorch's MPS backend rather than in TorchMetrics itself; however, it would be good for the community to be aware of this edge case.
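One candidate for avoiding F.pad inside _binary_clf_curve is to append the final index with torch.cat. This is only a sketch: append_value is a hypothetical helper, and we have not verified that torch.cat is unaffected on MPS, so it would need testing on an affected machine before being treated as a fix.

```python
import torch


def append_value(indices: torch.Tensor, value: int) -> torch.Tensor:
    """Append `value` to a 1-D tensor without F.pad (hypothetical alternative)."""
    # Build the one-element tail on the same device and with the same dtype.
    tail = torch.full((1,), value, dtype=indices.dtype, device=indices.device)
    return torch.cat([indices, tail])


print(append_value(torch.tensor([0, 2, 5]), 7))  # tensor([0, 2, 5, 7])
```

On CPU this gives the same result as F.pad(indices, [0, 1], value=value) for 1-D tensors.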

Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)