diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 2dca02518a3c..5b78f10e7373 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -27,22 +27,20 @@ import asyncio import datetime from typing import ( - Any, - Awaitable, - Callable, + Any, + Awaitable, + Callable, Generic, - List, - Optional, - Type, + List, + Optional, + Type, TypeVar, Union, - cast, ) import aiohttp import discord import inspect -import logging import sys import traceback @@ -50,8 +48,6 @@ from discord.backoff import ExponentialBackoff from discord.utils import MISSING -_log = logging.getLogger(__name__) - __all__ = ( 'loop', ) @@ -61,7 +57,6 @@ LF = TypeVar('LF', bound=_func) FT = TypeVar('FT', bound=_func) ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) -LT = TypeVar('LT', bound='Loop') class SleepHandle: @@ -78,7 +73,7 @@ def recalculate(self, dt: datetime.datetime) -> None: relative_delta = discord.utils.compute_timedelta(dt) self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) - def wait(self) -> asyncio.Future: + def wait(self) -> asyncio.Future[Any]: return self.future def done(self) -> bool: @@ -94,7 +89,9 @@ class Loop(Generic[LF]): The main interface to create this is through :func:`loop`. """ - def __init__(self, + + def __init__( + self, coro: LF, seconds: float, hours: float, @@ -102,15 +99,15 @@ def __init__(self, time: Union[datetime.time, Sequence[datetime.time]], count: Optional[int], reconnect: bool, - loop: Optional[asyncio.AbstractEventLoop], + loop: asyncio.AbstractEventLoop, ) -> None: self.coro: LF = coro self.reconnect: bool = reconnect - self.loop: Optional[asyncio.AbstractEventLoop] = loop + self.loop: asyncio.AbstractEventLoop = loop self.count: Optional[int] = count self._current_loop = 0 - self._handle = None - self._task = None + self._handle: SleepHandle = MISSING + self._task: asyncio.Task[None] = MISSING self._injected = None self._valid_exception = ( OSError, @@ -131,7 +128,7 @@ def __init__(self, self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self._last_iteration_failed = False - self._last_iteration = None + self._last_iteration: datetime.datetime = MISSING self._next_iteration = None if not inspect.iscoroutinefunction(self.coro): @@ -147,9 +144,8 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non else: await coro(*args, **kwargs) - def _try_sleep_until(self, dt: datetime.datetime): - self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore + self._handle = SleepHandle(dt=dt, loop=self.loop) return self._handle.wait() async def _loop(self, *args: Any, **kwargs: Any) -> None: @@ -178,7 +174,7 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: await asyncio.sleep(backoff.delay()) else: await self._try_sleep_until(self._next_iteration) - + if self._stop_next_iteration: return @@ -211,14 +207,14 @@ def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]: if obj is None: return self - copy = Loop( - self.coro, - seconds=self._seconds, - hours=self._hours, + copy: Loop[LF] = Loop( + self.coro, + seconds=self._seconds, + hours=self._hours, minutes=self._minutes, - time=self._time, + time=self._time, count=self.count, - reconnect=self.reconnect, + reconnect=self.reconnect, loop=self.loop, ) copy._injected = obj @@ -237,7 +233,7 @@ def seconds(self) -> Optional[float]: """ if self._seconds is not MISSING: return self._seconds - + @property def minutes(self) -> Optional[float]: """Optional[:class:`float`]: Read-only value for the number of minutes @@ -247,7 +243,7 @@ def minutes(self) -> Optional[float]: """ if self._minutes is not MISSING: return self._minutes - + @property def hours(self) -> Optional[float]: """Optional[:class:`float`]: Read-only value for the number of hours @@ -279,7 +275,7 @@ def next_iteration(self) -> Optional[datetime.datetime]: .. versionadded:: 1.3 """ - if self._task is None: + if self._task is MISSING: return None elif self._task and self._task.done() or self._stop_next_iteration: return None @@ -305,7 +301,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: return await self.coro(*args, **kwargs) - def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: r"""Starts the internal task in the event loop. Parameters @@ -326,13 +322,13 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: The task that has been created. """ - if self._task is not None and not self._task.done(): + if self._task is not MISSING and not self._task.done(): raise RuntimeError('Task is already launched and is not completed.') if self._injected is not None: args = (self._injected, *args) - if self.loop is None: + if self.loop is MISSING: self.loop = asyncio.get_event_loop() self._task = self.loop.create_task(self._loop(*args, **kwargs)) @@ -356,7 +352,7 @@ def stop(self) -> None: .. versionadded:: 1.2 """ - if self._task and not self._task.done(): + if self._task is not MISSING and not self._task.done(): self._stop_next_iteration = True def _can_be_cancelled(self) -> bool: @@ -383,7 +379,7 @@ def restart(self, *args: Any, **kwargs: Any) -> None: The keyword arguments to use. """ - def restart_when_over(fut, *, args=args, kwargs=kwargs): + def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None: self._task.remove_done_callback(restart_when_over) self.start(*args, **kwargs) @@ -446,9 +442,9 @@ def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool: self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) return len(self._valid_exception) == old_length - len(exceptions) - def get_task(self) -> Optional[asyncio.Task]: + def get_task(self) -> Optional[asyncio.Task[None]]: """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" - return self._task + return self._task if self._task is not MISSING else None def is_being_cancelled(self) -> bool: """Whether the task is being cancelled.""" @@ -466,7 +462,7 @@ def is_running(self) -> bool: .. versionadded:: 1.4 """ - return not bool(self._task.done()) if self._task else False + return not bool(self._task.done()) if self._task is not MISSING else False async def _error(self, *args: Any) -> None: exception: Exception = args[-1] @@ -560,7 +556,9 @@ def _get_next_sleep_time(self) -> datetime.datetime: self._time_index = 0 if self._current_loop == 0: # if we're at the last index on the first iteration, we need to sleep until tomorrow - return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]) + return datetime.datetime.combine( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0] + ) next_time = self._time[self._time_index] @@ -568,7 +566,7 @@ def _get_next_sleep_time(self) -> datetime.datetime: self._time_index += 1 return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) - next_date = cast(datetime.datetime, self._last_iteration) + next_date = self._last_iteration if self._time_index == 0: # we can assume that the earliest time should be scheduled for "tomorrow" next_date += datetime.timedelta(days=1) @@ -576,12 +574,14 @@ def _get_next_sleep_time(self) -> datetime.datetime: self._time_index += 1 return datetime.datetime.combine(next_date, next_time) - def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None: + def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None: # now kwarg should be a datetime.datetime representing the time "now" # to calculate the next time index from # pre-condition: self._time is set - time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz() + time_now = ( + now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) + ).timetz() for idx, time in enumerate(self._time): if time >= time_now: self._time_index = idx @@ -597,20 +597,24 @@ def _get_time_parameter( utc: datetime.timezone = datetime.timezone.utc, ) -> List[datetime.time]: if isinstance(time, dt): - ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) - return [ret] + inner = time if time.tzinfo is not None else time.replace(tzinfo=utc) + return [inner] if not isinstance(time, Sequence): - raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') + raise TypeError( + f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.' + ) if not time: raise ValueError('time parameter must not be an empty sequence.') - ret = [] + ret: List[datetime.time] = [] for index, t in enumerate(time): if not isinstance(t, dt): - raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') + raise TypeError( + f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.' + ) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) - ret = sorted(set(ret)) # de-dupe and sort times + ret = sorted(set(ret)) # de-dupe and sort times return ret def change_interval( @@ -691,7 +695,7 @@ def loop( time: Union[datetime.time, Sequence[datetime.time]] = MISSING, count: Optional[int] = None, reconnect: bool = True, - loop: Optional[asyncio.AbstractEventLoop] = None, + loop: asyncio.AbstractEventLoop = MISSING, ) -> Callable[[LF], Loop[LF]]: """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -707,7 +711,7 @@ def loop( time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]] The exact times to run this loop at. Either a non-empty list or a single value of :class:`datetime.time` should be passed. Timezones are supported. - If no timezone is given for the times, it is assumed to represent UTC time. + If no timezone is given for the times, it is assumed to represent UTC time. This cannot be used in conjunction with the relative time parameters. @@ -724,7 +728,7 @@ def loop( Whether to handle errors and restart the task using an exponential back-off algorithm similar to the one used in :meth:`discord.Client.connect`. - loop: Optional[:class:`asyncio.AbstractEventLoop`] + loop: :class:`asyncio.AbstractEventLoop` The loop to use to register the task, if not given defaults to :func:`asyncio.get_event_loop`. @@ -736,15 +740,17 @@ def loop( The function was not a coroutine, an invalid value for the ``time`` parameter was passed, or ``time`` parameter was passed in conjunction with relative time parameters. """ + def decorator(func: LF) -> Loop[LF]: - kwargs = { - 'seconds': seconds, - 'minutes': minutes, - 'hours': hours, - 'count': count, - 'time': time, - 'reconnect': reconnect, - 'loop': loop, - } - return Loop(func, **kwargs) + return Loop[LF]( + func, + seconds=seconds, + minutes=minutes, + hours=hours, + count=count, + time=time, + reconnect=reconnect, + loop=loop, + ) + return decorator