diff --git a/docs/examples/responses/sse_responses_with_ping.py b/docs/examples/responses/sse_responses_with_ping.py new file mode 100644 index 0000000000..58a5a593d3 --- /dev/null +++ b/docs/examples/responses/sse_responses_with_ping.py @@ -0,0 +1,22 @@ +from asyncio import sleep +from typing import AsyncGenerator + +from litestar import Litestar, get +from litestar.response import ServerSentEvent, ServerSentEventMessage +from litestar.types import SSEData + + +async def my_slow_generator() -> AsyncGenerator[SSEData, None]: + count = 0 + while count < 1: + await sleep(1) + count += 1 + yield ServerSentEventMessage(data="content", event="message") + + +@get(path="/with_ping", sync_to_thread=False) +def sse_handler_with_ping_events() -> ServerSentEvent: + return ServerSentEvent(my_slow_generator(), ping_interval=0.1) + + +app = Litestar(route_handlers=[sse_handler_with_ping_events]) diff --git a/docs/usage/responses.rst b/docs/usage/responses.rst index 45b3b780b8..c5769098da 100644 --- a/docs/usage/responses.rst +++ b/docs/usage/responses.rst @@ -696,6 +696,12 @@ If you want to send a different event type, you can use a dictionary with the ke You can further customize all the sse parameters, add comments, and set the retry duration by using the :class:`ServerSentEvent <.response.ServerSentEvent>` class directly or by using the :class:`ServerSentEventMessage <.response.ServerSentEventMessage>` or dictionaries with the appropriate keys. +If the ``ServerSentEvent`` has ``ping_interval`` set to a positive value, a message with ``event_type`` ``ping`` +will be sent every ``ping_interval`` seconds. This is useful for applications that close connections after a timeout +(e.g., TelegramMiniApps; for more details, see `issue 4082 `_) + +.. literalinclude:: /examples/responses/sse_responses_with_ping.py + :language: python Template Responses ------------------ diff --git a/litestar/response/sse.py b/litestar/response/sse.py index ca9bf991cb..3980aed488 100644 --- a/litestar/response/sse.py +++ b/litestar/response/sse.py @@ -1,18 +1,39 @@ from __future__ import annotations import io +import itertools import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator +from functools import partial +from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator + +from anyio import Event, create_task_group, sleep from litestar.concurrency import sync_to_thread +from litestar.enums import MediaType from litestar.exceptions import ImproperlyConfiguredException -from litestar.response.streaming import Stream +from litestar.response.streaming import ASGIStreamingResponse, Stream from litestar.utils import AsyncIteratorWrapper +from litestar.utils.deprecation import warn_deprecation +from litestar.utils.helpers import get_enum_string_value if TYPE_CHECKING: + from litestar.app import Litestar from litestar.background_tasks import BackgroundTask, BackgroundTasks - from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType + from litestar.connection import Request + from litestar.datastructures.cookie import Cookie + from litestar.enums import OpenAPIMediaType + from litestar.response.base import ASGIResponse + from litestar.types import ( + HTTPResponseBodyEvent, + Receive, + ResponseCookies, + ResponseHeaders, + Send, + SSEData, + StreamType, + TypeEncodersMap, + ) _LINE_BREAK_RE = re.compile(r"\r\n|\r|\n") DEFAULT_SEPARATOR = "\r\n" @@ -127,7 +148,237 @@ def encode(self) -> bytes: return buffer.getvalue().encode("utf-8") -class ServerSentEvent(Stream): +class ASGIStreamingSSEResponse(ASGIStreamingResponse): + """A streaming response which support sending ping messages specific for SSE.""" + + __slots__ = ( + "is_content_end", + "ping_interval", + ) + + def __init__( + self, + *, + iterator: StreamType, + background: BackgroundTask | BackgroundTasks | None = None, + body: bytes | str = b"", + content_length: int | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + encoding: str = "utf-8", + headers: dict[str, Any] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + ping_interval: float = 0, + ) -> None: + """A SSE ASGI streaming response. + + Args: + background: A background task or a list of background tasks to be executed after the response is sent. + body: encoded content to send in the response body. + .. deprecated:: 2.16 + content_length: The response content length. + cookies: The response cookies. + encoded_headers: The response headers. + encoding: The response encoding. + headers: The response headers. + is_head_response: A boolean indicating if the response is a HEAD response. + iterator: An async iterator or iterable. + media_type: The response media type. + status_code: The response status code. + ping_interval: The interval in seconds between "ping" messages. + is_content_end: Indicates the ending of content in the iterator, e.g., use to stop sending ping events. + """ + super().__init__( + iterator=iterator, + background=background, + body=body, + content_length=content_length, + cookies=cookies, + encoding=encoding, + headers=headers, + is_head_response=is_head_response, + media_type=media_type, + status_code=status_code, + encoded_headers=encoded_headers, + ) + self.ping_interval = ping_interval + self.is_content_end = Event() + + async def _send_ping_event(self, send: Send) -> None: + """Send ping events every `ping_interval` second. + + Args: + send: The ASGI Send function. + + Returns: + None + """ + if not self.ping_interval: + return + + ping_event: HTTPResponseBodyEvent = { + "type": "http.response.body", + "body": b"event: ping\r\n\r\n", + "more_body": True, + } + + while not self.is_content_end.is_set(): + await send(ping_event) + await sleep(self.ping_interval) + + async def _stream(self, send: Send) -> None: + """Send the chunks from the iterator as a stream of ASGI 'http.response.body' events. + + Args: + send: The ASGI Send function. + + Returns: + None + """ + async for chunk in self.iterator: + stream_event: HTTPResponseBodyEvent = { + "type": "http.response.body", + "body": chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding), + "more_body": True, + } + await send(stream_event) + terminus_event: HTTPResponseBodyEvent = {"type": "http.response.body", "body": b"", "more_body": False} + self.is_content_end.set() + await send(terminus_event) + + async def send_body(self, send: Send, receive: Receive) -> None: + """Emit a stream of events correlating with the response body. + + Args: + send: The ASGI send function. + receive: The ASGI receive function. + + Returns: + None + """ + + async with create_task_group() as task_group: + task_group.start_soon(partial(self._stream, send)) + task_group.start_soon(partial(self._send_ping_event, send)) + await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive) + + +class SSEStream(Stream): + """An HTTP response that streams the response data as a series of ASGI ``http.response.body`` events.""" + + __slots__ = ("ping_interval",) + + def __init__( + self, + content: StreamType[str | bytes] | Callable[[], StreamType[str | bytes]], + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: ResponseCookies | None = None, + encoding: str = "utf-8", + headers: ResponseHeaders | None = None, + media_type: MediaType | OpenAPIMediaType | str | None = None, + status_code: int | None = None, + ping_interval: float = 0, + ) -> None: + """Initialize the response. + + Args: + content: A sync or async iterator or iterable. + background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or + :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. + Defaults to None. + cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response + ``Set-Cookie`` header. + encoding: The encoding to be used for the response headers. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + media_type: A value for the response ``Content-Type`` header. + status_code: An HTTP status code. + ping_interval: The interval in seconds between "ping" messages. + """ + super().__init__( + background=background, + content=b"", # type: ignore[arg-type] + cookies=cookies, + encoding=encoding, + headers=headers, + media_type=media_type, + status_code=status_code, + ) + self.iterator = content + + if ping_interval < 0: + raise ValueError("argument ping_interval must be not negative") + self.ping_interval = ping_interval + + def to_asgi_response( + self, + app: Litestar | None, + request: Request, + *, + background: BackgroundTask | BackgroundTasks | None = None, + cookies: Iterable[Cookie] | None = None, + encoded_headers: Iterable[tuple[bytes, bytes]] | None = None, + headers: dict[str, str] | None = None, + is_head_response: bool = False, + media_type: MediaType | str | None = None, + status_code: int | None = None, + type_encoders: TypeEncodersMap | None = None, + ) -> ASGIResponse: + """Create an ASGIStreamingResponse from a StremaingResponse instance. + + Args: + app: The :class:`Litestar <.app.Litestar>` application instance. + background: Background task(s) to be executed after the response is sent. + cookies: A list of cookies to be set on the response. + encoded_headers: A list of already encoded headers. + headers: Additional headers to be merged with the response headers. Response headers take precedence. + is_head_response: Whether the response is a HEAD response. + media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored. + request: The :class:`Request <.connection.Request>` instance. + status_code: Status code for the response. If ``status_code`` is already set on the response, this is + type_encoders: A dictionary of type encoders to use for encoding the response content. + + Returns: + An ASGIStreamingResponse instance. + """ + if app is not None: + warn_deprecation( + version="2.1", + deprecated_name="app", + kind="parameter", + removal_in="3.0.0", + alternative="request.app", + ) + + headers = {**headers, **self.headers} if headers is not None else self.headers + cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) + + media_type = get_enum_string_value(media_type or self.media_type or MediaType.JSON) + + iterator = self.iterator + if not isinstance(iterator, (Iterable, Iterator, AsyncIterable, AsyncIterator)) and callable(iterator): + iterator = iterator() + + return ASGIStreamingSSEResponse( + background=self.background or background, + content_length=0, + cookies=cookies, + encoded_headers=encoded_headers, + encoding=self.encoding, + headers=headers, + is_head_response=is_head_response, + iterator=iterator, + media_type=media_type, + status_code=self.status_code or status_code, + ping_interval=self.ping_interval, + ) + + +class ServerSentEvent(SSEStream): + """docs.""" + def __init__( self, content: str | bytes | StreamType[SSEData], @@ -141,6 +392,7 @@ def __init__( retry_duration: int | None = None, comment_message: str | None = None, status_code: int | None = None, + ping_interval: float = 0, ) -> None: """Initialize the response. @@ -159,6 +411,7 @@ def __init__( event_id: The event ID. This sets the event source's 'last event id'. retry_duration: Retry duration in milliseconds. comment_message: A comment message. This value is ignored by clients and is used mostly for pinging. + ping_interval: The interval in seconds between "ping" messages. """ super().__init__( content=_ServerSentEventIterator( @@ -174,6 +427,7 @@ def __init__( encoding=encoding, headers=headers, status_code=status_code, + ping_interval=ping_interval, ) self.headers.setdefault("Cache-Control", "no-cache") self.headers["Connection"] = "keep-alive" diff --git a/tests/examples/test_responses/test_sse_responses.py b/tests/examples/test_responses/test_sse_responses.py index 9f92d668f1..3d5fb7e738 100644 --- a/tests/examples/test_responses/test_sse_responses.py +++ b/tests/examples/test_responses/test_sse_responses.py @@ -1,4 +1,5 @@ from docs.examples.responses.sse_responses import app +from docs.examples.responses.sse_responses_with_ping import app as app2 from httpx_sse import aconnect_sse from litestar.testing import AsyncTestClient @@ -9,3 +10,13 @@ async def test_sse_responses_example() -> None: async with aconnect_sse(client, "GET", f"{client.base_url}/count") as event_source: events = [sse async for sse in event_source.aiter_sse()] assert len(events) == 50 + + +async def test_sse_responses_example_with_ping_events() -> None: + async with AsyncTestClient(app=app2) as client: + async with aconnect_sse(client, "GET", f"{client.base_url}/with_ping") as event_source: + events = [sse async for sse in event_source.aiter_sse()] + for i in range(9): + assert events[i].event == "ping" + assert events[10].event == "message" + assert events[10].data == "content" diff --git a/tests/unit/test_response/test_sse.py b/tests/unit/test_response/test_sse.py index bff7af75c2..a0b5297953 100644 --- a/tests/unit/test_response/test_sse.py +++ b/tests/unit/test_response/test_sse.py @@ -99,3 +99,29 @@ async def numbers() -> AsyncIterator[SSEData]: def test_invalid_content_type_raises() -> None: with pytest.raises(ImproperlyConfiguredException): ServerSentEvent(content=object()) # type: ignore[arg-type] + + +async def test_sse_ping_events() -> None: + @get("/test_ping") + async def handler() -> ServerSentEvent: + async def slow_generator() -> AsyncIterator[SSEData]: + for i in range(1): + await anyio.sleep(0.1) + yield i + + return ServerSentEvent(content=slow_generator(), ping_interval=0.01) + + async with create_async_test_client(handler) as client: + async with aconnect_sse(client, "GET", f"{client.base_url}/test_ping") as event_source: + events = [sse async for sse in event_source.aiter_sse()] + for i in range(9): + assert events[i].event == "ping" + assert events[i].data == "" + + assert events[10].event == "message" + assert events[10].data == "0" + + +async def test_sse_negative_ping_interval() -> None: + with pytest.raises(ValueError): + ServerSentEvent(content="content", ping_interval=-2)