Skip to content

Commit 6523f63

Browse files
added back http_sse.py
1 parent 54deb06 commit 6523f63

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

src/cohere/core/http_sse.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
"""
2+
Forked and fixed version of httpx_sse to handle multi-line SSE data fields properly.
3+
4+
This module consolidates all httpx_sse functionality into a single file.
5+
"""
6+
import json
7+
from collections.abc import AsyncGenerator
8+
from contextlib import asynccontextmanager, contextmanager
9+
from typing import Any, AsyncIterator, Iterator, List, Optional, cast
10+
11+
import httpx
12+
13+
14+
class SSEError(httpx.TransportError):
    """Error raised while processing a Server-Sent Events stream."""
17+
18+
19+
class ServerSentEvent:
    """One event taken from a Server-Sent Events stream.

    The four SSE fields are normalised once in the constructor and then
    exposed through read-only properties.
    """

    def __init__(
        self,
        event: Optional[str] = None,
        data: Optional[str] = None,
        id: Optional[str] = None,
        retry: Optional[int] = None,
    ) -> None:
        # A missing *or empty* event type falls back to "message".
        self._event = event or "message"
        # For data and id only None is replaced; any other value is kept as given.
        self._data = "" if data is None else data
        self._id = "" if id is None else id
        self._retry = retry

    @property
    def event(self) -> str:
        """Event type; ``"message"`` when none was supplied."""
        return self._event

    @property
    def data(self) -> str:
        """Event payload; empty string when none was supplied."""
        return self._data

    @property
    def id(self) -> str:
        """Event id; empty string when none was supplied."""
        return self._id

    @property
    def retry(self) -> Optional[int]:
        """Retry value, or ``None`` when none was supplied."""
        return self._retry

    def json(self) -> Any:
        """Parse :attr:`data` as JSON and return the resulting object."""
        return json.loads(self.data)

    def __repr__(self) -> str:
        # The event type is always shown; the other fields only when set.
        parts = [f"event={self.event!r}"]
        parts.extend(
            f"{label}={text!r}"
            for label, text in (("data", self.data), ("id", self.id))
            if text != ""
        )
        if self.retry is not None:
            parts.append(f"retry={self.retry!r}")
        return "ServerSentEvent(" + ", ".join(parts) + ")"
71+
72+
73+
class SSEDecoder:
    """Line-oriented decoder for a Server-Sent Events stream.

    Feed the stream to :meth:`decode` one line at a time, without the line
    terminator. Field lines are accumulated; a blank line turns the
    accumulated fields into a :class:`ServerSentEvent`.
    """

    def __init__(self) -> None:
        self._event = ""
        self._data: List[str] = []
        # Persists across events: an event without its own "id" field reports
        # the most recently seen id.
        self._last_event_id = ""
        self._retry: Optional[int] = None
        # True only while the event currently being accumulated has received a
        # non-empty "id" field of its own.
        self._id_pending = False

    def decode(self, line: str) -> "Optional[ServerSentEvent]":
        """Consume one line of the stream.

        Args:
            line: A single line with its trailing newline already removed.

        Returns:
            The completed event when ``line`` is blank and at least one field
            was received since the previous dispatch; otherwise ``None``.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501

        if not line:
            return self._dispatch()

        if line.startswith(":"):
            # Comment line (commonly used as a keep-alive); carries no field.
            return None

        fieldname, _, value = line.partition(":")

        # Exactly one leading space after the colon is part of the syntax.
        if value.startswith(" "):
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # Ids containing NUL are ignored.
            if "\0" not in value:
                self._last_event_id = value
                self._id_pending = bool(value)
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                # A non-integer retry value is deliberately ignored.
                pass
        # Any other field name is ignored.

        return None

    def _dispatch(self) -> "Optional[ServerSentEvent]":
        """Build an event from the accumulated fields and reset them.

        Returns ``None`` when nothing was received since the last dispatch.
        Only fields seen in the *current* block count: the remembered last
        event id alone must not trigger a dispatch, otherwise every blank line
        that follows the first id-carrying event (for example after a
        keep-alive comment) would produce a spurious empty event.
        """
        has_fields = (
            self._event
            or self._data
            or self._id_pending
            or self._retry is not None
        )
        if not has_fields:
            return None

        sse = ServerSentEvent(
            event=self._event,
            data="\n".join(self._data),
            id=self._last_event_id,
            retry=self._retry,
        )

        # NOTE: as per the SSE spec, do not reset last_event_id.
        self._event = ""
        self._data = []
        self._retry = None
        self._id_pending = False

        return sse
134+
135+
136+
class EventSource:
    """Reads Server-Sent Events from a streaming ``httpx`` response.

    Both iterators decode the body incrementally: bytes are buffered only
    until a line feed arrives, so events are yielded as soon as their
    terminating blank line has been received. Lines are split on LF / CRLF
    only, never on other Unicode line boundaries such as U+2028, so those
    characters may appear inside a ``data`` field.
    """

    def __init__(self, response: "httpx.Response") -> None:
        self._response = response

    def _check_content_type(self) -> None:
        """Raise ``SSEError`` unless Content-Type contains ``text/event-stream``."""
        content_type = self._response.headers.get("content-type", "").partition(";")[0]
        if "text/event-stream" not in content_type:
            raise SSEError(
                "Expected response header Content-Type to contain 'text/event-stream', "
                f"got {content_type!r}"
            )

    @staticmethod
    def _pop_lines(buffer: bytearray) -> List[str]:
        """Remove every complete line from ``buffer`` and return them decoded.

        A line is complete once its LF has arrived. Bytes after the last LF
        stay in ``buffer`` for the next call. Splitting happens on the raw
        bytes, which is safe because 0x0A never occurs inside a multi-byte
        UTF-8 sequence; a character split across two network chunks is
        therefore decoded intact. Invalid UTF-8 is replaced, not raised.

        Args:
            buffer: Accumulated, not yet consumed bytes. Modified in place.

        Returns:
            The decoded lines, without LF and without trailing CR.
        """
        last_lf = buffer.rfind(b"\n")
        if last_lf < 0:
            return []
        complete = bytes(buffer[:last_lf])
        del buffer[: last_lf + 1]
        return [
            raw.decode("utf-8", errors="replace").rstrip("\r")
            for raw in complete.split(b"\n")
        ]

    @property
    def response(self) -> "httpx.Response":
        """The underlying response object."""
        return self._response

    def iter_sse(self) -> "Iterator[ServerSentEvent]":
        """Yield events from the response as they arrive (synchronous).

        Raises:
            SSEError: If the response is not declared as ``text/event-stream``.
        """
        self._check_content_type()
        decoder = SSEDecoder()
        buffer = bytearray()

        for chunk in self._response.iter_bytes():
            buffer += chunk
            for line in self._pop_lines(buffer):
                sse = decoder.decode(line)
                if sse is not None:
                    yield sse
        # Whatever is still buffered is an unterminated line. An event that
        # never received its closing blank line is discarded, as the SSE
        # specification requires at end of stream.

    async def aiter_sse(self) -> "AsyncGenerator[ServerSentEvent, None]":
        """Yield events from the response as they arrive (asynchronous).

        Uses the same byte-level line splitting as :meth:`iter_sse`.

        Raises:
            SSEError: If the response is not declared as ``text/event-stream``.
        """
        self._check_content_type()
        decoder = SSEDecoder()
        buffer = bytearray()
        chunks = cast(AsyncGenerator[bytes, None], self._response.aiter_bytes())
        try:
            async for chunk in chunks:
                buffer += chunk
                for line in self._pop_lines(buffer):
                    sse = decoder.decode(line)
                    if sse is not None:
                        yield sse
        finally:
            # Close the byte iterator even when the consumer stops early.
            await chunks.aclose()
186+
187+
188+
@contextmanager
def connect_sse(
    client: httpx.Client, method: str, url: str, **kwargs: Any
) -> Iterator[EventSource]:
    """Open a streaming request and yield an ``EventSource`` for its response.

    Args:
        client: Synchronous httpx client used to send the request.
        method: HTTP method.
        url: Request URL.
        **kwargs: Passed through to ``client.stream``. A ``headers`` entry is
            copied, then ``Accept`` and ``Cache-Control`` are set on the copy.

    Yields:
        An ``EventSource`` wrapping the open response; the response is closed
        when the ``with`` block exits.
    """
    # Work on a copy so the caller's headers object is never mutated, and so
    # ``headers=None`` or a non-dict headers value is accepted.
    headers = dict(kwargs.pop("headers", None) or {})
    headers["Accept"] = "text/event-stream"
    headers["Cache-Control"] = "no-store"

    with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)
199+
200+
201+
@asynccontextmanager
async def aconnect_sse(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    **kwargs: Any,
) -> AsyncIterator[EventSource]:
    """Open a streaming request and yield an ``EventSource`` for its response.

    Args:
        client: Asynchronous httpx client used to send the request.
        method: HTTP method.
        url: Request URL.
        **kwargs: Passed through to ``client.stream``. A ``headers`` entry is
            copied, then ``Accept`` and ``Cache-Control`` are set on the copy.

    Yields:
        An ``EventSource`` wrapping the open response; the response is closed
        when the ``async with`` block exits.
    """
    # Work on a copy so the caller's headers object is never mutated, and so
    # ``headers=None`` or a non-dict headers value is accepted.
    headers = dict(kwargs.pop("headers", None) or {})
    headers["Accept"] = "text/event-stream"
    headers["Cache-Control"] = "no-store"

    async with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)
215+
216+
217+
# Public API
218+
# Names exported by ``from <module> import *``.
# NOTE(review): SSEDecoder is defined in this module but is not listed here —
# confirm it is meant to stay internal.
__all__ = ["EventSource", "connect_sse", "aconnect_sse", "ServerSentEvent", "SSEError"]

0 commit comments

Comments
 (0)