diff --git a/src/_pytask/capture.py b/src/_pytask/capture.py index 951cc506..9e0f1244 100644 --- a/src/_pytask/capture.py +++ b/src/_pytask/capture.py @@ -25,8 +25,8 @@ from __future__ import annotations +import abc import contextlib -import functools import io import os import sys @@ -35,13 +35,16 @@ from typing import TYPE_CHECKING from typing import Any from typing import AnyStr +from typing import BinaryIO from typing import Generator from typing import Generic from typing import Iterator +from typing import NamedTuple from typing import TextIO from typing import final import click +from typing_extensions import Self from _pytask.capture_utils import CaptureMethod from _pytask.capture_utils import ShowCapture @@ -50,6 +53,8 @@ from _pytask.shared import convert_to_enum if TYPE_CHECKING: + from types import TracebackType + from _pytask.node_protocols import PTask @@ -143,22 +148,29 @@ def write(self, s: str) -> int: return self._other.write(s) -class DontReadFromInput: +class DontReadFromInput(TextIO): """Class to disable reading from stdin while capturing is activated.""" - encoding = None + @property + def encoding(self) -> str: + return sys.__stdin__.encoding - def read(self, *_args: Any) -> None: + def read(self, size: int = -1) -> str: # noqa: ARG002 msg = ( "pytask: reading from stdin while output is captured! Consider using `-s`." ) raise OSError(msg) readline = read - readlines = read - __next__ = read - def __iter__(self) -> DontReadFromInput: + def __next__(self) -> str: + return self.readline() + + def readlines(self, hint: int | None = -1) -> list[str]: # noqa: ARG002 + msg = "reading from stdin while output is captured! Consider using `-s`." + raise OSError(msg) + + def __iter__(self) -> Iterator[str]: return self def fileno(self) -> int: @@ -178,7 +190,7 @@ def close(self) -> None: def readable(self) -> bool: return False - def seek(self, offset: int) -> int: # noqa: ARG002 + def seek(self, offset: int, whence: int = 0) -> int: # noqa: ARG002 msg = "Redirected stdin is pseudofile, has no seek(int)." raise UnsupportedOperation(msg) @@ -189,11 +201,11 @@ def tell(self) -> int: msg = "Redirected stdin is pseudofile, has no tell()." raise UnsupportedOperation(msg) - def truncate(self, size: int) -> None: # noqa: ARG002 + def truncate(self, size: int | None = None) -> int: # noqa: ARG002 msg = "Cannot truncate stdin." raise UnsupportedOperation(msg) - def write(self, *args: Any) -> None: # noqa: ARG002 + def write(self, data: str) -> int: # noqa: ARG002 msg = "Cannot write to stdin." raise UnsupportedOperation(msg) @@ -204,43 +216,105 @@ def writelines(self, *args: Any) -> None: # noqa: ARG002 def writable(self) -> bool: return False - @property - def buffer(self) -> DontReadFromInput: + def __enter__(self) -> Self: return self + def __exit__( + self, + type: type[BaseException] | None, # noqa: A002 + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + @property + def buffer(self) -> BinaryIO: + # The str/bytes doesn't actually matter in this type, so OK to fake. + return self # type: ignore[return-value] + # Capture classes. +class CaptureBase(abc.ABC, Generic[AnyStr]): + EMPTY_BUFFER: AnyStr + + @abc.abstractmethod + def __init__(self, fd: int) -> None: + raise NotImplementedError + + @abc.abstractmethod + def start(self) -> None: + raise NotImplementedError + + @abc.abstractmethod + def done(self) -> None: + raise NotImplementedError + + @abc.abstractmethod + def suspend(self) -> None: + raise NotImplementedError + + @abc.abstractmethod + def resume(self) -> None: + raise NotImplementedError + + @abc.abstractmethod + def writeorg(self, data: AnyStr) -> None: + raise NotImplementedError + + @abc.abstractmethod + def snap(self) -> AnyStr: + raise NotImplementedError + + patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"} """dict[int, str]: Map file descriptors to their names.""" -class NoCapture: +class NoCapture(CaptureBase[str]): """Dummy class when capturing is disabled.""" - EMPTY_BUFFER = None - __init__ = start = done = suspend = resume = lambda *_args: None + EMPTY_BUFFER = "" + def __init__(self, fd: int) -> None: + pass -class SysCaptureBinary: + def start(self) -> None: + pass + + def done(self) -> None: + pass + + def suspend(self) -> None: + pass + + def resume(self) -> None: + pass + + def snap(self) -> str: + return "" + + def writeorg(self, data: str) -> None: + pass + + +class SysCaptureBase(CaptureBase[AnyStr]): """Capture IO to/from Python's buffer for stdin, stdout, and stderr. Instead of :class:`SysCapture`, this class produces bytes instead of text. """ - EMPTY_BUFFER = b"" - - def __init__( # type: ignore + def __init__( self, fd: int, - tmpfile=None, # noqa: ANN001 + tmpfile: TextIO | None = None, *, tee: bool = False, ) -> None: name = patchsysdict[fd] - self._old = getattr(sys, name) + self._old: TextIO = getattr(sys, name) self.name = name if tmpfile is None: if name == "stdin": @@ -280,14 +354,6 @@ def start(self) -> None: setattr(sys, self.name, self.tmpfile) self._state = "started" - def snap(self) -> str: - self._assert_state("snap", ("started", "suspended")) - self.tmpfile.seek(0) - res = self.tmpfile.buffer.read() - self.tmpfile.seek(0) - self.tmpfile.truncate() - return res - def done(self) -> None: self._assert_state("done", ("initialized", "started", "suspended", "done")) if self._state == "done": @@ -309,23 +375,37 @@ def resume(self) -> None: setattr(sys, self.name, self.tmpfile) self._state = "started" - def writeorg(self, data: str) -> None: + +class SysCaptureBinary(SysCaptureBase[bytes]): + EMPTY_BUFFER = b"" + + def snap(self) -> bytes: + self._assert_state("snap", ("started", "suspended")) + self.tmpfile.seek(0) + res = self.tmpfile.buffer.read() + self.tmpfile.seek(0) + self.tmpfile.truncate() + return res + + def writeorg(self, data: bytes) -> None: self._assert_state("writeorg", ("started", "suspended")) self._old.flush() self._old.buffer.write(data) self._old.buffer.flush() -class SysCapture(SysCaptureBinary): +class SysCapture(SysCaptureBase[str]): """Capture IO to/from Python's buffer for stdin, stdout, and stderr. Instead of :class:`SysCaptureBinary`, this class produces text instead of bytes. """ - EMPTY_BUFFER = "" # type: ignore[assignment] + EMPTY_BUFFER = "" def snap(self) -> str: + self._assert_state("snap", ("started", "suspended")) + assert isinstance(self.tmpfile, CaptureIO) res = self.tmpfile.getvalue() self.tmpfile.seek(0) self.tmpfile.truncate() @@ -337,15 +417,13 @@ def writeorg(self, data: str) -> None: self._old.flush() -class FDCaptureBinary: +class FDCaptureBase(CaptureBase[AnyStr]): """Capture IO to/from a given OS-level file descriptor. snap() produces `bytes`. """ - EMPTY_BUFFER = b"" - def __init__(self, targetfd: int) -> None: self.targetfd = targetfd @@ -371,7 +449,7 @@ def __init__(self, targetfd: int) -> None: if targetfd == 0: self.tmpfile = open(os.devnull, encoding="utf-8") # noqa: SIM115, PTH123 - self.syscapture = SysCapture(targetfd) + self.syscapture: CaptureBase[str] = SysCapture(targetfd) else: self.tmpfile = EncodedFile( TemporaryFile(buffering=0), @@ -383,7 +461,7 @@ def __init__(self, targetfd: int) -> None: if targetfd in patchsysdict: self.syscapture = SysCapture(targetfd, self.tmpfile) else: - self.syscapture = NoCapture() + self.syscapture = NoCapture(targetfd) self._state = "initialized" @@ -410,14 +488,6 @@ def start(self) -> None: self.syscapture.start() self._state = "started" - def snap(self) -> bytes: - self._assert_state("snap", ("started", "suspended")) - self.tmpfile.seek(0) - res = self.tmpfile.buffer.read() - self.tmpfile.seek(0) - self.tmpfile.truncate() - return res - def done(self) -> None: """Stop capturing. @@ -454,24 +524,39 @@ def resume(self) -> None: os.dup2(self.tmpfile.fileno(), self.targetfd) self._state = "started" + +class FDCaptureBinary(FDCaptureBase[bytes]): + """Capture IO to/from a given OS-level file descriptor. + + snap() produces `bytes`. + """ + + EMPTY_BUFFER = b"" + + def snap(self) -> bytes: + self._assert_state("snap", ("started", "suspended")) + self.tmpfile.seek(0) + res = self.tmpfile.buffer.read() + self.tmpfile.seek(0) + self.tmpfile.truncate() + return res + def writeorg(self, data: bytes) -> None: """Write to original file descriptor.""" self._assert_state("writeorg", ("started", "suspended")) os.write(self.targetfd_save, data) -class FDCapture(FDCaptureBinary): +class FDCapture(FDCaptureBase[str]): """Capture IO to/from a given OS-level file descriptor. snap() produces text. """ - # Ignore type because it doesn't match the type in the superclass (bytes). - EMPTY_BUFFER = "" # type: ignore + EMPTY_BUFFER = "" - # Ignore type because it doesn't match the type in the superclass (bytes). - def snap(self) -> str: # type: ignore + def snap(self) -> str: self._assert_state("snap", ("started", "suspended")) self.tmpfile.seek(0) res = self.tmpfile.read() @@ -479,70 +564,36 @@ def snap(self) -> str: # type: ignore self.tmpfile.truncate() return res - # Ignore type because it doesn't match the type in the superclass (bytes). - def writeorg(self, data: str) -> None: # type: ignore + def writeorg(self, data: str) -> None: """Write to original file descriptor.""" - super().writeorg(data.encode("utf-8")) + self._assert_state("writeorg", ("started", "suspended")) + # Will be fixed by pytest. Use encoding of original stream + os.write(self.targetfd_save, data.encode("utf-8")) # MultiCapture -@final -@functools.total_ordering -class CaptureResult(Generic[AnyStr]): - """The result of :meth:`MultiCapture.readouterr` which wraps stdout and stderr. - - This class was a namedtuple, but due to mypy limitation [0]_ it could not be made - generic, so was replaced by a regular class which tries to emulate the pertinent - parts of a namedtuple. If the mypy limitation is ever lifted, can make it a - namedtuple again (https://github.com/python/mypy/issues/685). - - """ +# Generic NamedTuple only supported since Python 3.11. +if sys.version_info >= (3, 11) or TYPE_CHECKING: - __slots__ = ("out", "err") + @final + class CaptureResult(NamedTuple, Generic[AnyStr]): + """A class for capture results.""" - def __init__(self, out: AnyStr, err: AnyStr) -> None: - self.out: AnyStr = out - self.err: AnyStr = err + out: AnyStr + err: AnyStr - def __len__(self) -> int: - return 2 +else: + import collections - def __iter__(self) -> Iterator[AnyStr]: - return iter((self.out, self.err)) - - def __getitem__(self, item: int) -> AnyStr: - return tuple(self)[item] - - def _replace( - self, *, out: AnyStr | None = None, err: AnyStr | None = None - ) -> CaptureResult[AnyStr]: - return CaptureResult( - out=self.out if out is None else out, err=self.err if err is None else err - ) + class CaptureResult( + collections.namedtuple("CaptureResult", ["out", "err"]), # noqa: PYI024 + Generic[AnyStr], + ): + """A class for capture results.""" - def count(self, value: AnyStr) -> int: - return tuple(self).count(value) - - def index(self, value: int) -> int: - return tuple(self).index(value) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, (CaptureResult, tuple)): - return NotImplemented - return tuple(self) == tuple(other) - - def __hash__(self) -> int: - return hash(tuple(self)) - - def __lt__(self, other: object) -> bool: - if not isinstance(other, (CaptureResult, tuple)): - return NotImplemented - return tuple(self) < tuple(other) - - def __repr__(self) -> str: - return f"CaptureResult(out={self.out!r}, err={self.err!r})" + __slots__ = () class MultiCapture(Generic[AnyStr]): @@ -559,13 +610,13 @@ class MultiCapture(Generic[AnyStr]): def __init__( self, - in_: FDCapture | SysCapture | None, - out: FDCapture | SysCapture | None, - err: FDCapture | SysCapture | None, + in_: CaptureBase[AnyStr] | None, + out: CaptureBase[AnyStr] | None, + err: CaptureBase[AnyStr] | None, ) -> None: - self.in_ = in_ - self.out = out - self.err = err + self.in_: CaptureBase[AnyStr] | None = in_ + self.out: CaptureBase[AnyStr] | None = out + self.err: CaptureBase[AnyStr] | None = err def __repr__(self) -> str: return ( # noqa: UP032 @@ -592,9 +643,9 @@ def pop_outerr_to_orig(self) -> tuple[AnyStr, AnyStr]: """Pop current snapshot out/err capture and flush to orig streams.""" out, err = self.readouterr() if out: - self.out.writeorg(out) # type: ignore + self.out.writeorg(out) # type: ignore[union-attr] if err: - self.err.writeorg(err) # type: ignore + self.err.writeorg(err) # type: ignore[union-attr] return out, err def suspend_capturing(self, in_: bool = False) -> None: @@ -637,7 +688,8 @@ def is_started(self) -> bool: def readouterr(self) -> CaptureResult[AnyStr]: out = self.out.snap() if self.out else "" err = self.err.snap() if self.err else "" - return CaptureResult(out, err) # type: ignore + # Will be fixed by pytest. This type error is real, need to fix. + return CaptureResult(out, err) # type: ignore[arg-type] def _get_multicapture(method: CaptureMethod) -> MultiCapture[str]: