Skip to content

Commit 315d2d2

Browse files
authored
Utility function to warn changes in default arguments (#5738)
Fixes #5653. ### Description A new utility decorator that enables us to warn users of a changing default argument. Current implementation does the following: - If version < since no warning will be printed. - If the default argument is explicitly set by the caller no warning will be printed. - If since <= version < replaced a warning like this will be printed: "Default of argument `{name}` has been deprecated since version `{since}` from `{old_default}` to `{new_default}`. It will be changed in version `{replaced}`." - If replaced <= version a warning like this will be printed: "Default of argument `{name}` was changed in version `{changed}` from `{old_default}` to `{new_default}`." - It doesn't validate the `old_default`, so you can even use this in scenarios where the default is actually `None` but set later in the function. This also enables us to set `old_default` to any string if the default is e.g. not printable. - The only validation that will throw an error is, if the `new_default` == the actual default and version < replaced. Which means, that somebody replaced the value already, before the version was incremented. Apart from that also any value for `new_default` can be set, giving the same advantages as for the `old_default`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Felix Schnabel <[email protected]>
1 parent c5ea694 commit 315d2d2

File tree

3 files changed

+259
-4
lines changed

3 files changed

+259
-4
lines changed

monai/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# have to explicitly bring these in here to resolve circular import issues
1313
from .aliases import alias, resolve_name
1414
from .decorators import MethodReplacer, RestartGenerator
15-
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg
15+
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
1616
from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather
1717
from .enums import (
1818
Average,

monai/utils/deprecate_utils.py

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
import warnings
1515
from functools import wraps
1616
from types import FunctionType
17-
from typing import Optional
17+
from typing import Any, Optional
1818

1919
from monai.utils.module import version_leq
2020

2121
from .. import __version__
2222

23-
__all__ = ["deprecated", "deprecated_arg", "DeprecatedError"]
23+
__all__ = ["deprecated", "deprecated_arg", "DeprecatedError", "deprecated_arg_default"]
2424

2525

2626
class DeprecatedError(Exception):
@@ -223,3 +223,105 @@ def _wrapper(*args, **kwargs):
223223
return _wrapper
224224

225225
return _decorator
226+
227+
228+
def deprecated_arg_default(
229+
name: str,
230+
old_default: Any,
231+
new_default: Any,
232+
since: Optional[str] = None,
233+
replaced: Optional[str] = None,
234+
msg_suffix: str = "",
235+
version_val: str = __version__,
236+
warning_category=FutureWarning,
237+
):
238+
"""
239+
Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default`
240+
in version `changed`.
241+
242+
When the decorated definition is called, a `warning_category` is issued if `since` is given,
243+
the default is not explicitly set by the caller and the current version is at or later than that given.
244+
Another warning with the same category is issued if `changed` is given and the current version is at or later.
245+
246+
The relevant docstring of the deprecating function should also be updated accordingly,
247+
using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`.
248+
https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded
249+
250+
In the current implementation type annotations are not preserved.
251+
252+
253+
Args:
254+
name: name of position or keyword argument where the default is deprecated/changed.
255+
old_default: name of the old default. This is only for the warning message, it will not be validated.
256+
new_default: name of the new default.
257+
It is validated that this value is not present as the default before version `replaced`.
258+
This means, that you can also use this if the actual default value is `None` and set later in the function.
259+
You can also set this to any string representation, e.g. `"calculate_default_value()"`
260+
if the default is calculated from another function.
261+
since: version at which the argument default was marked deprecated but not replaced.
262+
replaced: version at which the argument default was/will be replaced.
263+
msg_suffix: message appended to warning/exception detailing reasons for deprecation.
264+
version_val: (used for testing) version to compare since and removed against, default is MONAI version.
265+
warning_category: a warning category class, defaults to `FutureWarning`.
266+
267+
Returns:
268+
Decorated callable which warns when deprecated default argument is not explicitly specified.
269+
"""
270+
271+
if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit():
272+
# version unknown, set version_val to a large value (assuming the latest version)
273+
version_val = f"{sys.maxsize}"
274+
if since is not None and replaced is not None and not version_leq(since, replaced):
275+
raise ValueError(f"since must be less or equal to replaced, got since={since}, replaced={replaced}.")
276+
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
277+
if is_not_yet_deprecated:
278+
# smaller than `since`, do nothing
279+
return lambda obj: obj
280+
if since is None and replaced is None:
281+
# raise a DeprecatedError directly
282+
is_replaced = True
283+
is_deprecated = True
284+
else:
285+
# compare the numbers
286+
is_deprecated = since is not None and version_leq(since, version_val)
287+
is_replaced = replaced is not None and version_leq(replaced, version_val)
288+
289+
def _decorator(func):
290+
argname = f"{func.__module__} {func.__qualname__}:{name}"
291+
292+
msg_prefix = f"Default of argument `{name}`"
293+
294+
if is_replaced:
295+
msg_infix = f"was replaced in version {replaced} from `{old_default}` to `{new_default}`."
296+
elif is_deprecated:
297+
msg_infix = f"has been deprecated since version {since} from `{old_default}` to `{new_default}`."
298+
if replaced is not None:
299+
msg_infix += f" It will be replaced in version {replaced}."
300+
else:
301+
msg_infix = f"has been deprecated from `{old_default}` to `{new_default}`."
302+
303+
msg = f"{msg_prefix} {msg_infix} {msg_suffix}".strip()
304+
305+
sig = inspect.signature(func)
306+
if name not in sig.parameters:
307+
raise ValueError(f"Argument `{name}` not found in signature of {func.__qualname__}.")
308+
param = sig.parameters[name]
309+
if param.default is inspect.Parameter.empty:
310+
raise ValueError(f"Argument `{name}` has no default value.")
311+
312+
if param.default == new_default and not is_replaced:
313+
raise ValueError(
314+
f"Argument `{name}` was replaced to the new default value `{new_default}` before the specified version {replaced}."
315+
)
316+
317+
@wraps(func)
318+
def _wrapper(*args, **kwargs):
319+
if name not in sig.bind(*args, **kwargs).arguments and is_deprecated:
320+
# arg was not found so the default value is used
321+
warn_deprecated(argname, msg, warning_category)
322+
323+
return func(*args, **kwargs)
324+
325+
return _wrapper
326+
327+
return _decorator

tests/test_deprecated.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import unittest
1313
import warnings
1414

15-
from monai.utils import DeprecatedError, deprecated, deprecated_arg
15+
from monai.utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
1616

1717

1818
class TestDeprecatedRC(unittest.TestCase):
@@ -287,6 +287,159 @@ def afoo4(a, b=None, **kwargs):
287287
self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg
288288
self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg
289289

290+
def test_deprecated_arg_default_explicit_default(self):
291+
"""
292+
Test deprecated arg default, where the default is explicitly set (no warning).
293+
"""
294+
295+
@deprecated_arg_default(
296+
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
297+
)
298+
def foo(a, b="a"):
299+
return a, b
300+
301+
with self.assertWarns(FutureWarning) as aw:
302+
self.assertEqual(foo("a", "a"), ("a", "a"))
303+
self.assertEqual(foo("a", "b"), ("a", "b"))
304+
self.assertEqual(foo("a", "c"), ("a", "c"))
305+
warnings.warn("fake warning", FutureWarning)
306+
307+
self.assertEqual(aw.warning.args[0], "fake warning")
308+
309+
def test_deprecated_arg_default_version_less_than_since(self):
310+
"""
311+
Test deprecated arg default, where the current version is less than `since` (no warning).
312+
"""
313+
314+
@deprecated_arg_default(
315+
"b", old_default="a", new_default="b", since=self.test_version, version_val=self.prev_version
316+
)
317+
def foo(a, b="a"):
318+
return a, b
319+
320+
with self.assertWarns(FutureWarning) as aw:
321+
self.assertEqual(foo("a"), ("a", "a"))
322+
self.assertEqual(foo("a", "a"), ("a", "a"))
323+
warnings.warn("fake warning", FutureWarning)
324+
325+
self.assertEqual(aw.warning.args[0], "fake warning")
326+
327+
def test_deprecated_arg_default_warning_deprecated(self):
328+
"""
329+
Test deprecated arg default, where the default is used.
330+
"""
331+
332+
@deprecated_arg_default(
333+
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
334+
)
335+
def foo(a, b="a"):
336+
return a, b
337+
338+
self.assertWarns(FutureWarning, lambda: foo("a"))
339+
340+
def test_deprecated_arg_default_warning_replaced(self):
341+
"""
342+
Test deprecated arg default, where the default is used.
343+
"""
344+
345+
@deprecated_arg_default(
346+
"b",
347+
old_default="a",
348+
new_default="b",
349+
since=self.prev_version,
350+
replaced=self.prev_version,
351+
version_val=self.test_version,
352+
)
353+
def foo(a, b="a"):
354+
return a, b
355+
356+
self.assertWarns(FutureWarning, lambda: foo("a"))
357+
358+
def test_deprecated_arg_default_warning_with_none_as_placeholder(self):
359+
"""
360+
Test deprecated arg default, where the default is used.
361+
"""
362+
363+
@deprecated_arg_default(
364+
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
365+
)
366+
def foo(a, b=None):
367+
if b is None:
368+
b = "a"
369+
return a, b
370+
371+
self.assertWarns(FutureWarning, lambda: foo("a"))
372+
373+
@deprecated_arg_default(
374+
"b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version
375+
)
376+
def foo2(a, b=None):
377+
if b is None:
378+
b = "b"
379+
return a, b
380+
381+
self.assertWarns(FutureWarning, lambda: foo2("a"))
382+
383+
def test_deprecated_arg_default_errors(self):
384+
"""
385+
Test deprecated arg default, where the decorator is wrongly used.
386+
"""
387+
388+
# since > replaced
389+
def since_grater_than_replaced():
390+
@deprecated_arg_default(
391+
"b",
392+
old_default="a",
393+
new_default="b",
394+
since=self.test_version,
395+
replaced=self.prev_version,
396+
version_val=self.test_version,
397+
)
398+
def foo(a, b=None):
399+
return a, b
400+
401+
self.assertRaises(ValueError, since_grater_than_replaced)
402+
403+
# argname doesnt exist
404+
def argname_doesnt_exist():
405+
@deprecated_arg_default(
406+
"other", old_default="a", new_default="b", since=self.test_version, version_val=self.test_version
407+
)
408+
def foo(a, b=None):
409+
return a, b
410+
411+
self.assertRaises(ValueError, argname_doesnt_exist)
412+
413+
# argname has no default
414+
def argname_has_no_default():
415+
@deprecated_arg_default(
416+
"a",
417+
old_default="a",
418+
new_default="b",
419+
since=self.prev_version,
420+
replaced=self.test_version,
421+
version_val=self.test_version,
422+
)
423+
def foo(a):
424+
return a
425+
426+
self.assertRaises(ValueError, argname_has_no_default)
427+
428+
# new default is used but version < replaced
429+
def argname_was_replaced_before_specified_version():
430+
@deprecated_arg_default(
431+
"a",
432+
old_default="a",
433+
new_default="b",
434+
since=self.prev_version,
435+
replaced=self.next_version,
436+
version_val=self.test_version,
437+
)
438+
def foo(a, b="b"):
439+
return a, b
440+
441+
self.assertRaises(ValueError, argname_was_replaced_before_specified_version)
442+
290443

291444
if __name__ == "__main__":
292445
unittest.main()

0 commit comments

Comments
 (0)