diff --git a/aeon/testing/estimator_checking/_yield_classification_checks.py b/aeon/testing/estimator_checking/_yield_classification_checks.py index 0591119f7a..399747c2db 100644 --- a/aeon/testing/estimator_checking/_yield_classification_checks.py +++ b/aeon/testing/estimator_checking/_yield_classification_checks.py @@ -374,3 +374,10 @@ def check_classifier_output(estimator, datatype): # check predict proba (all classifiers have predict_proba by default) y_proba = estimator.predict_proba(FULL_TEST_DATA_DICT[datatype]["test"][0]) _assert_predict_probabilities(y_proba, datatype, n_classes=len(unique_labels)) + + y_pred_proba_indices = np.argmax(y_proba, axis=1) + y_pred_proba = estimator.classes_[y_pred_proba_indices] + + np.testing.assert_array_equal( + y_pred, y_pred_proba, err_msg="predict and predict_proba are not consistent" + ) diff --git a/aeon/testing/estimator_checking/_yield_clustering_checks.py b/aeon/testing/estimator_checking/_yield_clustering_checks.py index 4e3940c489..6ed44212a8 100644 --- a/aeon/testing/estimator_checking/_yield_clustering_checks.py +++ b/aeon/testing/estimator_checking/_yield_clustering_checks.py @@ -141,6 +141,15 @@ def check_clusterer_output(estimator, datatype): assert isinstance(y_proba, np.ndarray) np.testing.assert_almost_equal(y_proba.sum(axis=1), 1, decimal=4) + # check predict and predict_proba have consistent outputs + y_pred_proba = np.argmax(y_proba, axis=1) + + np.testing.assert_array_equal( + y_pred, + y_pred_proba, + err_msg="predict and predict_proba outputs are inconsistent", + ) + def check_clusterer_saving_loading_deep_learning(estimator_class, datatype): """Test Deep Clusterer saving."""