Skip to content

Commit

Permalink
Make num_classes optional, in case of micro averaging (#2841)
Browse files: browse the repository at this point in the history
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Jan 13, 2025
1 parent 276b90b commit 5b0e9b8
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Make `num_classes` optional for classification in case of micro averaging ([#2841](https://github.com/PyTorchLightning/metrics/pull/2841))


- Enabled specifying weights path for FID ([#2867](https://github.com/PyTorchLightning/metrics/pull/2867))


Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class MulticlassStatScores(_AbstractStatScores):

def __init__(
self,
num_classes: int,
num_classes: Optional[int] = None,
top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
Expand All @@ -330,7 +330,7 @@ def __init__(
self.zero_division = zero_division

self._create_state(
size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average
size=1 if (average == "micro" and top_k == 1) else (num_classes or 1), multidim_average=multidim_average
)

def update(self, preds: Tensor, target: Tensor) -> None:
Expand All @@ -340,8 +340,9 @@ def update(self, preds: Tensor, target: Tensor) -> None:
preds, target, self.num_classes, self.multidim_average, self.ignore_index
)
preds, target = _multiclass_stat_scores_format(preds, target, self.top_k)
num_classes = self.num_classes if self.num_classes is not None else 1
tp, fp, tn, fn = _multiclass_stat_scores_update(
preds, target, self.num_classes, self.top_k, self.average, self.multidim_average, self.ignore_index
preds, target, num_classes, self.top_k, self.average, self.multidim_average, self.ignore_index
)
self._update_state(tp, fp, tn, fn)

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def binary_accuracy(
def multiclass_accuracy(
preds: Tensor,
target: Tensor,
num_classes: int,
num_classes: Optional[int] = None,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
Expand Down Expand Up @@ -266,7 +266,7 @@ def multiclass_accuracy(
_multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
preds, target = _multiclass_stat_scores_format(preds, target, top_k)
tp, fp, tn, fn = _multiclass_stat_scores_update(
preds, target, num_classes, top_k, average, multidim_average, ignore_index
preds, target, num_classes or 1, top_k, average, multidim_average, ignore_index
)
return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k)

Expand Down
32 changes: 18 additions & 14 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def binary_stat_scores(


def _multiclass_stat_scores_arg_validation(
num_classes: int,
num_classes: Optional[int],
top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
Expand All @@ -235,11 +235,15 @@ def _multiclass_stat_scores_arg_validation(
- ``zero_division`` has to be 0 or 1
"""
if not isinstance(num_classes, int) or num_classes < 2:
if num_classes is None and average != "micro":
raise ValueError(
f"Argument `num_classes` can only be `None` for `average='micro'`, but got `average={average}`."
)
if num_classes is not None and (not isinstance(num_classes, int) or num_classes < 2):
raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
if not isinstance(top_k, int) and top_k < 1:
raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}")
if top_k > num_classes:
if top_k > (num_classes if num_classes is not None else 1):
raise ValueError(
f"Expected argument `top_k` to be smaller or equal to `num_classes` but got {top_k} and {num_classes}"
)
Expand All @@ -260,7 +264,7 @@ def _multiclass_stat_scores_arg_validation(
def _multiclass_stat_scores_tensor_validation(
preds: Tensor,
target: Tensor,
num_classes: int,
num_classes: Optional[int],
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
) -> None:
Expand All @@ -278,7 +282,7 @@ def _multiclass_stat_scores_tensor_validation(
if preds.ndim == target.ndim + 1:
if not preds.is_floating_point():
raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.")
if preds.shape[1] != num_classes:
if num_classes is not None and preds.shape[1] != num_classes:
raise ValueError(
"If `preds` have one dimension more than `target`, `preds.shape[1]` should be"
" equal to number of classes."
Expand Down Expand Up @@ -310,15 +314,15 @@ def _multiclass_stat_scores_tensor_validation(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
" and `preds` should be (N, C, ...)."
)

check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
)
if num_classes is not None:
check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
)


def _multiclass_stat_scores_format(
Expand Down
6 changes: 6 additions & 0 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,12 @@ def test_corner_cases():
res = metric(preds, target)
assert res == 1.0

metric_micro1 = MulticlassAccuracy(num_classes=None, average="micro", ignore_index=0)
metric_micro2 = MulticlassAccuracy(num_classes=3, average="micro", ignore_index=0)
res1 = metric_micro1(preds, target)
res2 = metric_micro2(preds, target)
assert res1 == res2


@pytest.mark.parametrize(
("metric", "kwargs"),
Expand Down

0 comments on commit 5b0e9b8

Please sign in to comment.