Skip to content

Commit

Permalink
Improve multiprocessing stubs (#8202)
Browse files Browse the repository at this point in the history
Co-authored-by: Shantanu <[email protected]>
  • Loading branch information
AlexWaygood and hauntsaninja authored Jul 1, 2022
1 parent 7b54854 commit a2e8346
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 90 deletions.
90 changes: 25 additions & 65 deletions stdlib/multiprocessing/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
import sys
from collections.abc import Callable, Iterable
from logging import Logger
from multiprocessing import connection, context, pool, reduction as reducer, synchronize
from multiprocessing import context, reduction as reducer, synchronize
from multiprocessing.context import (
AuthenticationError as AuthenticationError,
BaseContext,
BufferTooShort as BufferTooShort,
DefaultContext,
Process as Process,
ProcessError as ProcessError,
SpawnContext,
TimeoutError as TimeoutError,
)
from multiprocessing.managers import SyncManager
from multiprocessing.process import active_children as active_children, current_process as current_process

# These are technically functions that return instances of these Queue classes.
# Using them as annotations is deprecated. Either use imports from
# multiprocessing.queues or the aliases defined below. See #4266 for discussion.
from multiprocessing.queues import JoinableQueue as JoinableQueue, Queue as Queue, SimpleQueue as SimpleQueue
from multiprocessing.spawn import freeze_support as freeze_support
from typing import Any, TypeVar, overload
from typing_extensions import Literal, TypeAlias
from typing import TypeVar
from typing_extensions import TypeAlias

if sys.version_info >= (3, 8):
from multiprocessing.process import parent_process as parent_process

if sys.platform != "win32":
from multiprocessing.context import ForkContext, ForkServerContext

__all__ = [
"Array",
"AuthenticationError",
Expand Down Expand Up @@ -92,60 +83,29 @@ _LockType: TypeAlias = synchronize.Lock
_RLockType: TypeAlias = synchronize.RLock
_SemaphoreType: TypeAlias = synchronize.Semaphore

# N.B. The functions below are generated at runtime by partially applying
# multiprocessing.context.BaseContext's methods, so the two signatures should
# be identical (modulo self).

# Synchronization primitives
_LockLike: TypeAlias = synchronize.Lock | synchronize.RLock
# These functions (really bound methods)
# are all autogenerated at runtime here: https://github.com/python/cpython/blob/600c65c094b0b48704d8ec2416930648052ba715/Lib/multiprocessing/__init__.py#L23
RawValue = context._default_context.RawValue
RawArray = context._default_context.RawArray
Value = context._default_context.Value
Array = context._default_context.Array

def Barrier(parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ...) -> _BarrierType: ...
def BoundedSemaphore(value: int = ...) -> _BoundedSemaphoreType: ...
def Condition(lock: _LockLike | None = ...) -> _ConditionType: ...
def Event() -> _EventType: ...
def Lock() -> _LockType: ...
def RLock() -> _RLockType: ...
def Semaphore(value: int = ...) -> _SemaphoreType: ...
def Pipe(duplex: bool = ...) -> tuple[connection.Connection, connection.Connection]: ...
def Pool(
processes: int | None = ...,
initializer: Callable[..., Any] | None = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: int | None = ...,
) -> pool.Pool: ...

# ----- multiprocessing function stubs -----
def allow_connection_pickling() -> None: ...
def cpu_count() -> int: ...
def get_logger() -> Logger: ...
def log_to_stderr(level: str | int | None = ...) -> Logger: ...
def Manager() -> SyncManager: ...
def set_executable(executable: str) -> None: ...
def set_forkserver_preload(module_names: list[str]) -> None: ...
def get_all_start_methods() -> list[str]: ...
def get_start_method(allow_none: bool = ...) -> str | None: ...
def set_start_method(method: str, force: bool | None = ...) -> None: ...

if sys.platform != "win32":
@overload
def get_context(method: None = ...) -> DefaultContext: ...
@overload
def get_context(method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(method: Literal["fork"]) -> ForkContext: ...
@overload
def get_context(method: Literal["forkserver"]) -> ForkServerContext: ...
@overload
def get_context(method: str) -> BaseContext: ...

else:
@overload
def get_context(method: None = ...) -> DefaultContext: ...
@overload
def get_context(method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(method: str) -> BaseContext: ...
Barrier = context._default_context.Barrier
BoundedSemaphore = context._default_context.BoundedSemaphore
Condition = context._default_context.Condition
Event = context._default_context.Event
Lock = context._default_context.Lock
RLock = context._default_context.RLock
Semaphore = context._default_context.Semaphore
Pipe = context._default_context.Pipe
Pool = context._default_context.Pool
allow_connection_pickling = context._default_context.allow_connection_pickling
cpu_count = context._default_context.cpu_count
get_logger = context._default_context.get_logger
log_to_stderr = context._default_context.log_to_stderr
Manager = context._default_context.Manager
set_executable = context._default_context.set_executable
set_forkserver_preload = context._default_context.set_forkserver_preload
get_all_start_methods = context._default_context.get_all_start_methods
get_start_method = context._default_context.get_start_method
set_start_method = context._default_context.set_start_method
get_context = context._default_context.get_context
2 changes: 1 addition & 1 deletion stdlib/multiprocessing/connection.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ def wait(
object_list: Iterable[Connection | socket.socket | int], timeout: float | None = ...
) -> list[Connection | socket.socket | int]: ...
def Client(address: _Address, family: str | None = ..., authkey: bytes | None = ...) -> Connection: ...
def Pipe(duplex: bool = ...) -> tuple[Connection, Connection]: ...
def Pipe(duplex: bool = ...) -> tuple[_ConnectionBase, _ConnectionBase]: ...
17 changes: 10 additions & 7 deletions stdlib/multiprocessing/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ from collections.abc import Callable, Iterable, Sequence
from ctypes import _CData
from logging import Logger
from multiprocessing import queues, synchronize
from multiprocessing.connection import _ConnectionBase
from multiprocessing.managers import SyncManager
from multiprocessing.pool import Pool as _Pool
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import SynchronizedArray, SynchronizedBase
Expand Down Expand Up @@ -42,12 +44,10 @@ class BaseContext:
@staticmethod
def active_children() -> list[BaseProcess]: ...
def cpu_count(self) -> int: ...
# TODO: change return to SyncManager once a stub exists in multiprocessing.managers
def Manager(self) -> Any: ...
# TODO: change return to Pipe once a stub exists in multiprocessing.connection
def Pipe(self, duplex: bool = ...) -> Any: ...
def Manager(self) -> SyncManager: ...
def Pipe(self, duplex: bool = ...) -> tuple[_ConnectionBase, _ConnectionBase]: ...
def Barrier(
self, parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ...
self, parties: int, action: Callable[..., object] | None = ..., timeout: float | None = ...
) -> synchronize.Barrier: ...
def BoundedSemaphore(self, value: int = ...) -> synchronize.BoundedSemaphore: ...
def Condition(self, lock: _LockLike | None = ...) -> synchronize.Condition: ...
Expand All @@ -61,7 +61,7 @@ class BaseContext:
def Pool(
self,
processes: int | None = ...,
initializer: Callable[..., Any] | None = ...,
initializer: Callable[..., object] | None = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: int | None = ...,
) -> _Pool: ...
Expand Down Expand Up @@ -120,7 +120,10 @@ class BaseContext:
@overload
def get_context(self, method: str) -> BaseContext: ...

def get_start_method(self, allow_none: bool = ...) -> str: ...
@overload
def get_start_method(self, allow_none: Literal[False] = ...) -> str: ...
@overload
def get_start_method(self, allow_none: bool) -> str | None: ...
def set_start_method(self, method: str | None, force: bool = ...) -> None: ...
@property
def reducer(self) -> str: ...
Expand Down
24 changes: 14 additions & 10 deletions stdlib/multiprocessing/dummy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@ import threading
import weakref
from collections.abc import Callable, Iterable, Mapping, Sequence
from queue import Queue as Queue
from threading import (
Barrier as Barrier,
BoundedSemaphore as BoundedSemaphore,
Condition as Condition,
Event as Event,
Lock as Lock,
RLock as RLock,
Semaphore as Semaphore,
)
from typing import Any
from typing_extensions import Literal

Expand All @@ -28,13 +37,6 @@ __all__ = [
]

JoinableQueue = Queue
Barrier = threading.Barrier
BoundedSemaphore = threading.BoundedSemaphore
Condition = threading.Condition
Event = threading.Event
Lock = threading.Lock
RLock = threading.RLock
Semaphore = threading.Semaphore

class DummyProcess(threading.Thread):
_children: weakref.WeakKeyDictionary[Any, Any]
Expand All @@ -46,7 +48,7 @@ class DummyProcess(threading.Thread):
def __init__(
self,
group: Any = ...,
target: Callable[..., Any] | None = ...,
target: Callable[..., object] | None = ...,
name: str | None = ...,
args: Iterable[Any] = ...,
kwargs: Mapping[str, Any] = ...,
Expand All @@ -67,8 +69,10 @@ class Value:

def Array(typecode: Any, sequence: Sequence[Any], lock: Any = ...) -> array.array[Any]: ...
def Manager() -> Any: ...
def Pool(processes: int | None = ..., initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...) -> Any: ...
def Pool(processes: int | None = ..., initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...) -> Any: ...
def active_children() -> list[Any]: ...
def current_process() -> threading.Thread: ...

current_process = threading.current_thread

def freeze_support() -> None: ...
def shutdown() -> None: ...
4 changes: 2 additions & 2 deletions stdlib/multiprocessing/managers.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class BaseManager:

def get_server(self) -> Server: ...
def connect(self) -> None: ...
def start(self, initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...) -> None: ...
def start(self, initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...) -> None: ...
def shutdown(self) -> None: ... # only available after start() was called
def join(self, timeout: float | None = ...) -> None: ... # undocumented
@property
Expand All @@ -157,7 +157,7 @@ class BaseManager:
def register(
cls,
typeid: str,
callable: Callable[..., Any] | None = ...,
callable: Callable[..., object] | None = ...,
proxytype: Any = ...,
exposed: Sequence[str] | None = ...,
method_to_typeid: Mapping[str, str] | None = ...,
Expand Down
4 changes: 2 additions & 2 deletions stdlib/multiprocessing/pool.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Pool:
def __init__(
self,
processes: int | None = ...,
initializer: Callable[..., None] | None = ...,
initializer: Callable[..., object] | None = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: int | None = ...,
context: Any | None = ...,
Expand Down Expand Up @@ -118,7 +118,7 @@ class Pool:

class ThreadPool(Pool):
def __init__(
self, processes: int | None = ..., initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...
self, processes: int | None = ..., initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...
) -> None: ...

# undocumented
Expand Down
2 changes: 1 addition & 1 deletion stdlib/multiprocessing/process.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BaseProcess:
def __init__(
self,
group: None = ...,
target: Callable[..., Any] | None = ...,
target: Callable[..., object] | None = ...,
name: str | None = ...,
args: Iterable[Any] = ...,
kwargs: Mapping[str, Any] = ...,
Expand Down
3 changes: 1 addition & 2 deletions stdlib/multiprocessing/synchronize.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ from collections.abc import Callable
from contextlib import AbstractContextManager
from multiprocessing.context import BaseContext
from types import TracebackType
from typing import Any
from typing_extensions import TypeAlias

__all__ = ["Lock", "RLock", "Semaphore", "BoundedSemaphore", "Condition", "Event"]
Expand All @@ -13,7 +12,7 @@ _LockLike: TypeAlias = Lock | RLock

class Barrier(threading.Barrier):
def __init__(
self, parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ..., *ctx: BaseContext
self, parties: int, action: Callable[[], object] | None = ..., timeout: float | None = ..., *ctx: BaseContext
) -> None: ...

class BoundedSemaphore(Semaphore):
Expand Down

0 comments on commit a2e8346

Please sign in to comment.