diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py
index 00eee202cc0..a40427d16c4 100644
--- a/tests/unittests/classification/test_precision_recall.py
+++ b/tests/unittests/classification/test_precision_recall.py
@@ -659,6 +659,37 @@ def test_corner_case():
     assert res == 1.0
 
 
+def test_multiclass_recall_ignore_index():
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
+    y_true = torch.tensor([0, 0, 1, 1])
+    y_pred = torch.tensor([
+        [0.9, 0.1],
+        [0.9, 0.1],
+        [0.9, 0.1],
+        [0.1, 0.9],
+    ])
+
+    # Test with ignore_index=0 and average="macro"
+    metric_ignore_0 = MulticlassRecall(num_classes=2, ignore_index=0, average="macro")
+    res_ignore_0 = metric_ignore_0(y_pred, y_true)
+    assert res_ignore_0 == 0.5, f"Expected 0.5, but got {res_ignore_0}"
+
+    # Test with ignore_index=1 and average="macro"
+    metric_ignore_1 = MulticlassRecall(num_classes=2, ignore_index=1, average="macro")
+    res_ignore_1 = metric_ignore_1(y_pred, y_true)
+    assert res_ignore_1 == 1.0, f"Expected 1.0, but got {res_ignore_1}"
+
+    # Test with no ignore_index and average="macro"
+    metric_no_ignore = MulticlassRecall(num_classes=2, average="macro")
+    res_no_ignore = metric_no_ignore(y_pred, y_true)
+    assert res_no_ignore == 0.75, f"Expected 0.75, but got {res_no_ignore}"
+
+    # Test with ignore_index=0 and average="none"
+    metric_none = MulticlassRecall(num_classes=2, ignore_index=0, average="none")
+    res_none = metric_none(y_pred, y_true)
+    assert torch.allclose(res_none, torch.tensor([0.0, 0.5])), f"Expected [0.0, 0.5], but got {res_none}"
+
+
 @pytest.mark.parametrize(
     ("metric", "kwargs", "base_metric"),
     [
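
Note for reviewers: a minimal standalone sketch of the arithmetic behind the asserted values, using plain torch only (no torchmetrics; the variable names are mine, not from the patch):

import torch

y_true = torch.tensor([0, 0, 1, 1])
y_pred = torch.tensor([
    [0.9, 0.1],
    [0.9, 0.1],
    [0.9, 0.1],
    [0.1, 0.9],
])
preds = y_pred.argmax(dim=1)  # tensor([0, 0, 0, 1])

# Per-class recall without ignore_index: TP / actual positives.
# Class 0: 2/2 = 1.0, class 1: 1/2 = 0.5, so macro = 0.75.
for cls in (0, 1):
    mask = y_true == cls
    print(cls, (preds[mask] == cls).float().mean().item())

# With ignore_index=0, samples whose target is 0 are dropped, leaving
# targets [1, 1] vs preds [0, 1]. All remaining targets are class 1,
# so this mean is exactly the class-1 recall: 0.5. Per issue #2441 the
# macro average should then cover only the non-ignored class, hence 0.5
# (and symmetrically 1.0 for ignore_index=1).
keep = y_true != 0
print((preds[keep] == y_true[keep]).float().mean().item())  # 0.5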