From 09fcccae39e69b6a4de9ca8842e0e80dc10f5c4f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 22 Aug 2023 16:04:40 +0200 Subject: [PATCH] Use lightning utils `apply_to_collection` (#2013) * refactor to use lightning utils * increase requirement --- requirements.txt | 2 +- src/torchmetrics/detection/mean_ap.py | 2 +- src/torchmetrics/metric.py | 2 +- src/torchmetrics/utilities/__init__.py | 2 - src/torchmetrics/utilities/data.py | 54 +------------------ src/torchmetrics/wrappers/bootstrapping.py | 2 +- src/torchmetrics/wrappers/multioutput.py | 2 +- tests/unittests/helpers/testers.py | 3 +- .../unittests/wrappers/test_bootstrapping.py | 2 +- 9 files changed, 10 insertions(+), 61 deletions(-) diff --git a/requirements.txt b/requirements.txt index b68dbb40274..27b5d0d3feb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ numpy >1.20.0 torch >=1.8.1, <=2.0.1 typing-extensions; python_version < '3.9' packaging # hotfix for utils, can be dropped with lit-utils >=0.5 -lightning-utilities >=0.7.0, <0.10.0 +lightning-utilities >=0.8.0, <0.10.0 diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 075adeeefd4..0d47049a608 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -18,6 +18,7 @@ import numpy as np import torch +from lightning_utilities import apply_to_collection from torch import Tensor from torch import distributed as dist from typing_extensions import Literal @@ -25,7 +26,6 @@ from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator, _validate_iou_type_arg from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.data import apply_to_collection from torchmetrics.utilities.imports import ( _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index ddf386ebe6c..ef3f61a9d6f 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -23,13 +23,13 @@ from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Union import torch +from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module from torchmetrics.utilities.data import ( _flatten, _squeeze_if_scalar, - apply_to_collection, dim_zero_cat, dim_zero_max, dim_zero_mean, diff --git a/src/torchmetrics/utilities/__init__.py b/src/torchmetrics/utilities/__init__.py index 234e3474873..9250d7b2f46 100644 --- a/src/torchmetrics/utilities/__init__.py +++ b/src/torchmetrics/utilities/__init__.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.utilities.checks import check_forward_full_state_property -from torchmetrics.utilities.data import apply_to_collection from torchmetrics.utilities.distributed import class_reduce, reduce from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn __all__ = [ "check_forward_full_state_property", - "apply_to_collection", "class_reduce", "reduce", "rank_zero_debug", diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 16e59f78df9..ebb81679a02 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union import torch +from lightning_utilities import apply_to_collection from torch import Tensor from torchmetrics.utilities.exceptions import TorchMetricsUserWarning @@ -152,57 +153,6 @@ def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor: return torch.argmax(x, dim=argmax_dim) -def apply_to_collection( - data: Any, - dtype: Union[type, tuple], - function: Callable, - *args: Any, - wrong_dtype: Optional[Union[type, tuple]] = None, - **kwargs: Any, -) -> Any: - """Recursively applies a function to all elements of a certain dtype. - - Args: - data: the collection to apply the function to - dtype: the given function will be applied to all elements of this dtype - function: the function to apply - *args: positional arguments (will be forwarded to call of ``function``) - wrong_dtype: the given function won't be applied if this type is specified and the given collections is of - the :attr:`wrong_type` even if it is of type :attr`dtype` - **kwargs: keyword arguments (will be forwarded to call of ``function``) - - Returns: - the resulting collection - - Example: - >>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=Tensor, function=lambda x: x ** 2) - tensor([64, 0, 4, 36, 49]) - >>> apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2) - [64, 0, 4, 36, 49] - >>> apply_to_collection(dict(abc=123), dtype=int, function=lambda x: x ** 2) - {'abc': 15129} - - """ - elem_type = type(data) - - # Breaking condition - if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): - return function(data, *args, **kwargs) - - # Recursively apply to collection items - if isinstance(data, Mapping): - return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) - - if isinstance(data, tuple) and hasattr(data, "_fields"): # named tuple - return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) - - if isinstance(data, Sequence) and not isinstance(data, str): - return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) - - # data is neither of dtype, nor a collection - return data - - def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor: return x.squeeze() if x.numel() == 1 else x diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index 304d254d5b5..0157ce99f71 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -15,11 +15,11 @@ from typing import Any, Dict, Optional, Sequence, Union import torch +from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import ModuleList from torchmetrics.metric import Metric -from torchmetrics.utilities import apply_to_collection from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE from torchmetrics.wrappers.abstract import WrapperMetric diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py index 0fb86f12f38..7853e6257e6 100644 --- a/src/torchmetrics/wrappers/multioutput.py +++ b/src/torchmetrics/wrappers/multioutput.py @@ -15,11 +15,11 @@ from typing import Any, List, Optional, Sequence, Tuple, Union import torch +from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import ModuleList from torchmetrics.metric import Metric -from torchmetrics.utilities import apply_to_collection from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE from torchmetrics.wrappers.abstract import WrapperMetric diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index c9a37e48972..3740e7bf335 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -20,9 +20,10 @@ import numpy as np import pytest import torch +from lightning_utilities import apply_to_collection from torch import Tensor, tensor from torchmetrics import Metric -from torchmetrics.utilities.data import _flatten, apply_to_collection +from torchmetrics.utilities.data import _flatten from unittests import NUM_PROCESSES diff --git a/tests/unittests/wrappers/test_bootstrapping.py b/tests/unittests/wrappers/test_bootstrapping.py index e2339e3b719..7d38d5728ac 100644 --- a/tests/unittests/wrappers/test_bootstrapping.py +++ b/tests/unittests/wrappers/test_bootstrapping.py @@ -18,11 +18,11 @@ import numpy as np import pytest import torch +from lightning_utilities import apply_to_collection from sklearn.metrics import mean_squared_error, precision_score, recall_score from torch import Tensor from torchmetrics.classification import MulticlassPrecision, MulticlassRecall from torchmetrics.regression import MeanSquaredError -from torchmetrics.utilities import apply_to_collection from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler from unittests.helpers import seed_all