diff --git a/aiomisc/thread_pool.py b/aiomisc/thread_pool.py index c5cd4f30..dcb1a1a0 100644 --- a/aiomisc/thread_pool.py +++ b/aiomisc/thread_pool.py @@ -15,8 +15,9 @@ from types import MappingProxyType from typing import ( Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, Generator, Generic, - Optional, Set, Tuple, TypeVar, Union, overload, + Optional, Set, Tuple, TypeVar, Union, overload, MutableMapping, ) +from weakref import WeakKeyDictionary from ._context_vars import EVENT_LOOP from .compat import Concatenate, ParamSpec @@ -379,6 +380,8 @@ class Threaded(ThreadedBase[P, T]): func_type: type def __init__(self, func: Callable[P, T]) -> None: + self.__cache: MutableMapping[Any, Any] = WeakKeyDictionary() + if isinstance(func, staticmethod): self.func_type = staticmethod self.func = func.__func__ @@ -415,14 +418,22 @@ def __get__( instance: Any, owner: Optional[type] = None, ) -> "Threaded[P, T] | BoundThreaded[Any, T]": + key = instance + result: Any + if key in self.__cache: + return self.__cache[key] if self.func_type is staticmethod: - return self + result = self elif self.func_type is classmethod: cls = owner if instance is None else type(instance) - return BoundThreaded(self.func, cls) + result = BoundThreaded(self.func, cls) elif instance is not None: - return BoundThreaded(self.func, instance) - return self + result = BoundThreaded(self.func, instance) + else: + result = self + + self.__cache[key] = result + return result class BoundThreaded(ThreadedBase[P, T]): @@ -570,6 +581,7 @@ def __init__( self.func = actual_func self.max_size = max_size + self.__cache: MutableMapping[Any, Any] = WeakKeyDictionary() @overload def __get__( @@ -592,14 +604,23 @@ def __get__( instance: Any, owner: Optional[type] = None, ) -> "ThreadedIterable[P, T] | BoundThreadedIterable[Any, T]": + key = instance + result: Any + if key in self.__cache: + return self.__cache[key] + if self.func_type is staticmethod: - return self + result = self elif self.func_type is classmethod: cls = owner if instance is None else type(instance) - return BoundThreadedIterable(self.func, cls, self.max_size) + result = BoundThreadedIterable(self.func, cls, self.max_size) elif instance is not None: - return BoundThreadedIterable(self.func, instance, self.max_size) - return self + result = BoundThreadedIterable(self.func, instance, self.max_size) + else: + result = self + + self.__cache[key] = result + return result class BoundThreadedIterable(ThreadedIterableBase[P, T]): diff --git a/tests/test_thread_pool.py b/tests/test_thread_pool.py index 31d5b06f..751fe54b 100644 --- a/tests/test_thread_pool.py +++ b/tests/test_thread_pool.py @@ -584,6 +584,7 @@ def foo(self): return 42 instance = TestClass() + assert instance.foo is instance.foo assert instance.foo.sync_call() == 42 assert await instance.foo() == 42 assert await instance.foo.async_call() == 42 @@ -597,6 +598,7 @@ def foo(): return 42 instance = TestClass() + assert instance.foo is instance.foo assert instance.foo.sync_call() == 42 assert await instance.foo() == 42 assert await instance.foo.async_call() == 42 @@ -610,6 +612,7 @@ def foo(cls): return 42 instance = TestClass() + assert instance.foo is instance.foo assert instance.foo.sync_call() == 42 assert await instance.foo() == 42 assert await instance.foo.async_call() == 42 @@ -620,6 +623,7 @@ async def test_threaded_iterator_class_func(): def foo(): yield 42 + assert foo is foo assert list(foo.sync_call()) == [42] assert [x async for x in foo()] == [42] assert [x async for x in foo.async_call()] == [42] @@ -632,6 +636,7 @@ def foo(self): yield 42 instance = TestClass() + assert instance.foo is instance.foo assert list(instance.foo.sync_call()) == [42] assert [x async for x in instance.foo()] == [42] assert [x async for x in instance.foo.async_call()] == [42] @@ -645,6 +650,7 @@ def foo(): yield 42 instance = TestClass() + assert instance.foo is instance.foo assert list(instance.foo.sync_call()) == [42] assert [x async for x in instance.foo()] == [42] assert [x async for x in instance.foo.async_call()] == [42] @@ -658,6 +664,7 @@ def foo(cls): yield 42 instance = TestClass() + assert instance.foo is instance.foo assert list(instance.foo.sync_call()) == [42] assert [x async for x in instance.foo()] == [42] assert [x async for x in instance.foo.async_call()] == [42]