diff --git a/changelog.d/696.misc b/changelog.d/696.misc new file mode 100644 index 00000000..066426d6 --- /dev/null +++ b/changelog.d/696.misc @@ -0,0 +1 @@ +Typed DefaultClient \ No newline at end of file diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 15a5067a..9485627f 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,8 +3,7 @@ import socket from collections import OrderedDict from contextlib import suppress -from datetime import datetime -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from django.conf import settings from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func @@ -12,6 +11,7 @@ from django.utils.module_loading import import_string from redis import Redis from redis.exceptions import ConnectionError, ResponseError, TimeoutError +from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT from django_redis import pool from django_redis.exceptions import CompressorError, ConnectionInterrupted @@ -63,7 +63,7 @@ def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None: self.connection_factory = pool.get_connection_factory(options=self._options) - def __contains__(self, key: Any) -> bool: + def __contains__(self, key: KeyT) -> bool: return self.has_key(key) def get_next_client_index( @@ -92,8 +92,7 @@ def get_client( self, write: bool = True, tried: Optional[List[int]] = None, - show_index: bool = False, - ): + ) -> Redis: """ Method used for obtain a raw redis client. @@ -106,10 +105,26 @@ def get_client( if self._clients[index] is None: self._clients[index] = self.connect(index) - if show_index: - return self._clients[index], index + return self._clients[index] # type:ignore + + def get_client_with_index( + self, + write: bool = True, + tried: Optional[List[int]] = None, + ) -> Tuple[Redis, int]: + """ + Method used for obtain a raw redis client. + + This function is used by almost all cache backend + operations for obtain a native redis client/connection + instance. + """ + index = self.get_next_client_index(write=write, tried=tried) + + if self._clients[index] is None: + self._clients[index] = self.connect(index) - return self._clients[index] + return self._clients[index], index # type:ignore def connect(self, index: int = 0) -> Redis: """ @@ -119,16 +134,20 @@ def connect(self, index: int = 0) -> Redis: """ return self.connection_factory.connect(self._server[index]) - def disconnect(self, index=0, client=None): - """delegates the connection factory to disconnect the client""" - if not client: + def disconnect(self, index: int = 0, client: Optional[Redis] = None) -> None: + """ + delegates the connection factory to disconnect the client + """ + if client is None: client = self._clients[index] - return self.connection_factory.disconnect(client) if client else None + + if client is not None: + self.connection_factory.disconnect(client) def set( # noqa: A003 self, - key: Any, - value: Any, + key: KeyT, + value: EncodableT, timeout: Optional[float] = DEFAULT_TIMEOUT, version: Optional[int] = None, client: Optional[Redis] = None, @@ -152,9 +171,7 @@ def set( # noqa: A003 while True: try: if client is None: - client, index = self.get_client( - write=True, tried=tried, show_index=True - ) + client, index = self.get_client_with_index(write=True, tried=tried) if timeout is not None: # Convert to milliseconds @@ -186,7 +203,7 @@ def set( # noqa: A003 def incr_version( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -211,7 +228,8 @@ def incr_version( raise ConnectionInterrupted(connection=client) from e if value is None: - raise ValueError("Key '%s' not found" % key) + error_message = f"Key '{key!r}' not found" + raise ValueError(error_message) if isinstance(key, CacheKey): new_key = self.make_key(key.original_key(), version=version + delta) @@ -224,10 +242,10 @@ def incr_version( def add( self, - key: Any, - value: Any, - timeout: Any = DEFAULT_TIMEOUT, - version: Optional[Any] = None, + key: KeyT, + value: EncodableT, + timeout: Optional[float] = DEFAULT_TIMEOUT, + version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: """ @@ -239,8 +257,8 @@ def add( def get( self, - key: Any, - default=None, + key: KeyT, + default: Optional[Any] = None, version: Optional[int] = None, client: Optional[Redis] = None, ) -> Any: @@ -265,7 +283,7 @@ def get( return self.decode(value) def persist( - self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None ) -> bool: if client is None: client = self.get_client(write=True) @@ -276,8 +294,8 @@ def persist( def expire( self, - key: Any, - timeout, + key: KeyT, + timeout: ExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: @@ -286,9 +304,17 @@ def expire( key = self.make_key(key, version=version) - return client.expire(key, timeout) + # for some strange reason mypy complains, + # saying that timeout type is float | timedelta + return client.expire(key, timeout) # type: ignore - def pexpire(self, key, timeout, version=None, client=None) -> bool: + def pexpire( + self, + key: KeyT, + timeout: ExpiryT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: if client is None: client = self.get_client(write=True) @@ -296,12 +322,14 @@ def pexpire(self, key, timeout, version=None, client=None) -> bool: # Temporary casting until https://github.com/redis/redis-py/issues/1664 # is fixed. - return bool(client.pexpire(key, timeout)) + # for some strange reason mypy complains, + # saying that timeout type is float | timedelta + return bool(client.pexpire(key, timeout)) # type: ignore def pexpire_at( self, - key: Any, - when: Union[datetime, int], + key: KeyT, + when: AbsExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: @@ -318,8 +346,8 @@ def pexpire_at( def expire_at( self, - key: Any, - when: Union[datetime, int], + key: KeyT, + when: AbsExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: @@ -336,13 +364,13 @@ def expire_at( def lock( self, - key, + key: KeyT, version: Optional[int] = None, - timeout=None, - sleep=0.1, - blocking_timeout=None, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking_timeout: Optional[float] = None, client: Optional[Redis] = None, - thread_local=True, + thread_local: bool = True, ): if client is None: client = self.get_client(write=True) @@ -358,7 +386,7 @@ def lock( def delete( self, - key: Any, + key: KeyT, version: Optional[int] = None, prefix: Optional[str] = None, client: Optional[Redis] = None, @@ -405,8 +433,11 @@ def delete_pattern( raise ConnectionInterrupted(connection=client) from e def delete_many( - self, keys, version: Optional[int] = None, client: Optional[Redis] = None - ): + self, + keys: Iterable[KeyT], + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: """ Remove multiple keys at once. """ @@ -417,7 +448,7 @@ def delete_many( keys = [self.make_key(k, version=version) for k in keys] if not keys: - return None + return 0 try: return client.delete(*keys) @@ -437,7 +468,7 @@ def clear(self, client: Optional[Redis] = None) -> None: except _main_exceptions as e: raise ConnectionInterrupted(connection=client) from e - def decode(self, value: Union[bytes, int]) -> Any: + def decode(self, value: EncodableT) -> Any: """ Decode the given value. """ @@ -450,7 +481,7 @@ def decode(self, value: Union[bytes, int]) -> Any: value = self._serializer.loads(value) return value - def encode(self, value: Any) -> Union[bytes, Any]: + def encode(self, value: EncodableT) -> Union[bytes, int]: """ Encode the given value. """ @@ -462,7 +493,10 @@ def encode(self, value: Any) -> Union[bytes, Any]: return value def get_many( - self, keys, version: Optional[int] = None, client: Optional[Redis] = None + self, + keys: Iterable[KeyT], + version: Optional[int] = None, + client: Optional[Redis] = None, ) -> OrderedDict: """ Retrieve many keys. @@ -491,7 +525,7 @@ def get_many( def set_many( self, - data: Dict[Any, Any], + data: Dict[KeyT, EncodableT], timeout: Optional[float] = DEFAULT_TIMEOUT, version: Optional[int] = None, client: Optional[Redis] = None, @@ -516,7 +550,7 @@ def set_many( def _incr( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -545,7 +579,7 @@ def _incr( """ value = client.eval(lua, 1, key, delta) if value is None: - error_message = f"Key '{key}' not found" + error_message = f"Key '{key!r}' not found" raise ValueError(error_message) except ResponseError as e: # if cached value or total value is greater than 64 bit signed @@ -559,7 +593,7 @@ def _incr( # returns -2 if the key does not exist # means, that key have expired if timeout == -2: - error_message = f"Key '{key}' not found" + error_message = f"Key '{key!r}' not found" raise ValueError(error_message) from e value = self.get(key, version=version, client=client) + delta self.set(key, value, version=version, timeout=timeout, client=client) @@ -570,7 +604,7 @@ def _incr( def incr( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -591,7 +625,7 @@ def incr( def decr( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -603,7 +637,7 @@ def decr( return self._incr(key=key, delta=-delta, version=version, client=client) def ttl( - self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None ) -> Optional[int]: """ Executes TTL redis command and return the "time-to-live" of specified key. @@ -628,7 +662,9 @@ def ttl( # Should never reach here return None - def pttl(self, key, version=None, client=None): + def pttl( + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None + ) -> Optional[int]: """ Executes PTTL redis command and return the "time-to-live" of specified key. If key is a non volatile key, it returns None. @@ -653,7 +689,7 @@ def pttl(self, key, version=None, client=None): return None def has_key( - self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None ) -> bool: """ Test if key exists. @@ -707,8 +743,8 @@ def keys( raise ConnectionInterrupted(connection=client) from e def make_key( - self, key: Any, version: Optional[Any] = None, prefix: Optional[str] = None - ) -> CacheKey: + self, key: KeyT, version: Optional[int] = None, prefix: Optional[str] = None + ) -> KeyT: if isinstance(key, CacheKey): return key @@ -722,7 +758,7 @@ def make_key( def make_pattern( self, pattern: str, version: Optional[int] = None, prefix: Optional[str] = None - ) -> CacheKey: + ) -> str: if isinstance(pattern, CacheKey): return pattern @@ -736,7 +772,7 @@ def make_pattern( return CacheKey(self._backend.key_func(pattern, prefix, version_str)) - def close(self): + def close(self) -> None: close_flag = self._options.get( "CLOSE_CONNECTION", getattr(settings, "DJANGO_REDIS_CLOSE_CONNECTION", False), @@ -744,8 +780,10 @@ def close(self): if close_flag: self.do_close_clients() - def do_close_clients(self): - """default implementation: Override in custom client""" + def do_close_clients(self) -> None: + """ + default implementation: Override in custom client + """ num_clients = len(self._clients) for idx in range(num_clients): self.disconnect(index=idx) @@ -753,7 +791,7 @@ def do_close_clients(self): def touch( self, - key: Any, + key: KeyT, timeout: Optional[float] = DEFAULT_TIMEOUT, version: Optional[int] = None, client: Optional[Redis] = None, diff --git a/django_redis/client/sharded.py b/django_redis/client/sharded.py index 7c34e6e7..7a0fb0f4 100644 --- a/django_redis/client/sharded.py +++ b/django_redis/client/sharded.py @@ -79,7 +79,14 @@ def get_many(self, keys, version=None): return recovered_data def set( # noqa: A003 - self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None, nx=False + self, + key, + value, + timeout=DEFAULT_TIMEOUT, + version=None, + client=None, + nx=False, + xx=False, ): """ Persist a value to the cache, and set an optional expiration time. @@ -89,10 +96,16 @@ def set( # noqa: A003 client = self.get_server(key) return super().set( - key=key, value=value, timeout=timeout, version=version, client=client, nx=nx + key=key, + value=value, + timeout=timeout, + version=version, + client=client, + nx=nx, + xx=xx, ) - def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None): + def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None): """ Set a bunch of values in the cache at once from a dict of key/value pairs. This is much more efficient than calling set() multiple times. @@ -101,7 +114,7 @@ def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None): the default cache timeout will be used. """ for key, value in data.items(): - self.set(key, value, timeout, version=version) + self.set(key, value, timeout, version=version, client=client) def has_key(self, key, version=None, client=None): """ diff --git a/django_redis/pool.py b/django_redis/pool.py index 979c463c..b0e5f2a3 100644 --- a/django_redis/pool.py +++ b/django_redis/pool.py @@ -5,7 +5,7 @@ from django.core.exceptions import ImproperlyConfigured from django.utils.module_loading import import_string from redis import Redis -from redis.connection import DefaultParser, to_bool +from redis.connection import ConnectionPool, DefaultParser, to_bool from redis.sentinel import Sentinel @@ -16,7 +16,7 @@ class ConnectionFactory: # ConnectionFactory is instantiated, as Django creates new cache client # (DefaultClient) instance for every request. - _pools: Dict[str, Redis] = {} + _pools: Dict[str, ConnectionPool] = {} def __init__(self, options): pool_cls_path = options.get( @@ -70,7 +70,7 @@ def connect(self, url: str) -> Redis: params = self.make_connection_params(url) return self.get_connection(params) - def disconnect(self, connection): + def disconnect(self, connection: Redis) -> None: """ Given a not null client connection it disconnect from the Redis server. diff --git a/tests/test_backend.py b/tests/test_backend.py index 7f4beb70..58d6b92c 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -734,6 +734,18 @@ def test_primary_replica_switching(self, cache: RedisCache): assert client.get_client(write=True) == "Foo" assert client.get_client(write=False) == "Bar" + def test_primary_replica_switching_with_index(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache = cast(RedisCache, caches["sample"]) + client = cache.client + client._server = ["foo", "bar"] + client._clients = ["Foo", "Bar"] + + assert client.get_client_with_index(write=True) == ("Foo", 0) + assert client.get_client_with_index(write=False) == ("Bar", 1) + def test_touch_zero_timeout(self, cache: RedisCache): cache.set("test_key", 222, timeout=10)