From 820618b8a86b5d5b4260589d8ebd5146fd5adbd9 Mon Sep 17 00:00:00 2001 From: WisdomPill Date: Wed, 1 Nov 2023 21:09:35 +0200 Subject: [PATCH 1/7] Added types to DefaultClient --- django_redis/client/default.py | 157 ++++++++++++++++++++------------- django_redis/client/sharded.py | 21 ++++- django_redis/pool.py | 6 +- tests/start_redis.sh | 2 +- tests/test_backend.py | 12 +++ 5 files changed, 129 insertions(+), 69 deletions(-) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 15a5067a..05f06dd2 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -4,7 +4,7 @@ from collections import OrderedDict from contextlib import suppress from datetime import datetime -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union, Any from django.conf import settings from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func @@ -12,6 +12,7 @@ from django.utils.module_loading import import_string from redis import Redis from redis.exceptions import ConnectionError, ResponseError, TimeoutError +from redis.typing import EncodableT, KeyT, AbsExpiryT, ExpiryT from django_redis import pool from django_redis.exceptions import CompressorError, ConnectionInterrupted @@ -63,7 +64,7 @@ def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None: self.connection_factory = pool.get_connection_factory(options=self._options) - def __contains__(self, key: Any) -> bool: + def __contains__(self, key: KeyT) -> bool: return self.has_key(key) def get_next_client_index( @@ -92,8 +93,7 @@ def get_client( self, write: bool = True, tried: Optional[List[int]] = None, - show_index: bool = False, - ): + ) -> Redis: """ Method used for obtain a raw redis client. @@ -106,10 +106,26 @@ def get_client( if self._clients[index] is None: self._clients[index] = self.connect(index) - if show_index: - return self._clients[index], index + return self._clients[index] # type:ignore + + def get_client_with_index( + self, + write: bool = True, + tried: Optional[List[int]] = None, + ) -> Tuple[Redis, int]: + """ + Method used for obtain a raw redis client. + + This function is used by almost all cache backend + operations for obtain a native redis client/connection + instance. + """ + index = self.get_next_client_index(write=write, tried=tried) + + if self._clients[index] is None: + self._clients[index] = self.connect(index) - return self._clients[index] + return self._clients[index], index # type:ignore def connect(self, index: int = 0) -> Redis: """ @@ -119,16 +135,20 @@ def connect(self, index: int = 0) -> Redis: """ return self.connection_factory.connect(self._server[index]) - def disconnect(self, index=0, client=None): - """delegates the connection factory to disconnect the client""" - if not client: + def disconnect(self, index: int = 0, client: Optional[Redis] = None) -> None: + """ + delegates the connection factory to disconnect the client + """ + if client is None: client = self._clients[index] - return self.connection_factory.disconnect(client) if client else None + + if client is not None: + self.connection_factory.disconnect(client) def set( # noqa: A003 self, - key: Any, - value: Any, + key: KeyT, + value: EncodableT, timeout: Optional[float] = DEFAULT_TIMEOUT, version: Optional[int] = None, client: Optional[Redis] = None, @@ -152,9 +172,7 @@ def set( # noqa: A003 while True: try: if client is None: - client, index = self.get_client( - write=True, tried=tried, show_index=True - ) + client, index = self.get_client_with_index(write=True, tried=tried) if timeout is not None: # Convert to milliseconds @@ -186,7 +204,7 @@ def set( # noqa: A003 def incr_version( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -211,7 +229,8 @@ def incr_version( raise ConnectionInterrupted(connection=client) from e if value is None: - raise ValueError("Key '%s' not found" % key) + error_message = f"Key '{key!r}' not found" + raise ValueError(error_message) if isinstance(key, CacheKey): new_key = self.make_key(key.original_key(), version=version + delta) @@ -224,10 +243,10 @@ def incr_version( def add( self, - key: Any, - value: Any, - timeout: Any = DEFAULT_TIMEOUT, - version: Optional[Any] = None, + key: KeyT, + value: EncodableT, + timeout: Optional[float] = DEFAULT_TIMEOUT, + version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: """ @@ -239,8 +258,8 @@ def add( def get( self, - key: Any, - default=None, + key: KeyT, + default: Optional[Any] = None, version: Optional[int] = None, client: Optional[Redis] = None, ) -> Any: @@ -265,7 +284,7 @@ def get( return self.decode(value) def persist( - self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None ) -> bool: if client is None: client = self.get_client(write=True) @@ -276,8 +295,8 @@ def persist( def expire( self, - key: Any, - timeout, + key: KeyT, + timeout: ExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: @@ -288,7 +307,13 @@ def expire( return client.expire(key, timeout) - def pexpire(self, key, timeout, version=None, client=None) -> bool: + def pexpire( + self, + key: KeyT, + timeout: ExpiryT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: if client is None: client = self.get_client(write=True) @@ -300,8 +325,8 @@ def pexpire(self, key, timeout, version=None, client=None) -> bool: def pexpire_at( self, - key: Any, - when: Union[datetime, int], + key: KeyT, + when: AbsExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: @@ -318,8 +343,8 @@ def pexpire_at( def expire_at( self, - key: Any, - when: Union[datetime, int], + key: KeyT, + when: AbsExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, ) -> bool: @@ -336,13 +361,13 @@ def expire_at( def lock( self, - key, + key: KeyT, version: Optional[int] = None, - timeout=None, - sleep=0.1, - blocking_timeout=None, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking_timeout: Optional[float] = None, client: Optional[Redis] = None, - thread_local=True, + thread_local: bool = True, ): if client is None: client = self.get_client(write=True) @@ -358,7 +383,7 @@ def lock( def delete( self, - key: Any, + key: KeyT, version: Optional[int] = None, prefix: Optional[str] = None, client: Optional[Redis] = None, @@ -405,8 +430,11 @@ def delete_pattern( raise ConnectionInterrupted(connection=client) from e def delete_many( - self, keys, version: Optional[int] = None, client: Optional[Redis] = None - ): + self, + keys: Iterable[KeyT], + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: """ Remove multiple keys at once. """ @@ -417,7 +445,7 @@ def delete_many( keys = [self.make_key(k, version=version) for k in keys] if not keys: - return None + return 0 try: return client.delete(*keys) @@ -437,7 +465,7 @@ def clear(self, client: Optional[Redis] = None) -> None: except _main_exceptions as e: raise ConnectionInterrupted(connection=client) from e - def decode(self, value: Union[bytes, int]) -> Any: + def decode(self, value: EncodableT) -> Any: """ Decode the given value. """ @@ -450,19 +478,22 @@ def decode(self, value: Union[bytes, int]) -> Any: value = self._serializer.loads(value) return value - def encode(self, value: Any) -> Union[bytes, Any]: + def encode(self, value: EncodableT) -> Union[bytes, int]: """ Encode the given value. """ - if isinstance(value, bool) or not isinstance(value, int): + if not isinstance(value, int): value = self._serializer.dumps(value) return self._compressor.compress(value) return value def get_many( - self, keys, version: Optional[int] = None, client: Optional[Redis] = None + self, + keys: Iterable[KeyT], + version: Optional[int] = None, + client: Optional[Redis] = None, ) -> OrderedDict: """ Retrieve many keys. @@ -491,7 +522,7 @@ def get_many( def set_many( self, - data: Dict[Any, Any], + data: Dict[KeyT, EncodableT], timeout: Optional[float] = DEFAULT_TIMEOUT, version: Optional[int] = None, client: Optional[Redis] = None, @@ -516,7 +547,7 @@ def set_many( def _incr( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -545,7 +576,7 @@ def _incr( """ value = client.eval(lua, 1, key, delta) if value is None: - error_message = f"Key '{key}' not found" + error_message = f"Key '{key!r}' not found" raise ValueError(error_message) except ResponseError as e: # if cached value or total value is greater than 64 bit signed @@ -559,7 +590,7 @@ def _incr( # returns -2 if the key does not exist # means, that key have expired if timeout == -2: - error_message = f"Key '{key}' not found" + error_message = f"Key '{key!r}' not found" raise ValueError(error_message) from e value = self.get(key, version=version, client=client) + delta self.set(key, value, version=version, timeout=timeout, client=client) @@ -570,7 +601,7 @@ def _incr( def incr( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -591,7 +622,7 @@ def incr( def decr( self, - key: Any, + key: KeyT, delta: int = 1, version: Optional[int] = None, client: Optional[Redis] = None, @@ -603,7 +634,7 @@ def decr( return self._incr(key=key, delta=-delta, version=version, client=client) def ttl( - self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None ) -> Optional[int]: """ Executes TTL redis command and return the "time-to-live" of specified key. @@ -628,7 +659,9 @@ def ttl( # Should never reach here return None - def pttl(self, key, version=None, client=None): + def pttl( + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None + ) -> Optional[int]: """ Executes PTTL redis command and return the "time-to-live" of specified key. If key is a non volatile key, it returns None. @@ -653,7 +686,7 @@ def pttl(self, key, version=None, client=None): return None def has_key( - self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None + self, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None ) -> bool: """ Test if key exists. @@ -707,8 +740,8 @@ def keys( raise ConnectionInterrupted(connection=client) from e def make_key( - self, key: Any, version: Optional[Any] = None, prefix: Optional[str] = None - ) -> CacheKey: + self, key: KeyT, version: Optional[int] = None, prefix: Optional[str] = None + ) -> KeyT: if isinstance(key, CacheKey): return key @@ -722,7 +755,7 @@ def make_key( def make_pattern( self, pattern: str, version: Optional[int] = None, prefix: Optional[str] = None - ) -> CacheKey: + ) -> str: if isinstance(pattern, CacheKey): return pattern @@ -736,7 +769,7 @@ def make_pattern( return CacheKey(self._backend.key_func(pattern, prefix, version_str)) - def close(self): + def close(self) -> None: close_flag = self._options.get( "CLOSE_CONNECTION", getattr(settings, "DJANGO_REDIS_CLOSE_CONNECTION", False), @@ -744,8 +777,10 @@ def close(self): if close_flag: self.do_close_clients() - def do_close_clients(self): - """default implementation: Override in custom client""" + def do_close_clients(self) -> None: + """ + default implementation: Override in custom client + """ num_clients = len(self._clients) for idx in range(num_clients): self.disconnect(index=idx) @@ -753,7 +788,7 @@ def do_close_clients(self): def touch( self, - key: Any, + key: KeyT, timeout: Optional[float] = DEFAULT_TIMEOUT, version: Optional[int] = None, client: Optional[Redis] = None, diff --git a/django_redis/client/sharded.py b/django_redis/client/sharded.py index 7c34e6e7..7a0fb0f4 100644 --- a/django_redis/client/sharded.py +++ b/django_redis/client/sharded.py @@ -79,7 +79,14 @@ def get_many(self, keys, version=None): return recovered_data def set( # noqa: A003 - self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None, nx=False + self, + key, + value, + timeout=DEFAULT_TIMEOUT, + version=None, + client=None, + nx=False, + xx=False, ): """ Persist a value to the cache, and set an optional expiration time. @@ -89,10 +96,16 @@ def set( # noqa: A003 client = self.get_server(key) return super().set( - key=key, value=value, timeout=timeout, version=version, client=client, nx=nx + key=key, + value=value, + timeout=timeout, + version=version, + client=client, + nx=nx, + xx=xx, ) - def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None): + def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None): """ Set a bunch of values in the cache at once from a dict of key/value pairs. This is much more efficient than calling set() multiple times. @@ -101,7 +114,7 @@ def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None): the default cache timeout will be used. """ for key, value in data.items(): - self.set(key, value, timeout, version=version) + self.set(key, value, timeout, version=version, client=client) def has_key(self, key, version=None, client=None): """ diff --git a/django_redis/pool.py b/django_redis/pool.py index 979c463c..b0e5f2a3 100644 --- a/django_redis/pool.py +++ b/django_redis/pool.py @@ -5,7 +5,7 @@ from django.core.exceptions import ImproperlyConfigured from django.utils.module_loading import import_string from redis import Redis -from redis.connection import DefaultParser, to_bool +from redis.connection import ConnectionPool, DefaultParser, to_bool from redis.sentinel import Sentinel @@ -16,7 +16,7 @@ class ConnectionFactory: # ConnectionFactory is instantiated, as Django creates new cache client # (DefaultClient) instance for every request. - _pools: Dict[str, Redis] = {} + _pools: Dict[str, ConnectionPool] = {} def __init__(self, options): pool_cls_path = options.get( @@ -70,7 +70,7 @@ def connect(self, url: str) -> Redis: params = self.make_connection_params(url) return self.get_connection(params) - def disconnect(self, connection): + def disconnect(self, connection: Redis) -> None: """ Given a not null client connection it disconnect from the Redis server. diff --git a/tests/start_redis.sh b/tests/start_redis.sh index 00bf2b03..9766fc51 100755 --- a/tests/start_redis.sh +++ b/tests/start_redis.sh @@ -47,7 +47,7 @@ docker run \ --health-interval 10s \ --health-retries 5 \ --health-timeout 5s \ - --network host \ --user $(id -u):$(id -g) \ + --publish $PORT:$PORT \ --volume /tmp:/tmp \ --detach redis:latest redis-server "${ARGS[@]}" diff --git a/tests/test_backend.py b/tests/test_backend.py index 7f4beb70..58d6b92c 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -734,6 +734,18 @@ def test_primary_replica_switching(self, cache: RedisCache): assert client.get_client(write=True) == "Foo" assert client.get_client(write=False) == "Bar" + def test_primary_replica_switching_with_index(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache = cast(RedisCache, caches["sample"]) + client = cache.client + client._server = ["foo", "bar"] + client._clients = ["Foo", "Bar"] + + assert client.get_client_with_index(write=True) == ("Foo", 0) + assert client.get_client_with_index(write=False) == ("Bar", 1) + def test_touch_zero_timeout(self, cache: RedisCache): cache.set("test_key", 222, timeout=10) From de08bc3a2907a29bb8965b5a9d286170d8239814 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Nov 2023 19:11:10 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- django_redis/client/default.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 05f06dd2..251668fe 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,8 +3,7 @@ import socket from collections import OrderedDict from contextlib import suppress -from datetime import datetime -from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union, Any +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from django.conf import settings from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func @@ -12,7 +11,7 @@ from django.utils.module_loading import import_string from redis import Redis from redis.exceptions import ConnectionError, ResponseError, TimeoutError -from redis.typing import EncodableT, KeyT, AbsExpiryT, ExpiryT +from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT from django_redis import pool from django_redis.exceptions import CompressorError, ConnectionInterrupted From e29200c6a6eacd5e5d5ff150535f4759db272b4f Mon Sep 17 00:00:00 2001 From: WisdomPill Date: Wed, 1 Nov 2023 21:28:35 +0200 Subject: [PATCH 3/7] Restored start_redis script --- tests/start_redis.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/start_redis.sh b/tests/start_redis.sh index 9766fc51..00bf2b03 100755 --- a/tests/start_redis.sh +++ b/tests/start_redis.sh @@ -47,7 +47,7 @@ docker run \ --health-interval 10s \ --health-retries 5 \ --health-timeout 5s \ + --network host \ --user $(id -u):$(id -g) \ - --publish $PORT:$PORT \ --volume /tmp:/tmp \ --detach redis:latest redis-server "${ARGS[@]}" From e96c4ceec57ebfb26a35a52f46204086231e1662 Mon Sep 17 00:00:00 2001 From: WisdomPill Date: Wed, 1 Nov 2023 21:39:02 +0200 Subject: [PATCH 4/7] Silenced strange mypy errors with pexpire and expire --- django_redis/client/default.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 251668fe..db9aac23 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,6 +3,7 @@ import socket from collections import OrderedDict from contextlib import suppress +from datetime import timedelta from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from django.conf import settings @@ -304,7 +305,9 @@ def expire( key = self.make_key(key, version=version) - return client.expire(key, timeout) + # for some strange reason mypy complains, + # saying that timeout type is float | timedelta + return client.expire(key, timeout) # type: ignore def pexpire( self, @@ -320,7 +323,9 @@ def pexpire( # Temporary casting until https://github.com/redis/redis-py/issues/1664 # is fixed. - return bool(client.pexpire(key, timeout)) + # for some strange reason mypy complains, + # saying that timeout type is float | timedelta + return bool(client.pexpire(key, timeout)) # type: ignore def pexpire_at( self, From 13aee50506a88fb87ef8a50f22b2cb85de94a556 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Nov 2023 19:39:13 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- django_redis/client/default.py | 1 - 1 file changed, 1 deletion(-) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index db9aac23..21b93a04 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,7 +3,6 @@ import socket from collections import OrderedDict from contextlib import suppress -from datetime import timedelta from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from django.conf import settings From 2e2dc55696769d8d02300fc5649c6d8323c5ea96 Mon Sep 17 00:00:00 2001 From: WisdomPill Date: Thu, 2 Nov 2023 09:44:04 +0200 Subject: [PATCH 6/7] restore encoding logic --- django_redis/client/default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 21b93a04..9485627f 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -486,7 +486,7 @@ def encode(self, value: EncodableT) -> Union[bytes, int]: Encode the given value. """ - if not isinstance(value, int): + if isinstance(value, bool) or not isinstance(value, int): value = self._serializer.dumps(value) return self._compressor.compress(value) From 28500908cb0f608e2fc36566e3324c3d0d6ecca5 Mon Sep 17 00:00:00 2001 From: WisdomPill Date: Thu, 2 Nov 2023 16:00:16 +0200 Subject: [PATCH 7/7] Added changelog --- changelog.d/696.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/696.misc diff --git a/changelog.d/696.misc b/changelog.d/696.misc new file mode 100644 index 00000000..066426d6 --- /dev/null +++ b/changelog.d/696.misc @@ -0,0 +1 @@ +Typed DefaultClient \ No newline at end of file