Skip to content

Commit

Permalink
test for _check_arg
Browse files Browse the repository at this point in the history
relative
  • Loading branch information
Borda committed Apr 20, 2021
1 parent 1aab672 commit db5a1a1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ addopts =
[coverage:run]
parallel = True
concurrency = thread
relative_files = True

[coverage:report]
exclude_lines =
Expand Down
8 changes: 8 additions & 0 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,11 @@ def test_metric_collection_same_order():
col2 = MetricCollection({"b": m2, "a": m1})
for k1, k2 in zip(col1.keys(), col2.keys()):
assert k1 == k2


def test_collection_check_arg():
assert MetricCollection._check_arg(None, 'prefix') is None
assert MetricCollection._check_arg('sample', 'prefix') == 'sample'

with pytest.raises(ValueError, match="Expected input `postfix` to be a string, but got"):
MetricCollection._check_arg(1, 'postfix')
10 changes: 4 additions & 6 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ def _set_name(self, base: str) -> str:
return name

@staticmethod
def _check_arg(arg: str, name: str) -> Optional[str]:
if arg is not None:
if isinstance(arg, str):
return arg
raise ValueError(f'Expected input {name} to be a string')
return None
def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
if arg is None or isinstance(arg, str):
return arg
raise ValueError(f'Expected input `{name}` to be a string, but got {type(arg)}')

0 comments on commit db5a1a1

Please sign in to comment.