From b8ebc96941639328aec7b81502e479b558aff898 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 1 Jul 2024 18:32:33 -0700 Subject: [PATCH] Do not trigger B901 with explicit Generator return type --- bugbear.py | 34 ++++++++++++++++++++++++++++++---- tests/b901.py | 38 +++++++++++++++++++++++++++++++++++--- tests/test_bugbear.py | 5 ++++- 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/bugbear.py b/bugbear.py index 61aeaca..5e80576 100644 --- a/bugbear.py +++ b/bugbear.py @@ -1192,10 +1192,27 @@ def _loop(parent, node): for child in node.body: yield from _loop(node, child) - def check_for_b901(self, node): + def check_for_b901(self, node: ast.FunctionDef) -> None: if node.name == "__await__": return + # If the user explicitly wrote the 3-argument version of Generator as the + # return annotation, they probably know what they were doing. + if ( + node.returns is not None + and isinstance(node.returns, ast.Subscript) + and ( + is_name(node.returns.value, "Generator") + or is_name(node.returns.value, "typing.Generator") + or is_name(node.returns.value, "collections.abc.Generator") + ) + ): + slice = node.returns.slice + if sys.version_info < (3, 9) and isinstance(slice, ast.Index): + slice = slice.value + if isinstance(slice, ast.Tuple) and len(slice.elts) == 3: + return + has_yield = False return_node = None @@ -1209,9 +1226,8 @@ def check_for_b901(self, node): if isinstance(x, ast.Return) and x.value is not None: return_node = x - if has_yield and return_node is not None: - self.errors.append(B901(return_node.lineno, return_node.col_offset)) - break + if has_yield and return_node is not None: + self.errors.append(B901(return_node.lineno, return_node.col_offset)) # taken from pep8-naming @classmethod @@ -1708,6 +1724,16 @@ def compose_call_path(node): yield node.id +def is_name(node: ast.expr, name: str) -> bool: + if "." not in name: + return isinstance(node, ast.Name) and node.id == name + else: + if not isinstance(node, ast.Attribute): + return False + rest, attr = name.rsplit(".", maxsplit=1) + return node.attr == attr and is_name(node.value, rest) + + def _tansform_slice_to_py39(slice: ast.Slice) -> ast.Slice | ast.Name: """Transform a py38 style slice to a py39 style slice. diff --git a/tests/b901.py b/tests/b901.py index 8caea73..276eaca 100644 --- a/tests/b901.py +++ b/tests/b901.py @@ -1,11 +1,11 @@ """ Should emit: -B901 - on lines 9, 36 +B901 """ def broken(): if True: - return [1, 2, 3] + return [1, 2, 3] # B901 yield 3 yield 2 @@ -32,7 +32,7 @@ def not_broken3(): def broken2(): - return [3, 2, 1] + return [3, 2, 1] # B901 yield from not_broken() @@ -75,3 +75,35 @@ class NotBroken9(object): def __await__(self): yield from function() return 42 + + +def broken3(): + if True: + return [1, 2, 3] # B901 + else: + yield 3 + + +def broken4() -> Iterable[str]: + yield "x" + return ["x"] # B901 + + +def broken5() -> Generator[str]: + yield "x" + return ["x"] # B901 + + +def not_broken10() -> Generator[str, int, float]: + yield "x" + return 1.0 + + +def not_broken11() -> typing.Generator[str, int, float]: + yield "x" + return 1.0 + + +def not_broken12() -> collections.abc.Generator[str, int, float]: + yield "x" + return 1.0 diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index 2de9f3b..50c04e6 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -792,7 +792,10 @@ def test_b901(self): filename = Path(__file__).absolute().parent / "b901.py" bbc = BugBearChecker(filename=str(filename)) errors = list(bbc.run()) - self.assertEqual(errors, self.errors(B901(8, 8), B901(35, 4))) + self.assertEqual( + errors, + self.errors(B901(8, 8), B901(35, 4), B901(82, 8), B901(89, 4), B901(94, 4)), + ) def test_b902(self): filename = Path(__file__).absolute().parent / "b902.py"