Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2c7fdd9
added ping_interval keyword
RenameMe1 Apr 10, 2025
65e97f7
created _send_ping_event method
RenameMe1 Apr 10, 2025
0d051b3
added _send_ping_event in task_group
RenameMe1 Apr 10, 2025
381a1ea
Conveted the ping_interval type to a float
RenameMe1 Apr 11, 2025
01c188c
Changed doc string to ping_interval
RenameMe1 Apr 11, 2025
22ac9aa
Update _send_ping_event doc string
RenameMe1 Apr 11, 2025
9acda57
added new sse example with `ping_interval`
RenameMe1 Apr 11, 2025
cbe9e48
added `event:` in ping body
RenameMe1 Apr 11, 2025
7018a69
added test case sse response with ping events
RenameMe1 Apr 11, 2025
08f84d7
lint updated
RenameMe1 Apr 11, 2025
8e9bd4c
Merge branch 'main' into sse_timeout_events
RenameMe1 Apr 11, 2025
6c6a00c
added ping events test case in test_sse.py
RenameMe1 Apr 12, 2025
76d0aae
added type hint
RenameMe1 Apr 12, 2025
23a8455
Merge branch 'main' into sse_timeout_events
RenameMe1 Apr 15, 2025
385d8fc
Merge branch 'litestar-org:main' into sse_timeout_events
RenameMe1 Apr 20, 2025
642d3f6
removed ping_event from ASGIStreamingResponse, added stop if the SSE …
RenameMe1 Apr 20, 2025
c0798be
added test for negative number for ping_interval
RenameMe1 Apr 20, 2025
0b5f8f0
Merge branch 'main' into sse_timeout_events
RenameMe1 May 8, 2025
174bc0f
Merge branch 'litestar-org:main' into sse_timeout_events
RenameMe1 May 11, 2025
9429e83
Change doc message to ASGIStreamingSSEResponse
RenameMe1 May 11, 2025
b7e7bc1
added a more informative description of `ping_interval`
RenameMe1 May 12, 2025
372e891
pre-commit update
RenameMe1 May 12, 2025
10e805e
Added anyio.Event to stop sending ping events
RenameMe1 May 18, 2025
e6aa91c
update doc to __init__
RenameMe1 May 18, 2025
4f4ff27
extra spaces have been removed
RenameMe1 May 21, 2025
75ee8ba
transferring the code to sse.py
RenameMe1 May 29, 2025
c00ff86
slots update
RenameMe1 May 29, 2025
5d744f2
decreased ping_interval to test case
RenameMe1 May 29, 2025
165dd59
fix docs warning
RenameMe1 May 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/examples/responses/sse_responses_with_ping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from asyncio import sleep
from typing import AsyncGenerator

from litestar import Litestar, get
from litestar.response import ServerSentEvent, ServerSentEventMessage
from litestar.types import SSEData

async def my_slow_generator() -> AsyncGenerator[SSEData, None]:
count = 0
while count < 1:
await sleep(1)
count += 1
yield ServerSentEventMessage(data="content", event="message")

@get(path="/with_ping", sync_to_thread=False)
def sse_handler_with_ping_events() -> ServerSentEvent:
return ServerSentEvent(my_slow_generator(), ping_interval=0.1)


app = Litestar(route_handlers=[sse_handler_with_ping_events])
6 changes: 6 additions & 0 deletions docs/usage/responses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,12 @@ If you want to send a different event type, you can use a dictionary with the ke

You can further customize all the sse parameters, add comments, and set the retry duration by using the :class:`ServerSentEvent <.response.ServerSentEvent>` class directly or by using the :class:`ServerSentEventMessage <.response.ServerSentEventMessage>` or dictionaries with the appropriate keys.

If the ``ServerSentEvent`` has ``ping_interval`` set to a positive value, a message with ``event_type`` ``ping``
will be sent every ``ping_interval`` seconds. This is useful for applications that close connections after a timeout
(e.g., TelegramMiniApps; for more details, see `issue 4082 <https://github.com/litestar-org/litestar/issues/4082>`_)

.. literalinclude:: /examples/responses/sse_responses_with_ping.py
:language: python

Template Responses
------------------
Expand Down
3 changes: 3 additions & 0 deletions litestar/response/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(
retry_duration: int | None = None,
comment_message: str | None = None,
status_code: int | None = None,
ping_interval: float = 0,
) -> None:
"""Initialize the response.

Expand All @@ -159,6 +160,7 @@ def __init__(
event_id: The event ID. This sets the event source's 'last event id'.
retry_duration: Retry duration in milliseconds.
comment_message: A comment message. This value is ignored by clients and is used mostly for pinging.
ping_interval: The interval in seconds between "ping" messages.
"""
super().__init__(
content=_ServerSentEventIterator(
Expand All @@ -174,6 +176,7 @@ def __init__(
encoding=encoding,
headers=headers,
status_code=status_code,
ping_interval=ping_interval,
)
self.headers.setdefault("Cache-Control", "no-cache")
self.headers["Connection"] = "keep-alive"
Expand Down
130 changes: 127 additions & 3 deletions litestar/response/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator, Union

from anyio import CancelScope, create_task_group
from anyio import CancelScope, create_task_group, sleep

from litestar.enums import MediaType
from litestar.response.base import ASGIResponse, Response
Expand Down Expand Up @@ -145,10 +145,127 @@ async def send_body(self, send: Send, receive: Receive) -> None:
await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive)


class ASGIStreamingSSEResponse(ASGIStreamingResponse):
"""A streaming response which support sending ping messages specific for SSE."""

__slots__ = (
"content_exist",
"ping_interval",
)

def __init__(
self,
*,
iterator: StreamType,
background: BackgroundTask | BackgroundTasks | None = None,
body: bytes | str = b"",
content_length: int | None = None,
cookies: Iterable[Cookie] | None = None,
encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
encoding: str = "utf-8",
headers: dict[str, Any] | None = None,
is_head_response: bool = False,
media_type: MediaType | str | None = None,
status_code: int | None = None,
ping_interval: float = 0,
) -> None:
"""A low-level ASGI streaming response.

Args:
background: A background task or a list of background tasks to be executed after the response is sent.
body: encoded content to send in the response body.
.. deprecated:: 2.16
content_length: The response content length.
cookies: The response cookies.
encoded_headers: The response headers.
encoding: The response encoding.
headers: The response headers.
is_head_response: A boolean indicating if the response is a HEAD response.
iterator: An async iterator or iterable.
media_type: The response media type.
status_code: The response status code.
ping_interval: The interval in seconds between "ping" messages.
"""
super().__init__(
iterator=iterator,
background=background,
body=body,
content_length=content_length,
cookies=cookies,
encoding=encoding,
headers=headers,
is_head_response=is_head_response,
media_type=media_type,
status_code=status_code,
encoded_headers=encoded_headers,
)
self.ping_interval = ping_interval
self.content_exist = True

async def _send_ping_event(self, send: Send) -> None:
"""Send ping events every `ping_interval` second.

Args:
send: The ASGI Send function.

Returns:
None
"""
stream_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"event: ping\r\n\r\n",
"more_body": True,
}

while self.content_exist:
await send(stream_event)
await sleep(self.ping_interval)

async def _stream(self, send: Send) -> None:
"""Send the chunks from the iterator as a stream of ASGI 'http.response.body' events.

Args:
send: The ASGI Send function.

Returns:
None
"""
async for chunk in self.iterator:
stream_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding),
"more_body": True,
}
await send(stream_event)
terminus_event: HTTPResponseBodyEvent = {"type": "http.response.body", "body": b"", "more_body": False}
self.content_exist = False
await send(terminus_event)

async def send_body(self, send: Send, receive: Receive) -> None:
"""Emit a stream of events correlating with the response body.

Args:
send: The ASGI send function.
receive: The ASGI receive function.

Returns:
None
"""

async with create_task_group() as task_group:
task_group.start_soon(partial(self._stream, send))
if self.ping_interval:
task_group.start_soon(partial(self._send_ping_event, send))
await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive)


class Stream(Response[StreamType[Union[str, bytes]]]):
"""An HTTP response that streams the response data as a series of ASGI ``http.response.body`` events."""

__slots__ = ("iterator",)
__slots__ = (
"iterator",
"ping_interval",
)

def __init__(
self,
Expand All @@ -160,6 +277,7 @@ def __init__(
headers: ResponseHeaders | None = None,
media_type: MediaType | OpenAPIMediaType | str | None = None,
status_code: int | None = None,
ping_interval: float = 0,
) -> None:
"""Initialize the response.

Expand All @@ -174,6 +292,7 @@ def __init__(
headers: A string keyed dictionary of response headers. Header keys are insensitive.
media_type: A value for the response ``Content-Type`` header.
status_code: An HTTP status code.
ping_interval: The interval in seconds between "ping" messages.
"""
super().__init__(
background=background,
Expand All @@ -186,6 +305,10 @@ def __init__(
)
self.iterator = content

if ping_interval < 0:
raise ValueError("argument ping_interval must be not negative")
self.ping_interval = ping_interval

def to_asgi_response(
self,
app: Litestar | None,
Expand Down Expand Up @@ -235,7 +358,7 @@ def to_asgi_response(
if not isinstance(iterator, (Iterable, Iterator, AsyncIterable, AsyncIterator)) and callable(iterator):
iterator = iterator()

return ASGIStreamingResponse(
return ASGIStreamingSSEResponse(
background=self.background or background,
content_length=0,
cookies=cookies,
Expand All @@ -246,4 +369,5 @@ def to_asgi_response(
iterator=iterator,
media_type=media_type,
status_code=self.status_code or status_code,
ping_interval=self.ping_interval,
)
11 changes: 11 additions & 0 deletions tests/examples/test_responses/test_sse_responses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from docs.examples.responses.sse_responses import app
from docs.examples.responses.sse_responses_with_ping import app as app2
from httpx_sse import aconnect_sse

from litestar.testing import AsyncTestClient
Expand All @@ -9,3 +10,13 @@ async def test_sse_responses_example() -> None:
async with aconnect_sse(client, "GET", f"{client.base_url}/count") as event_source:
events = [sse async for sse in event_source.aiter_sse()]
assert len(events) == 50


async def test_sse_responses_example_with_ping_events() -> None:
async with AsyncTestClient(app=app2) as client:
async with aconnect_sse(client, "GET", f"{client.base_url}/with_ping") as event_source:
events = [sse async for sse in event_source.aiter_sse()]
for i in range(9):
assert events[i].event == " ping"
assert events[10].event == "message"
assert events[10].data == "content"
26 changes: 26 additions & 0 deletions tests/unit/test_response/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,29 @@ async def numbers() -> AsyncIterator[SSEData]:
def test_invalid_content_type_raises() -> None:
with pytest.raises(ImproperlyConfiguredException):
ServerSentEvent(content=object()) # type: ignore[arg-type]


async def test_sse_ping_events() -> None:
@get("/test_ping")
async def handler() -> ServerSentEvent:
async def slow_generator() -> AsyncIterator[SSEData]:
for i in range(1):
await anyio.sleep(1)
yield i

return ServerSentEvent(content=slow_generator(), ping_interval=0.1)

async with create_async_test_client(handler) as client:
async with aconnect_sse(client, "GET", f"{client.base_url}/test_ping") as event_source:
events = [sse async for sse in event_source.aiter_sse()]
for i in range(9):
assert events[i].event == " ping"
assert events[i].data == ""

assert events[10].event == "message"
assert events[10].data == "0"


async def test_sse_negatove_ping_interval() -> None:
with pytest.raises(ValueError):
ServerSentEvent(content="content", ping_interval=-2)
Loading