Skip to content

Commit d500a6f

Browse files
fix: prts version specification until CompML/PRTS#77 is merged and add skiptests conditions if not installed
1 parent e8849fa commit d500a6f

File tree

3 files changed

+61
-29
lines changed

3 files changed

+61
-29
lines changed

aeon/performance_metrics/anomaly_detection/_binary.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ def range_precision(
5555
1920–30. 2018.
5656
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
5757
"""
58-
_check_soft_dependencies(
59-
"prts>=1.0.0.3", obj="range_precision", suppress_import_stdout=True
60-
)
58+
_check_soft_dependencies("prts", obj="range_precision", suppress_import_stdout=True)
6159

6260
from prts import ts_precision
6361

@@ -117,9 +115,7 @@ def range_recall(
117115
1920–30. 2018.
118116
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
119117
"""
120-
_check_soft_dependencies(
121-
"prts>=1.0.0.3", obj="range_recall", suppress_import_stdout=True
122-
)
118+
_check_soft_dependencies("prts", obj="range_recall", suppress_import_stdout=True)
123119

124120
from prts import ts_recall
125121

@@ -187,9 +183,7 @@ def range_f_score(
187183
1920–30. 2018.
188184
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
189185
"""
190-
_check_soft_dependencies(
191-
"prts>=1.0.0.3", obj="range_recall", suppress_import_stdout=True
192-
)
186+
_check_soft_dependencies("prts", obj="range_recall", suppress_import_stdout=True)
193187

194188
from prts import ts_fscore
195189

aeon/performance_metrics/anomaly_detection/_continuous.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def f_score_at_k_ranges(
157157
Function used to find the threshold.
158158
"""
159159
_check_soft_dependencies(
160-
"prts>=1.0.0.3", obj="f_score_at_k_ranges", suppress_import_stdout=True
160+
"prts", obj="f_score_at_k_ranges", suppress_import_stdout=True
161161
)
162162

163163
from prts import ts_fscore
@@ -231,7 +231,7 @@ def rp_rr_auc_score(
231231
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
232232
"""
233233
_check_soft_dependencies(
234-
"prts>=1.0.0.3", obj="f_score_at_k_ranges", suppress_import_stdout=True
234+
"prts", obj="f_score_at_k_ranges", suppress_import_stdout=True
235235
)
236236

237237
from prts import ts_precision, ts_recall

aeon/performance_metrics/anomaly_detection/tests/test_ad_metrics.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,36 +18,46 @@
1818
rp_rr_auc_score,
1919
)
2020
from aeon.testing.data_generation import make_example_1d_numpy
21+
from aeon.utils.validation._dependencies import _check_soft_dependencies
2122

22-
pr_metrics = [pr_auc_score, rp_rr_auc_score]
23+
pr_metrics = [pr_auc_score]
2324
range_metrics = [
2425
range_roc_auc_score,
2526
range_pr_auc_score,
2627
range_roc_vus_score,
2728
range_pr_vus_score,
28-
range_recall,
29-
range_precision,
30-
range_f_score,
3129
]
3230
other_metrics = [
3331
roc_auc_score,
3432
f_score_at_k_points,
35-
f_score_at_k_ranges,
36-
rp_rr_auc_score,
3733
]
34+
continuous_metrics = [*pr_metrics, *other_metrics, *range_metrics]
35+
binary_metrics = []
36+
37+
if _check_soft_dependencies("prts", severity="none"):
38+
pr_metrics.append(rp_rr_auc_score)
39+
range_metrics.extend(
40+
[
41+
range_recall,
42+
range_precision,
43+
range_f_score,
44+
]
45+
)
46+
other_metrics.extend(
47+
[
48+
f_score_at_k_ranges,
49+
rp_rr_auc_score,
50+
]
51+
)
52+
continuous_metrics.extend(
53+
[
54+
rp_rr_auc_score,
55+
f_score_at_k_ranges,
56+
]
57+
)
58+
binary_metrics = [range_recall, range_precision, range_f_score]
59+
3860
metrics = [*pr_metrics, *range_metrics, *other_metrics]
39-
continuous_metrics = [
40-
*pr_metrics,
41-
range_roc_auc_score,
42-
range_pr_auc_score,
43-
range_roc_vus_score,
44-
range_pr_vus_score,
45-
roc_auc_score,
46-
rp_rr_auc_score,
47-
f_score_at_k_points,
48-
f_score_at_k_ranges,
49-
]
50-
binary_metrics = [range_recall, range_precision, range_f_score]
5161

5262

5363
@pytest.mark.parametrize("metric", metrics, ids=[m.__name__ for m in metrics])
@@ -140,6 +150,10 @@ def test_edge_cases_pr_metrics(metric):
140150
assert score <= 0.2, f"{metric.__name__}(y_true, y_inverted)={score} is not <= 0.2"
141151

142152

153+
@pytest.mark.skipif(
154+
not _check_soft_dependencies("prts", severity="none"),
155+
reason="required soft dependency prts not available",
156+
)
143157
def test_range_based_f1():
144158
"""Test range-based F1 score."""
145159
y_pred = np.array([0, 1, 1, 0])
@@ -148,6 +162,10 @@ def test_range_based_f1():
148162
np.testing.assert_almost_equal(result, 0.66666, decimal=4)
149163

150164

165+
@pytest.mark.skipif(
166+
not _check_soft_dependencies("prts", severity="none"),
167+
reason="required soft dependency prts not available",
168+
)
151169
def test_range_based_precision():
152170
"""Test range-based precision."""
153171
y_pred = np.array([0, 1, 1, 0])
@@ -156,6 +174,10 @@ def test_range_based_precision():
156174
assert result == 0.5
157175

158176

177+
@pytest.mark.skipif(
178+
not _check_soft_dependencies("prts", severity="none"),
179+
reason="required soft dependency prts not available",
180+
)
159181
def test_range_based_recall():
160182
"""Test range-based recall."""
161183
y_pred = np.array([0, 1, 1, 0])
@@ -164,6 +186,10 @@ def test_range_based_recall():
164186
assert result == 1
165187

166188

189+
@pytest.mark.skipif(
190+
not _check_soft_dependencies("prts", severity="none"),
191+
reason="required soft dependency prts not available",
192+
)
167193
def test_rf1_value_error():
168194
"""Test range-based F1 score raises ValueError on binary predictions."""
169195
y_pred = np.array([0, 0.2, 0.7, 0])
@@ -187,6 +213,10 @@ def test_pr_curve_auc():
187213
# np.testing.assert_almost_equal(result, 0.8333, decimal=4)
188214

189215

216+
@pytest.mark.skipif(
217+
not _check_soft_dependencies("prts", severity="none"),
218+
reason="required soft dependency prts not available",
219+
)
190220
def test_range_based_p_range_based_r_curve_auc():
191221
"""Test range-based precision-recall curve AUC."""
192222
y_pred = np.array([0, 0.1, 1.0, 0.5, 0.1, 0])
@@ -195,6 +225,10 @@ def test_range_based_p_range_based_r_curve_auc():
195225
np.testing.assert_almost_equal(result, 0.9792, decimal=4)
196226

197227

228+
@pytest.mark.skipif(
229+
not _check_soft_dependencies("prts", severity="none"),
230+
reason="required soft dependency prts not available",
231+
)
198232
def test_range_based_p_range_based_r_auc_perfect_hit():
199233
"""Test range-based precision-recall curve AUC with perfect hit."""
200234
y_pred = np.array([0, 0, 0.5, 0.5, 0, 0])
@@ -203,6 +237,10 @@ def test_range_based_p_range_based_r_auc_perfect_hit():
203237
np.testing.assert_almost_equal(result, 1.0000, decimal=4)
204238

205239

240+
@pytest.mark.skipif(
241+
not _check_soft_dependencies("prts", severity="none"),
242+
reason="required soft dependency prts not available",
243+
)
206244
def test_f_score_at_k_ranges():
207245
"""Test range-based F1 score at k ranges."""
208246
y_pred = np.array([0.4, 0.1, 1.0, 0.5, 0.1, 0, 0.4, 0.5])

0 commit comments

Comments
 (0)