Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BinaryPrecisionRecallCurve for large datasets (>100 million samples) #1309

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Next Next commit
first version
jpcbertoldo committed Nov 2, 2022
commit 48188d568281250dda63b1e4655a34094d175b3a
105 changes: 104 additions & 1 deletion src/torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import scipy
import torch
from torch import Tensor
from typing_extensions import Literal
@@ -41,6 +45,59 @@
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat

# Value of ``thresholds`` that selects the dynamic thresholds mode.
DYNAMIC_THRESHOLDS_MODE_STR = "dynamic"

# Number of quantile levels requested when thresholds are estimated from the data.
_DYNAMIC_THRESHOLDS_NBINS = 3_000
# Samples per bin used to derive the smallest budget that does not trigger a warning.
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 4_000
# Budgets at or below this number of samples trigger the "budget is small" warning.
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_NBINS * _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN
# Upper bound on how many predictions are handed to the quantile estimator.
_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 12_000_000


def _validate_budget(budget: Optional[int]) -> int:
    """Validate the sample budget used by the dynamic thresholds mode.

    Args:
        budget: Number of samples checked against the module-level limits of
            the dynamic thresholds mode.

    Returns:
        ``budget`` unchanged, once it has passed validation.

    Raises:
        ValueError: If ``budget`` is ``None`` or is not strictly positive.
    """
    if budget is None:
        raise ValueError("Budget must be specified when using dynamic thresholds mode.")

    if budget <= 0:
        raise ValueError("Budget must be larger than 0.")

    # Both warnings below go through the project's ``rank_zero_warn`` helper (imported at
    # module level) instead of raw ``warnings.warn``.
    # NOTE(review): the helper is expected to emit only on the rank-zero process, which
    # avoids one duplicate warning per process in distributed runs - confirm.
    if budget <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES:
        rank_zero_warn(
            f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. "
            "This mode is recommended for a number of samples larger than "
            f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples.",
            UserWarning,
        )

    # Above this size the threshold estimation only looks at a random subset of the predictions.
    if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION:
        rank_zero_warn(
            f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples "
            f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples "
            "to estimate the thresholds.",
            UserWarning,
        )

    return budget


def _estimate_threhsholds(preds: Tensor) -> Tensor:
    """Estimate binning thresholds as empirical quantiles of ``preds``.

    When ``preds`` holds more than
    ``_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION`` values, a random
    subset of that size (drawn without replacement) is used instead of the
    full tensor.

    Args:
        preds: Predictions to derive the thresholds from. The tensor is
            flattened before use.

    Returns:
        A 1-D tensor on ``preds.device`` holding the
        ``_DYNAMIC_THRESHOLDS_NBINS - 2`` inner quantiles (the minimum and the
        maximum are dropped).
    """
    # NOTE: the function name is misspelled ("threhsholds") but is kept because ``update()``
    # calls it by this name.

    # NOTE(review): a bare ``import scipy`` may not expose ``scipy.stats`` on older SciPy
    # releases; importing the function directly avoids relying on that.
    from scipy.stats.mstats import mquantiles

    # ``detach`` lets tensors that require grad be converted to NumPy; flattening keeps the
    # index-based sampling below valid for inputs of any shape.
    flat_preds = preds.detach().reshape(-1)
    npreds = flat_preds.numel()

    # sample from the predictions if there are too many (computation of mquantiles can be very slow)
    if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION:
        indices = torch.randperm(npreds, device=flat_preds.device)
        flat_preds = flat_preds[indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION]]

    quantiles = mquantiles(
        flat_preds.cpu().numpy(),  # it has to be on the CPU
        prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS),
    )

    # remove the min/max so lower/higher values will go to the first/last "bin"
    inner_quantiles = np.asarray(quantiles[1:-1])

    # ``mquantiles`` returns a NumPy masked array; convert it so the declared ``Tensor``
    # return type holds and the result can be stored in the ``thresholds`` buffer.
    out_dtype = preds.dtype if preds.is_floating_point() else torch.get_default_dtype()
    return torch.as_tensor(inner_quantiles, dtype=out_dtype, device=preds.device)


class BinaryPrecisionRecallCurve(Metric):
r"""Computes the precision-recall curve for binary tasks. The curve consists of multiple pairs of precision and
@@ -104,11 +161,19 @@ class BinaryPrecisionRecallCurve(Metric):
higher_is_better: Optional[bool] = None
full_state_update: bool = False

class _DynamicModeState(Enum):
"""Internal state of the dynamic mode."""

NONE = "none"
NON_BINNED = "non_binned"
BINNED = "binned"

def __init__(
self,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
budget: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -119,7 +184,31 @@ def __init__(
self.validate_args = validate_args

thresholds = _adjust_threshold_arg(thresholds)
if thresholds is None:

if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR:
raise ValueError(f"Invalid thresholds mode '{thresholds}'.")

self._dynamic_mode_state = self._DynamicModeState.NONE

if thresholds == DYNAMIC_THRESHOLDS_MODE_STR:

self.budget = _validate_budget(budget)
self._dynamic_mode_state = self._DynamicModeState.NON_BINNED

# they are deleted after the switch to binned mode
self.preds = []
self.target = []

# used after the switch to binned mode
self.register_buffer("thresholds", None)
self.add_state(
# "-1" here compensates for the min/max that were removed (see comments in update() to understand)
"confmat",
default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long),
dist_reduce_fx="sum",
)

elif thresholds is None:
self.thresholds = thresholds
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
@@ -140,6 +229,20 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.preds.append(state[0])
self.target.append(state[1])

if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED:
return

all_preds = dim_zero_cat(self.preds)

if all_preds.numel() < self.budget:
return

# switch to binned mode
self.thresholds = _estimate_threhsholds(all_preds)
self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds)
del self.preds, self.target
self._dynamic_mode_state = self._DynamicModeState.BINNED

def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
if self.thresholds is None:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]