Skip to content

Commit

Permalink
feat: split into 2 errors invalid/missing default
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacrimento committed Jan 31, 2023
1 parent 6a64bf3 commit 12ec4e8
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 28 deletions.
41 changes: 27 additions & 14 deletions get_chaining.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import ast
from enum import Enum, unique
import sys
from typing import Any, Generator, List, Optional, Tuple, Type

__version__ = "0.1.0"
__version__ = "0.2.0"

ERR_MSG = "dict.get chaining might crash"

@unique
class ErrorType(Enum):
DGC1001 = "DGC1001"
DGC1002 = "DGC1002"


ERR_MSGS = {
ErrorType.DGC1001: "missing default argument when chaining dict.get()",
ErrorType.DGC1002: "invalid default argument when chaining dict.get()",
}


def call_position(call: ast.Call) -> Tuple[int, Optional[int]]:
Expand All @@ -24,13 +35,13 @@ def run(self) -> Generator[Tuple[int, Optional[int], str, Type[Any]], None, None
visitor = GetChainingVisitor()
visitor.visit(self._tree)

for lineno, col in visitor.issues:
yield lineno, col, f"DGC1001 {ERR_MSG}", type(self)
for err, lineno, col in visitor.issues:
yield lineno, col, f"{err.value} {ERR_MSGS[err]}", type(self)


class GetChainingVisitor(ast.NodeVisitor):
def __init__(self) -> None:
self.issues: List[Tuple[int, Optional[int]]] = []
self.issues: List[Tuple[ErrorType, int, Optional[int]]] = []

def visit_Call(self, node: ast.Call) -> Any:

Expand All @@ -47,19 +58,21 @@ def visit_Call(self, node: ast.Call) -> Any:

if len(get_call.args) > 1:
arg = get_call.args[1]
if isinstance(arg, ast.Dict) or (
if not isinstance(arg, ast.Dict) and not (
isinstance(arg, ast.Name) and arg.id.isidentifier()
):
return self.generic_visit(node)
self.issues.append(call_position(get_call))
self.issues.append((ErrorType.DGC1002, *call_position(get_call)))
elif get_call.keywords:
for kw in get_call.keywords:
if kw.arg == "default":
if isinstance(kw.value, ast.Dict):
return self.generic_visit(node)
if isinstance(kw.value, ast.Name) and kw.value.id.isidentifier():
return self.generic_visit(node)
self.issues.append(call_position(get_call))
if not isinstance(kw.value, ast.Dict) and not (
isinstance(kw.value, ast.Name) and kw.value.id.isidentifier()
):
self.issues.append(
(ErrorType.DGC1002, *call_position(get_call))
)
return self.generic_visit(node)
self.issues.append((ErrorType.DGC1001, *call_position(get_call)))
else:
self.issues.append(call_position(get_call))
self.issues.append((ErrorType.DGC1001, *call_position(get_call)))
return self.generic_visit(node)
74 changes: 60 additions & 14 deletions tests/test_get_chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

import pytest

from get_chaining import ERR_MSG, GetChainingChecker

EXPECTED_ERR_MSG = f"DGC1001 {ERR_MSG}"
from get_chaining import ERR_MSGS, ErrorType, GetChainingChecker


def _gl(line_col):
Expand All @@ -15,6 +13,10 @@ def _gl(line_col):
return line_col # pragma: no coverage


def _err_msg(error_type: ErrorType):
return f"{error_type.value} {ERR_MSGS[error_type]}"


def _results(code: str) -> List[str]:
tree = ast.parse(code)
checker = GetChainingChecker(tree)
Expand All @@ -28,6 +30,8 @@ def _results(code: str) -> List[str]:
'test.do("nothing")',
'test.get("test")',
"attr.get",
"attr.get().get",
"attr.get.get()",
"notadict.get()",
],
)
Expand All @@ -48,28 +52,70 @@ def test_valid_chaining(inp):
assert not _results(inp)


@pytest.mark.parametrize(
"inp, err_pos",
[
('test.get("test").get("test")', "1:16"),
('test.get("test", {}).get("test").get("test")', "1:32"),
('test.get("test").get("test", {}).get("test")', "1:16"),
('test.get("test", default={}).get("test").get("test")', "1:40"),
('test.get("test").get("test", default={}).get("test")', "1:16"),
('test.get("test", test=None).get("test")', "1:27"),
],
)
def test_DGC1001(inp, err_pos):
assert _results(inp) == [f"{_gl(err_pos)} {_err_msg(ErrorType.DGC1001)}"]


@pytest.mark.parametrize(
"inp, err_pos",
[
('test.get("test", default=None).get("test")', "1:30"),
('test.get("test", None).get("test")', "1:22"),
('test.get("test", None, default=None).get("test")', "1:36"),
('test.get("test", test=foo, default=None).get("test")', "1:40"),
('test.get("test", {}).get("test", default=None).get("test")', "1:46"),
('test.get("test", default=None).get("test", {}).get("test")', "1:30"),
('test.get("test", default={}).get("test", None).get("test")', "1:46"),
('test.get("test", None).get("test", default={}).get("test")', "1:22"),
],
)
def test_DGC1002(inp, err_pos):
assert _results(inp) == [f"{_gl(err_pos)} {_err_msg(ErrorType.DGC1002)}"]


@pytest.mark.parametrize(
"inp, errs",
[
('test.get("test").get("test")', [f"{_gl('1:16')} {EXPECTED_ERR_MSG}"]),
('test.get("test", None).get("test")', [f"{_gl('1:22')} {EXPECTED_ERR_MSG}"]),
(
'test.get("test", default=None).get("test")',
[f"{_gl('1:30')} {EXPECTED_ERR_MSG}"],
'test.get("test").get("test").get("test")',
[
f"{_gl('1:28')} {_err_msg(ErrorType.DGC1001)}",
f"{_gl('1:16')} {_err_msg(ErrorType.DGC1001)}",
],
),
(
'test.get("test", test=None).get("test")',
[f"{_gl('1:27')} {EXPECTED_ERR_MSG}"],
'test.get("test", default=None).get("test", None).get("test")',
[
f"{_gl('1:48')} {_err_msg(ErrorType.DGC1002)}",
f"{_gl('1:30')} {_err_msg(ErrorType.DGC1002)}",
],
),
(
'test.get("test", {}).get("test").get("test")',
[f"{_gl('1:32')} {EXPECTED_ERR_MSG}"],
'test.get("test").get("test", None).get("test")',
[
f"{_gl('1:34')} {_err_msg(ErrorType.DGC1002)}",
f"{_gl('1:16')} {_err_msg(ErrorType.DGC1001)}",
],
),
(
'test.get("test").get("test").get("test")',
[f"{_gl('1:28')} {EXPECTED_ERR_MSG}", f"{_gl('1:16')} {EXPECTED_ERR_MSG}"],
'test.get("test", default=None).get("test").get("test")',
[
f"{_gl('1:42')} {_err_msg(ErrorType.DGC1001)}",
f"{_gl('1:30')} {_err_msg(ErrorType.DGC1002)}",
],
),
],
)
def test_invalid_chaining(inp, errs):
def test_multiple_invalid_chaining(inp, errs):
assert _results(inp) == errs

0 comments on commit 12ec4e8

Please sign in to comment.