11from __future__ import annotations
22
3- import io
43import re
5- from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Iterable , Iterator
4+ from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Callable , Iterable , Iterator
65from dataclasses import dataclass
7- from typing import TYPE_CHECKING , Any
6+ from io import StringIO
7+ from typing import TYPE_CHECKING
88
99from litestar .concurrency import sync_to_thread
1010from litestar .exceptions import ImproperlyConfiguredException
@@ -26,7 +26,7 @@ class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
2626
2727 def __init__ (
2828 self ,
29- content : str | bytes | StreamType [SSEData ],
29+ content : str | bytes | StreamType [SSEData ] | Callable [[], str | bytes | StreamType [ SSEData ]] ,
3030 event_type : str | None = None ,
3131 event_id : int | str | None = None ,
3232 retry_duration : int | None = None ,
@@ -37,33 +37,34 @@ def __init__(
3737 self .event_type = event_type
3838 self .retry_duration = retry_duration
3939 chunks : list [bytes ] = []
40+
4041 if comment_message is not None :
41- chunks .extend ([ f": { chunk } \r \n " .encode () for chunk in _LINE_BREAK_RE .split (comment_message )] )
42+ chunks .extend (f": { chunk } { DEFAULT_SEPARATOR } " .encode () for chunk in _LINE_BREAK_RE .split (comment_message ))
4243
4344 if event_id is not None :
44- chunks .append (f"id: { event_id } \r \n " .encode ())
45+ chunks .append (f"id: { event_id } { DEFAULT_SEPARATOR } " .encode ())
4546
4647 if event_type is not None :
47- chunks .append (f"event: { event_type } \r \n " .encode ())
48+ chunks .append (f"event: { event_type } { DEFAULT_SEPARATOR } " .encode ())
4849
4950 if retry_duration is not None :
50- chunks .append (f"retry: { retry_duration } \r \n " .encode ())
51+ chunks .append (f"retry: { retry_duration } { DEFAULT_SEPARATOR } " .encode ())
5152
5253 super ().__init__ (iterator = chunks )
5354
5455 if not isinstance (content , (Iterator , AsyncIterator , AsyncIteratorWrapper )) and callable (content ):
55- content = content () # type: ignore[unreachable]
56+ content = content ()
5657
5758 if isinstance (content , (str , bytes )):
5859 self .content_async_iterator = AsyncIteratorWrapper ([content ])
59- elif isinstance (content , ( Iterable , Iterator ) ):
60+ elif isinstance (content , Iterable ):
6061 self .content_async_iterator = AsyncIteratorWrapper (content )
61- elif isinstance (content , (AsyncIterable , AsyncIterator , AsyncIteratorWrapper )):
62+ elif isinstance (content , (AsyncIterable , AsyncIteratorWrapper )):
6263 self .content_async_iterator = content
6364 else :
6465 raise ImproperlyConfiguredException (f"Invalid type { type (content )} for ServerSentEvent" )
6566
66- def ensure_bytes (self , data : str | int | bytes | dict | ServerSentEventMessage | Any , sep : str ) -> bytes :
67+ def ensure_bytes (self , data : str | int | bytes | dict | ServerSentEventMessage , sep : str ) -> bytes :
6768 if isinstance (data , ServerSentEventMessage ):
6869 return data .encode ()
6970 if isinstance (data , dict ):
@@ -100,7 +101,7 @@ class ServerSentEventMessage:
100101 sep : str = DEFAULT_SEPARATOR
101102
102103 def encode (self ) -> bytes :
103- buffer = io . StringIO ()
104+ buffer = StringIO ()
104105 if self .comment is not None :
105106 for chunk in _LINE_BREAK_RE .split (str (self .comment )):
106107 buffer .write (f": { chunk } " )
0 commit comments