Skip to content

FEAT: type hint #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 52 additions & 36 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
decode_hosts,
)

import typing

if typing.TYPE_CHECKING:
from redis.asyncio.connection import ConnectionPool
from redis.asyncio.client import Redis
Comment on lines +26 to +27
Copy link
Author

@iwakitakuma33 iwakitakuma33 May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amirreza8002
Shouldn't I also use Redis and ConnectionPool?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those are just classes

what i mean is, if you see something like Redis().some_method(), there is a change the type hint on some_method is wrong, and we shouldn't just put the same thing in this codebase

from typing_extensions import Buffer

logger = logging.getLogger(__name__)


Expand All @@ -32,23 +39,27 @@ class ChannelLock:
"""

def __init__(self):
self.locks = collections.defaultdict(asyncio.Lock)
self.wait_counts = collections.defaultdict(int)
self.locks: collections.defaultdict[str, asyncio.Lock] = (
collections.defaultdict(asyncio.Lock)
)
self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict(
int
)

async def acquire(self, channel):
async def acquire(self, channel: str) -> bool:
"""
Acquire the lock for the given channel.
"""
self.wait_counts[channel] += 1
return await self.locks[channel].acquire()

def locked(self, channel):
def locked(self, channel: str) -> bool:
"""
Return ``True`` if the lock for the given channel is acquired.
"""
return self.locks[channel].locked()

def release(self, channel):
def release(self, channel: str):
"""
Release the lock for the given channel.
"""
Expand All @@ -73,12 +84,12 @@ def put_nowait(self, item):


class RedisLoopLayer:
def __init__(self, channel_layer):
def __init__(self, channel_layer: "RedisChannelLayer"):
self._lock = asyncio.Lock()
self.channel_layer = channel_layer
self._connections = {}
self._connections: typing.Dict[int, "Redis"] = {}

def get_connection(self, index):
def get_connection(self, index: int) -> "Redis":
if index not in self._connections:
pool = self.channel_layer.create_pool(index)
self._connections[index] = aioredis.Redis(connection_pool=pool)
Expand Down Expand Up @@ -134,7 +145,7 @@ def __init__(
symmetric_encryption_keys=symmetric_encryption_keys,
)
# Cached redis connection pools and the event loop they are from
self._layers = {}
self._layers: typing.Dict[asyncio.AbstractEventLoop, "RedisLoopLayer"] = {}
# Normal channels choose a host index by cycling through the available hosts
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
Expand All @@ -143,27 +154,27 @@ def __init__(
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
self.receive_lock = None
self.receive_lock: typing.Optional[asyncio.Lock] = None
# Event loop they are trying to receive on
self.receive_event_loop = None
self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None
# Buffered messages by process-local channel name
self.receive_buffer = collections.defaultdict(
functools.partial(BoundedQueue, self.capacity)
self.receive_buffer: collections.defaultdict[str, BoundedQueue] = (
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
)
# Detached channel cleanup tasks
self.receive_cleaners = []
self.receive_cleaners: typing.List[asyncio.Task] = []
# Per-channel cleanup locks to prevent a receive starting and moving
# a message back into the main queue before its cleanup has completed
self.receive_clean_locks = ChannelLock()

def create_pool(self, index):
def create_pool(self, index: int) -> "ConnectionPool":
return create_pool(self.hosts[index])

### Channel layer API ###

extensions = ["groups", "flush"]

async def send(self, channel, message):
async def send(self, channel: str, message):
"""
Send a message onto a (general or specific) channel.
"""
Expand Down Expand Up @@ -203,13 +214,15 @@ async def send(self, channel, message):
await connection.zadd(channel_key, {self.serialize(message): time.time()})
await connection.expire(channel_key, int(self.expiry))

def _backup_channel_name(self, channel):
def _backup_channel_name(self, channel: str) -> str:
"""
Construct the key used as a backup queue for the given channel.
"""
return channel + "$inflight"

async def _brpop_with_clean(self, index, channel, timeout):
async def _brpop_with_clean(
self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str]
):
"""
Perform a Redis BRPOP and manage the backup processing queue.
In case of cancellation, make sure the message is not lost.
Expand Down Expand Up @@ -240,15 +253,15 @@ async def _brpop_with_clean(self, index, channel, timeout):

return member

async def _clean_receive_backup(self, index, channel):
async def _clean_receive_backup(self, index: int, channel: str):
"""
Pop the oldest message off the channel backup queue.
The result isn't interesting as it was already processed.
"""
connection = self.connection(index)
await connection.zpopmin(self._backup_channel_name(channel))

async def receive(self, channel):
async def receive(self, channel: str):
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, the first waiter
Expand Down Expand Up @@ -372,7 +385,7 @@ async def receive(self, channel):
# Do a plain direct receive
return (await self.receive_single(channel))[1]

async def receive_single(self, channel):
async def receive_single(self, channel: str) -> typing.Tuple:
"""
Receives a single message off of the channel and returns it.
"""
Expand Down Expand Up @@ -408,7 +421,7 @@ async def receive_single(self, channel):
)
self.receive_cleaners.append(cleaner)

def _cleanup_done(cleaner):
def _cleanup_done(cleaner: asyncio.Task):
self.receive_cleaners.remove(cleaner)
self.receive_clean_locks.release(channel_key)

Expand All @@ -427,7 +440,7 @@ def _cleanup_done(cleaner):
del message["__asgi_channel__"]
return channel, message

async def new_channel(self, prefix="specific"):
async def new_channel(self, prefix: str = "specific") -> str:
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
Expand Down Expand Up @@ -477,13 +490,13 @@ async def wait_received(self):

### Groups extension ###

async def group_add(self, group, channel):
async def group_add(self, group: str, channel: str):
"""
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
assert self.valid_group_name(group), True
assert self.valid_channel_name(channel), True
# Get a connection to the right shard
group_key = self._group_key(group)
connection = self.connection(self.consistent_hash(group))
Expand All @@ -493,7 +506,7 @@ async def group_add(self, group, channel):
# it at this point is guaranteed to expire before that
await connection.expire(group_key, self.group_expiry)

async def group_discard(self, group, channel):
async def group_discard(self, group: str, channel: str):
"""
Removes the channel from the named group if it is in the group;
does nothing otherwise (does not error)
Expand All @@ -504,7 +517,7 @@ async def group_discard(self, group, channel):
connection = self.connection(self.consistent_hash(group))
await connection.zrem(key, channel)

async def group_send(self, group, message):
async def group_send(self, group: str, message):
"""
Sends a message to the entire group.
"""
Expand Down Expand Up @@ -573,7 +586,12 @@ async def group_send(self, group, message):
channels_over_capacity = await connection.eval(
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
)
if channels_over_capacity > 0:
_channels_over_capacity = -1
try:
_channels_over_capacity = float(channels_over_capacity)
except Exception:
pass
if _channels_over_capacity > 0:
Comment on lines +589 to +594

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am not sure what's happening here
but i think it should be in another commit

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connection.eval is inferred to return await[str] | str.
By converting it to float, it can be compared with 0.

logger.info(
"%s of %s channels over capacity in group %s",
channels_over_capacity,
Expand Down Expand Up @@ -631,37 +649,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
channel_key_to_capacity,
)

def _group_key(self, group):
def _group_key(self, group: str) -> bytes:
"""
Common function to make the storage key for the group.
"""
return f"{self.prefix}:group:{group}".encode("utf8")

### Serialization ###

def serialize(self, message):
def serialize(self, message) -> bytes:
"""
Serializes message to a byte string.
"""
return self._serializer.serialize(message)

def deserialize(self, message):
def deserialize(self, message: bytes):
"""
Deserializes from a byte string.
"""
return self._serializer.deserialize(message)

### Internal functions ###

def consistent_hash(self, value):
def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int:
return _consistent_hash(value, self.ring_size)

def __str__(self):
return f"{self.__class__.__name__}(hosts={self.hosts})"

### Connection handling ###

def connection(self, index):
def connection(self, index: int) -> "Redis":
"""
Returns the correct connection for the index given.
Lazily instantiates pools.
Expand Down
Loading