Skip to content

fix: Guarante at most once delivery #414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions channels_redis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import uuid

import aioredlock
from redis import asyncio as aioredis

from .serializers import registry
Expand Down Expand Up @@ -117,6 +118,11 @@ def __init__(
RedisSingleShardConnection(host, self) for host in decode_hosts(hosts)
]

# Create lock manager for all redis connections
redis_connections = [shard.host for shard in self._shards]
self._lock_manager = aioredlock.Aioredlock()
self._lock_manager.redis_connections = redis_connections

def _get_shard(self, channel_or_group_name):
"""
Return the shard that is used exclusively for this channel or group.
Expand All @@ -135,10 +141,21 @@ def _get_group_channel_name(self, group):
"""
return f"{self.prefix}__group__{group}"

async def _acquire_lock(self, channel):
try:
await self._lock_manager.lock(channel, lock_timeout=60)
except aioredlock.LockError:
logger.debug("Failed to acquire lock on channel %s", channel)
return False

return True

async def _subscribe_to_channel(self, channel):
self.channels[channel] = asyncio.Queue()
shard = self._get_shard(channel)
await shard.subscribe(channel)

if await self._acquire_lock(channel):
shard = self._get_shard(channel)
await shard.subscribe(channel)

extensions = ["groups", "flush"]

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
include_package_data=True,
python_requires=">=3.8",
install_requires=[
"aioredlock>=0.7.3,<1",
"redis>=4.6",
"msgpack~=1.0",
"asgiref>=3.2.10,<4",
Expand Down
44 changes: 43 additions & 1 deletion tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import async_timeout
import pytest

from asgiref.sync import async_to_sync

from channels_redis.pubsub import RedisPubSubChannelLayer
from channels_redis.utils import _close_redis

Expand Down Expand Up @@ -261,3 +261,45 @@ async def test_discard_before_add(channel_layer):
channel_name = await channel_layer.new_channel(prefix="test-channel")
# Make sure that we can remove a group before it was ever added without crashing.
await channel_layer.group_discard("test-group", channel_name)


@pytest.mark.asyncio
async def test_guarantee_at_most_once_delivery() -> None:
"""
Tests that at most once delivery is guaranteed.

If two consumers are listening on the same channel,
the message should be delivered to only one of them.
"""

channel_name = "same-channel"
loop = asyncio.get_running_loop()

channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
channel_layer_2 = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
future_channel_layer = loop.create_future()
future_channel_layer_2 = loop.create_future()

async def receive_task(
channel_layer: RedisPubSubChannelLayer, future: asyncio.Future
) -> None:
message = await channel_layer.receive(channel_name)
future.set_result(message)

# Ensure that receive_task_2 is scheduled first and accquires the lock
asyncio.create_task(receive_task(channel_layer_2, future_channel_layer_2))
await asyncio.sleep(1)
asyncio.create_task(receive_task(channel_layer, future_channel_layer))
await asyncio.sleep(1)

await channel_layer.send(channel_name, {"type": "test.message", "text": "Hello!"})

result = await future_channel_layer_2
assert result["type"] == "test.message"
assert result["text"] == "Hello!"

# Channel layer 1 should not receive the message
# as it is already consumed by channel layer 2
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(1):
await future_channel_layer