From 8108512196b0831962eb923b7c3f5c5935b1fde1 Mon Sep 17 00:00:00 2001 From: ttys0dev <126845556+ttys0dev@users.noreply.github.com> Date: Thu, 1 Feb 2024 18:26:35 -0700 Subject: [PATCH] Move variable initialization in AsyncToSync from __init__ to __call__ (#440) Co-authored-by: germaniuss --- asgiref/sync.py | 40 +++++++++++++++++----------------- tests/test_sync.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 20 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index 692cbc75..f263a432 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -153,30 +153,30 @@ def __init__( self.__self__ = self.awaitable.__self__ # type: ignore[union-attr] except AttributeError: pass - if force_new_loop: - # They have asked that we always run in a new sub-loop. - self.main_event_loop = None - else: - try: - self.main_event_loop = asyncio.get_running_loop() - except RuntimeError: - # There's no event loop in this thread. Look for the threadlocal if - # we're inside SyncToAsync - main_event_loop_pid = getattr( - SyncToAsync.threadlocal, "main_event_loop_pid", None - ) - # We make sure the parent loop is from the same process - if - # they've forked, this is not going to be valid any more (#194) - if main_event_loop_pid and main_event_loop_pid == os.getpid(): - self.main_event_loop = getattr( - SyncToAsync.threadlocal, "main_event_loop", None - ) - else: - self.main_event_loop = None + self.force_new_loop = force_new_loop + self.main_event_loop = None + try: + self.main_event_loop = asyncio.get_running_loop() + except RuntimeError: + # There's no event loop in this thread. + pass def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: __traceback_hide__ = True # noqa: F841 + if not self.force_new_loop and not self.main_event_loop: + # There's no event loop in this thread. Look for the threadlocal if + # we're inside SyncToAsync + main_event_loop_pid = getattr( + SyncToAsync.threadlocal, "main_event_loop_pid", None + ) + # We make sure the parent loop is from the same process - if + # they've forked, this is not going to be valid any more (#194) + if main_event_loop_pid and main_event_loop_pid == os.getpid(): + self.main_event_loop = getattr( + SyncToAsync.threadlocal, "main_event_loop", None + ) + # You can't call AsyncToSync from a thread with a running event loop try: event_loop = asyncio.get_running_loop() diff --git a/tests/test_sync.py b/tests/test_sync.py index 3e83c91b..a4d2413b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,6 +1,7 @@ import asyncio import functools import multiprocessing +import sys import threading import time import warnings @@ -223,6 +224,58 @@ def sync_function(): assert result["thread"] == threading.current_thread() +@pytest.mark.asyncio +async def test_async_to_sync_to_async_decorator(): + """ + Test async_to_sync as a function decorator uses the outer thread + when used inside sync_to_async. + """ + result = {} + + # Define async function + @async_to_sync + async def inner_async_function(): + result["worked"] = True + result["thread"] = threading.current_thread() + return 42 + + # Define sync function + @sync_to_async + def sync_function(): + return inner_async_function() + + # Check it works right + number = await sync_function() + assert number == 42 + assert result["worked"] + # Make sure that it didn't needlessly make a new async loop + assert result["thread"] == threading.current_thread() + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9") +async def test_async_to_sync_to_thread_decorator(): + """ + Test async_to_sync as a function decorator uses the outer thread + when used inside another sync thread. + """ + result = {} + + # Define async function + @async_to_sync + async def inner_async_function(): + result["worked"] = True + result["thread"] = threading.current_thread() + return 42 + + # Check it works right + number = await asyncio.to_thread(inner_async_function) + assert number == 42 + assert result["worked"] + # Make sure that it didn't needlessly make a new async loop + assert result["thread"] == threading.current_thread() + + def test_async_to_sync_fail_non_function(): """ async_to_sync raises a TypeError when applied to a non-function.