     rp_rr_auc_score,
 )
 from aeon.testing.data_generation import make_example_1d_numpy
+from aeon.utils.validation._dependencies import _check_soft_dependencies

-pr_metrics = [pr_auc_score, rp_rr_auc_score]
+pr_metrics = [pr_auc_score]
 range_metrics = [
     range_roc_auc_score,
     range_pr_auc_score,
     range_roc_vus_score,
     range_pr_vus_score,
-    range_recall,
-    range_precision,
-    range_f_score,
 ]
 other_metrics = [
     roc_auc_score,
     f_score_at_k_points,
-    f_score_at_k_ranges,
-    rp_rr_auc_score,
 ]
+continuous_metrics = [*pr_metrics, *other_metrics, *range_metrics]
+binary_metrics = []
+
+if _check_soft_dependencies("prts", severity="none"):
+    pr_metrics.append(rp_rr_auc_score)
+    range_metrics.extend(
+        [
+            range_recall,
+            range_precision,
+            range_f_score,
+        ]
+    )
+    other_metrics.extend(
+        [
+            f_score_at_k_ranges,
+            rp_rr_auc_score,
+        ]
+    )
+    continuous_metrics.extend(
+        [
+            rp_rr_auc_score,
+            f_score_at_k_ranges,
+        ]
+    )
+    binary_metrics = [range_recall, range_precision, range_f_score]
+
 metrics = [*pr_metrics, *range_metrics, *other_metrics]
-continuous_metrics = [
-    *pr_metrics,
-    range_roc_auc_score,
-    range_pr_auc_score,
-    range_roc_vus_score,
-    range_pr_vus_score,
-    roc_auc_score,
-    rp_rr_auc_score,
-    f_score_at_k_points,
-    f_score_at_k_ranges,
-]
-binary_metrics = [range_recall, range_precision, range_f_score]


 @pytest.mark.parametrize("metric", metrics, ids=[m.__name__ for m in metrics])
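The hunk above moves every prts-backed metric behind a runtime check, so the module imports (and pytest collects) cleanly when the optional package is absent. A minimal self-contained sketch of that gating pattern follows; the two metric functions are placeholders, not aeon's real callables, and only _check_soft_dependencies is taken from the diff itself.

from aeon.utils.validation._dependencies import _check_soft_dependencies


def always_available_metric(y_true, y_pred):  # stands in for e.g. pr_auc_score
    return 0.0


def prts_backed_metric(y_true, y_pred):  # stands in for e.g. range_f_score
    return 0.0


pr_metrics = [always_available_metric]
binary_metrics = []
# severity="none" makes the check return a bool instead of raising, so the
# prts-backed entries are simply never added on installations without prts.
if _check_soft_dependencies("prts", severity="none"):
    pr_metrics.append(prts_backed_metric)
    binary_metrics = [prts_backed_metric]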
@@ -140,6 +150,10 @@ def test_edge_cases_pr_metrics(metric):
     assert score <= 0.2, f"{metric.__name__}(y_true, y_inverted)={score} is not <= 0.2"


+@pytest.mark.skipif(
+    not _check_soft_dependencies("prts", severity="none"),
+    reason="required soft dependency prts not available",
+)
 def test_range_based_f1():
     """Test range-based F1 score."""
     y_pred = np.array([0, 1, 1, 0])
@@ -148,6 +162,10 @@ def test_range_based_f1():
     np.testing.assert_almost_equal(result, 0.66666, decimal=4)


+@pytest.mark.skipif(
+    not _check_soft_dependencies("prts", severity="none"),
+    reason="required soft dependency prts not available",
+)
 def test_range_based_precision():
     """Test range-based precision."""
     y_pred = np.array([0, 1, 1, 0])
@@ -156,6 +174,10 @@ def test_range_based_precision():
     assert result == 0.5


+@pytest.mark.skipif(
+    not _check_soft_dependencies("prts", severity="none"),
+    reason="required soft dependency prts not available",
+)
 def test_range_based_recall():
     """Test range-based recall."""
     y_pred = np.array([0, 1, 1, 0])
@@ -164,6 +186,10 @@ def test_range_based_recall():
     assert result == 1


+@pytest.mark.skipif(
+    not _check_soft_dependencies("prts", severity="none"),
+    reason="required soft dependency prts not available",
+)
 def test_rf1_value_error():
     """Test range-based F1 score raises ValueError on binary predictions."""
     y_pred = np.array([0, 0.2, 0.7, 0])
@@ -187,6 +213,10 @@ def test_pr_curve_auc():
     # np.testing.assert_almost_equal(result, 0.8333, decimal=4)


+@pytest.mark.skipif(
+    not _check_soft_dependencies("prts", severity="none"),
+    reason="required soft dependency prts not available",
+)
 def test_range_based_p_range_based_r_curve_auc():
     """Test range-based precision-recall curve AUC."""
     y_pred = np.array([0, 0.1, 1.0, 0.5, 0.1, 0])
@@ -195,6 +225,10 @@ def test_range_based_p_range_based_r_curve_auc():
     np.testing.assert_almost_equal(result, 0.9792, decimal=4)


+@pytest.mark.skipif(
+    not _check_soft_dependencies("prts", severity="none"),
+    reason="required soft dependency prts not available",
+)
 def test_range_based_p_range_based_r_auc_perfect_hit():
     """Test range-based precision-recall curve AUC with perfect hit."""
     y_pred = np.array([0, 0, 0.5, 0.5, 0, 0])
@@ -203,6 +237,10 @@ def test_range_based_p_range_based_r_auc_perfect_hit():
     np.testing.assert_almost_equal(result, 1.0000, decimal=4)


+@pytest.mark.skipif(
+    not _check_soft_dependencies("prts", severity="none"),
+    reason="required soft dependency prts not available",
+)
 def test_f_score_at_k_ranges():
     """Test range-based F1 score at k ranges."""
     y_pred = np.array([0.4, 0.1, 1.0, 0.5, 0.1, 0, 0.4, 0.5])
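Every remaining hunk applies the same guard as a decorator: the skipif condition is evaluated at collection time, so prts-only tests are reported as skipped rather than erroring on installations without the soft dependency. A self-contained sketch of the pattern, with a hypothetical test name:

import pytest

from aeon.utils.validation._dependencies import _check_soft_dependencies


@pytest.mark.skipif(
    not _check_soft_dependencies("prts", severity="none"),
    reason="required soft dependency prts not available",
)
def test_prts_only_behaviour():  # hypothetical example test
    # Body runs only when prts is importable; otherwise pytest skips it.
    ...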
0 commit comments