From 2d991b5fbdd16849acf9532471de176c64eb473a Mon Sep 17 00:00:00 2001 From: Benny Huang Date: Mon, 9 Dec 2024 02:03:19 -0500 Subject: [PATCH 1/3] Fix for warning test arguments with default values --- src/_pytest/python.py | 16 +++++++++++++++- src/_pytest/warning_types.py | 7 +++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 85e3cb0ae71..0c23297d2a2 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -69,7 +69,7 @@ from _pytest.scope import _ScopeName from _pytest.scope import Scope from _pytest.stash import StashKey -from _pytest.warning_types import PytestCollectionWarning +from _pytest.warning_types import PytestCollectionWarning, PytestDefaultArgumentWarning, warn_explicit_for if TYPE_CHECKING: @@ -143,9 +143,23 @@ def async_fail(nodeid: str) -> None: fail(msg, pytrace=False) +def async_default_arg_warn(nodeid: str, function_name, param) -> None: + msg = ( + "Test function '" + function_name + "' has a default argument '" + param.name + "=" + str(param.default) + "'.\n" + ) + warnings.simplefilter("always", PytestDefaultArgumentWarning) + warnings.warn(PytestDefaultArgumentWarning(msg)) + + @hookimpl(trylast=True) def pytest_pyfunc_call(pyfuncitem: Function) -> object | None: testfunction = pyfuncitem.obj + sig = inspect.signature(testfunction) + for param in sig.parameters.values(): + if param.default is not param.empty: + function_name = testfunction.__name__ + async_default_arg_warn(pyfuncitem.nodeid, function_name, param) + if is_async_function(testfunction): async_fail(pyfuncitem.nodeid) funcargs = pyfuncitem.funcargs diff --git a/src/_pytest/warning_types.py b/src/_pytest/warning_types.py index b8e9998cd2e..76f33b19fe8 100644 --- a/src/_pytest/warning_types.py +++ b/src/_pytest/warning_types.py @@ -44,6 +44,13 @@ class PytestCollectionWarning(PytestWarning): __module__ = "pytest" +@final +class PytestDefaultArgumentWarning(PytestWarning): + """Warning emitted when a test function has default arguments.""" + + __module__ = "pytest" + + class PytestDeprecationWarning(PytestWarning, DeprecationWarning): """Warning class for features that will be removed in a future version.""" From 8b4bd3abc1c7c9f2eed218410c39bc2954dc8d8d Mon Sep 17 00:00:00 2001 From: Benny Huang Date: Mon, 9 Dec 2024 14:35:33 -0500 Subject: [PATCH 2/3] Fixed placements of the warnings to be in fixture, changed warning message from using string concat to f-strings for more readability, and added a function in compat to grab default parameters and its values --- src/_pytest/compat.py | 15 ++++++++++ src/_pytest/fixtures.py | 22 +++++++++++++- src/_pytest/python.py | 13 --------- testing/test_default_params.py | 52 ++++++++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 14 deletions(-) create mode 100644 testing/test_default_params.py diff --git a/src/_pytest/compat.py b/src/_pytest/compat.py index 1aa7495bddb..6d3cc2c5310 100644 --- a/src/_pytest/compat.py +++ b/src/_pytest/compat.py @@ -173,6 +173,21 @@ def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]: and p.default is not Parameter.empty ) +# def check_default_arguments(func, nodeid): +# """Check for default arguments in the function and issue warnings.""" +# sig = inspect.signature(func) +# for param in sig.parameters.values(): +# if param.default is not param.empty: +# function_name = func.__name__ +# async_default_arg_warn(nodeid, function_name, param) + +def get_default_name_val(function: Callable[..., Any]) -> dict[str, Any]: + sig = signature(function) + return { + param.name: param.default + for param in sig.parameters.values() + if param.default is not Parameter.empty + } _non_printable_ascii_translate_table = { i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127) diff --git a/src/_pytest/fixtures.py b/src/_pytest/fixtures.py index fc6541c1404..ce8974f4780 100644 --- a/src/_pytest/fixtures.py +++ b/src/_pytest/fixtures.py @@ -41,6 +41,7 @@ from _pytest._code.code import TerminalRepr from _pytest._io import TerminalWriter from _pytest.compat import _PytestWrapper +from _pytest.compat import get_default_name_val from _pytest.compat import assert_never from _pytest.compat import get_real_func from _pytest.compat import get_real_method @@ -70,7 +71,7 @@ from _pytest.scope import _ScopeName from _pytest.scope import HIGH_SCOPES from _pytest.scope import Scope -from _pytest.warning_types import PytestRemovedIn9Warning +from _pytest.warning_types import PytestDefaultArgumentWarning, PytestRemovedIn9Warning from _pytest.warning_types import PytestWarning @@ -1448,6 +1449,20 @@ def deduplicate_names(*seqs: Iterable[str]) -> tuple[str, ...]: return tuple(dict.fromkeys(name for seq in seqs for name in seq)) +def default_arg_warn(nodeid: str, function_name, param_name, param_val) -> None: + msg = ( + f"Test function '{function_name}' has a default argument '{param_name}={param_val}'.\n" + ) + warnings.simplefilter("always", PytestDefaultArgumentWarning) + warnings.warn(PytestDefaultArgumentWarning(msg)) + + +def check_default_arguments(func_name, default_args, nodeid): + """Check for default arguments in the function and issue warnings.""" + for arg_name, default_val in default_args.items(): + default_arg_warn(nodeid, func_name, arg_name, default_val) + + class FixtureManager: """pytest fixture definitions and information is stored and managed from this class. @@ -1528,6 +1543,11 @@ def getfixtureinfo( ignore_args=direct_parametrize_args, ) + if func is not None: + function_name = func.__name__ + default_args = get_default_name_val(func) + check_default_arguments(function_name, default_args, node.nodeid) + return FuncFixtureInfo(argnames, initialnames, names_closure, arg2fixturedefs) def pytest_plugin_registered(self, plugin: _PluggyPlugin, plugin_name: str) -> None: diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 0c23297d2a2..0f4eb3ebad1 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -143,22 +143,9 @@ def async_fail(nodeid: str) -> None: fail(msg, pytrace=False) -def async_default_arg_warn(nodeid: str, function_name, param) -> None: - msg = ( - "Test function '" + function_name + "' has a default argument '" + param.name + "=" + str(param.default) + "'.\n" - ) - warnings.simplefilter("always", PytestDefaultArgumentWarning) - warnings.warn(PytestDefaultArgumentWarning(msg)) - - @hookimpl(trylast=True) def pytest_pyfunc_call(pyfuncitem: Function) -> object | None: testfunction = pyfuncitem.obj - sig = inspect.signature(testfunction) - for param in sig.parameters.values(): - if param.default is not param.empty: - function_name = testfunction.__name__ - async_default_arg_warn(pyfuncitem.nodeid, function_name, param) if is_async_function(testfunction): async_fail(pyfuncitem.nodeid) diff --git a/testing/test_default_params.py b/testing/test_default_params.py new file mode 100644 index 00000000000..1df15f1c7d3 --- /dev/null +++ b/testing/test_default_params.py @@ -0,0 +1,52 @@ +from _pytest.pytester import Pytester + +def test_no_default_argument(pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_with_default_param(param): + assert param == 42 + """ + ) + result = pytester.runpytest() + result.stdout.fnmatch_lines([ + "*fixture 'param' not found*" + ]) + + +def test_default_argument_warning(pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_with_default_param(param=42): + assert param == 42 + """ + ) + result = pytester.runpytest() + result.stdout.fnmatch_lines([ + "*PytestDefaultArgumentWarning: Test function 'test_with_default_param' has a default argument 'param=42'.*" + ]) + + +def test_no_warning_for_no_default_param(pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_without_default_param(param): + assert param is None + """ + ) + result = pytester.runpytest() + assert "PytestDefaultArgumentWarning" not in result.stdout.str() + + +def test_warning_for_multiple_default_params(pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_with_multiple_defaults(param1=42, param2="default"): + assert param1 == 42 + assert param2 == "default" + """ + ) + result = pytester.runpytest() + result.stdout.fnmatch_lines([ + "*PytestDefaultArgumentWarning: Test function 'test_with_multiple_defaults' has a default argument 'param1=42'.*", + "*PytestDefaultArgumentWarning: Test function 'test_with_multiple_defaults' has a default argument 'param2=default'.*" + ]) \ No newline at end of file From 97fd1feec76be1419514667394326367e3fd6f28 Mon Sep 17 00:00:00 2001 From: Benny Huang Date: Mon, 9 Dec 2024 14:41:14 -0500 Subject: [PATCH 3/3] Removed comment from compat.py --- src/_pytest/compat.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/_pytest/compat.py b/src/_pytest/compat.py index 6d3cc2c5310..c6491ee6ec9 100644 --- a/src/_pytest/compat.py +++ b/src/_pytest/compat.py @@ -173,13 +173,6 @@ def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]: and p.default is not Parameter.empty ) -# def check_default_arguments(func, nodeid): -# """Check for default arguments in the function and issue warnings.""" -# sig = inspect.signature(func) -# for param in sig.parameters.values(): -# if param.default is not param.empty: -# function_name = func.__name__ -# async_default_arg_warn(nodeid, function_name, param) def get_default_name_val(function: Callable[..., Any]) -> dict[str, Any]: sig = signature(function) @@ -189,6 +182,7 @@ def get_default_name_val(function: Callable[..., Any]) -> dict[str, Any]: if param.default is not Parameter.empty } + _non_printable_ascii_translate_table = { i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127) }