From 5cb99b7c1d375a527377f555b30a3f1102eee4cb Mon Sep 17 00:00:00 2001 From: Aryan Ramani Date: Mon, 19 May 2025 14:20:23 +0100 Subject: [PATCH 1/2] enh: adds a check for consistent output for predict and predict_proba --- .../estimator_checking/_yield_classification_checks.py | 7 +++++++ .../estimator_checking/_yield_clustering_checks.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/aeon/testing/estimator_checking/_yield_classification_checks.py b/aeon/testing/estimator_checking/_yield_classification_checks.py index 0591119f7a..40a7d71244 100644 --- a/aeon/testing/estimator_checking/_yield_classification_checks.py +++ b/aeon/testing/estimator_checking/_yield_classification_checks.py @@ -374,3 +374,10 @@ def check_classifier_output(estimator, datatype): # check predict proba (all classifiers have predict_proba by default) y_proba = estimator.predict_proba(FULL_TEST_DATA_DICT[datatype]["test"][0]) _assert_predict_probabilities(y_proba, datatype, n_classes=len(unique_labels)) + + y_pred_proba = np.argmax(y_proba, axis=1) + _assert_predict_labels(y_pred_proba, datatype, unique_labels=unique_labels) + + np.testing.assert_array_equal( + y_pred, y_pred_proba, err_msg="predict and predict_proba are not consistent" + ) diff --git a/aeon/testing/estimator_checking/_yield_clustering_checks.py b/aeon/testing/estimator_checking/_yield_clustering_checks.py index 4e3940c489..6ed44212a8 100644 --- a/aeon/testing/estimator_checking/_yield_clustering_checks.py +++ b/aeon/testing/estimator_checking/_yield_clustering_checks.py @@ -141,6 +141,15 @@ def check_clusterer_output(estimator, datatype): assert isinstance(y_proba, np.ndarray) np.testing.assert_almost_equal(y_proba.sum(axis=1), 1, decimal=4) + # check predict and predict_proba have consistent outputs + y_pred_proba = np.argmax(y_proba, axis=1) + + np.testing.assert_array_equal( + y_pred, + y_pred_proba, + err_msg="predict and predict_proba outputs are inconsistent", + ) + def check_clusterer_saving_loading_deep_learning(estimator_class, datatype): """Test Deep Clusterer saving.""" From ca80c1863589d4b21e45c1c22b1b7a89f863bc79 Mon Sep 17 00:00:00 2001 From: Aryan Ramani Date: Thu, 22 May 2025 09:58:22 +0100 Subject: [PATCH 2/2] using classes_ to obtain predictions labels --- .../estimator_checking/_yield_classification_checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/testing/estimator_checking/_yield_classification_checks.py b/aeon/testing/estimator_checking/_yield_classification_checks.py index 40a7d71244..399747c2db 100644 --- a/aeon/testing/estimator_checking/_yield_classification_checks.py +++ b/aeon/testing/estimator_checking/_yield_classification_checks.py @@ -375,8 +375,8 @@ def check_classifier_output(estimator, datatype): y_proba = estimator.predict_proba(FULL_TEST_DATA_DICT[datatype]["test"][0]) _assert_predict_probabilities(y_proba, datatype, n_classes=len(unique_labels)) - y_pred_proba = np.argmax(y_proba, axis=1) - _assert_predict_labels(y_pred_proba, datatype, unique_labels=unique_labels) + y_pred_proba_indices = np.argmax(y_proba, axis=1) + y_pred_proba = estimator.classes_[y_pred_proba_indices] np.testing.assert_array_equal( y_pred, y_pred_proba, err_msg="predict and predict_proba are not consistent"