|
9 | 9 | import sys |
10 | 10 | import threading |
11 | 11 | import warnings |
| 12 | +from contextvars import ContextVar |
12 | 13 | from pathlib import Path |
13 | 14 | from types import FrameType |
14 | 15 | from typing import Any, Awaitable, Callable, TypeVar, cast |
@@ -126,6 +127,7 @@ def run(self, coro: Any) -> Any: |
126 | 127 |
|
127 | 128 |
|
128 | 129 | _runner_map: dict[str, _TaskRunner] = {} |
| 130 | +_loop: ContextVar[asyncio.AbstractEventLoop | None] = ContextVar("_loop", default=None) |
129 | 131 |
|
130 | 132 |
|
131 | 133 | def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]: |
@@ -159,22 +161,30 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: |
159 | 161 | pass |
160 | 162 |
|
161 | 163 | # Run the loop for this thread. |
162 | | - # In Python 3.12, a deprecation warning is raised, which |
163 | | - # may later turn into a RuntimeError. We handle both |
164 | | - # cases. |
165 | | - with warnings.catch_warnings(): |
166 | | - warnings.simplefilter("ignore", DeprecationWarning) |
167 | | - try: |
168 | | - loop = asyncio.get_event_loop() |
169 | | - except RuntimeError: |
170 | | - loop = asyncio.new_event_loop() |
171 | | - asyncio.set_event_loop(loop) |
172 | | - return loop.run_until_complete(inner) |
| 164 | + loop = ensure_event_loop() |
| 165 | + return loop.run_until_complete(inner) |
173 | 166 |
|
174 | 167 | wrapped.__doc__ = coro.__doc__ |
175 | 168 | return wrapped |
176 | 169 |
|
177 | 170 |
|
| 171 | +def ensure_event_loop(prefer_selector_loop: bool = False) -> asyncio.AbstractEventLoop: |
| 172 | + # Get the loop for this thread, or create a new one. |
| 173 | + loop = _loop.get() |
| 174 | + if loop is not None and not loop.is_closed(): |
| 175 | + return loop |
| 176 | + try: |
| 177 | + loop = asyncio.get_running_loop() |
| 178 | + except RuntimeError: |
| 179 | + if sys.platform == "win32" and prefer_selector_loop: |
| 180 | + loop = asyncio.WindowsSelectorEventLoopPolicy().new_event_loop() |
| 181 | + else: |
| 182 | + loop = asyncio.new_event_loop() |
| 183 | + asyncio.set_event_loop(loop) |
| 184 | + _loop.set(loop) |
| 185 | + return loop |
| 186 | + |
| 187 | + |
178 | 188 | async def ensure_async(obj: Awaitable[T] | T) -> T: |
179 | 189 | """Convert a non-awaitable object to a coroutine if needed, |
180 | 190 | and await it if it was not already awaited. |
|
0 commit comments