diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 3343650425..5f354a6982 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -7,13 +7,17 @@ from io import StringIO from traceback import format_exception_only from types import CodeType +from types import FrameType from types import TracebackType from typing import Any +from typing import Callable from typing import Dict from typing import Generic +from typing import Iterable from typing import List from typing import Optional from typing import Pattern +from typing import Sequence from typing import Set from typing import Tuple from typing import TypeVar @@ -27,9 +31,13 @@ import _pytest from _pytest._io.saferepr import safeformat from _pytest._io.saferepr import saferepr +from _pytest.compat import overload if False: # TYPE_CHECKING from typing import Type + from typing_extensions import Literal + + from _pytest._code import Source class Code: @@ -38,13 +46,12 @@ class Code: def __init__(self, rawcode) -> None: if not hasattr(rawcode, "co_filename"): rawcode = getrawcode(rawcode) - try: - self.filename = rawcode.co_filename - self.firstlineno = rawcode.co_firstlineno - 1 - self.name = rawcode.co_name - except AttributeError: + if not isinstance(rawcode, CodeType): raise TypeError("not a code object: {!r}".format(rawcode)) - self.raw = rawcode # type: CodeType + self.filename = rawcode.co_filename + self.firstlineno = rawcode.co_firstlineno - 1 + self.name = rawcode.co_name + self.raw = rawcode def __eq__(self, other): return self.raw == other.raw @@ -72,7 +79,7 @@ def path(self): return p @property - def fullsource(self): + def fullsource(self) -> Optional["Source"]: """ return a _pytest._code.Source object for the full source file of the code """ from _pytest._code import source @@ -80,7 +87,7 @@ def fullsource(self): full, _ = source.findsource(self.raw) return full - def source(self): + def source(self) -> "Source": """ return a _pytest._code.Source object for the code object's source only """ # return source only for that part of code @@ -88,7 +95,7 @@ def source(self): return _pytest._code.Source(self.raw) - def getargs(self, var=False): + def getargs(self, var: bool = False) -> Tuple[str, ...]: """ return a tuple with the argument names for the code object if 'var' is set True also return the names of the variable and @@ -107,7 +114,7 @@ class Frame: """Wrapper around a Python frame holding f_locals and f_globals in which expressions can be evaluated.""" - def __init__(self, frame): + def __init__(self, frame: FrameType) -> None: self.lineno = frame.f_lineno - 1 self.f_globals = frame.f_globals self.f_locals = frame.f_locals @@ -115,7 +122,7 @@ def __init__(self, frame): self.code = Code(frame.f_code) @property - def statement(self): + def statement(self) -> "Source": """ statement this frame is at """ import _pytest._code @@ -134,7 +141,7 @@ def eval(self, code, **vars): f_locals.update(vars) return eval(code, self.f_globals, f_locals) - def exec_(self, code, **vars): + def exec_(self, code, **vars) -> None: """ exec 'code' in the frame 'vars' are optional; additional local variables @@ -143,7 +150,7 @@ def exec_(self, code, **vars): f_locals.update(vars) exec(code, self.f_globals, f_locals) - def repr(self, object): + def repr(self, object: object) -> str: """ return a 'safe' (non-recursive, one-line) string repr for 'object' """ return saferepr(object) @@ -151,7 +158,7 @@ def repr(self, object): def is_true(self, object): return object - def getargs(self, var=False): + def getargs(self, var: bool = False): """ return a list of tuples (name, value) for all arguments if 'var' is set True also include the variable and keyword @@ -169,35 +176,34 @@ def getargs(self, var=False): class TracebackEntry: """ a single entry in a traceback """ - _repr_style = None + _repr_style = None # type: Optional[Literal['short', 'long']] exprinfo = None - def __init__(self, rawentry, excinfo=None): + def __init__(self, rawentry: TracebackType, excinfo=None) -> None: self._excinfo = excinfo self._rawentry = rawentry self.lineno = rawentry.tb_lineno - 1 - def set_repr_style(self, mode): + def set_repr_style(self, mode: "Literal['short', 'long']") -> None: assert mode in ("short", "long") self._repr_style = mode @property - def frame(self): - import _pytest._code - - return _pytest._code.Frame(self._rawentry.tb_frame) + def frame(self) -> Frame: + return Frame(self._rawentry.tb_frame) @property - def relline(self): + def relline(self) -> int: return self.lineno - self.frame.code.firstlineno - def __repr__(self): + def __repr__(self) -> str: return "" % (self.frame.code.path, self.lineno + 1) @property - def statement(self): + def statement(self) -> "Source": """ _pytest._code.Source object for the current statement """ source = self.frame.code.fullsource + assert source is not None return source.getstatement(self.lineno) @property @@ -206,14 +212,14 @@ def path(self): return self.frame.code.path @property - def locals(self): + def locals(self) -> Dict[str, Any]: """ locals of underlying frame """ return self.frame.f_locals - def getfirstlinesource(self): + def getfirstlinesource(self) -> int: return self.frame.code.firstlineno - def getsource(self, astcache=None): + def getsource(self, astcache=None) -> Optional["Source"]: """ return failing source code. """ # we use the passed in astcache to not reparse asttrees # within exception info printing @@ -258,7 +264,7 @@ def ishidden(self): return tbh(None if self._excinfo is None else self._excinfo()) return tbh - def __str__(self): + def __str__(self) -> str: try: fn = str(self.path) except py.error.Error: @@ -273,31 +279,40 @@ def __str__(self): return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line) @property - def name(self): + def name(self) -> str: """ co_name of underlying code """ return self.frame.code.raw.co_name -class Traceback(list): +class Traceback(List[TracebackEntry]): """ Traceback objects encapsulate and offer higher level access to Traceback entries. """ - def __init__(self, tb, excinfo=None): + def __init__( + self, tb: Union[TracebackType, Iterable[TracebackEntry]], excinfo=None + ) -> None: """ initialize from given python traceback object and ExceptionInfo """ self._excinfo = excinfo - if hasattr(tb, "tb_next"): + if isinstance(tb, TracebackType): - def f(cur): - while cur is not None: - yield TracebackEntry(cur, excinfo=excinfo) - cur = cur.tb_next + def f(cur: TracebackType) -> Iterable[TracebackEntry]: + cur_ = cur # type: Optional[TracebackType] + while cur_ is not None: + yield TracebackEntry(cur_, excinfo=excinfo) + cur_ = cur_.tb_next - list.__init__(self, f(tb)) + super().__init__(f(tb)) else: - list.__init__(self, tb) + super().__init__(tb) - def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None): + def cut( + self, + path=None, + lineno: Optional[int] = None, + firstlineno: Optional[int] = None, + excludepath=None, + ) -> "Traceback": """ return a Traceback instance wrapping part of this Traceback by providing any combination of path, lineno and firstlineno, the @@ -323,13 +338,25 @@ def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None): return Traceback(x._rawentry, self._excinfo) return self - def __getitem__(self, key): - val = super().__getitem__(key) - if isinstance(key, type(slice(0))): - val = self.__class__(val) - return val + @overload + def __getitem__(self, key: int) -> TracebackEntry: + raise NotImplementedError() + + @overload # noqa: F811 + def __getitem__(self, key: slice) -> "Traceback": # noqa: F811 + raise NotImplementedError() + + def __getitem__( # noqa: F811 + self, key: Union[int, slice] + ) -> Union[TracebackEntry, "Traceback"]: + if isinstance(key, slice): + return self.__class__(super().__getitem__(key)) + else: + return super().__getitem__(key) - def filter(self, fn=lambda x: not x.ishidden()): + def filter( + self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden() + ) -> "Traceback": """ return a Traceback instance with certain items removed fn is a function that gets a single argument, a TracebackEntry @@ -341,7 +368,7 @@ def filter(self, fn=lambda x: not x.ishidden()): """ return Traceback(filter(fn, self), self._excinfo) - def getcrashentry(self): + def getcrashentry(self) -> TracebackEntry: """ return last non-hidden traceback entry that lead to the exception of a traceback. """ @@ -351,7 +378,7 @@ def getcrashentry(self): return entry return self[-1] - def recursionindex(self): + def recursionindex(self) -> Optional[int]: """ return the index of the frame/TracebackEntry where recursion originates if appropriate, None if no recursion occurred """ @@ -619,16 +646,16 @@ class FormattedExcinfo: flow_marker = ">" fail_marker = "E" - showlocals = attr.ib(default=False) - style = attr.ib(default="long") - abspath = attr.ib(default=True) - tbfilter = attr.ib(default=True) - funcargs = attr.ib(default=False) - truncate_locals = attr.ib(default=True) - chain = attr.ib(default=True) + showlocals = attr.ib(type=bool, default=False) + style = attr.ib(type=str, default="long") + abspath = attr.ib(type=bool, default=True) + tbfilter = attr.ib(type=bool, default=True) + funcargs = attr.ib(type=bool, default=False) + truncate_locals = attr.ib(type=bool, default=True) + chain = attr.ib(type=bool, default=True) astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False) - def _getindent(self, source): + def _getindent(self, source: "Source") -> int: # figure out indent for given source try: s = str(source.getstatement(len(source) - 1)) @@ -643,20 +670,27 @@ def _getindent(self, source): return 0 return 4 + (len(s) - len(s.lstrip())) - def _getentrysource(self, entry): + def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]: source = entry.getsource(self.astcache) if source is not None: source = source.deindent() return source - def repr_args(self, entry): + def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]: if self.funcargs: args = [] for argname, argvalue in entry.frame.getargs(var=True): args.append((argname, saferepr(argvalue))) return ReprFuncArgs(args) + return None - def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]: + def get_source( + self, + source: "Source", + line_index: int = -1, + excinfo: Optional[ExceptionInfo] = None, + short: bool = False, + ) -> List[str]: """ return formatted and marked up source lines. """ import _pytest._code @@ -680,19 +714,21 @@ def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[s lines.extend(self.get_exconly(excinfo, indent=indent, markall=True)) return lines - def get_exconly(self, excinfo, indent=4, markall=False): + def get_exconly( + self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False + ) -> List[str]: lines = [] - indent = " " * indent + indentstr = " " * indent # get the real exception information out exlines = excinfo.exconly(tryshort=True).split("\n") - failindent = self.fail_marker + indent[1:] + failindent = self.fail_marker + indentstr[1:] for line in exlines: lines.append(failindent + line) if not markall: - failindent = indent + failindent = indentstr return lines - def repr_locals(self, locals): + def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]: if self.showlocals: lines = [] keys = [loc for loc in locals if loc[0] != "@"] @@ -717,8 +753,11 @@ def repr_locals(self, locals): # # XXX # pprint.pprint(value, stream=self.excinfowriter) return ReprLocals(lines) + return None - def repr_traceback_entry(self, entry, excinfo=None): + def repr_traceback_entry( + self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None + ) -> "ReprEntry": import _pytest._code source = self._getentrysource(entry) @@ -729,9 +768,7 @@ def repr_traceback_entry(self, entry, excinfo=None): line_index = entry.lineno - entry.getfirstlinesource() lines = [] # type: List[str] - style = entry._repr_style - if style is None: - style = self.style + style = entry._repr_style if entry._repr_style is not None else self.style if style in ("short", "long"): short = style == "short" reprargs = self.repr_args(entry) if not short else None @@ -761,7 +798,7 @@ def _makepath(self, path): path = np return path - def repr_traceback(self, excinfo): + def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback": traceback = excinfo.traceback if self.tbfilter: traceback = traceback.filter() @@ -779,7 +816,9 @@ def repr_traceback(self, excinfo): entries.append(reprentry) return ReprTraceback(entries, extraline, style=self.style) - def _truncate_recursive_traceback(self, traceback): + def _truncate_recursive_traceback( + self, traceback: Traceback + ) -> Tuple[Traceback, Optional[str]]: """ Truncate the given recursive traceback trying to find the starting point of the recursion. @@ -806,7 +845,10 @@ def _truncate_recursive_traceback(self, traceback): max_frames=max_frames, total=len(traceback), ) # type: Optional[str] - traceback = traceback[:max_frames] + traceback[-max_frames:] + # Type ignored because mypy thinks the slice is a List[TracebackEntry] + # instead of the Traceback subclass, even though we override __getitem__(). + # Should figure out why it does that. + traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore else: if recursionindex is not None: extraline = "!!! Recursion detected (same locals & position)" @@ -863,7 +905,7 @@ def repr_excinfo(self, excinfo: ExceptionInfo) -> "ExceptionChainRepr": class TerminalRepr: - def __str__(self): + def __str__(self) -> str: # FYI this is called from pytest-xdist's serialization of exception # information. io = StringIO() @@ -871,7 +913,7 @@ def __str__(self): self.toterminal(tw) return io.getvalue().strip() - def __repr__(self): + def __repr__(self) -> str: return "<{} instance at {:0x}>".format(self.__class__, id(self)) def toterminal(self, tw) -> None: @@ -882,7 +924,7 @@ class ExceptionRepr(TerminalRepr): def __init__(self) -> None: self.sections = [] # type: List[Tuple[str, str, str]] - def addsection(self, name, content, sep="-"): + def addsection(self, name: str, content: str, sep: str = "-") -> None: self.sections.append((name, content, sep)) def toterminal(self, tw) -> None: @@ -892,7 +934,12 @@ def toterminal(self, tw) -> None: class ExceptionChainRepr(ExceptionRepr): - def __init__(self, chain): + def __init__( + self, + chain: Sequence[ + Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]] + ], + ) -> None: super().__init__() self.chain = chain # reprcrash and reprtraceback of the outermost (the newest) exception @@ -910,7 +957,9 @@ def toterminal(self, tw) -> None: class ReprExceptionInfo(ExceptionRepr): - def __init__(self, reprtraceback, reprcrash): + def __init__( + self, reprtraceback: "ReprTraceback", reprcrash: "ReprFileLocation" + ) -> None: super().__init__() self.reprtraceback = reprtraceback self.reprcrash = reprcrash @@ -923,7 +972,12 @@ def toterminal(self, tw) -> None: class ReprTraceback(TerminalRepr): entrysep = "_ " - def __init__(self, reprentries, extraline, style): + def __init__( + self, + reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]], + extraline: Optional[str], + style: str, + ) -> None: self.reprentries = reprentries self.extraline = extraline self.style = style @@ -948,7 +1002,7 @@ def toterminal(self, tw) -> None: class ReprTracebackNative(ReprTraceback): - def __init__(self, tblines): + def __init__(self, tblines: Sequence[str]) -> None: self.style = "native" self.reprentries = [ReprEntryNative(tblines)] self.extraline = None @@ -957,7 +1011,7 @@ def __init__(self, tblines): class ReprEntryNative(TerminalRepr): style = "native" - def __init__(self, tblines): + def __init__(self, tblines: Sequence[str]) -> None: self.lines = tblines def toterminal(self, tw) -> None: @@ -965,7 +1019,14 @@ def toterminal(self, tw) -> None: class ReprEntry(TerminalRepr): - def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style): + def __init__( + self, + lines: Sequence[str], + reprfuncargs: Optional["ReprFuncArgs"], + reprlocals: Optional["ReprLocals"], + filelocrepr: Optional["ReprFileLocation"], + style: str, + ) -> None: self.lines = lines self.reprfuncargs = reprfuncargs self.reprlocals = reprlocals @@ -974,6 +1035,7 @@ def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style): def toterminal(self, tw) -> None: if self.style == "short": + assert self.reprfileloc is not None self.reprfileloc.toterminal(tw) for line in self.lines: red = line.startswith("E ") @@ -992,14 +1054,14 @@ def toterminal(self, tw) -> None: tw.line("") self.reprfileloc.toterminal(tw) - def __str__(self): + def __str__(self) -> str: return "{}\n{}\n{}".format( "\n".join(self.lines), self.reprlocals, self.reprfileloc ) class ReprFileLocation(TerminalRepr): - def __init__(self, path, lineno, message): + def __init__(self, path, lineno: int, message: str) -> None: self.path = str(path) self.lineno = lineno self.message = message @@ -1016,7 +1078,7 @@ def toterminal(self, tw) -> None: class ReprLocals(TerminalRepr): - def __init__(self, lines): + def __init__(self, lines: Sequence[str]) -> None: self.lines = lines def toterminal(self, tw) -> None: @@ -1025,7 +1087,7 @@ def toterminal(self, tw) -> None: class ReprFuncArgs(TerminalRepr): - def __init__(self, args): + def __init__(self, args: Sequence[Tuple[str, object]]) -> None: self.args = args def toterminal(self, tw) -> None: @@ -1047,7 +1109,7 @@ def toterminal(self, tw) -> None: tw.line("") -def getrawcode(obj, trycall=True): +def getrawcode(obj, trycall: bool = True): """ return code object for given function. """ try: return obj.__code__ @@ -1075,7 +1137,7 @@ def getrawcode(obj, trycall=True): _PY_DIR = py.path.local(py.__file__).dirpath() -def filter_traceback(entry): +def filter_traceback(entry: TracebackEntry) -> bool: """Return True if a TracebackEntry instance should be removed from tracebacks: * dynamically generated code (no code to show up for it); * internal traceback from pytest or its internal libraries, py and pluggy. diff --git a/testing/code/test_code.py b/testing/code/test_code.py index 2f55720b42..f8e1ce17f2 100644 --- a/testing/code/test_code.py +++ b/testing/code/test_code.py @@ -1,18 +1,19 @@ import sys +from types import FrameType from unittest import mock import _pytest._code import pytest -def test_ne(): +def test_ne() -> None: code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec")) assert code1 == code1 code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec")) assert code2 != code1 -def test_code_gives_back_name_for_not_existing_file(): +def test_code_gives_back_name_for_not_existing_file() -> None: name = "abc-123" co_code = compile("pass\n", name, "exec") assert co_code.co_filename == name @@ -21,68 +22,67 @@ def test_code_gives_back_name_for_not_existing_file(): assert code.fullsource is None -def test_code_with_class(): +def test_code_with_class() -> None: class A: pass pytest.raises(TypeError, _pytest._code.Code, A) -def x(): +def x() -> None: raise NotImplementedError() -def test_code_fullsource(): +def test_code_fullsource() -> None: code = _pytest._code.Code(x) full = code.fullsource assert "test_code_fullsource()" in str(full) -def test_code_source(): +def test_code_source() -> None: code = _pytest._code.Code(x) src = code.source() - expected = """def x(): + expected = """def x() -> None: raise NotImplementedError()""" assert str(src) == expected -def test_frame_getsourcelineno_myself(): - def func(): +def test_frame_getsourcelineno_myself() -> None: + def func() -> FrameType: return sys._getframe(0) - f = func() - f = _pytest._code.Frame(f) + f = _pytest._code.Frame(func()) source, lineno = f.code.fullsource, f.lineno + assert source is not None assert source[lineno].startswith(" return sys._getframe(0)") -def test_getstatement_empty_fullsource(): - def func(): +def test_getstatement_empty_fullsource() -> None: + def func() -> FrameType: return sys._getframe(0) - f = func() - f = _pytest._code.Frame(f) + f = _pytest._code.Frame(func()) with mock.patch.object(f.code.__class__, "fullsource", None): assert f.statement == "" -def test_code_from_func(): +def test_code_from_func() -> None: co = _pytest._code.Code(test_frame_getsourcelineno_myself) assert co.firstlineno assert co.path -def test_unicode_handling(): +def test_unicode_handling() -> None: value = "ąć".encode() - def f(): + def f() -> None: raise Exception(value) excinfo = pytest.raises(Exception, f) str(excinfo) -def test_code_getargs(): +def test_code_getargs() -> None: def f1(x): raise NotImplementedError() @@ -108,26 +108,26 @@ def f4(x, *y, **z): assert c4.getargs(var=True) == ("x", "y", "z") -def test_frame_getargs(): - def f1(x): +def test_frame_getargs() -> None: + def f1(x) -> FrameType: return sys._getframe(0) fr1 = _pytest._code.Frame(f1("a")) assert fr1.getargs(var=True) == [("x", "a")] - def f2(x, *y): + def f2(x, *y) -> FrameType: return sys._getframe(0) fr2 = _pytest._code.Frame(f2("a", "b", "c")) assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))] - def f3(x, **z): + def f3(x, **z) -> FrameType: return sys._getframe(0) fr3 = _pytest._code.Frame(f3("a", b="c")) assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})] - def f4(x, *y, **z): + def f4(x, *y, **z) -> FrameType: return sys._getframe(0) fr4 = _pytest._code.Frame(f4("a", "b", c="d")) @@ -135,7 +135,7 @@ def f4(x, *y, **z): class TestExceptionInfo: - def test_bad_getsource(self): + def test_bad_getsource(self) -> None: try: if False: pass @@ -145,13 +145,13 @@ def test_bad_getsource(self): exci = _pytest._code.ExceptionInfo.from_current() assert exci.getrepr() - def test_from_current_with_missing(self): + def test_from_current_with_missing(self) -> None: with pytest.raises(AssertionError, match="no current exception"): _pytest._code.ExceptionInfo.from_current() class TestTracebackEntry: - def test_getsource(self): + def test_getsource(self) -> None: try: if False: pass @@ -161,12 +161,13 @@ def test_getsource(self): exci = _pytest._code.ExceptionInfo.from_current() entry = exci.traceback[0] source = entry.getsource() + assert source is not None assert len(source) == 6 assert "assert False" in source[5] class TestReprFuncArgs: - def test_not_raise_exception_with_mixed_encoding(self, tw_mock): + def test_not_raise_exception_with_mixed_encoding(self, tw_mock) -> None: from _pytest._code.code import ReprFuncArgs args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")] diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index f088086489..997b14e2f6 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -3,6 +3,7 @@ import queue import sys import textwrap +from typing import Union import py @@ -224,23 +225,25 @@ def f(n): repr = excinfo.getrepr() assert "RuntimeError: hello" in str(repr.reprcrash) - def test_traceback_no_recursion_index(self): - def do_stuff(): + def test_traceback_no_recursion_index(self) -> None: + def do_stuff() -> None: raise RuntimeError - def reraise_me(): + def reraise_me() -> None: import sys exc, val, tb = sys.exc_info() + assert val is not None raise val.with_traceback(tb) - def f(n): + def f(n: int) -> None: try: do_stuff() except: # noqa reraise_me() excinfo = pytest.raises(RuntimeError, f, 8) + assert excinfo is not None traceback = excinfo.traceback recindex = traceback.recursionindex() assert recindex is None @@ -596,7 +599,6 @@ def func1(): assert lines[3] == "E world" assert not lines[4:] - loc = repr_entry.reprlocals is not None loc = repr_entry.reprfileloc assert loc.path == mod.__file__ assert loc.lineno == 3 @@ -1286,9 +1288,10 @@ def unreraise(): @pytest.mark.parametrize("style", ["short", "long"]) @pytest.mark.parametrize("encoding", [None, "utf8", "utf16"]) def test_repr_traceback_with_unicode(style, encoding): - msg = "☹" - if encoding is not None: - msg = msg.encode(encoding) + if encoding is None: + msg = "☹" # type: Union[str, bytes] + else: + msg = "☹".encode(encoding) try: raise RuntimeError(msg) except RuntimeError: diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 519344dd43..c6e68fc1ef 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -4,13 +4,16 @@ import ast import inspect import sys +from typing import Any +from typing import Dict +from typing import Optional import _pytest._code import pytest from _pytest._code import Source -def test_source_str_function(): +def test_source_str_function() -> None: x = Source("3") assert str(x) == "3" @@ -25,7 +28,7 @@ def test_source_str_function(): assert str(x) == "\n3" -def test_unicode(): +def test_unicode() -> None: x = Source("4") assert str(x) == "4" co = _pytest._code.compile('"å"', mode="eval") @@ -33,12 +36,12 @@ def test_unicode(): assert isinstance(val, str) -def test_source_from_function(): +def test_source_from_function() -> None: source = _pytest._code.Source(test_source_str_function) - assert str(source).startswith("def test_source_str_function():") + assert str(source).startswith("def test_source_str_function() -> None:") -def test_source_from_method(): +def test_source_from_method() -> None: class TestClass: def test_method(self): pass @@ -47,13 +50,13 @@ def test_method(self): assert source.lines == ["def test_method(self):", " pass"] -def test_source_from_lines(): +def test_source_from_lines() -> None: lines = ["a \n", "b\n", "c"] source = _pytest._code.Source(lines) assert source.lines == ["a ", "b", "c"] -def test_source_from_inner_function(): +def test_source_from_inner_function() -> None: def f(): pass @@ -63,7 +66,7 @@ def f(): assert str(source).startswith("def f():") -def test_source_putaround_simple(): +def test_source_putaround_simple() -> None: source = Source("raise ValueError") source = source.putaround( "try:", @@ -85,7 +88,7 @@ def test_source_putaround_simple(): ) -def test_source_putaround(): +def test_source_putaround() -> None: source = Source() source = source.putaround( """ @@ -96,28 +99,30 @@ def test_source_putaround(): assert str(source).strip() == "if 1:\n x=1" -def test_source_strips(): +def test_source_strips() -> None: source = Source("") assert source == Source() assert str(source) == "" assert source.strip() == source -def test_source_strip_multiline(): +def test_source_strip_multiline() -> None: source = Source() source.lines = ["", " hello", " "] source2 = source.strip() assert source2.lines == [" hello"] -def test_syntaxerror_rerepresentation(): +def test_syntaxerror_rerepresentation() -> None: ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz") + assert ex is not None assert ex.value.lineno == 1 assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5 + assert ex.value.text is not None assert ex.value.text.strip(), "x x" -def test_isparseable(): +def test_isparseable() -> None: assert Source("hello").isparseable() assert Source("if 1:\n pass").isparseable() assert Source(" \nif 1:\n pass").isparseable() @@ -127,7 +132,7 @@ def test_isparseable(): class TestAccesses: - def setup_class(self): + def setup_class(self) -> None: self.source = Source( """\ def f(x): @@ -137,26 +142,26 @@ def g(x): """ ) - def test_getrange(self): + def test_getrange(self) -> None: x = self.source[0:2] assert x.isparseable() assert len(x.lines) == 2 assert str(x) == "def f(x):\n pass" - def test_getline(self): + def test_getline(self) -> None: x = self.source[0] assert x == "def f(x):" - def test_len(self): + def test_len(self) -> None: assert len(self.source) == 4 - def test_iter(self): + def test_iter(self) -> None: values = [x for x in self.source] assert len(values) == 4 class TestSourceParsingAndCompiling: - def setup_class(self): + def setup_class(self) -> None: self.source = Source( """\ def f(x): @@ -166,19 +171,19 @@ def f(x): """ ).strip() - def test_compile(self): + def test_compile(self) -> None: co = _pytest._code.compile("x=3") - d = {} + d = {} # type: Dict[str, Any] exec(co, d) assert d["x"] == 3 - def test_compile_and_getsource_simple(self): + def test_compile_and_getsource_simple(self) -> None: co = _pytest._code.compile("x=3") exec(co) source = _pytest._code.Source(co) assert str(source) == "x=3" - def test_compile_and_getsource_through_same_function(self): + def test_compile_and_getsource_through_same_function(self) -> None: def gensource(source): return _pytest._code.compile(source) @@ -199,7 +204,7 @@ def f(): source2 = inspect.getsource(co2) assert "ValueError" in source2 - def test_getstatement(self): + def test_getstatement(self) -> None: # print str(self.source) ass = str(self.source[1:]) for i in range(1, 4): @@ -208,7 +213,7 @@ def test_getstatement(self): # x = s.deindent() assert str(s) == ass - def test_getstatementrange_triple_quoted(self): + def test_getstatementrange_triple_quoted(self) -> None: # print str(self.source) source = Source( """hello(''' @@ -219,7 +224,7 @@ def test_getstatementrange_triple_quoted(self): s = source.getstatement(1) assert s == str(source) - def test_getstatementrange_within_constructs(self): + def test_getstatementrange_within_constructs(self) -> None: source = Source( """\ try: @@ -241,7 +246,7 @@ def test_getstatementrange_within_constructs(self): # assert source.getstatementrange(5) == (0, 7) assert source.getstatementrange(6) == (6, 7) - def test_getstatementrange_bug(self): + def test_getstatementrange_bug(self) -> None: source = Source( """\ try: @@ -255,7 +260,7 @@ def test_getstatementrange_bug(self): assert len(source) == 6 assert source.getstatementrange(2) == (1, 4) - def test_getstatementrange_bug2(self): + def test_getstatementrange_bug2(self) -> None: source = Source( """\ assert ( @@ -272,7 +277,7 @@ def test_getstatementrange_bug2(self): assert len(source) == 9 assert source.getstatementrange(5) == (0, 9) - def test_getstatementrange_ast_issue58(self): + def test_getstatementrange_ast_issue58(self) -> None: source = Source( """\ @@ -286,38 +291,44 @@ def test_some(): assert getstatement(2, source).lines == source.lines[2:3] assert getstatement(3, source).lines == source.lines[3:4] - def test_getstatementrange_out_of_bounds_py3(self): + def test_getstatementrange_out_of_bounds_py3(self) -> None: source = Source("if xxx:\n from .collections import something") r = source.getstatementrange(1) assert r == (1, 2) - def test_getstatementrange_with_syntaxerror_issue7(self): + def test_getstatementrange_with_syntaxerror_issue7(self) -> None: source = Source(":") pytest.raises(SyntaxError, lambda: source.getstatementrange(0)) - def test_compile_to_ast(self): + def test_compile_to_ast(self) -> None: source = Source("x = 4") mod = source.compile(flag=ast.PyCF_ONLY_AST) assert isinstance(mod, ast.Module) compile(mod, "", "exec") - def test_compile_and_getsource(self): + def test_compile_and_getsource(self) -> None: co = self.source.compile() exec(co, globals()) - f(7) - excinfo = pytest.raises(AssertionError, f, 6) + f(7) # type: ignore + excinfo = pytest.raises(AssertionError, f, 6) # type: ignore + assert excinfo is not None frame = excinfo.traceback[-1].frame + assert isinstance(frame.code.fullsource, Source) stmt = frame.code.fullsource.getstatement(frame.lineno) assert str(stmt).strip().startswith("assert") @pytest.mark.parametrize("name", ["", None, "my"]) - def test_compilefuncs_and_path_sanity(self, name): + def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None: def check(comp, name): co = comp(self.source, name) if not name: - expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) + expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) # type: ignore else: - expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2) + expected = "codegen %r %s:%d>" % ( + name, + mypath, # type: ignore + mylineno + 2 + 2, # type: ignore + ) # type: ignore fn = co.co_filename assert fn.endswith(expected) @@ -332,9 +343,9 @@ def test_offsetless_synerr(self): pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval") -def test_getstartingblock_singleline(): +def test_getstartingblock_singleline() -> None: class A: - def __init__(self, *args): + def __init__(self, *args) -> None: frame = sys._getframe(1) self.source = _pytest._code.Frame(frame).statement @@ -344,22 +355,22 @@ def __init__(self, *args): assert len(values) == 1 -def test_getline_finally(): - def c(): +def test_getline_finally() -> None: + def c() -> None: pass with pytest.raises(TypeError) as excinfo: teardown = None try: - c(1) + c(1) # type: ignore finally: if teardown: teardown() source = excinfo.traceback[-1].statement - assert str(source).strip() == "c(1)" + assert str(source).strip() == "c(1) # type: ignore" -def test_getfuncsource_dynamic(): +def test_getfuncsource_dynamic() -> None: source = """ def f(): raise ValueError @@ -368,11 +379,13 @@ def g(): pass """ co = _pytest._code.compile(source) exec(co, globals()) - assert str(_pytest._code.Source(f)).strip() == "def f():\n raise ValueError" - assert str(_pytest._code.Source(g)).strip() == "def g(): pass" + f_source = _pytest._code.Source(f) # type: ignore + g_source = _pytest._code.Source(g) # type: ignore + assert str(f_source).strip() == "def f():\n raise ValueError" + assert str(g_source).strip() == "def g(): pass" -def test_getfuncsource_with_multine_string(): +def test_getfuncsource_with_multine_string() -> None: def f(): c = """while True: pass @@ -387,7 +400,7 @@ def f(): assert str(_pytest._code.Source(f)) == expected.rstrip() -def test_deindent(): +def test_deindent() -> None: from _pytest._code.source import deindent as deindent assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"] @@ -401,7 +414,7 @@ def g(): assert lines == ["def f():", " def g():", " pass"] -def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot): +def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot) -> None: # this test fails because the implicit inspect.getsource(A) below # does not return the "x = 1" last line. source = _pytest._code.Source( @@ -423,7 +436,7 @@ def x(): pass -def test_getsource_fallback(): +def test_getsource_fallback() -> None: from _pytest._code.source import getsource expected = """def x(): @@ -432,7 +445,7 @@ def test_getsource_fallback(): assert src == expected -def test_idem_compile_and_getsource(): +def test_idem_compile_and_getsource() -> None: from _pytest._code.source import getsource expected = "def x(): pass" @@ -441,15 +454,16 @@ def test_idem_compile_and_getsource(): assert src == expected -def test_findsource_fallback(): +def test_findsource_fallback() -> None: from _pytest._code.source import findsource src, lineno = findsource(x) + assert src is not None assert "test_findsource_simple" in str(src) assert src[lineno] == " def x():" -def test_findsource(): +def test_findsource() -> None: from _pytest._code.source import findsource co = _pytest._code.compile( @@ -460,19 +474,21 @@ def x(): ) src, lineno = findsource(co) + assert src is not None assert "if 1:" in str(src) - d = {} + d = {} # type: Dict[str, Any] eval(co, d) src, lineno = findsource(d["x"]) + assert src is not None assert "if 1:" in str(src) assert src[lineno] == " def x():" -def test_getfslineno(): +def test_getfslineno() -> None: from _pytest._code import getfslineno - def f(x): + def f(x) -> None: pass fspath, lineno = getfslineno(f) @@ -498,40 +514,40 @@ class B: assert getfslineno(B)[1] == -1 -def test_code_of_object_instance_with_call(): +def test_code_of_object_instance_with_call() -> None: class A: pass pytest.raises(TypeError, lambda: _pytest._code.Source(A())) class WithCall: - def __call__(self): + def __call__(self) -> None: pass code = _pytest._code.Code(WithCall()) assert "pass" in str(code.source()) class Hello: - def __call__(self): + def __call__(self) -> None: pass pytest.raises(TypeError, lambda: _pytest._code.Code(Hello)) -def getstatement(lineno, source): +def getstatement(lineno: int, source) -> Source: from _pytest._code.source import getstatementrange_ast - source = _pytest._code.Source(source, deindent=False) - ast, start, end = getstatementrange_ast(lineno, source) - return source[start:end] + src = _pytest._code.Source(source, deindent=False) + ast, start, end = getstatementrange_ast(lineno, src) + return src[start:end] -def test_oneline(): +def test_oneline() -> None: source = getstatement(0, "raise ValueError") assert str(source) == "raise ValueError" -def test_comment_and_no_newline_at_end(): +def test_comment_and_no_newline_at_end() -> None: from _pytest._code.source import getstatementrange_ast source = Source( @@ -545,12 +561,12 @@ def test_comment_and_no_newline_at_end(): assert end == 2 -def test_oneline_and_comment(): +def test_oneline_and_comment() -> None: source = getstatement(0, "raise ValueError\n#hello") assert str(source) == "raise ValueError" -def test_comments(): +def test_comments() -> None: source = '''def test(): "comment 1" x = 1 @@ -576,7 +592,7 @@ def test_comments(): assert str(getstatement(line, source)) == '"""\ncomment 4\n"""' -def test_comment_in_statement(): +def test_comment_in_statement() -> None: source = """test(foo=1, # comment 1 bar=2) @@ -588,17 +604,17 @@ def test_comment_in_statement(): ) -def test_single_line_else(): +def test_single_line_else() -> None: source = getstatement(1, "if False: 2\nelse: 3") assert str(source) == "else: 3" -def test_single_line_finally(): +def test_single_line_finally() -> None: source = getstatement(1, "try: 1\nfinally: 3") assert str(source) == "finally: 3" -def test_issue55(): +def test_issue55() -> None: source = ( "def round_trip(dinp):\n assert 1 == dinp\n" 'def test_rt():\n round_trip("""\n""")\n' @@ -607,7 +623,7 @@ def test_issue55(): assert str(s) == ' round_trip("""\n""")' -def test_multiline(): +def test_multiline() -> None: source = getstatement( 0, """\ @@ -621,7 +637,7 @@ def test_multiline(): class TestTry: - def setup_class(self): + def setup_class(self) -> None: self.source = """\ try: raise ValueError @@ -631,25 +647,25 @@ def setup_class(self): raise KeyError() """ - def test_body(self): + def test_body(self) -> None: source = getstatement(1, self.source) assert str(source) == " raise ValueError" - def test_except_line(self): + def test_except_line(self) -> None: source = getstatement(2, self.source) assert str(source) == "except Something:" - def test_except_body(self): + def test_except_body(self) -> None: source = getstatement(3, self.source) assert str(source) == " raise IndexError(1)" - def test_else(self): + def test_else(self) -> None: source = getstatement(5, self.source) assert str(source) == " raise KeyError()" class TestTryFinally: - def setup_class(self): + def setup_class(self) -> None: self.source = """\ try: raise ValueError @@ -657,17 +673,17 @@ def setup_class(self): raise IndexError(1) """ - def test_body(self): + def test_body(self) -> None: source = getstatement(1, self.source) assert str(source) == " raise ValueError" - def test_finally(self): + def test_finally(self) -> None: source = getstatement(3, self.source) assert str(source) == " raise IndexError(1)" class TestIf: - def setup_class(self): + def setup_class(self) -> None: self.source = """\ if 1: y = 3 @@ -677,24 +693,24 @@ def setup_class(self): y = 7 """ - def test_body(self): + def test_body(self) -> None: source = getstatement(1, self.source) assert str(source) == " y = 3" - def test_elif_clause(self): + def test_elif_clause(self) -> None: source = getstatement(2, self.source) assert str(source) == "elif False:" - def test_elif(self): + def test_elif(self) -> None: source = getstatement(3, self.source) assert str(source) == " y = 5" - def test_else(self): + def test_else(self) -> None: source = getstatement(5, self.source) assert str(source) == " y = 7" -def test_semicolon(): +def test_semicolon() -> None: s = """\ hello ; pytest.skip() """ @@ -702,7 +718,7 @@ def test_semicolon(): assert str(source) == s.strip() -def test_def_online(): +def test_def_online() -> None: s = """\ def func(): raise ValueError(42) @@ -713,7 +729,7 @@ def something(): assert str(source) == "def func(): raise ValueError(42)" -def XXX_test_expression_multiline(): +def XXX_test_expression_multiline() -> None: source = """\ something ''' @@ -722,7 +738,7 @@ def XXX_test_expression_multiline(): assert str(result) == "'''\n'''" -def test_getstartingblock_multiline(): +def test_getstartingblock_multiline() -> None: class A: def __init__(self, *args): frame = sys._getframe(1)