From 48188d568281250dda63b1e4655a34094d175b3a Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 00:02:42 +0100 Subject: [PATCH 1/8] first version --- .../classification/precision_recall_curve.py | 105 +++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 8615d4cf51a..2aacff9f772 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from enum import Enum from typing import Any, List, Optional, Tuple, Union +import numpy as np +import scipy import torch from torch import Tensor from typing_extensions import Literal @@ -41,6 +45,59 @@ from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +DYNAMIC_THRESHOLDS_MODE_STR = "dynamic" +_DYNAMIC_THRESHOLDS_NBINS = 3 * 10**3 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 4 * 10**3 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS +_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 12 * 10**6 + + +def _validate_budget(budget: Optional[int]) -> int: + if budget is None: + raise ValueError("Budget must be specified when using dynamic thresholds mode.") + + if budget <= 0: + raise ValueError("Budget must be larger than 0.") + + if budget <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: + warnings.warn( + f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. " + "This mode is recommended for a number of samples larger than " + f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples." + ) + + if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: + warnings.warn( + f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples " + f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples " + "to estimate the thresholds." + ) + + return budget + + +def _estimate_threhsholds(preds: Tensor) -> Tensor: + global _DYNAMIC_THRESHOLDS_NBINS, _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN + npreds = preds.numel() + + # sample from the predictions if there are too many (computation of mquantiles can be very slow) + if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: + indices = torch.randperm(npreds, device=preds.device) + indices = indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION] + preds_q_estimation = preds[indices] + else: + preds_q_estimation = preds + + preds_q_estimation = preds_q_estimation.cpu() + + thresholds = scipy.stats.mstats.mquantiles( + preds_q_estimation, # it has to be on the CPU + prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS), + ) + + # remove the min/max so lower/higher values will go to the first/last "bin" + return thresholds[1:-1] + class BinaryPrecisionRecallCurve(Metric): r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and @@ -104,11 +161,19 @@ class BinaryPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False + class _DynamicModeState(Enum): + """Internal state of the dynamic mode.""" + + NONE = "none" + NON_BINNED = "non_binned" + BINNED = "binned" + def __init__( self, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + budget: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -119,7 +184,31 @@ def __init__( self.validate_args = validate_args thresholds = _adjust_threshold_arg(thresholds) - if thresholds is None: + + if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR: + raise ValueError(f"Invalid thresholds mode '{thresholds}'.") + + self._dynamic_mode_state = self._DynamicModeState.NONE + + if thresholds == DYNAMIC_THRESHOLDS_MODE_STR: + + self.budget = _validate_budget(budget) + self._dynamic_mode_state = self._DynamicModeState.NON_BINNED + + # they are deleted after the switch to binned mode + self.preds = [] + self.target = [] + + # used after the switch to binned mode + self.register_buffer("thresholds", None) + self.add_state( + # "-1" here compenstes the lack min/max removed (see comments in update() to understand) + "confmat", + default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long), + dist_reduce_fx="sum", + ) + + elif thresholds is None: self.thresholds = thresholds self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") @@ -140,6 +229,20 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(state[0]) self.target.append(state[1]) + if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED: + return + + all_preds = dim_zero_cat(self.preds) + + if all_preds.numel() < self.budget: + return + + # switch to binned mode + self.thresholds = _estimate_threhsholds(all_preds) + self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds) + del self.preds, self.target + self._dynamic_mode_state = self._DynamicModeState.BINNED + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: if self.thresholds is None: state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] From 1e5f130c22c09f63d0cb07e9513c407a27168445 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 01:02:40 +0100 Subject: [PATCH 2/8] update consts --- src/torchmetrics/classification/precision_recall_curve.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 2aacff9f772..0093f44cd70 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -46,10 +46,10 @@ from torchmetrics.utilities.data import dim_zero_cat DYNAMIC_THRESHOLDS_MODE_STR = "dynamic" -_DYNAMIC_THRESHOLDS_NBINS = 3 * 10**3 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 4 * 10**3 +_DYNAMIC_THRESHOLDS_NBINS = 10**4 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 10**4 _DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS -_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 12 * 10**6 +_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 10**7 def _validate_budget(budget: Optional[int]) -> int: From b98edf7c32543a3969e1a1e2c5e0d11804e2375f Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:29:32 +0100 Subject: [PATCH 3/8] simplify --- requirements/classification.txt | 1 + .../classification/precision_recall_curve.py | 110 +++++++----------- .../classification/precision_recall_curve.py | 14 ++- 3 files changed, 52 insertions(+), 73 deletions(-) create mode 100644 requirements/classification.txt diff --git a/requirements/classification.txt b/requirements/classification.txt new file mode 100644 index 00000000000..f5368c4974d --- /dev/null +++ b/requirements/classification.txt @@ -0,0 +1 @@ +humanfriendly diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 0093f44cd70..5a0d78d6244 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -15,14 +15,14 @@ from enum import Enum from typing import Any, List, Optional, Tuple, Union -import numpy as np -import scipy +import humanfriendly import torch from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( _adjust_threshold_arg, + _binary_clf_curve, _binary_precision_recall_curve_arg_validation, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format, @@ -45,60 +45,27 @@ from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -DYNAMIC_THRESHOLDS_MODE_STR = "dynamic" -_DYNAMIC_THRESHOLDS_NBINS = 10**4 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 10**4 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS -_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 10**7 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = 1024**2 # 1MiB,TODO: find a better way to estimate a reasonable minimum -def _validate_budget(budget: Optional[int]) -> int: - if budget is None: - raise ValueError("Budget must be specified when using dynamic thresholds mode.") +def _budget_bytes_to_nsamples(budget_bytes: int): + # assume that both preds and target ("* 2") will be of size (N, 1) and of type float32 (4 bytes) + return budget_bytes / (2 * 4) + +def _validate_memory_budget(budget: int): if budget <= 0: raise ValueError("Budget must be larger than 0.") - if budget <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: - warnings.warn( - f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. " - "This mode is recommended for a number of samples larger than " - f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples." - ) - - if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: + if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: warnings.warn( - f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples " - f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples " - "to estimate the thresholds." + f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). " + "The dynamic mode is recommended for bigger samples." ) return budget -def _estimate_threhsholds(preds: Tensor) -> Tensor: - global _DYNAMIC_THRESHOLDS_NBINS, _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN - npreds = preds.numel() - - # sample from the predictions if there are too many (computation of mquantiles can be very slow) - if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: - indices = torch.randperm(npreds, device=preds.device) - indices = indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION] - preds_q_estimation = preds[indices] - else: - preds_q_estimation = preds - - preds_q_estimation = preds_q_estimation.cpu() - - thresholds = scipy.stats.mstats.mquantiles( - preds_q_estimation, # it has to be on the CPU - prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS), - ) - - # remove the min/max so lower/higher values will go to the first/last "bin" - return thresholds[1:-1] - - class BinaryPrecisionRecallCurve(Metric): r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. @@ -130,6 +97,7 @@ class BinaryPrecisionRecallCurve(Metric): - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + - If set to a `str`, the value is interpreted as a memory budget and the dynamic mode approach is used. validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. @@ -161,19 +129,27 @@ class BinaryPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - class _DynamicModeState(Enum): + class _ComputationMode(Enum): """Internal state of the dynamic mode.""" - NONE = "none" - NON_BINNED = "non_binned" BINNED = "binned" + NON_BINNED = "non-binned" + NON_BINNED_DYNAMIC = "non-binned-dynamic" + + @staticmethod + def _deduce_computation_mode(thresholds: Optional[Union[int, List[float], Tensor, str]]) -> _ComputationMode: + if isinstance(thresholds, str): + return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC + elif thresholds is None: + return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED + else: + return BinaryPrecisionRecallCurve._ComputationMode.BINNED def __init__( self, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, List[float], Tensor, str]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, - budget: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -183,30 +159,21 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args + self._computation_mode = self._deduce_computation_mode(thresholds) thresholds = _adjust_threshold_arg(thresholds) - if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR: - raise ValueError(f"Invalid thresholds mode '{thresholds}'.") - - self._dynamic_mode_state = self._DynamicModeState.NONE - - if thresholds == DYNAMIC_THRESHOLDS_MODE_STR: - - self.budget = _validate_budget(budget) - self._dynamic_mode_state = self._DynamicModeState.NON_BINNED - - # they are deleted after the switch to binned mode - self.preds = [] - self.target = [] - + if self._computation_mode == self._ComputationMode.NON_BINNED_DYNAMIC: + self._memory_budget_bytes = _validate_memory_budget(thresholds) # used after the switch to binned mode self.register_buffer("thresholds", None) self.add_state( - # "-1" here compenstes the lack min/max removed (see comments in update() to understand) "confmat", - default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long), + default=torch.zeros(_budget_bytes_to_nsamples(self._memory_budget_bytes), 2, 2, dtype=torch.long), dist_reduce_fx="sum", ) + # they are deleted after the switch to binned mode + self.preds = [] + self.target = [] elif thresholds is None: self.thresholds = thresholds @@ -229,19 +196,22 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(state[0]) self.target.append(state[1]) - if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED: + if self._computation_mode != self._ComputationMode.NON_BINNED_DYNAMIC: return all_preds = dim_zero_cat(self.preds) + mem_used = all_preds.element_size() * all_preds.nelement() * 2 # 2 accounts for the target - if all_preds.numel() < self.budget: + if mem_used < self._memory_budget_bytes: return # switch to binned mode - self.thresholds = _estimate_threhsholds(all_preds) - self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds) + self.preds, self.target = all_preds, dim_zero_cat(self.target) + _, _, self.thresholds = _binary_clf_curve(self.preds, self.target) + # if the number of thr + self.confmat = _binary_precision_recall_curve_update(self.preds, self.target, self.thresholds) del self.preds, self.target - self._dynamic_mode_state = self._DynamicModeState.BINNED + self._computation_mode = self._ComputationMode.BINNED def compute(self) -> Tuple[Tensor, Tensor, Tensor]: if self.thresholds is None: diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index a505898f040..792ab5c306d 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -14,6 +14,7 @@ from typing import List, Optional, Sequence, Tuple, Union +import humandfriendly import torch from torch import Tensor, tensor from torch.nn import functional as F @@ -81,13 +82,20 @@ def _binary_clf_curve( def _adjust_threshold_arg( - thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None -) -> Optional[Tensor]: - """Utility function for converting the threshold arg for list and int to tensor format.""" + thresholds: Optional[Union[int, List[float], Tensor, str]] = None, device: Optional[torch.device] = None +) -> Optional[Union[Tensor, int]]: + """Utility function for converting the threshold arg. + + - list and int -> tensor + - None -> None + - str -> int (memory budget) in Mb + """ if isinstance(thresholds, int): thresholds = torch.linspace(0, 1, thresholds, device=device) if isinstance(thresholds, list): thresholds = torch.tensor(thresholds, device=device) + if isinstance(thresholds, str): + thresholds = humandfriendly.parse_size(thresholds, binary=True) return thresholds From ca963bbe74e7c968cd4051f509309c20cbf96b41 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 00:02:42 +0100 Subject: [PATCH 4/8] first version --- .../classification/precision_recall_curve.py | 105 +++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 8615d4cf51a..2aacff9f772 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from enum import Enum from typing import Any, List, Optional, Tuple, Union +import numpy as np +import scipy import torch from torch import Tensor from typing_extensions import Literal @@ -41,6 +45,59 @@ from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +DYNAMIC_THRESHOLDS_MODE_STR = "dynamic" +_DYNAMIC_THRESHOLDS_NBINS = 3 * 10**3 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 4 * 10**3 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS +_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 12 * 10**6 + + +def _validate_budget(budget: Optional[int]) -> int: + if budget is None: + raise ValueError("Budget must be specified when using dynamic thresholds mode.") + + if budget <= 0: + raise ValueError("Budget must be larger than 0.") + + if budget <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: + warnings.warn( + f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. " + "This mode is recommended for a number of samples larger than " + f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples." + ) + + if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: + warnings.warn( + f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples " + f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples " + "to estimate the thresholds." + ) + + return budget + + +def _estimate_threhsholds(preds: Tensor) -> Tensor: + global _DYNAMIC_THRESHOLDS_NBINS, _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN + npreds = preds.numel() + + # sample from the predictions if there are too many (computation of mquantiles can be very slow) + if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: + indices = torch.randperm(npreds, device=preds.device) + indices = indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION] + preds_q_estimation = preds[indices] + else: + preds_q_estimation = preds + + preds_q_estimation = preds_q_estimation.cpu() + + thresholds = scipy.stats.mstats.mquantiles( + preds_q_estimation, # it has to be on the CPU + prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS), + ) + + # remove the min/max so lower/higher values will go to the first/last "bin" + return thresholds[1:-1] + class BinaryPrecisionRecallCurve(Metric): r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and @@ -104,11 +161,19 @@ class BinaryPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False + class _DynamicModeState(Enum): + """Internal state of the dynamic mode.""" + + NONE = "none" + NON_BINNED = "non_binned" + BINNED = "binned" + def __init__( self, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + budget: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -119,7 +184,31 @@ def __init__( self.validate_args = validate_args thresholds = _adjust_threshold_arg(thresholds) - if thresholds is None: + + if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR: + raise ValueError(f"Invalid thresholds mode '{thresholds}'.") + + self._dynamic_mode_state = self._DynamicModeState.NONE + + if thresholds == DYNAMIC_THRESHOLDS_MODE_STR: + + self.budget = _validate_budget(budget) + self._dynamic_mode_state = self._DynamicModeState.NON_BINNED + + # they are deleted after the switch to binned mode + self.preds = [] + self.target = [] + + # used after the switch to binned mode + self.register_buffer("thresholds", None) + self.add_state( + # "-1" here compenstes the lack min/max removed (see comments in update() to understand) + "confmat", + default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long), + dist_reduce_fx="sum", + ) + + elif thresholds is None: self.thresholds = thresholds self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") @@ -140,6 +229,20 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(state[0]) self.target.append(state[1]) + if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED: + return + + all_preds = dim_zero_cat(self.preds) + + if all_preds.numel() < self.budget: + return + + # switch to binned mode + self.thresholds = _estimate_threhsholds(all_preds) + self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds) + del self.preds, self.target + self._dynamic_mode_state = self._DynamicModeState.BINNED + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: if self.thresholds is None: state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] From 6698d76f3c8e2e2aecf0bc8d9f445beea1f715f6 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 01:02:40 +0100 Subject: [PATCH 5/8] update consts --- src/torchmetrics/classification/precision_recall_curve.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 2aacff9f772..0093f44cd70 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -46,10 +46,10 @@ from torchmetrics.utilities.data import dim_zero_cat DYNAMIC_THRESHOLDS_MODE_STR = "dynamic" -_DYNAMIC_THRESHOLDS_NBINS = 3 * 10**3 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 4 * 10**3 +_DYNAMIC_THRESHOLDS_NBINS = 10**4 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 10**4 _DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS -_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 12 * 10**6 +_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 10**7 def _validate_budget(budget: Optional[int]) -> int: From d5a8309ae54c879935b61a375016ac80a0fbaba6 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:29:32 +0100 Subject: [PATCH 6/8] simplify --- requirements/classification.txt | 1 + .../classification/precision_recall_curve.py | 110 +++++++----------- .../classification/precision_recall_curve.py | 14 ++- 3 files changed, 52 insertions(+), 73 deletions(-) create mode 100644 requirements/classification.txt diff --git a/requirements/classification.txt b/requirements/classification.txt new file mode 100644 index 00000000000..f5368c4974d --- /dev/null +++ b/requirements/classification.txt @@ -0,0 +1 @@ +humanfriendly diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 0093f44cd70..5a0d78d6244 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -15,14 +15,14 @@ from enum import Enum from typing import Any, List, Optional, Tuple, Union -import numpy as np -import scipy +import humanfriendly import torch from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( _adjust_threshold_arg, + _binary_clf_curve, _binary_precision_recall_curve_arg_validation, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format, @@ -45,60 +45,27 @@ from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -DYNAMIC_THRESHOLDS_MODE_STR = "dynamic" -_DYNAMIC_THRESHOLDS_NBINS = 10**4 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 10**4 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS -_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 10**7 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = 1024**2 # 1MiB,TODO: find a better way to estimate a reasonable minimum -def _validate_budget(budget: Optional[int]) -> int: - if budget is None: - raise ValueError("Budget must be specified when using dynamic thresholds mode.") +def _budget_bytes_to_nsamples(budget_bytes: int): + # assume that both preds and target ("* 2") will be of size (N, 1) and of type float32 (4 bytes) + return budget_bytes / (2 * 4) + +def _validate_memory_budget(budget: int): if budget <= 0: raise ValueError("Budget must be larger than 0.") - if budget <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: - warnings.warn( - f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. " - "This mode is recommended for a number of samples larger than " - f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples." - ) - - if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: + if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: warnings.warn( - f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples " - f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples " - "to estimate the thresholds." + f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). " + "The dynamic mode is recommended for bigger samples." ) return budget -def _estimate_threhsholds(preds: Tensor) -> Tensor: - global _DYNAMIC_THRESHOLDS_NBINS, _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN - npreds = preds.numel() - - # sample from the predictions if there are too many (computation of mquantiles can be very slow) - if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: - indices = torch.randperm(npreds, device=preds.device) - indices = indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION] - preds_q_estimation = preds[indices] - else: - preds_q_estimation = preds - - preds_q_estimation = preds_q_estimation.cpu() - - thresholds = scipy.stats.mstats.mquantiles( - preds_q_estimation, # it has to be on the CPU - prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS), - ) - - # remove the min/max so lower/higher values will go to the first/last "bin" - return thresholds[1:-1] - - class BinaryPrecisionRecallCurve(Metric): r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. @@ -130,6 +97,7 @@ class BinaryPrecisionRecallCurve(Metric): - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + - If set to a `str`, the value is interpreted as a memory budget and the dynamic mode approach is used. validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. @@ -161,19 +129,27 @@ class BinaryPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - class _DynamicModeState(Enum): + class _ComputationMode(Enum): """Internal state of the dynamic mode.""" - NONE = "none" - NON_BINNED = "non_binned" BINNED = "binned" + NON_BINNED = "non-binned" + NON_BINNED_DYNAMIC = "non-binned-dynamic" + + @staticmethod + def _deduce_computation_mode(thresholds: Optional[Union[int, List[float], Tensor, str]]) -> _ComputationMode: + if isinstance(thresholds, str): + return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC + elif thresholds is None: + return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED + else: + return BinaryPrecisionRecallCurve._ComputationMode.BINNED def __init__( self, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, List[float], Tensor, str]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, - budget: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -183,30 +159,21 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args + self._computation_mode = self._deduce_computation_mode(thresholds) thresholds = _adjust_threshold_arg(thresholds) - if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR: - raise ValueError(f"Invalid thresholds mode '{thresholds}'.") - - self._dynamic_mode_state = self._DynamicModeState.NONE - - if thresholds == DYNAMIC_THRESHOLDS_MODE_STR: - - self.budget = _validate_budget(budget) - self._dynamic_mode_state = self._DynamicModeState.NON_BINNED - - # they are deleted after the switch to binned mode - self.preds = [] - self.target = [] - + if self._computation_mode == self._ComputationMode.NON_BINNED_DYNAMIC: + self._memory_budget_bytes = _validate_memory_budget(thresholds) # used after the switch to binned mode self.register_buffer("thresholds", None) self.add_state( - # "-1" here compenstes the lack min/max removed (see comments in update() to understand) "confmat", - default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long), + default=torch.zeros(_budget_bytes_to_nsamples(self._memory_budget_bytes), 2, 2, dtype=torch.long), dist_reduce_fx="sum", ) + # they are deleted after the switch to binned mode + self.preds = [] + self.target = [] elif thresholds is None: self.thresholds = thresholds @@ -229,19 +196,22 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(state[0]) self.target.append(state[1]) - if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED: + if self._computation_mode != self._ComputationMode.NON_BINNED_DYNAMIC: return all_preds = dim_zero_cat(self.preds) + mem_used = all_preds.element_size() * all_preds.nelement() * 2 # 2 accounts for the target - if all_preds.numel() < self.budget: + if mem_used < self._memory_budget_bytes: return # switch to binned mode - self.thresholds = _estimate_threhsholds(all_preds) - self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds) + self.preds, self.target = all_preds, dim_zero_cat(self.target) + _, _, self.thresholds = _binary_clf_curve(self.preds, self.target) + # if the number of thr + self.confmat = _binary_precision_recall_curve_update(self.preds, self.target, self.thresholds) del self.preds, self.target - self._dynamic_mode_state = self._DynamicModeState.BINNED + self._computation_mode = self._ComputationMode.BINNED def compute(self) -> Tuple[Tensor, Tensor, Tensor]: if self.thresholds is None: diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index a505898f040..792ab5c306d 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -14,6 +14,7 @@ from typing import List, Optional, Sequence, Tuple, Union +import humandfriendly import torch from torch import Tensor, tensor from torch.nn import functional as F @@ -81,13 +82,20 @@ def _binary_clf_curve( def _adjust_threshold_arg( - thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None -) -> Optional[Tensor]: - """Utility function for converting the threshold arg for list and int to tensor format.""" + thresholds: Optional[Union[int, List[float], Tensor, str]] = None, device: Optional[torch.device] = None +) -> Optional[Union[Tensor, int]]: + """Utility function for converting the threshold arg. + + - list and int -> tensor + - None -> None + - str -> int (memory budget) in Mb + """ if isinstance(thresholds, int): thresholds = torch.linspace(0, 1, thresholds, device=device) if isinstance(thresholds, list): thresholds = torch.tensor(thresholds, device=device) + if isinstance(thresholds, str): + thresholds = humandfriendly.parse_size(thresholds, binary=True) return thresholds From a1b687d631b0b79ae4d4224f520c514386b187e4 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:44:41 +0100 Subject: [PATCH 7/8] correct import --- .../functional/classification/precision_recall_curve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 792ab5c306d..1bfd26769df 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -14,7 +14,7 @@ from typing import List, Optional, Sequence, Tuple, Union -import humandfriendly +import humanfriendly import torch from torch import Tensor, tensor from torch.nn import functional as F @@ -95,7 +95,7 @@ def _adjust_threshold_arg( if isinstance(thresholds, list): thresholds = torch.tensor(thresholds, device=device) if isinstance(thresholds, str): - thresholds = humandfriendly.parse_size(thresholds, binary=True) + thresholds = humanfriendly.parse_size(thresholds, binary=True) return thresholds From 86b11780cbee63841c777afe9bba0b538c282dda Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:44:56 +0100 Subject: [PATCH 8/8] add humanfriendly to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index a386aed41f6..b77912c703a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy>=1.17.2 torch>=1.8.1 packaging typing-extensions; python_version < '3.9' +humanfriendly