diff --git a/asgiref/local.py b/asgiref/local.py index a8b9459b..7d228aeb 100644 --- a/asgiref/local.py +++ b/asgiref/local.py @@ -2,37 +2,38 @@ import contextlib import contextvars import threading -from typing import Any, Dict, Union +from typing import Any, Union class _CVar: """Storage utility for Local.""" def __init__(self) -> None: - self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar( - "asgiref.local" - ) + self._data: dict[str, contextvars.ContextVar[Any]] = {} - def __getattr__(self, key): - storage_object = self._data.get({}) + def __getattr__(self, key: str) -> Any: try: - return storage_object[key] + var = self._data[key] except KeyError: raise AttributeError(f"{self!r} object has no attribute {key!r}") + try: + return var.get() + except LookupError: + raise AttributeError(f"{self!r} object has no attribute {key!r}") + def __setattr__(self, key: str, value: Any) -> None: if key == "_data": return super().__setattr__(key, value) - storage_object = self._data.get({}) - storage_object[key] = value - self._data.set(storage_object) + var = self._data.get(key) + if var is None: + self._data[key] = var = contextvars.ContextVar(key) + var.set(value) def __delattr__(self, key: str) -> None: - storage_object = self._data.get({}) - if key in storage_object: - del storage_object[key] - self._data.set(storage_object) + if key in self._data: + del self._data[key] else: raise AttributeError(f"{self!r} object has no attribute {key!r}") diff --git a/tests/test_local.py b/tests/test_local.py index d50cba21..cdcbd280 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -1,6 +1,7 @@ import asyncio import gc import threading +from threading import Thread import pytest @@ -338,3 +339,39 @@ async def async_function(): # inner value was set inside a new async context, meaning that # we do not see it, as context vars don't propagate up the stack assert not hasattr(test_local_not_tc, "test_value") + + +def test_visibility_thread_asgiref() -> None: + """Check visibility with subthreads.""" + test_local = Local() + test_local.value = 0 + + def _test() -> None: + # Local() is cleared when changing thread + assert not hasattr(test_local, "value") + setattr(test_local, "value", 1) + assert test_local.value == 1 + + thread = Thread(target=_test) + thread.start() + thread.join() + + assert test_local.value == 0 + + +@pytest.mark.asyncio +async def test_visibility_task() -> None: + """Check visibility with asyncio tasks.""" + test_local = Local() + test_local.value = 0 + + async def _test() -> None: + # Local is inherited when changing task + assert test_local.value == 0 + test_local.value = 1 + assert test_local.value == 1 + + await asyncio.create_task(_test()) + + # Changes should not leak to the caller + assert test_local.value == 0