Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix leave_context_asyncio to handle cancelled asyncio task #141

Merged
merged 4 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 3 additions & 49 deletions asynq/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
from ._debug import options as _debug_options


ASYNCIO_CONTEXT_FIELD = "_asynq_contexts"
ASYNCIO_CONTEXT_ACTIVE_FIELD = "_asynq_contexts_active"


class NonAsyncContext(object):
"""Indicates that context can't contain yield statements.

Expand All @@ -38,15 +34,11 @@ class NonAsyncContext(object):
"""

def __enter__(self):
if is_asyncio_mode():
enter_context_asyncio(self)
else:
if not is_asyncio_mode():
self._active_task = enter_context(self)

def __exit__(self, typ, val, tb):
if is_asyncio_mode():
leave_context_asyncio(self)
else:
if not is_asyncio_mode():
leave_context(self, self._active_task)

def pause(self):
Expand Down Expand Up @@ -75,41 +67,6 @@ def leave_context(context, active_task):
active_task._leave_context(context)


def enter_context_asyncio(context):
if _debug_options.DUMP_CONTEXTS:
debug.write("@async: +context: %s" % debug.str(context))

# since we are in asyncio mode, there is an active task
task = asyncio.current_task()

if hasattr(task, ASYNCIO_CONTEXT_FIELD):
getattr(task, ASYNCIO_CONTEXT_FIELD)[id(context)] = context
else:
setattr(task, ASYNCIO_CONTEXT_FIELD, {id(context): context})


def leave_context_asyncio(context):
if _debug_options.DUMP_CONTEXTS:
debug.write("@async: -context: %s" % debug.str(context))

task = asyncio.current_task()
del getattr(task, ASYNCIO_CONTEXT_FIELD)[id(context)] # type: ignore


def pause_contexts_asyncio(task):
if getattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, False):
setattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, False)
for ctx in reversed(list(getattr(task, ASYNCIO_CONTEXT_FIELD, {}).values())):
ctx.pause()


def resume_contexts_asyncio(task):
if not getattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, False):
setattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, True)
for ctx in getattr(task, ASYNCIO_CONTEXT_FIELD, {}).values():
ctx.resume()


class AsyncContext(object):
"""Base class for contexts that should pause and resume during an async's function execution.

Expand All @@ -127,17 +84,14 @@ class AsyncContext(object):
"""

def __enter__(self):
if is_asyncio_mode():
enter_context_asyncio(self)
else:
if not is_asyncio_mode():
self._active_task = enter_context(self)

self.resume()
return self

def __exit__(self, ty, value, tb):
if is_asyncio_mode():
leave_context_asyncio(self)
self.pause()
else:
leave_context(self, self._active_task)
Expand Down
3 changes: 0 additions & 3 deletions asynq/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from . import async_task, futures
from .asynq_to_async import AsyncioMode, is_asyncio_mode, resolve_awaitables
from .contexts import pause_contexts_asyncio, resume_contexts_asyncio

__traceback_hide__ = True

Expand Down Expand Up @@ -114,7 +113,6 @@ async def wrapped(*_args, **_kwargs):

generator = fn(*_args, **_kwargs)
while True:
resume_contexts_asyncio(task)
try:
if exception is None:
result = generator.send(send)
Expand All @@ -125,7 +123,6 @@ async def wrapped(*_args, **_kwargs):
except StopIteration as exc:
return exc.value

pause_contexts_asyncio(task)
try:
send = await resolve_awaitables(result)
exception = None
Expand Down
12 changes: 6 additions & 6 deletions asynq/tests/test_asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,33 +105,33 @@ async def blocking_op():
await asyncio.sleep(0.1)

@asynq.asynq()
def f1():
def f1(): # 500ms
with AsyncTimer() as timer:
time.sleep(0.1)
t1, t2 = yield f2.asynq()
time.sleep(0.1)
return timer.total_time, t1, t2

@asynq.asynq()
def f2():
def f2(): # 300ms
with AsyncTimer() as timer:
time.sleep(0.1)
t = yield f3.asynq()
time.sleep(0.1)
return timer.total_time, t

@asynq.asynq()
def f3():
def f3(): # 100ms
with AsyncTimer() as timer:
# since AsyncTimer is paused on blocking operations,
# the time for TestBatch is not measured
yield [blocking_op(), blocking_op()]
return timer.total_time

t1, t2, t3 = asyncio.run(f1.asyncio())
assert_eq(400000, t1, tolerance=10000) # 400ms, 10us tolerance
assert_eq(200000, t2, tolerance=10000) # 200ms, 10us tolerance
assert_eq(000000, t3, tolerance=10000) # 0ms, 10us tolerance
assert_eq(500000, t1, tolerance=10000) # 400ms, 10us tolerance
assert_eq(300000, t2, tolerance=10000) # 200ms, 10us tolerance
assert_eq(100000, t3, tolerance=10000) # 0ms, 10us tolerance


def test_method():
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ qcore
pygments
black==24.3.0
mypy==1.4.1
typing_extensions==4.11.0
Loading