🐛 Bug
The retrieval metrics (`RetrievalMAP`, `RetrievalRecall`, etc.) crash or allocate excessive memory when the `indexes` tensor contains sparse or high-valued integers, even if the number of unique queries is small.
This is because `torchmetrics.utilities.data._bincount()` relies on `index.max()` to determine the size of internal tensors. When deterministic mode is enabled (or on XLA/MPS), the fallback implementation can allocate a massive `[len(indexes), index.max()]` tensor, leading to out-of-memory (OOM) errors.
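For illustration, here is a minimal sketch of how such a deterministic fallback behaves (not the exact torchmetrics code, just the same one-hot comparison pattern): the intermediate comparison matrix has shape `[len(x), x.max() + 1]`, so memory scales with the largest value rather than with the number of unique values.

```python
import torch

def bincount_fallback(x: torch.Tensor) -> torch.Tensor:
    # Deterministic-style bincount: compare every element against every
    # possible value 0..x.max(), then sum. The intermediate boolean matrix
    # has shape [len(x), x.max() + 1], regardless of how few unique values
    # actually occur in x.
    mesh = torch.arange(x.max() + 1, device=x.device)
    return (x.unsqueeze(-1) == mesh).sum(dim=0)

x = torch.tensor([1000, 1000, 3, 3, 7])
counts = bincount_fallback(x)
print(counts.numel())  # 1001 -- sized by the max value, not by 3 unique values
print(counts[3])       # tensor(2)
```

With `x.max() == 90_000_000` and six elements, the same pattern implies a `[6, 90000001]` intermediate, which matches the OOM reported below.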
✅ Expected behavior
The metrics should group predictions by query regardless of the numerical values in `indexes`. The actual values shouldn't impact performance or memory.
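Until this is fixed upstream, a possible user-side workaround (a sketch, not part of the torchmetrics API) is to densify the IDs with `torch.unique(..., return_inverse=True)` before passing them to the metric; since only equality between IDs matters for grouping, the remapping preserves the metric's result:

```python
import torch

# Remap sparse query IDs to contiguous integers 0..num_queries-1.
# Equal IDs stay equal, so grouping (and thus the metric) is unchanged.
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])
_, dense_indexes = torch.unique(indexes, return_inverse=True)
print(dense_indexes)  # tensor([0, 0, 1, 1, 2, 2])
```

After remapping, `dense_indexes.max()` is 2, so the deterministic fallback would allocate a `[6, 3]` tensor instead of `[6, 90000001]`.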
To Reproduce
Steps:
- Simulate a retrieval task with a few queries using high-value IDs
- Use `RetrievalMAP()` or similar
- Call `.update()` and `.compute()`
Code sample

```python
import torch
from torchmetrics.retrieval import RetrievalMAP

# Simulate predictions and labels for 3 queries with sparse/high index values
preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])  # only 3 unique queries

# Enable deterministic mode (triggers fallback path)
torch.use_deterministic_algorithms(True)

metric = RetrievalMAP()
metric.update(preds, target, indexes)

# This line will likely cause a crash or massive memory use due to high index values
result = metric.compute()
print(result)
```
💥 What happens
With `torch.use_deterministic_algorithms(True)` enabled:
- The `_bincount()` fallback tries to allocate a tensor of shape `[len(indexes), index.max()] = [6, 90000001]`
- This results in a >67 GB memory allocation and often crashes

Without deterministic mode, `torch.bincount()` is used directly, which also scales poorly if `index.max()` is large.
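The scaling of the non-deterministic path is easy to check: `torch.bincount` always returns a tensor of length `x.max() + 1`, so a single large ID inflates the allocation no matter how few unique values exist.

```python
import torch

# torch.bincount's output length is x.max() + 1, independent of how many
# distinct values x actually contains.
x = torch.tensor([0, 5, 5])
print(torch.bincount(x).numel())  # 6

x_sparse = torch.tensor([0, 1_000_000])  # only 2 unique values
print(torch.bincount(x_sparse).numel())  # 1000001
```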
Environment
- TorchMetrics version: 1.8.1 (reproduced on latest)
- Python version: 3.12.10
- PyTorch version: 2.8.0
- OS: Ubuntu 22.04 / Windows 11
- Device: CPU and CUDA (but also applies to MPS/XLA)
Additional context
This is especially common in real-world retrieval problems where `indexes` come from:
- Row numbers or IDs in large datasets
- Sparse query IDs (e.g., from database keys)

Since the metric only uses `indexes` to group elements, their actual values are irrelevant; only equality matters.