Skip to content

Commit 08e8dad

Browse files
authored
refactor: type annotate content properly (#4427)
* refactor: type annotate `content` properly * refactor: reuse `DEFAULT_SEPARATOR`
1 parent f39e7ba commit 08e8dad

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

litestar/response/sse.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
import io
43
import re
5-
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator
4+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator
65
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Any
6+
from io import StringIO
7+
from typing import TYPE_CHECKING
88

99
from litestar.concurrency import sync_to_thread
1010
from litestar.exceptions import ImproperlyConfiguredException
@@ -26,7 +26,7 @@ class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
2626

2727
def __init__(
2828
self,
29-
content: str | bytes | StreamType[SSEData],
29+
content: str | bytes | StreamType[SSEData] | Callable[[], str | bytes | StreamType[SSEData]],
3030
event_type: str | None = None,
3131
event_id: int | str | None = None,
3232
retry_duration: int | None = None,
@@ -37,33 +37,34 @@ def __init__(
3737
self.event_type = event_type
3838
self.retry_duration = retry_duration
3939
chunks: list[bytes] = []
40+
4041
if comment_message is not None:
41-
chunks.extend([f": {chunk}\r\n".encode() for chunk in _LINE_BREAK_RE.split(comment_message)])
42+
chunks.extend(f": {chunk}{DEFAULT_SEPARATOR}".encode() for chunk in _LINE_BREAK_RE.split(comment_message))
4243

4344
if event_id is not None:
44-
chunks.append(f"id: {event_id}\r\n".encode())
45+
chunks.append(f"id: {event_id}{DEFAULT_SEPARATOR}".encode())
4546

4647
if event_type is not None:
47-
chunks.append(f"event: {event_type}\r\n".encode())
48+
chunks.append(f"event: {event_type}{DEFAULT_SEPARATOR}".encode())
4849

4950
if retry_duration is not None:
50-
chunks.append(f"retry: {retry_duration}\r\n".encode())
51+
chunks.append(f"retry: {retry_duration}{DEFAULT_SEPARATOR}".encode())
5152

5253
super().__init__(iterator=chunks)
5354

5455
if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content):
55-
content = content() # type: ignore[unreachable]
56+
content = content()
5657

5758
if isinstance(content, (str, bytes)):
5859
self.content_async_iterator = AsyncIteratorWrapper([content])
59-
elif isinstance(content, (Iterable, Iterator)):
60+
elif isinstance(content, Iterable):
6061
self.content_async_iterator = AsyncIteratorWrapper(content)
61-
elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)):
62+
elif isinstance(content, (AsyncIterable, AsyncIteratorWrapper)):
6263
self.content_async_iterator = content
6364
else:
6465
raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent")
6566

66-
def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage | Any, sep: str) -> bytes:
67+
def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage, sep: str) -> bytes:
6768
if isinstance(data, ServerSentEventMessage):
6869
return data.encode()
6970
if isinstance(data, dict):
@@ -100,7 +101,7 @@ class ServerSentEventMessage:
100101
sep: str = DEFAULT_SEPARATOR
101102

102103
def encode(self) -> bytes:
103-
buffer = io.StringIO()
104+
buffer = StringIO()
104105
if self.comment is not None:
105106
for chunk in _LINE_BREAK_RE.split(str(self.comment)):
106107
buffer.write(f": {chunk}")

0 commit comments

Comments
 (0)