From a2e8346d9a7c60304a51a52abeee4535cd86af67 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 1 Jul 2022 19:20:39 +0100 Subject: [PATCH] Improve `multiprocessing` stubs (#8202) Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com> --- stdlib/multiprocessing/__init__.pyi | 90 +++++++---------------- stdlib/multiprocessing/connection.pyi | 2 +- stdlib/multiprocessing/context.pyi | 17 +++-- stdlib/multiprocessing/dummy/__init__.pyi | 24 +++--- stdlib/multiprocessing/managers.pyi | 4 +- stdlib/multiprocessing/pool.pyi | 4 +- stdlib/multiprocessing/process.pyi | 2 +- stdlib/multiprocessing/synchronize.pyi | 3 +- 8 files changed, 56 insertions(+), 90 deletions(-) diff --git a/stdlib/multiprocessing/__init__.pyi b/stdlib/multiprocessing/__init__.pyi index 41af971bc619..4359b6c080aa 100644 --- a/stdlib/multiprocessing/__init__.pyi +++ b/stdlib/multiprocessing/__init__.pyi @@ -1,18 +1,12 @@ import sys -from collections.abc import Callable, Iterable -from logging import Logger -from multiprocessing import connection, context, pool, reduction as reducer, synchronize +from multiprocessing import context, reduction as reducer, synchronize from multiprocessing.context import ( AuthenticationError as AuthenticationError, - BaseContext, BufferTooShort as BufferTooShort, - DefaultContext, Process as Process, ProcessError as ProcessError, - SpawnContext, TimeoutError as TimeoutError, ) -from multiprocessing.managers import SyncManager from multiprocessing.process import active_children as active_children, current_process as current_process # These are technically functions that return instances of these Queue classes. @@ -20,15 +14,12 @@ from multiprocessing.process import active_children as active_children, current_ # multiprocessing.queues or the aliases defined below. See #4266 for discussion. from multiprocessing.queues import JoinableQueue as JoinableQueue, Queue as Queue, SimpleQueue as SimpleQueue from multiprocessing.spawn import freeze_support as freeze_support -from typing import Any, TypeVar, overload -from typing_extensions import Literal, TypeAlias +from typing import TypeVar +from typing_extensions import TypeAlias if sys.version_info >= (3, 8): from multiprocessing.process import parent_process as parent_process -if sys.platform != "win32": - from multiprocessing.context import ForkContext, ForkServerContext - __all__ = [ "Array", "AuthenticationError", @@ -92,60 +83,29 @@ _LockType: TypeAlias = synchronize.Lock _RLockType: TypeAlias = synchronize.RLock _SemaphoreType: TypeAlias = synchronize.Semaphore -# N.B. The functions below are generated at runtime by partially applying -# multiprocessing.context.BaseContext's methods, so the two signatures should -# be identical (modulo self). - -# Synchronization primitives -_LockLike: TypeAlias = synchronize.Lock | synchronize.RLock +# These functions (really bound methods) +# are all autogenerated at runtime here: https://github.com/python/cpython/blob/600c65c094b0b48704d8ec2416930648052ba715/Lib/multiprocessing/__init__.py#L23 RawValue = context._default_context.RawValue RawArray = context._default_context.RawArray Value = context._default_context.Value Array = context._default_context.Array - -def Barrier(parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ...) -> _BarrierType: ... -def BoundedSemaphore(value: int = ...) -> _BoundedSemaphoreType: ... -def Condition(lock: _LockLike | None = ...) -> _ConditionType: ... -def Event() -> _EventType: ... -def Lock() -> _LockType: ... -def RLock() -> _RLockType: ... -def Semaphore(value: int = ...) -> _SemaphoreType: ... -def Pipe(duplex: bool = ...) -> tuple[connection.Connection, connection.Connection]: ... -def Pool( - processes: int | None = ..., - initializer: Callable[..., Any] | None = ..., - initargs: Iterable[Any] = ..., - maxtasksperchild: int | None = ..., -) -> pool.Pool: ... - -# ----- multiprocessing function stubs ----- -def allow_connection_pickling() -> None: ... -def cpu_count() -> int: ... -def get_logger() -> Logger: ... -def log_to_stderr(level: str | int | None = ...) -> Logger: ... -def Manager() -> SyncManager: ... -def set_executable(executable: str) -> None: ... -def set_forkserver_preload(module_names: list[str]) -> None: ... -def get_all_start_methods() -> list[str]: ... -def get_start_method(allow_none: bool = ...) -> str | None: ... -def set_start_method(method: str, force: bool | None = ...) -> None: ... - -if sys.platform != "win32": - @overload - def get_context(method: None = ...) -> DefaultContext: ... - @overload - def get_context(method: Literal["spawn"]) -> SpawnContext: ... - @overload - def get_context(method: Literal["fork"]) -> ForkContext: ... - @overload - def get_context(method: Literal["forkserver"]) -> ForkServerContext: ... - @overload - def get_context(method: str) -> BaseContext: ... - -else: - @overload - def get_context(method: None = ...) -> DefaultContext: ... - @overload - def get_context(method: Literal["spawn"]) -> SpawnContext: ... - @overload - def get_context(method: str) -> BaseContext: ... +Barrier = context._default_context.Barrier +BoundedSemaphore = context._default_context.BoundedSemaphore +Condition = context._default_context.Condition +Event = context._default_context.Event +Lock = context._default_context.Lock +RLock = context._default_context.RLock +Semaphore = context._default_context.Semaphore +Pipe = context._default_context.Pipe +Pool = context._default_context.Pool +allow_connection_pickling = context._default_context.allow_connection_pickling +cpu_count = context._default_context.cpu_count +get_logger = context._default_context.get_logger +log_to_stderr = context._default_context.log_to_stderr +Manager = context._default_context.Manager +set_executable = context._default_context.set_executable +set_forkserver_preload = context._default_context.set_forkserver_preload +get_all_start_methods = context._default_context.get_all_start_methods +get_start_method = context._default_context.get_start_method +set_start_method = context._default_context.set_start_method +get_context = context._default_context.get_context diff --git a/stdlib/multiprocessing/connection.pyi b/stdlib/multiprocessing/connection.pyi index 7b227a697abe..489e8bd9a9f1 100644 --- a/stdlib/multiprocessing/connection.pyi +++ b/stdlib/multiprocessing/connection.pyi @@ -58,4 +58,4 @@ def wait( object_list: Iterable[Connection | socket.socket | int], timeout: float | None = ... ) -> list[Connection | socket.socket | int]: ... def Client(address: _Address, family: str | None = ..., authkey: bytes | None = ...) -> Connection: ... -def Pipe(duplex: bool = ...) -> tuple[Connection, Connection]: ... +def Pipe(duplex: bool = ...) -> tuple[_ConnectionBase, _ConnectionBase]: ... diff --git a/stdlib/multiprocessing/context.pyi b/stdlib/multiprocessing/context.pyi index d618d1028112..ed52325915c4 100644 --- a/stdlib/multiprocessing/context.pyi +++ b/stdlib/multiprocessing/context.pyi @@ -5,6 +5,8 @@ from collections.abc import Callable, Iterable, Sequence from ctypes import _CData from logging import Logger from multiprocessing import queues, synchronize +from multiprocessing.connection import _ConnectionBase +from multiprocessing.managers import SyncManager from multiprocessing.pool import Pool as _Pool from multiprocessing.process import BaseProcess from multiprocessing.sharedctypes import SynchronizedArray, SynchronizedBase @@ -42,12 +44,10 @@ class BaseContext: @staticmethod def active_children() -> list[BaseProcess]: ... def cpu_count(self) -> int: ... - # TODO: change return to SyncManager once a stub exists in multiprocessing.managers - def Manager(self) -> Any: ... - # TODO: change return to Pipe once a stub exists in multiprocessing.connection - def Pipe(self, duplex: bool = ...) -> Any: ... + def Manager(self) -> SyncManager: ... + def Pipe(self, duplex: bool = ...) -> tuple[_ConnectionBase, _ConnectionBase]: ... def Barrier( - self, parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ... + self, parties: int, action: Callable[..., object] | None = ..., timeout: float | None = ... ) -> synchronize.Barrier: ... def BoundedSemaphore(self, value: int = ...) -> synchronize.BoundedSemaphore: ... def Condition(self, lock: _LockLike | None = ...) -> synchronize.Condition: ... @@ -61,7 +61,7 @@ class BaseContext: def Pool( self, processes: int | None = ..., - initializer: Callable[..., Any] | None = ..., + initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ..., maxtasksperchild: int | None = ..., ) -> _Pool: ... @@ -120,7 +120,10 @@ class BaseContext: @overload def get_context(self, method: str) -> BaseContext: ... - def get_start_method(self, allow_none: bool = ...) -> str: ... + @overload + def get_start_method(self, allow_none: Literal[False] = ...) -> str: ... + @overload + def get_start_method(self, allow_none: bool) -> str | None: ... def set_start_method(self, method: str | None, force: bool = ...) -> None: ... @property def reducer(self) -> str: ... diff --git a/stdlib/multiprocessing/dummy/__init__.pyi b/stdlib/multiprocessing/dummy/__init__.pyi index bbddfd16ded7..5d289c058e03 100644 --- a/stdlib/multiprocessing/dummy/__init__.pyi +++ b/stdlib/multiprocessing/dummy/__init__.pyi @@ -3,6 +3,15 @@ import threading import weakref from collections.abc import Callable, Iterable, Mapping, Sequence from queue import Queue as Queue +from threading import ( + Barrier as Barrier, + BoundedSemaphore as BoundedSemaphore, + Condition as Condition, + Event as Event, + Lock as Lock, + RLock as RLock, + Semaphore as Semaphore, +) from typing import Any from typing_extensions import Literal @@ -28,13 +37,6 @@ __all__ = [ ] JoinableQueue = Queue -Barrier = threading.Barrier -BoundedSemaphore = threading.BoundedSemaphore -Condition = threading.Condition -Event = threading.Event -Lock = threading.Lock -RLock = threading.RLock -Semaphore = threading.Semaphore class DummyProcess(threading.Thread): _children: weakref.WeakKeyDictionary[Any, Any] @@ -46,7 +48,7 @@ class DummyProcess(threading.Thread): def __init__( self, group: Any = ..., - target: Callable[..., Any] | None = ..., + target: Callable[..., object] | None = ..., name: str | None = ..., args: Iterable[Any] = ..., kwargs: Mapping[str, Any] = ..., @@ -67,8 +69,10 @@ class Value: def Array(typecode: Any, sequence: Sequence[Any], lock: Any = ...) -> array.array[Any]: ... def Manager() -> Any: ... -def Pool(processes: int | None = ..., initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...) -> Any: ... +def Pool(processes: int | None = ..., initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...) -> Any: ... def active_children() -> list[Any]: ... -def current_process() -> threading.Thread: ... + +current_process = threading.current_thread + def freeze_support() -> None: ... def shutdown() -> None: ... diff --git a/stdlib/multiprocessing/managers.pyi b/stdlib/multiprocessing/managers.pyi index 212ffcbf5a3a..5537ea937bae 100644 --- a/stdlib/multiprocessing/managers.pyi +++ b/stdlib/multiprocessing/managers.pyi @@ -148,7 +148,7 @@ class BaseManager: def get_server(self) -> Server: ... def connect(self) -> None: ... - def start(self, initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...) -> None: ... + def start(self, initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...) -> None: ... def shutdown(self) -> None: ... # only available after start() was called def join(self, timeout: float | None = ...) -> None: ... # undocumented @property @@ -157,7 +157,7 @@ class BaseManager: def register( cls, typeid: str, - callable: Callable[..., Any] | None = ..., + callable: Callable[..., object] | None = ..., proxytype: Any = ..., exposed: Sequence[str] | None = ..., method_to_typeid: Mapping[str, str] | None = ..., diff --git a/stdlib/multiprocessing/pool.pyi b/stdlib/multiprocessing/pool.pyi index 4e1b6c159a2a..2b97e16f0525 100644 --- a/stdlib/multiprocessing/pool.pyi +++ b/stdlib/multiprocessing/pool.pyi @@ -72,7 +72,7 @@ class Pool: def __init__( self, processes: int | None = ..., - initializer: Callable[..., None] | None = ..., + initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ..., maxtasksperchild: int | None = ..., context: Any | None = ..., @@ -118,7 +118,7 @@ class Pool: class ThreadPool(Pool): def __init__( - self, processes: int | None = ..., initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ... + self, processes: int | None = ..., initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ... ) -> None: ... # undocumented diff --git a/stdlib/multiprocessing/process.pyi b/stdlib/multiprocessing/process.pyi index 1601decbbebc..f903cef6fa72 100644 --- a/stdlib/multiprocessing/process.pyi +++ b/stdlib/multiprocessing/process.pyi @@ -15,7 +15,7 @@ class BaseProcess: def __init__( self, group: None = ..., - target: Callable[..., Any] | None = ..., + target: Callable[..., object] | None = ..., name: str | None = ..., args: Iterable[Any] = ..., kwargs: Mapping[str, Any] = ..., diff --git a/stdlib/multiprocessing/synchronize.pyi b/stdlib/multiprocessing/synchronize.pyi index e93d6c58b5cf..7a86935f7d18 100644 --- a/stdlib/multiprocessing/synchronize.pyi +++ b/stdlib/multiprocessing/synchronize.pyi @@ -4,7 +4,6 @@ from collections.abc import Callable from contextlib import AbstractContextManager from multiprocessing.context import BaseContext from types import TracebackType -from typing import Any from typing_extensions import TypeAlias __all__ = ["Lock", "RLock", "Semaphore", "BoundedSemaphore", "Condition", "Event"] @@ -13,7 +12,7 @@ _LockLike: TypeAlias = Lock | RLock class Barrier(threading.Barrier): def __init__( - self, parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ..., *ctx: BaseContext + self, parties: int, action: Callable[[], object] | None = ..., timeout: float | None = ..., *ctx: BaseContext ) -> None: ... class BoundedSemaphore(Semaphore):