Description
🐛 Bug
Hi, I am one of the maintainers of the Anomalib library.
We use torchmetrics' BinaryPrecisionRecallCurve
to compute F1AdaptiveThreshold.
We have observed that F.pad
produces corrupted/garbage values when padding int64 tensors on the MPS device. This affects the _binary_clf_curve
function:
torchmetrics/src/torchmetrics/functional/classification/precision_recall_curve.py
Line 71 at 122a852:
threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
This produces a corrupted threshold_idxs,
which in turn leaves tps
with 0's at the end.
For example
# After the corrupted padding is used for indexing:
tps = _cumsum(target * weight, dim=0)[threshold_idxs]
# Result: tensor([314511., 314511., 314511., ..., 0., 0., 0.], device='mps:0')
This in turn produces NaN
values in recall:
torchmetrics/src/torchmetrics/functional/classification/precision_recall_curve.py
Line 278 at 122a852:
recall = tps / tps[-1]
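To illustrate why the trailing zeros matter (a minimal sketch with hypothetical numbers, not the issue's actual tensors): once the corrupted indexing leaves tps ending in 0, the division recall = tps / tps[-1] divides by zero, yielding inf for nonzero entries and NaN for the zero ones.

```python
import torch

# Hypothetical values illustrating the failure mode described above:
# corrupted padding leaves trailing zeros in tps, so tps[-1] == 0.
tps = torch.tensor([314511.0, 314511.0, 0.0, 0.0])

# recall = tps / tps[-1] then divides by zero:
# nonzero / 0 -> inf, and 0 / 0 -> nan.
recall = tps / tps[-1]
print(recall)  # tensor([inf, inf, nan, nan])
```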
To Reproduce
Code sample
import torch
import torch.nn.functional as F
# Create int64 tensor on MPS
distinct_indices = torch.linspace(0, 5439486, 4921573, device="mps", dtype=torch.int64)
target = 5439487
# Attempt to pad
thresh_idxs_mps = F.pad(distinct_indices, [0, 1], value=target)
print("MPS result (last 10 values):", thresh_idxs_mps[-10:])
# Compare with CPU
thresh_idxs_cpu = F.pad(distinct_indices.to("cpu"), [0, 1], value=target)
print("CPU result (last 10 values):", thresh_idxs_cpu[-10:])
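For reference, a smaller CPU-only sketch of the expected behavior (sizes and values shortened for illustration; not the issue's actual tensors). F.pad with [0, 1] should append exactly one element equal to the pad value and leave the original values untouched:

```python
import torch
import torch.nn.functional as F

# Shortened, CPU-only illustration of the expected semantics of
# F.pad(t, [0, 1], value=v) on an int64 1-D tensor.
idx = torch.arange(10, dtype=torch.int64)
padded = F.pad(idx, [0, 1], value=10)

assert padded[-1].item() == 10       # pad value appended at the end
assert torch.equal(padded[:-1], idx)  # original values preserved
print(padded)  # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
```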
Environment
- TorchMetrics version: 1.8.2
- PyTorch version: 2.8.0; Python version: 3.10.17
- OS: macOS 26.0.1 (arm64)
- Hardware: Apple Silicon (M1/M2/M3) with MPS backend
Additional context
Expected Behavior
CPU (correct):
tensor([ 0, 1, 2, ..., 5439485, 5439486, 5439487])
The tensor is properly padded with the target value 5439487
at the end.
Actual Behavior
MPS (incorrect):
The behavior is inconsistent and produces corrupted values. Observed behaviors include:
Scenario 1:
tensor([5439487, 5439487, 5439487, ..., 0, 0, 0], device='mps:0')
The entire tensor is corrupted: the original values are replaced with the padding value or zeros.
Scenario 2:
tensor([5439487, 5439487, 5439487, ..., 4594404260803942997, 4596536720656838091, 5439487],
device='mps:0')
Some positions contain garbage/overflow values that are far outside the expected range.
Current workaround
Our current workaround is to move the predictions and targets to CPU before computing on Mac.
https://github.com/open-edge-platform/anomalib/blob/163f8bc653df1f0c893b62302688bf8e4fc00c12/src/anomalib/metrics/threshold/f1_adaptive_threshold.py#L51
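A minimal sketch of such a workaround (a hypothetical helper written for this report, not Anomalib's actual code): route int64 padding through the CPU when the tensor lives on MPS, then move the result back.

```python
import torch
import torch.nn.functional as F

def pad_safely(t: torch.Tensor, pad, value):
    """Hypothetical helper: work around corrupted int64 F.pad on MPS.

    Falls back to padding on the CPU for int64 MPS tensors and moves
    the result back to the original device; otherwise pads in place.
    """
    if t.device.type == "mps" and t.dtype == torch.int64:
        return F.pad(t.cpu(), pad, value=value).to(t.device)
    return F.pad(t, pad, value=value)

# Usage (works on any device; the CPU detour only triggers on MPS):
idx = torch.arange(5, dtype=torch.int64)
out = pad_safely(idx, [0, 1], value=5)
print(out)  # tensor([0, 1, 2, 3, 4, 5])
```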
Solution
This seems to originate in PyTorch's MPS backend implementation; however, it would be good for the community to be aware of this edge case.