From 91e00d84267f059df078a2e5132764f533472f3a Mon Sep 17 00:00:00 2001 From: Rapptz Date: Tue, 30 Apr 2019 01:44:50 -0400 Subject: [PATCH] [tasks] Add way to query cancellation state for Loop.after_loop Fixes #2121 --- discord/ext/tasks/__init__.py | 28 ++++++++++++++++++++-------- docs/ext/tasks/index.rst | 31 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index a6425a8b7bda..44194917f761 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -37,6 +37,7 @@ def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop): self._before_loop = None self._after_loop = None + self._is_being_cancelled = False if self.count is not None and self.count <= 0: raise ValueError('count must be greater than 0 or None.') @@ -69,8 +70,6 @@ async def _loop(self, *args, **kwargs): while True: try: await self.coro(*args, **kwargs) - except asyncio.CancelledError: - break except self._valid_exception as exc: if not self.reconnect: raise @@ -81,8 +80,12 @@ async def _loop(self, *args, **kwargs): break await asyncio.sleep(self._sleep) + except asyncio.CancelledError: + self._is_being_cancelled = True + raise finally: await self._call_loop_function('after_loop') + self._is_being_cancelled = False def __get__(self, obj, objtype): if obj is None: @@ -108,7 +111,7 @@ def start(self, *args, **kwargs): Raises -------- RuntimeError - A task has already been launched. + A task has already been launched and is running. Returns --------- @@ -116,8 +119,8 @@ def start(self, *args, **kwargs): The task that has been created. """ - if self._task is not None: - raise RuntimeError('Task is already launched.') + if self._task is not None and not self._task.done(): + raise RuntimeError('Task is already launched and is not completed.') if self._injected is not None: args = (self._injected, *args) @@ -126,10 +129,9 @@ def start(self, *args, **kwargs): return self._task def cancel(self): - """Cancels the internal task, if any are running.""" - if self._task: + """Cancels the internal task, if it is running.""" + if not self._is_being_cancelled and self._task and not self._task.done(): self._task.cancel() - self._task = None def add_exception_type(self, exc): r"""Adds an exception type to be handled during the reconnect logic. @@ -189,6 +191,10 @@ def get_task(self): """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" return self._task + def is_being_cancelled(self): + """:class:`bool`: Whether the task is being cancelled.""" + return self._is_being_cancelled + def before_loop(self, coro): """A decorator that registers a coroutine to be called before the loop starts running. @@ -219,6 +225,12 @@ def after_loop(self, coro): The coroutine must take no arguments (except ``self`` in a class context). + .. note:: + + This coroutine is called even during cancellation. If it is desirable + to tell apart whether something was cancelled or not, check to see + whether :meth:`is_being_cancelled` is ``True`` or not. + Parameters ------------ coro: :ref:`coroutine ` diff --git a/docs/ext/tasks/index.rst b/docs/ext/tasks/index.rst index 2aeae796e6cb..0e9a65b933f5 100644 --- a/docs/ext/tasks/index.rst +++ b/docs/ext/tasks/index.rst @@ -97,6 +97,37 @@ Waiting until the bot is ready before the loop starts: print('waiting...') await self.bot.wait_until_ready() +Doing something during cancellation: + +.. code-block:: python3 + + from discord.ext import tasks, commands + import asyncio + + class MyCog(commands.Cog): + def __init__(self, bot): + self.bot= bot + self._batch = [] + self.lock = asyncio.Lock(loop=bot.loop) + self.bulker.start() + + async def do_bulk(self): + # bulk insert data here + ... + + @tasks.loop(seconds=10.0) + async def bulker(self): + async with self.lock: + await self.do_bulk() + + @bulker.after_loop + async def on_bulker_cancel(self): + if self.bulker.is_being_cancelled() and len(self._batch) != 0: + # if we're cancelled and we have some data left... + # let's insert it to our database + await self.do_bulk() + + API Reference ---------------