diff --git a/channels_redis/core.py b/channels_redis/core.py index 669be2a..adca5c9 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -166,6 +166,80 @@ class UnsupportedRedis(Exception): pass +class ReceiveBuffer: + """ + Receive buffer + + It manages waiters and buffers messages for all specific channels under the same 'real channel' + Also manages the receive loop for the 'real channel' + """ + + def __init__(self, receive_single, real_channel): + self.loop = None + self.real_channel = real_channel + self.receive_single = receive_single + self.getters = collections.defaultdict(collections.deque) + self.buffers = collections.defaultdict(collections.deque) + self.receiver = None + + def __bool__(self): + return bool(self.getters) + + def get(self, channel): + """ + :param channel: name of the channel + :return: Future for the next message on channel + """ + assert channel.startswith( + self.real_channel + ), "channel not managed by this buffer" + getter = self.loop.create_future() + + if channel in self.buffers: + getter.set_result(self.buffers[channel].popleft()) + if not self.buffers[channel]: + del self.buffers[channel] + else: + getter.channel = channel + getter.add_done_callback(self._getter_done_prematurely) + self.getters[channel].append(getter) + + # ensure receiver is running + if not self.receiver: + self.receiver = asyncio.ensure_future(self.receiver_factory()) + return getter + + def _getter_done_prematurely(self, getter): + channel = getter.channel + self.getters[channel].remove(getter) + if not self.getters[channel]: + del self.getters[channel] + if not self and self.receiver: + self.receiver.cancel() + + def put(self, channel, message): + if channel in self.getters: + getter = self.getters[channel].popleft() + getter.remove_done_callback(self._getter_done_prematurely) + if not self.getters[channel]: + del self.getters[channel] + getter.set_result(message) + else: + self.buffers[channel].append(message) + + async def receiver_factory(self): + try: + while self: + message_channel, message = await self.receive_single(self.real_channel) + if type(message_channel) is list: + for chan in message_channel: + self.put(chan, message) + else: + self.put(message_channel, message) + finally: + self.receiver = None + + class RedisChannelLayer(BaseChannelLayer): """ Redis channel layer. @@ -209,14 +283,8 @@ def __init__( ) # Set up any encryption objects self._setup_encryption(symmetric_encryption_keys) - # Number of coroutines trying to receive right now - self.receive_count = 0 - # The receive lock - self.receive_lock = None - # Event loop they are trying to receive on - self.receive_event_loop = None # Buffered messages by process-local channel name - self.receive_buffer = collections.defaultdict(asyncio.Queue) + self.receive_buffers = {} # Detached channel cleanup tasks self.receive_cleaners = [] # Per-channel cleanup locks to prevent a receive starting and moving @@ -352,110 +420,20 @@ async def receive(self, channel): ), "Wrong client prefix" # Enter receiving section loop = asyncio.get_event_loop() - self.receive_count += 1 - try: - if self.receive_count == 1: - # If we're the first coroutine in, create the receive lock! - self.receive_lock = asyncio.Lock() - self.receive_event_loop = loop - else: - # Otherwise, check our event loop matches - if self.receive_event_loop != loop: - raise RuntimeError( - "Two event loops are trying to receive() on one channel layer at once!" - ) - - # Wait for our message to appear - message = None - while self.receive_buffer[channel].empty(): - tasks = [ - self.receive_lock.acquire(), - self.receive_buffer[channel].get(), - ] - tasks = [asyncio.ensure_future(task) for task in tasks] - try: - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - # Cancel all pending tasks. - task.cancel() - except asyncio.CancelledError: - # Ensure all tasks are cancelled if we are cancelled. - # Also see: https://bugs.python.org/issue23859 - del self.receive_buffer[channel] - for task in tasks: - if not task.cancel(): - assert task.done() - if task.result() is True: - self.receive_lock.release() - - raise - - message, token, exception = None, None, None - for task in done: - try: - result = task.result() - except Exception as error: # NOQA - # We should not propagate exceptions immediately as otherwise this may cause - # the lock to be held and never be released. - exception = error - continue - - if result is True: - token = result - else: - assert isinstance(result, dict) - message = result - - if message or exception: - if token: - # We will not be receving as we already have the message. - self.receive_lock.release() - - if exception: - raise exception - else: - break - else: - assert token - - # We hold the receive lock, receive and then release it. - try: - # There is no interruption point from when the message is - # unpacked in receive_single to when we get back here, so - # the following lines are essentially atomic. - message_channel, message = await self.receive_single( - real_channel - ) - if type(message_channel) is list: - for chan in message_channel: - self.receive_buffer[chan].put_nowait(message) - else: - self.receive_buffer[message_channel].put_nowait(message) - message = None - except: - del self.receive_buffer[channel] - raise - finally: - self.receive_lock.release() - - # We know there's a message available, because there - # couldn't have been any interruption between empty() and here - if message is None: - message = self.receive_buffer[channel].get_nowait() - - if self.receive_buffer[channel].empty(): - del self.receive_buffer[channel] - return message - - finally: - self.receive_count -= 1 - # If we were the last out, drop the receive lock - if self.receive_count == 0: - assert not self.receive_lock.locked() - self.receive_lock = None - self.receive_event_loop = None + if real_channel not in self.receive_buffers: + self.receive_buffers[real_channel] = ReceiveBuffer( + self.receive_single, real_channel + ) + receive_buffer = self.receive_buffers[real_channel] + + # Check our event loop matches + if receive_buffer.loop != loop and receive_buffer.receiver: + raise RuntimeError( + "Two event loops are trying to receive() on one channel layer at once!" + ) + else: + receive_buffer.loop = loop + return await receive_buffer.get(channel) else: # Do a plain direct receive return (await self.receive_single(channel))[1] diff --git a/tests/test_core.py b/tests/test_core.py index dbf2401..9d9fbe8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,7 +5,7 @@ from async_generator import async_generator, yield_ from asgiref.sync import async_to_sync -from channels_redis.core import ChannelFull, RedisChannelLayer +from channels_redis.core import ChannelFull, ReceiveBuffer, RedisChannelLayer TEST_HOSTS = [("localhost", 6379)] @@ -343,3 +343,68 @@ async def test_receive_cancel(channel_layer): await asyncio.wait_for(task, None) except asyncio.CancelledError: pass + + +@pytest.mark.asyncio +async def test_receive_multiple_specific_prefixes(channel_layer): + """ + Makes sure we receive on multiple real channels + """ + channel_layer = RedisChannelLayer(capacity=10) + channel1 = await channel_layer.new_channel() + channel2 = await channel_layer.new_channel(prefix="thing") + r1, _, r2 = tasks = [ + asyncio.ensure_future(x) + for x in ( + channel_layer.receive(channel1), + channel_layer.send(channel2, {"type": "message"}), + channel_layer.receive(channel2), + ) + ] + await asyncio.wait(tasks, timeout=0.5) + + assert not r1.done() + assert r2.done() and r2.result()["type"] == "message" + r1.cancel() + + +@pytest.mark.asyncio +async def test_buffer_wrong_channel(channel_layer): + async def dummy_receive(channel): + return channel, {"type": "message"} + + buffer = ReceiveBuffer(dummy_receive, "whatever!") + buffer.loop = asyncio.get_event_loop() + with pytest.raises(AssertionError): + buffer.get("wrong!13685sjmh") + + +@pytest.mark.asyncio +async def test_buffer_receiver_stopped(channel_layer): + async def dummy_receive(channel): + return "whatever!meh", {"type": "message"} + + buffer = ReceiveBuffer(dummy_receive, "whatever!") + buffer.loop = asyncio.get_event_loop() + + await buffer.get("whatever!meh") + assert buffer.receiver is None + + +@pytest.mark.asyncio +async def test_buffer_receiver_canceled(channel_layer): + async def dummy_receive(channel): + await asyncio.sleep(2) + return "whatever!meh", {"type": "message"} + + buffer = ReceiveBuffer(dummy_receive, "whatever!") + buffer.loop = asyncio.get_event_loop() + + get1 = buffer.get("whatever!meh") + assert buffer.receiver is not None + get2 = buffer.get("whatever!meh2") + get1.cancel() + assert buffer.receiver is not None + get2.cancel() + await asyncio.sleep(0.1) + assert buffer.receiver is None