diff --git a/get_chaining.py b/get_chaining.py index afaa9ad..39358f7 100644 --- a/get_chaining.py +++ b/get_chaining.py @@ -1,10 +1,21 @@ import ast +from enum import Enum, unique import sys from typing import Any, Generator, List, Optional, Tuple, Type -__version__ = "0.1.0" +__version__ = "0.2.0" -ERR_MSG = "dict.get chaining might crash" + +@unique +class ErrorType(Enum): + DGC1001 = "DGC1001" + DGC1002 = "DGC1002" + + +ERR_MSGS = { + ErrorType.DGC1001: "missing default argument when chaining dict.get()", + ErrorType.DGC1002: "invalid default argument when chaining dict.get()", +} def call_position(call: ast.Call) -> Tuple[int, Optional[int]]: @@ -24,13 +35,13 @@ def run(self) -> Generator[Tuple[int, Optional[int], str, Type[Any]], None, None visitor = GetChainingVisitor() visitor.visit(self._tree) - for lineno, col in visitor.issues: - yield lineno, col, f"DGC1001 {ERR_MSG}", type(self) + for err, lineno, col in visitor.issues: + yield lineno, col, f"{err.value} {ERR_MSGS[err]}", type(self) class GetChainingVisitor(ast.NodeVisitor): def __init__(self) -> None: - self.issues: List[Tuple[int, Optional[int]]] = [] + self.issues: List[Tuple[ErrorType, int, Optional[int]]] = [] def visit_Call(self, node: ast.Call) -> Any: @@ -47,19 +58,21 @@ def visit_Call(self, node: ast.Call) -> Any: if len(get_call.args) > 1: arg = get_call.args[1] - if isinstance(arg, ast.Dict) or ( + if not isinstance(arg, ast.Dict) and not ( isinstance(arg, ast.Name) and arg.id.isidentifier() ): - return self.generic_visit(node) - self.issues.append(call_position(get_call)) + self.issues.append((ErrorType.DGC1002, *call_position(get_call))) elif get_call.keywords: for kw in get_call.keywords: if kw.arg == "default": - if isinstance(kw.value, ast.Dict): - return self.generic_visit(node) - if isinstance(kw.value, ast.Name) and kw.value.id.isidentifier(): - return self.generic_visit(node) - self.issues.append(call_position(get_call)) + if not isinstance(kw.value, ast.Dict) and not ( + isinstance(kw.value, ast.Name) and kw.value.id.isidentifier() + ): + self.issues.append( + (ErrorType.DGC1002, *call_position(get_call)) + ) + return self.generic_visit(node) + self.issues.append((ErrorType.DGC1001, *call_position(get_call))) else: - self.issues.append(call_position(get_call)) + self.issues.append((ErrorType.DGC1001, *call_position(get_call))) return self.generic_visit(node) diff --git a/tests/test_get_chaining.py b/tests/test_get_chaining.py index 2947746..43b486a 100644 --- a/tests/test_get_chaining.py +++ b/tests/test_get_chaining.py @@ -4,9 +4,7 @@ import pytest -from get_chaining import ERR_MSG, GetChainingChecker - -EXPECTED_ERR_MSG = f"DGC1001 {ERR_MSG}" +from get_chaining import ERR_MSGS, ErrorType, GetChainingChecker def _gl(line_col): @@ -15,6 +13,10 @@ def _gl(line_col): return line_col # pragma: no coverage +def _err_msg(error_type: ErrorType): + return f"{error_type.value} {ERR_MSGS[error_type]}" + + def _results(code: str) -> List[str]: tree = ast.parse(code) checker = GetChainingChecker(tree) @@ -28,6 +30,8 @@ def _results(code: str) -> List[str]: 'test.do("nothing")', 'test.get("test")', "attr.get", + "attr.get().get", + "attr.get.get()", "notadict.get()", ], ) @@ -48,28 +52,70 @@ def test_valid_chaining(inp): assert not _results(inp) +@pytest.mark.parametrize( + "inp, err_pos", + [ + ('test.get("test").get("test")', "1:16"), + ('test.get("test", {}).get("test").get("test")', "1:32"), + ('test.get("test").get("test", {}).get("test")', "1:16"), + ('test.get("test", default={}).get("test").get("test")', "1:40"), + ('test.get("test").get("test", default={}).get("test")', "1:16"), + ('test.get("test", test=None).get("test")', "1:27"), + ], +) +def test_DGC1001(inp, err_pos): + assert _results(inp) == [f"{_gl(err_pos)} {_err_msg(ErrorType.DGC1001)}"] + + +@pytest.mark.parametrize( + "inp, err_pos", + [ + ('test.get("test", default=None).get("test")', "1:30"), + ('test.get("test", None).get("test")', "1:22"), + ('test.get("test", None, default=None).get("test")', "1:36"), + ('test.get("test", test=foo, default=None).get("test")', "1:40"), + ('test.get("test", {}).get("test", default=None).get("test")', "1:46"), + ('test.get("test", default=None).get("test", {}).get("test")', "1:30"), + ('test.get("test", default={}).get("test", None).get("test")', "1:46"), + ('test.get("test", None).get("test", default={}).get("test")', "1:22"), + ], +) +def test_DGC1002(inp, err_pos): + assert _results(inp) == [f"{_gl(err_pos)} {_err_msg(ErrorType.DGC1002)}"] + + @pytest.mark.parametrize( "inp, errs", [ - ('test.get("test").get("test")', [f"{_gl('1:16')} {EXPECTED_ERR_MSG}"]), - ('test.get("test", None).get("test")', [f"{_gl('1:22')} {EXPECTED_ERR_MSG}"]), ( - 'test.get("test", default=None).get("test")', - [f"{_gl('1:30')} {EXPECTED_ERR_MSG}"], + 'test.get("test").get("test").get("test")', + [ + f"{_gl('1:28')} {_err_msg(ErrorType.DGC1001)}", + f"{_gl('1:16')} {_err_msg(ErrorType.DGC1001)}", + ], ), ( - 'test.get("test", test=None).get("test")', - [f"{_gl('1:27')} {EXPECTED_ERR_MSG}"], + 'test.get("test", default=None).get("test", None).get("test")', + [ + f"{_gl('1:48')} {_err_msg(ErrorType.DGC1002)}", + f"{_gl('1:30')} {_err_msg(ErrorType.DGC1002)}", + ], ), ( - 'test.get("test", {}).get("test").get("test")', - [f"{_gl('1:32')} {EXPECTED_ERR_MSG}"], + 'test.get("test").get("test", None).get("test")', + [ + f"{_gl('1:34')} {_err_msg(ErrorType.DGC1002)}", + f"{_gl('1:16')} {_err_msg(ErrorType.DGC1001)}", + ], ), ( - 'test.get("test").get("test").get("test")', - [f"{_gl('1:28')} {EXPECTED_ERR_MSG}", f"{_gl('1:16')} {EXPECTED_ERR_MSG}"], + 'test.get("test", default=None).get("test").get("test")', + [ + f"{_gl('1:42')} {_err_msg(ErrorType.DGC1001)}", + f"{_gl('1:30')} {_err_msg(ErrorType.DGC1002)}", + ], ), ], ) -def test_invalid_chaining(inp, errs): +def test_multiple_invalid_chaining(inp, errs): assert _results(inp) == errs