
Retrieval metrics crash or use excessive memory with high-valued query indexes #3290

Description

@ramon-adalia-lmd

πŸ› Bug

The retrieval metrics (RetrievalMAP, RetrievalRecall, etc.) crash or allocate excessive memory when the indexes tensor contains sparse or high-valued integers, even if the number of unique queries is small.

This is because torchmetrics.utilities.data._bincount() relies on indexes.max() to determine the size of internal tensors. When deterministic mode is enabled (or on XLA/MPS), the fallback implementation can allocate massive [len(indexes), indexes.max() + 1] tensors, leading to out-of-memory (OOM) errors.
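
For illustration, the deterministic fallback behaves roughly like the sketch below (an approximation, not the actual torchmetrics code): it builds a dense comparison grid whose size is governed by the largest index value, not by the number of elements or unique queries.

import torch

def bincount_via_grid(x: torch.Tensor) -> torch.Tensor:
    # Illustrative deterministic bincount: compare every element against
    # every possible bin value. The intermediate grid has shape
    # [len(x), x.max() + 1], so memory grows with the largest value in x,
    # even if x contains only a handful of distinct values.
    bins = int(x.max()) + 1
    grid = torch.arange(bins, device=x.device)      # shape: [bins]
    return (x.unsqueeze(1) == grid).sum(dim=0)      # intermediate: [len(x), bins]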

✅ Expected behavior

The metrics should group predictions by query regardless of the numerical values in indexes; the actual values should not affect runtime or memory usage.


To Reproduce

Steps:

  1. Simulate a retrieval task with a few queries using high-value IDs
  2. Use RetrievalMAP() or similar
  3. Call .update() and .compute()
Code sample
import torch
from torchmetrics.retrieval import RetrievalMAP

# Simulate predictions and labels for 3 queries with sparse/high index values
preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])  # only 3 unique queries

# Enable deterministic mode (triggers fallback path)
torch.use_deterministic_algorithms(True)

metric = RetrievalMAP()
metric.update(preds, target, indexes)

# This line will likely cause a crash or massive memory use due to high index values
result = metric.compute()
print(result)

🔥 What happens

With torch.use_deterministic_algorithms(True) enabled:

  • The _bincount() fallback tries to allocate a tensor of shape [len(indexes), index.max()] = [6, 90000001]
  • This results in >67 GB memory allocation and often crashes

Without deterministic mode, torch.bincount() is used directly, which also scales poorly when indexes.max() is large, since it allocates one counter per possible value.
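
For example, even a tiny tensor forces torch.bincount() to allocate counters up to the maximum value:

import torch

# torch.bincount allocates indexes.max() + 1 counters, no matter how few
# distinct values are present:
print(torch.bincount(torch.tensor([3, 3, 7])).shape)  # torch.Size([8])
# With indexes like [1000, 50000, 90000000] this becomes 90_000_001 slots.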


Environment
  • TorchMetrics version: 1.8.1 (reproduced on latest)
  • Python version: 3.12.10
  • PyTorch version: 2.8.0
  • OS: Ubuntu 22.04 / Windows 11
  • Device: CPU and CUDA (but also applies to MPS/XLA)

Additional context

This is especially common in real-world retrieval problems where indexes come from:

  • Row numbers or IDs in large datasets
  • Sparse query IDs (e.g., from database keys)

Since the metric only uses indexes to group elements, their actual values are irrelevant; only equality matters.
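
As a user-side workaround (and one possible direction for a fix), the raw IDs can be remapped to a contiguous 0..n_queries-1 range before calling update(), e.g. via torch.unique(..., return_inverse=True). This is only a sketch of the idea, not a proposed API change:

import torch
from torchmetrics.retrieval import RetrievalMAP

preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
raw_indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])

# Remap sparse query IDs to contiguous group IDs; grouping is preserved
# because equal raw IDs map to equal remapped IDs.
_, indexes = torch.unique(raw_indexes, return_inverse=True)  # tensor([0, 0, 1, 1, 2, 2])

metric = RetrievalMAP()
metric.update(preds, target, indexes)
print(metric.compute())  # same grouping, no dependence on the magnitude of the IDs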

Metadata

Labels

bug / fix (Something isn't working), help wanted (Extra attention is needed)
