Skip to content

Commit 346bcdc

Browse files
Bordapre-commit-ci[bot]SkafteNicki
authored
bump: support torch>=2.0 (#2671)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent 7f579eb commit 346bcdc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+125
-232
lines changed

.azure/gpu-integrations.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ jobs:
1717
- job: integrate_GPU
1818
strategy:
1919
matrix:
20-
"torch | 1.x":
21-
docker-image: "pytorchlightning/torchmetrics:ubuntu22.04-cuda11.8.0-py3.9-torch1.13"
22-
torch-ver: "1.13"
20+
"torch | 2.0":
21+
docker-image: "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime"
22+
torch-ver: "2.0"
2323
requires: "oldest"
2424
"torch | 2.x":
2525
docker-image: "pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime"

.azure/gpu-unittests.yml

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,10 @@ jobs:
2424
- job: unitest_GPU
2525
strategy:
2626
matrix:
27-
"PyTorch | 1.10 oldest":
27+
"PyTorch | 2.0 oldest":
2828
# Torch does not have build wheels with old Torch versions for newer CUDA
29-
docker-image: "ubuntu20.04-cuda11.3.1-py3.9-torch1.10"
30-
torch-ver: "1.10"
31-
"PyTorch | 1.X LTS":
32-
docker-image: "ubuntu22.04-cuda11.8.0-py3.9-torch1.13"
33-
torch-ver: "1.13"
29+
docker-image: "ubuntu22.04-cuda11.8.0-py3.10-torch2.0"
30+
torch-ver: "2.0"
3431
"PyTorch | 2.X stable":
3532
docker-image: "ubuntu22.04-cuda12.1.1-py3.11-torch2.4"
3633
torch-ver: "2.4"
@@ -123,7 +120,7 @@ jobs:
123120
124121
- bash: |
125122
python .github/assistant.py set-oldest-versions
126-
condition: eq(variables['torch-ver'], '1.10')
123+
condition: eq(variables['torch-ver'], '2.0')
127124
displayName: "Setting oldest versions"
128125
129126
- bash: |
@@ -191,7 +188,7 @@ jobs:
191188
workingDirectory: "tests/"
192189
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
193190
condition: and(succeeded(), ne(variables['TEST_DIRS'], ''))
194-
timeoutInMinutes: "90"
191+
timeoutInMinutes: "95"
195192
displayName: "UnitTesting common"
196193
197194
- bash: |
@@ -203,7 +200,7 @@ jobs:
203200
workingDirectory: "tests/"
204201
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
205202
condition: and(succeeded(), ne(variables['TEST_DIRS'], ''))
206-
timeoutInMinutes: "90"
203+
timeoutInMinutes: "95"
207204
displayName: "UnitTesting DDP"
208205
209206
- bash: |

.github/workflows/ci-tests.yml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,22 @@ jobs:
3434
os: ["ubuntu-20.04"]
3535
python-version: ["3.9"]
3636
pytorch-version:
37-
- "1.10.2"
38-
- "1.11.0"
39-
- "1.12.1"
40-
- "1.13.1"
4137
- "2.0.1"
4238
- "2.1.2"
4339
- "2.2.2"
4440
- "2.3.1"
4541
- "2.4.0"
4642
include:
4743
# cover additional python and PT combinations
48-
- { os: "ubuntu-22.04", python-version: "3.8", pytorch-version: "1.13.1" }
4944
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" }
5045
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" }
5146
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" }
5247
# standard mac machine, not the M1
53-
- { os: "macOS-13", python-version: "3.8", pytorch-version: "1.13.1" }
5448
- { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" }
5549
# using the ARM based M1 machine
5650
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1" }
5751
- { os: "macOS-14", python-version: "3.11", pytorch-version: "2.4.0" }
5852
# some windows
59-
- { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" }
6053
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1" }
6154
- { os: "windows-2022", python-version: "3.11", pytorch-version: "2.4.0" }
6255
# Future released version

.github/workflows/docker-build.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ jobs:
6666
include:
6767
# These are the base images for PL release docker images,
6868
# so include at least all the combinations in release-dockers.yml.
69-
- { python: "3.9", pytorch: "1.10", cuda: "11.3.1", ubuntu: "20.04" }
70-
#- { python: "3.9", pytorch: "1.11", cuda: "11.8.0", ubuntu: "22.04" }
71-
- { python: "3.9", pytorch: "1.13", cuda: "11.8.0", ubuntu: "22.04" }
7269
- { python: "3.10", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" }
7370
- { python: "3.11", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" }
7471
- { python: "3.11", pytorch: "2.3", cuda: "12.1.1", ubuntu: "22.04" }

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222

2323
### Removed
2424

25-
-
25+
- Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671))
2626

2727

2828
### Fixed

requirements/_tests.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

4+
codecov ==2.1.13
45
coverage ==7.6.*
56
codecov ==2.1.13
67
pytest ==8.3.*

requirements/audio.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# this need to be the same as used inside speechmetrics
55
pesq >=0.0.4, <0.0.5
66
pystoi >=0.4.0, <0.5.0
7-
torchaudio >=0.10.0, <2.5.0
7+
torchaudio >=2.0.1, <2.5.0
88
gammatone >=1.0.0, <1.1.0
9-
librosa >=0.9.0, <0.11.0
9+
librosa >=0.10.0, <0.11.0
1010
onnxruntime >=1.12.0, <1.20 # installing onnxruntime_gpu-gpu failed on macos
1111
requests >=2.19.0, <2.33.0

requirements/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33

44
numpy >1.20.0, <2.0 # strict, for compatibility reasons
55
packaging >17.1
6-
torch >=1.10.0, <2.5.0
6+
torch >=2.0.0, <2.5.0
77
typing-extensions; python_version < '3.9'
88
lightning-utilities >=0.8.0, <0.12.0

requirements/detection.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

4-
torchvision >=0.8, <0.20.0
4+
torchvision >=0.15.1, <0.20.0
55
pycocotools >2.0.0, <2.1.0

requirements/image.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

44
scipy >1.0.0, <1.15.0
5-
torchvision >=0.8, <0.20.0
5+
torchvision >=0.15.1, <0.20.0
66
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing

src/torchmetrics/audio/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,17 @@
2828
_ONNXRUNTIME_AVAILABLE,
2929
_PESQ_AVAILABLE,
3030
_PYSTOI_AVAILABLE,
31+
_SCIPI_AVAILABLE,
3132
_TORCHAUDIO_AVAILABLE,
32-
_TORCHAUDIO_GREATER_EQUAL_0_10,
3333
)
3434

35+
if _SCIPI_AVAILABLE:
36+
import scipy.signal
37+
38+
# back compatibility patch due to SMRMpy using scipy.signal.hamming
39+
if not hasattr(scipy.signal, "hamming"):
40+
scipy.signal.hamming = scipy.signal.windows.hamming
41+
3542
__all__ = [
3643
"PermutationInvariantTraining",
3744
"ScaleInvariantSignalDistortionRatio",
@@ -52,7 +59,7 @@
5259

5360
__all__ += ["ShortTimeObjectiveIntelligibility"]
5461

55-
if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10:
62+
if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE:
5663
from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio
5764

5865
__all__ += ["SpeechReverberationModulationEnergyRatio"]

src/torchmetrics/audio/srmr.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@
2424
_GAMMATONE_AVAILABLE,
2525
_MATPLOTLIB_AVAILABLE,
2626
_TORCHAUDIO_AVAILABLE,
27-
_TORCHAUDIO_GREATER_EQUAL_0_10,
2827
)
2928
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
3029

31-
if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHAUDIO_GREATER_EQUAL_0_10]):
30+
if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE]):
3231
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
3332
elif not _MATPLOTLIB_AVAILABLE:
3433
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]
@@ -105,7 +104,7 @@ def __init__(
105104
**kwargs: Any,
106105
) -> None:
107106
super().__init__(**kwargs)
108-
if not _TORCHAUDIO_AVAILABLE or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABLE:
107+
if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
109108
raise ModuleNotFoundError(
110109
"speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
111110
" `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "

src/torchmetrics/detection/__init__.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from torchmetrics.detection.panoptic_qualities import ModifiedPanopticQuality, PanopticQuality
15-
from torchmetrics.utilities.imports import (
16-
_TORCHVISION_GREATER_EQUAL_0_8,
17-
_TORCHVISION_GREATER_EQUAL_0_13,
18-
)
15+
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
1916

2017
__all__ = ["ModifiedPanopticQuality", "PanopticQuality"]
2118

22-
if _TORCHVISION_GREATER_EQUAL_0_8:
19+
if _TORCHVISION_AVAILABLE:
20+
from torchmetrics.detection.ciou import CompleteIntersectionOverUnion
21+
from torchmetrics.detection.diou import DistanceIntersectionOverUnion
2322
from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion
2423
from torchmetrics.detection.iou import IntersectionOverUnion
2524
from torchmetrics.detection.mean_ap import MeanAveragePrecision
2625

27-
__all__ += ["MeanAveragePrecision", "GeneralizedIntersectionOverUnion", "IntersectionOverUnion"]
28-
29-
if _TORCHVISION_GREATER_EQUAL_0_13:
30-
from torchmetrics.detection.ciou import CompleteIntersectionOverUnion
31-
from torchmetrics.detection.diou import DistanceIntersectionOverUnion
32-
33-
__all__ += ["CompleteIntersectionOverUnion", "DistanceIntersectionOverUnion"]
26+
__all__ += [
27+
"MeanAveragePrecision",
28+
"GeneralizedIntersectionOverUnion",
29+
"IntersectionOverUnion",
30+
"CompleteIntersectionOverUnion",
31+
"DistanceIntersectionOverUnion",
32+
]

src/torchmetrics/detection/_deprecated.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,8 @@
11
from typing import Any, Collection
22

33
from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
4-
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
54
from torchmetrics.utilities.prints import _deprecated_root_import_class
65

7-
if not _TORCH_GREATER_EQUAL_1_12:
8-
__doctest_skip__ = [
9-
"_PanopticQuality",
10-
"_PanopticQuality.*",
11-
"_ModifiedPanopticQuality",
12-
"_ModifiedPanopticQuality.*",
13-
]
14-
156

167
class _ModifiedPanopticQuality(ModifiedPanopticQuality):
178
"""Wrapper for deprecated import.

src/torchmetrics/detection/_mean_ap.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
2323
from torchmetrics.metric import Metric
2424
from torchmetrics.utilities.data import _cumsum
25-
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
25+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_AVAILABLE
2626
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
2727

2828
if not _MATPLOTLIB_AVAILABLE:
2929
__doctest_skip__ = ["MeanAveragePrecision.plot"]
3030

31-
if not _TORCHVISION_GREATER_EQUAL_0_8 or not _PYCOCOTOOLS_AVAILABLE:
31+
if not _TORCHVISION_AVAILABLE or not _PYCOCOTOOLS_AVAILABLE:
3232
__doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"]
3333

3434
log = logging.getLogger(__name__)
@@ -327,10 +327,10 @@ def __init__(
327327
"`MAP` metric requires that `pycocotools` installed."
328328
" Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
329329
)
330-
if not _TORCHVISION_GREATER_EQUAL_0_8:
330+
if not _TORCHVISION_AVAILABLE:
331331
raise ModuleNotFoundError(
332-
"`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed."
333-
" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
332+
"`MeanAveragePrecision` metric requires that `torchvision` is installed."
333+
" Please install with `pip install torchmetrics[detection]`."
334334
)
335335

336336
allowed_box_formats = ("xyxy", "xywh", "cxcywh")

src/torchmetrics/detection/ciou.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from torchmetrics.detection.iou import IntersectionOverUnion
1919
from torchmetrics.functional.detection.ciou import _ciou_compute, _ciou_update
20-
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
20+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
2121
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
2222

23-
if not _TORCHVISION_GREATER_EQUAL_0_13:
23+
if not _TORCHVISION_AVAILABLE:
2424
__doctest_skip__ = ["CompleteIntersectionOverUnion", "CompleteIntersectionOverUnion.plot"]
2525
elif not _MATPLOTLIB_AVAILABLE:
2626
__doctest_skip__ = ["CompleteIntersectionOverUnion.plot"]
@@ -110,10 +110,10 @@ def __init__(
110110
respect_labels: bool = True,
111111
**kwargs: Any,
112112
) -> None:
113-
if not _TORCHVISION_GREATER_EQUAL_0_13:
113+
if not _TORCHVISION_AVAILABLE:
114114
raise ModuleNotFoundError(
115-
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
116-
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
115+
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
116+
" Please install with `pip install torchmetrics[detection]`."
117117
)
118118
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)
119119

src/torchmetrics/detection/diou.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from torchmetrics.detection.iou import IntersectionOverUnion
1919
from torchmetrics.functional.detection.diou import _diou_compute, _diou_update
20-
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
20+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
2121
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
2222

23-
if not _TORCHVISION_GREATER_EQUAL_0_13:
23+
if not _TORCHVISION_AVAILABLE:
2424
__doctest_skip__ = ["DistanceIntersectionOverUnion", "DistanceIntersectionOverUnion.plot"]
2525
elif not _MATPLOTLIB_AVAILABLE:
2626
__doctest_skip__ = ["DistanceIntersectionOverUnion.plot"]
@@ -110,10 +110,10 @@ def __init__(
110110
respect_labels: bool = True,
111111
**kwargs: Any,
112112
) -> None:
113-
if not _TORCHVISION_GREATER_EQUAL_0_13:
113+
if not _TORCHVISION_AVAILABLE:
114114
raise ModuleNotFoundError(
115-
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
116-
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
115+
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
116+
" Please install with `pip install torchmetrics[detection]`."
117117
)
118118
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)
119119

src/torchmetrics/detection/giou.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from torchmetrics.detection.iou import IntersectionOverUnion
1919
from torchmetrics.functional.detection.giou import _giou_compute, _giou_update
20-
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
20+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
2121
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
2222

23-
if not _TORCHVISION_GREATER_EQUAL_0_8:
23+
if not _TORCHVISION_AVAILABLE:
2424
__doctest_skip__ = ["GeneralizedIntersectionOverUnion", "GeneralizedIntersectionOverUnion.plot"]
2525
elif not _MATPLOTLIB_AVAILABLE:
2626
__doctest_skip__ = ["GeneralizedIntersectionOverUnion.plot"]

src/torchmetrics/detection/iou.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from torchmetrics.functional.detection.iou import _iou_compute, _iou_update
2121
from torchmetrics.metric import Metric
2222
from torchmetrics.utilities.data import dim_zero_cat
23-
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
23+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
2424
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
2525

26-
if not _TORCHVISION_GREATER_EQUAL_0_8:
26+
if not _TORCHVISION_AVAILABLE:
2727
__doctest_skip__ = ["IntersectionOverUnion", "IntersectionOverUnion.plot"]
2828
elif not _MATPLOTLIB_AVAILABLE:
2929
__doctest_skip__ = ["IntersectionOverUnion.plot"]
@@ -146,10 +146,10 @@ def __init__(
146146
) -> None:
147147
super().__init__(**kwargs)
148148

149-
if not _TORCHVISION_GREATER_EQUAL_0_8:
149+
if not _TORCHVISION_AVAILABLE:
150150
raise ModuleNotFoundError(
151-
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.8.0 or newer is installed."
152-
" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
151+
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
152+
" Please install with `pip install torchmetrics[detection]`."
153153
)
154154

155155
allowed_box_formats = ("xyxy", "xywh", "cxcywh")

0 commit comments

Comments
 (0)