Skip to content

Commit

Permalink
first version
Browse files Browse the repository at this point in the history
  • Loading branch information
jpcbertoldo committed Nov 2, 2022
1 parent 8b02cd5 commit 48188d5
Showing 1 changed file with 104 additions and 1 deletion.
105 changes: 104 additions & 1 deletion src/torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import scipy
import torch
from torch import Tensor
from typing_extensions import Literal
Expand Down Expand Up @@ -41,6 +45,59 @@
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat

DYNAMIC_THRESHOLDS_MODE_STR = "dynamic"
_DYNAMIC_THRESHOLDS_NBINS = 3 * 10**3
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 4 * 10**3
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS
_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 12 * 10**6


def _validate_budget(budget: Optional[int]) -> int:
if budget is None:
raise ValueError("Budget must be specified when using dynamic thresholds mode.")

if budget <= 0:
raise ValueError("Budget must be larger than 0.")

if budget <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES:
warnings.warn(
f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. "
"This mode is recommended for a number of samples larger than "
f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples."
)

if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION:
warnings.warn(
f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples "
f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples "
"to estimate the thresholds."
)

return budget


def _estimate_threhsholds(preds: Tensor) -> Tensor:
global _DYNAMIC_THRESHOLDS_NBINS, _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN
npreds = preds.numel()

# sample from the predictions if there are too many (computation of mquantiles can be very slow)
if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION:
indices = torch.randperm(npreds, device=preds.device)
indices = indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION]
preds_q_estimation = preds[indices]
else:
preds_q_estimation = preds

preds_q_estimation = preds_q_estimation.cpu()

thresholds = scipy.stats.mstats.mquantiles(
preds_q_estimation, # it has to be on the CPU
prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS),
)

# remove the min/max so lower/higher values will go to the first/last "bin"
return thresholds[1:-1]


class BinaryPrecisionRecallCurve(Metric):
r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and
Expand Down Expand Up @@ -104,11 +161,19 @@ class BinaryPrecisionRecallCurve(Metric):
higher_is_better: Optional[bool] = None
full_state_update: bool = False

class _DynamicModeState(Enum):
"""Internal state of the dynamic mode."""

NONE = "none"
NON_BINNED = "non_binned"
BINNED = "binned"

def __init__(
self,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
budget: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -119,7 +184,31 @@ def __init__(
self.validate_args = validate_args

thresholds = _adjust_threshold_arg(thresholds)
if thresholds is None:

if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR:
raise ValueError(f"Invalid thresholds mode '{thresholds}'.")

self._dynamic_mode_state = self._DynamicModeState.NONE

if thresholds == DYNAMIC_THRESHOLDS_MODE_STR:

self.budget = _validate_budget(budget)
self._dynamic_mode_state = self._DynamicModeState.NON_BINNED

# they are deleted after the switch to binned mode
self.preds = []
self.target = []

# used after the switch to binned mode
self.register_buffer("thresholds", None)
self.add_state(
# "-1" here compenstes the lack min/max removed (see comments in update() to understand)
"confmat",
default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long),
dist_reduce_fx="sum",
)

elif thresholds is None:
self.thresholds = thresholds
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
Expand All @@ -140,6 +229,20 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.preds.append(state[0])
self.target.append(state[1])

if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED:
return

all_preds = dim_zero_cat(self.preds)

if all_preds.numel() < self.budget:
return

# switch to binned mode
self.thresholds = _estimate_threhsholds(all_preds)
self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds)
del self.preds, self.target
self._dynamic_mode_state = self._DynamicModeState.BINNED

def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
if self.thresholds is None:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]
Expand Down

0 comments on commit 48188d5

Please sign in to comment.