FEAT: type hint #417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
@@ -20,6 +20,13 @@
     decode_hosts,
 )

+import typing
+
+if typing.TYPE_CHECKING:
+    from redis.asyncio.connection import ConnectionPool
+    from redis.asyncio.client import Redis
+    from typing_extensions import Buffer
+
 logger = logging.getLogger(__name__)

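For context on the new import block: typing.TYPE_CHECKING is false at runtime, so the guarded imports exist only for the type checker and must be referenced via quoted annotations. A minimal sketch of the pattern (make_client is a hypothetical illustration, not code from this PR):

import typing

if typing.TYPE_CHECKING:
    # Evaluated only by static type checkers; redis need not be
    # imported (or even installed) at runtime for annotations to work.
    from redis.asyncio.client import Redis


def make_client(url: str) -> "Redis":
    # The quoted annotation defers evaluation, so the guarded import
    # above is enough for the checker to resolve it.
    from redis.asyncio.client import Redis  # real import only when called

    return Redis.from_url(url)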
@@ -32,23 +39,27 @@ class ChannelLock:
     """

     def __init__(self):
-        self.locks = collections.defaultdict(asyncio.Lock)
-        self.wait_counts = collections.defaultdict(int)
+        self.locks: collections.defaultdict[str, asyncio.Lock] = (
+            collections.defaultdict(asyncio.Lock)
+        )
+        self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict(
+            int
+        )

-    async def acquire(self, channel):
+    async def acquire(self, channel: str) -> bool:
         """
         Acquire the lock for the given channel.
         """
         self.wait_counts[channel] += 1
         return await self.locks[channel].acquire()

-    def locked(self, channel):
+    def locked(self, channel: str) -> bool:
         """
         Return ``True`` if the lock for the given channel is acquired.
         """
         return self.locks[channel].locked()

-    def release(self, channel):
+    def release(self, channel: str):
         """
         Release the lock for the given channel.
         """
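A quick usage sketch of the newly annotated ChannelLock API, assuming it is imported from channels_redis.core:

import asyncio

from channels_redis.core import ChannelLock


async def main() -> None:
    lock = ChannelLock()
    acquired: bool = await lock.acquire("chan.1")  # now hinted to return bool
    try:
        assert acquired and lock.locked("chan.1")
    finally:
        lock.release("chan.1")


asyncio.run(main())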
@@ -73,12 +84,12 @@ def put_nowait(self, item):


 class RedisLoopLayer:
-    def __init__(self, channel_layer):
+    def __init__(self, channel_layer: "RedisChannelLayer"):
         self._lock = asyncio.Lock()
         self.channel_layer = channel_layer
-        self._connections = {}
+        self._connections: typing.Dict[int, "Redis"] = {}

-    def get_connection(self, index):
+    def get_connection(self, index: int) -> "Redis":
         if index not in self._connections:
             pool = self.channel_layer.create_pool(index)
             self._connections[index] = aioredis.Redis(connection_pool=pool)
@@ -134,7 +145,7 @@ def __init__(
             symmetric_encryption_keys=symmetric_encryption_keys,
         )
         # Cached redis connection pools and the event loop they are from
-        self._layers = {}
+        self._layers: typing.Dict[asyncio.AbstractEventLoop, "RedisLoopLayer"] = {}
         # Normal channels choose a host index by cycling through the available hosts
         self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
         self._send_index_generator = itertools.cycle(range(len(self.hosts)))
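For context, a hypothetical sketch of the caching pattern _layers implements: entries are keyed by the running event loop, so each loop gets its own connections rather than sharing them across loops (LoopLocalCache is illustrative, not part of the PR):

import asyncio
import typing


class LoopLocalCache:
    # Stand-in for the per-loop cache: one entry per running event
    # loop, created lazily on first use from that loop.
    def __init__(self) -> None:
        self._layers: typing.Dict[asyncio.AbstractEventLoop, object] = {}

    def get(self) -> object:
        loop = asyncio.get_running_loop()
        if loop not in self._layers:
            self._layers[loop] = object()  # stand-in for RedisLoopLayer
        return self._layers[loop]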
@@ -143,27 +154,27 @@ def __init__(
         # Number of coroutines trying to receive right now
         self.receive_count = 0
         # The receive lock
-        self.receive_lock = None
+        self.receive_lock: typing.Optional[asyncio.Lock] = None
         # Event loop they are trying to receive on
-        self.receive_event_loop = None
+        self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None
         # Buffered messages by process-local channel name
-        self.receive_buffer = collections.defaultdict(
-            functools.partial(BoundedQueue, self.capacity)
+        self.receive_buffer: collections.defaultdict[str, BoundedQueue] = (
+            collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
         )
         # Detached channel cleanup tasks
-        self.receive_cleaners = []
+        self.receive_cleaners: typing.List[asyncio.Task] = []
         # Per-channel cleanup locks to prevent a receive starting and moving
         # a message back into the main queue before its cleanup has completed
         self.receive_clean_locks = ChannelLock()

-    def create_pool(self, index):
+    def create_pool(self, index: int) -> "ConnectionPool":
         return create_pool(self.hosts[index])

     ### Channel layer API ###

     extensions = ["groups", "flush"]

-    async def send(self, channel, message):
+    async def send(self, channel: str, message):
         """
         Send a message onto a (general or specific) channel.
         """
@@ -203,13 +214,15 @@ async def send(self, channel, message):
         await connection.zadd(channel_key, {self.serialize(message): time.time()})
         await connection.expire(channel_key, int(self.expiry))

-    def _backup_channel_name(self, channel):
+    def _backup_channel_name(self, channel: str) -> str:
         """
         Construct the key used as a backup queue for the given channel.
         """
         return channel + "$inflight"

-    async def _brpop_with_clean(self, index, channel, timeout):
+    async def _brpop_with_clean(
+        self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str]
+    ):
         """
         Perform a Redis BRPOP and manage the backup processing queue.
         In case of cancellation, make sure the message is not lost.
@@ -240,15 +253,15 @@ async def _brpop_with_clean(self, index, channel, timeout):

         return member

-    async def _clean_receive_backup(self, index, channel):
+    async def _clean_receive_backup(self, index: int, channel: str):
         """
         Pop the oldest message off the channel backup queue.
         The result isn't interesting as it was already processed.
         """
         connection = self.connection(index)
         await connection.zpopmin(self._backup_channel_name(channel))

-    async def receive(self, channel):
+    async def receive(self, channel: str):
         """
         Receive the first message that arrives on the channel.
         If more than one coroutine waits on the same channel, the first waiter
@@ -372,7 +385,7 @@ async def receive(self, channel):
             # Do a plain direct receive
             return (await self.receive_single(channel))[1]

-    async def receive_single(self, channel):
+    async def receive_single(self, channel: str) -> typing.Tuple:
         """
         Receives a single message off of the channel and returns it.
         """
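One note on the signature above: a bare typing.Tuple means tuple[Any, ...]. Since receive() unpacks the result as a (channel name, message) pair, a parameterized hint would be tighter; a sketch of that alternative (illustration only, not what the PR does):

import typing


class RedisChannelLayer:  # mirrors only the signature under discussion
    async def receive_single(self, channel: str) -> typing.Tuple[str, dict]:
        """
        Receives a single message off of the channel and returns it as a
        (channel_name, message) pair, matching the unpacking in receive().
        """
        ...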
@@ -408,7 +421,7 @@ async def receive_single(self, channel):
                 )
                 self.receive_cleaners.append(cleaner)

-                def _cleanup_done(cleaner):
+                def _cleanup_done(cleaner: asyncio.Task):
                     self.receive_cleaners.remove(cleaner)
                     self.receive_clean_locks.release(channel_key)
@@ -427,7 +440,7 @@ def _cleanup_done(cleaner):
         del message["__asgi_channel__"]
         return channel, message

-    async def new_channel(self, prefix="specific"):
+    async def new_channel(self, prefix: str = "specific") -> str:
         """
         Returns a new channel name that can be used by something in our
         process as a specific channel.
@@ -477,13 +490,13 @@ async def wait_received(self):

     ### Groups extension ###

-    async def group_add(self, group, channel):
+    async def group_add(self, group: str, channel: str):
         """
         Adds the channel name to a group.
         """
         # Check the inputs
         assert self.valid_group_name(group), "Group name not valid"
         assert self.valid_channel_name(channel), "Channel name not valid"
         # Get a connection to the right shard
         group_key = self._group_key(group)
         connection = self.connection(self.consistent_hash(group))
@@ -493,7 +506,7 @@ async def group_add(self, group, channel):
         # it at this point is guaranteed to expire before that
         await connection.expire(group_key, self.group_expiry)

-    async def group_discard(self, group, channel):
+    async def group_discard(self, group: str, channel: str):
         """
         Removes the channel from the named group if it is in the group;
         does nothing otherwise (does not error)
@@ -504,7 +517,7 @@ async def group_discard(self, group, channel):
         connection = self.connection(self.consistent_hash(group))
         await connection.zrem(key, channel)

-    async def group_send(self, group, message):
+    async def group_send(self, group: str, message):
         """
         Sends a message to the entire group.
         """
@@ -573,7 +586,12 @@ async def group_send(self, group, message):
             channels_over_capacity = await connection.eval(
                 group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
             )
-            if channels_over_capacity > 0:
+            _channels_over_capacity = -1
+            try:
+                _channels_over_capacity = float(channels_over_capacity)
+            except Exception:
+                pass
+            if _channels_over_capacity > 0:
                 logger.info(
                     "%s of %s channels over capacity in group %s",
                     channels_over_capacity,

Comment on lines +589 to +594:

Reviewer: am not sure what's happening here.

Author: connection.eval is inferred to return Awaitable[str] | str.
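Given that inference, an alternative to the try/float workaround would be a local typing.cast, since the Lua script always returns a number. A hedged sketch (eval_int is a hypothetical helper, not code from this PR):

import typing


async def eval_int(connection, script: str, keys: typing.Sequence[str], *args) -> int:
    # redis-py annotates eval() loosely; when the script is known to
    # return an integer, casting the awaited result narrows the type
    # without any runtime conversion.
    result = await connection.eval(script, len(keys), *keys, *args)
    return typing.cast(int, result)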
@@ -631,37 +649,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
             channel_key_to_capacity,
         )

-    def _group_key(self, group):
+    def _group_key(self, group: str) -> bytes:
         """
         Common function to make the storage key for the group.
         """
         return f"{self.prefix}:group:{group}".encode("utf8")

     ### Serialization ###

-    def serialize(self, message):
+    def serialize(self, message) -> bytes:
         """
         Serializes message to a byte string.
         """
         return self._serializer.serialize(message)

-    def deserialize(self, message):
+    def deserialize(self, message: bytes):
         """
         Deserializes from a byte string.
         """
         return self._serializer.deserialize(message)

     ### Internal functions ###

-    def consistent_hash(self, value):
+    def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int:
         return _consistent_hash(value, self.ring_size)

     def __str__(self):
         return f"{self.__class__.__name__}(hosts={self.hosts})"

     ### Connection handling ###

-    def connection(self, index):
+    def connection(self, index: int) -> "Redis":
         """
         Returns the correct connection for the index given.
         Lazily instantiates pools.
@amirreza8002 Shouldn't I also use Redis and ConnectionPool?
Those are just classes. What I mean is: if you see something like Redis().some_method(), there is a chance the type hint on some_method is wrong, and we shouldn't just put the same thing in this codebase.
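That caution can be checked empirically: run the call once and compare the runtime type against what the stub claims before mirroring it. A throwaway sketch (assumes redis-py 5.x and a local Redis; illustrative only):

import asyncio

from redis.asyncio.client import Redis


async def main() -> None:
    client = Redis.from_url("redis://localhost:6379")
    result = await client.zadd("some-key", {"member": 1.0})
    # Compare what actually comes back with the stub's annotation
    # before copying that annotation into this codebase.
    print(type(result))  # int at runtime with redis-py 5.x
    await client.aclose()


asyncio.run(main())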