Skip to content

Commit d87aff7

Browse files
authored
ci: fix missed testing with oldest (#2803)
1 parent 9cc354c commit d87aff7

8 files changed

Lines changed: 39 additions & 11 deletions

File tree

.github/workflows/ci-integrate.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ jobs:
5353

5454
- name: source cashing
5555
uses: ./.github/actions/pull-caches
56+
with:
57+
requires: ${{ matrix.requires }}
5658
- name: set oldest if/only for integrations
5759
if: matrix.requires == 'oldest'
5860
run: python .github/assistant.py set-oldest-versions --req_files='["requirements/_integrate.txt"]'

.github/workflows/ci-tests.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ jobs:
3131
strategy:
3232
fail-fast: false
3333
matrix:
34-
os: ["ubuntu-20.04"]
35-
python-version: ["3.9"]
34+
os: ["ubuntu-22.04"]
35+
python-version: ["3.10"]
3636
pytorch-version:
3737
- "2.0.1"
3838
- "2.1.2"
@@ -42,9 +42,8 @@ jobs:
4242
- "2.5.0"
4343
include:
4444
# cover additional python and PT combinations
45-
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" }
46-
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" }
47-
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" }
45+
- { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" }
46+
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" }
4847
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" }
4948
# standard mac machine, not the M1
5049
- { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" }

.github/workflows/docs-build.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ jobs:
4444
- name: source cashing
4545
uses: ./.github/actions/pull-caches
4646
with:
47-
requires: ${{ matrix.requires }}
4847
pytorch-version: ${{ matrix.pytorch-version }}
4948
pypi-dir: ${{ env.PYPI_CACHE }}
5049

requirements/segmentation_test.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

44
scipy >1.0.0, <1.15.0
5-
monai ==1.4.0
5+
monai ==1.3.2 ; python_version < "3.9"
6+
monai ==1.4.0 ; python_version > "3.8"

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,5 +245,6 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx
245245
"Programming Language :: Python :: 3.9",
246246
"Programming Language :: Python :: 3.10",
247247
"Programming Language :: Python :: 3.11",
248+
"Programming Language :: Python :: 3.12",
248249
],
249250
)

src/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,5 @@ def collect(self) -> GeneratorExit:
3636
def pytest_collect_file(parent: Path, path: Path) -> Optional[DoctestModule]:
3737
"""Collect doctests and add the reset_random_seed fixture."""
3838
if path.ext == ".py":
39-
return DoctestModule.from_parent(parent, fspath=path)
39+
return DoctestModule.from_parent(parent, path=Path(path))
4040
return None

src/torchmetrics/collections.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,30 @@
3131
__doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"]
3232

3333

34+
def _remove_prefix(string: str, prefix: str) -> str:
35+
"""Patch for older version with missing method `removeprefix`.
36+
37+
>>> _remove_prefix("prefix_string", "prefix_")
38+
'string'
39+
>>> _remove_prefix("not_prefix_string", "prefix_")
40+
'not_prefix_string'
41+
42+
"""
43+
return string[len(prefix) :] if string.startswith(prefix) else string
44+
45+
46+
def _remove_suffix(string: str, suffix: str) -> str:
47+
"""Patch for older version with missing method `removesuffix`.
48+
49+
>>> _remove_suffix("string_suffix", "_suffix")
50+
'string'
51+
>>> _remove_suffix("string_suffix_missing", "_suffix")
52+
'string_suffix_missing'
53+
54+
"""
55+
return string[: -len(suffix)] if string.endswith(suffix) else string
56+
57+
3458
class MetricCollection(ModuleDict):
3559
"""MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
3660
@@ -558,9 +582,9 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric:
558582
"""
559583
self._compute_groups_create_state_ref(copy_state)
560584
if self.prefix:
561-
key = key.removeprefix(self.prefix)
585+
key = _remove_prefix(key, self.prefix)
562586
if self.postfix:
563-
key = key.removesuffix(self.postfix)
587+
key = _remove_suffix(key, self.postfix)
564588
return self._modules[key]
565589

566590
@staticmethod

tests/unittests/segmentation/test_generalized_dice_score.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest
1818
import torch
19+
from lightning_utilities.core.imports import RequirementCache
1920
from monai.metrics.generalized_dice import compute_generalized_dice
2021
from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
2122
from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
@@ -51,7 +52,8 @@ def _reference_generalized_dice(
5152
if input_format == "index":
5253
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
5354
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
54-
val = compute_generalized_dice(preds, target, include_background=include_background, sum_over_classes=True)
55+
monai_extra_arg = {"sum_over_classes": True} if RequirementCache("monai>=1.4.0") else {}
56+
val = compute_generalized_dice(preds, target, include_background=include_background, **monai_extra_arg)
5557
if reduce:
5658
val = val.mean()
5759
return val.squeeze()

0 commit comments

Comments
 (0)