Skip to content

Commit 54a84ca

Browse files
committed
Fix response buffer
1 parent ae1970d commit 54a84ca

File tree

1 file changed

+96
-51
lines changed

1 file changed

+96
-51
lines changed

src/v1/chat_completion/chat_completion_stream.rs

Lines changed: 96 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -113,71 +113,116 @@ pub struct ChatCompletionStream<S: Stream<Item = Result<bytes::Bytes, reqwest::E
113113
pub first_chunk: bool,
114114
}
115115

116-
impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream
117-
for ChatCompletionStream<S>
116+
impl<S> ChatCompletionStream<S>
117+
where
118+
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
118119
{
119-
type Item = ChatCompletionStreamResponse;
120-
121-
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122-
loop {
123-
match Pin::new(&mut self.as_mut().response).poll_next(cx) {
124-
Poll::Ready(Some(Ok(chunk))) => {
125-
let mut utf8_str = String::from_utf8_lossy(&chunk).to_string();
120+
fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
121+
let carriage_idx = buffer.find("\r\n\r\n");
122+
let newline_idx = buffer.find("\n\n");
123+
124+
match (carriage_idx, newline_idx) {
125+
(Some(r_idx), Some(n_idx)) => {
126+
if r_idx <= n_idx {
127+
Some((r_idx, 4))
128+
} else {
129+
Some((n_idx, 2))
130+
}
131+
}
132+
(Some(r_idx), None) => Some((r_idx, 4)),
133+
(None, Some(n_idx)) => Some((n_idx, 2)),
134+
(None, None) => None,
135+
}
136+
}
126137

127-
if self.first_chunk {
128-
let lines: Vec<&str> = utf8_str.lines().collect();
129-
utf8_str = if lines.len() >= 2 {
130-
lines[lines.len() - 2].to_string()
131-
} else {
132-
utf8_str.clone()
133-
};
134-
self.first_chunk = false;
138+
fn next_response_from_buffer(&mut self) -> Option<ChatCompletionStreamResponse> {
139+
while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
140+
let event = self.buffer[..idx].to_owned();
141+
self.buffer = self.buffer[idx + delimiter_len..].to_owned();
142+
143+
let mut data_payload = String::new();
144+
for line in event.lines() {
145+
let trimmed_line = line.trim_end_matches('\r');
146+
if let Some(content) = trimmed_line
147+
.strip_prefix("data: ")
148+
.or_else(|| trimmed_line.strip_prefix("data:"))
149+
{
150+
if !content.is_empty() {
151+
if !data_payload.is_empty() {
152+
data_payload.push('\n');
153+
}
154+
data_payload.push_str(content);
135155
}
156+
}
157+
}
136158

137-
let trimmed_str = utf8_str.trim_start_matches("data: ");
138-
if trimmed_str.contains("[DONE]") {
139-
return Poll::Ready(Some(ChatCompletionStreamResponse::Done));
140-
}
159+
if data_payload.is_empty() {
160+
continue;
161+
}
162+
163+
if data_payload == "[DONE]" {
164+
return Some(ChatCompletionStreamResponse::Done);
165+
}
141166

142-
self.buffer.push_str(trimmed_str);
143-
let json_result: Result<Value, _> = serde_json::from_str(&self.buffer);
144-
145-
if let Ok(json) = json_result {
146-
self.buffer.clear();
147-
148-
if let Some(choices) = json.get("choices") {
149-
if let Some(choice) = choices.get(0) {
150-
if let Some(delta) = choice.get("delta") {
151-
if let Some(tool_calls) = delta.get("tool_calls") {
152-
if let Some(tool_calls_array) = tool_calls.as_array() {
153-
let tool_calls_vec: Vec<ToolCall> = tool_calls_array
154-
.iter()
155-
.filter_map(|v| {
156-
serde_json::from_value(v.clone()).ok()
157-
})
158-
.collect();
159-
160-
return Poll::Ready(Some(
161-
ChatCompletionStreamResponse::ToolCall(
162-
tool_calls_vec,
163-
),
167+
match serde_json::from_str::<Value>(&data_payload) {
168+
Ok(json) => {
169+
if let Some(choices) = json.get("choices") {
170+
if let Some(choice) = choices.get(0) {
171+
if let Some(delta) = choice.get("delta") {
172+
if let Some(tool_calls) = delta.get("tool_calls") {
173+
if let Some(tool_calls_array) = tool_calls.as_array() {
174+
let tool_calls_vec: Vec<ToolCall> = tool_calls_array
175+
.iter()
176+
.filter_map(|v| serde_json::from_value(v.clone()).ok())
177+
.collect();
178+
179+
if !tool_calls_vec.is_empty() {
180+
return Some(ChatCompletionStreamResponse::ToolCall(
181+
tool_calls_vec,
164182
));
165183
}
166184
}
185+
}
167186

168-
if let Some(content) =
169-
delta.get("content").and_then(|c| c.as_str())
170-
{
171-
let output = content.replace("\\n", "\n");
172-
return Poll::Ready(Some(
173-
ChatCompletionStreamResponse::Content(output),
174-
));
175-
}
187+
if let Some(content) = delta.get("content").and_then(|c| c.as_str())
188+
{
189+
let output = content.replace("\\n", "\n");
190+
return Some(ChatCompletionStreamResponse::Content(output));
176191
}
177192
}
178193
}
179194
}
180195
}
196+
Err(error) => {
197+
eprintln!("Failed to parse SSE chunk as JSON: {}", error);
198+
}
199+
}
200+
}
201+
202+
None
203+
}
204+
}
205+
206+
impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream
207+
for ChatCompletionStream<S>
208+
{
209+
type Item = ChatCompletionStreamResponse;
210+
211+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
212+
loop {
213+
if let Some(response) = self.next_response_from_buffer() {
214+
return Poll::Ready(Some(response));
215+
}
216+
217+
match Pin::new(&mut self.as_mut().response).poll_next(cx) {
218+
Poll::Ready(Some(Ok(chunk))) => {
219+
let chunk_str = String::from_utf8_lossy(&chunk).to_string();
220+
221+
if self.first_chunk {
222+
self.first_chunk = false;
223+
}
224+
self.buffer.push_str(&chunk_str);
225+
}
181226
Poll::Ready(Some(Err(error))) => {
182227
eprintln!("Error in stream: {:?}", error);
183228
return Poll::Ready(None);

0 commit comments

Comments
 (0)