Skip to content

Commit f0300ec

Browse files
authored
Merge pull request #187 from dongri/fix-stream-for-openai
Fix stream for OpenAI
2 parents ae1970d + a6821db commit f0300ec

File tree

1 file changed

+95
-51
lines changed

1 file changed

+95
-51
lines changed

src/v1/chat_completion/chat_completion_stream.rs

Lines changed: 95 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -113,70 +113,114 @@ pub struct ChatCompletionStream<S: Stream<Item = Result<bytes::Bytes, reqwest::E
113113
pub first_chunk: bool,
114114
}
115115

116+
impl<S> ChatCompletionStream<S>
117+
where
118+
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
119+
{
120+
fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
121+
let carriage_idx = buffer.find("\r\n\r\n");
122+
let newline_idx = buffer.find("\n\n");
123+
124+
match (carriage_idx, newline_idx) {
125+
(Some(r_idx), Some(n_idx)) => {
126+
if r_idx <= n_idx {
127+
Some((r_idx, 4))
128+
} else {
129+
Some((n_idx, 2))
130+
}
131+
}
132+
(Some(r_idx), None) => Some((r_idx, 4)),
133+
(None, Some(n_idx)) => Some((n_idx, 2)),
134+
(None, None) => None,
135+
}
136+
}
137+
138+
fn next_response_from_buffer(&mut self) -> Option<ChatCompletionStreamResponse> {
139+
while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
140+
let event = self.buffer[..idx].to_owned();
141+
self.buffer = self.buffer[idx + delimiter_len..].to_owned();
142+
143+
let mut data_payload = String::new();
144+
for line in event.lines() {
145+
let trimmed_line = line.trim_end_matches('\r');
146+
if let Some(content) = trimmed_line
147+
.strip_prefix("data: ")
148+
.or_else(|| trimmed_line.strip_prefix("data:"))
149+
{
150+
if !content.is_empty() {
151+
if !data_payload.is_empty() {
152+
data_payload.push('\n');
153+
}
154+
data_payload.push_str(content);
155+
}
156+
}
157+
}
158+
159+
if data_payload.is_empty() {
160+
continue;
161+
}
162+
163+
if data_payload == "[DONE]" {
164+
return Some(ChatCompletionStreamResponse::Done);
165+
}
166+
167+
match serde_json::from_str::<Value>(&data_payload) {
168+
Ok(json) => {
169+
if let Some(delta) = json
170+
.get("choices")
171+
.and_then(|choices| choices.get(0))
172+
.and_then(|choice| choice.get("delta"))
173+
{
174+
if let Some(tool_call_response) = delta
175+
.get("tool_calls")
176+
.and_then(|tool_calls| tool_calls.as_array())
177+
.map(|tool_calls_array| {
178+
tool_calls_array
179+
.iter()
180+
.filter_map(|v| serde_json::from_value(v.clone()).ok())
181+
.collect::<Vec<ToolCall>>()
182+
})
183+
.filter(|tool_calls_vec| !tool_calls_vec.is_empty())
184+
.map(ChatCompletionStreamResponse::ToolCall)
185+
{
186+
return Some(tool_call_response);
187+
}
188+
189+
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
190+
let output = content.replace("\\n", "\n");
191+
return Some(ChatCompletionStreamResponse::Content(output));
192+
}
193+
}
194+
}
195+
Err(error) => {
196+
eprintln!("Failed to parse SSE chunk as JSON: {}", error);
197+
}
198+
}
199+
}
200+
201+
None
202+
}
203+
}
204+
116205
impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream
117206
for ChatCompletionStream<S>
118207
{
119208
type Item = ChatCompletionStreamResponse;
120209

121210
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122211
loop {
212+
if let Some(response) = self.next_response_from_buffer() {
213+
return Poll::Ready(Some(response));
214+
}
215+
123216
match Pin::new(&mut self.as_mut().response).poll_next(cx) {
124217
Poll::Ready(Some(Ok(chunk))) => {
125-
let mut utf8_str = String::from_utf8_lossy(&chunk).to_string();
218+
let chunk_str = String::from_utf8_lossy(&chunk).to_string();
126219

127220
if self.first_chunk {
128-
let lines: Vec<&str> = utf8_str.lines().collect();
129-
utf8_str = if lines.len() >= 2 {
130-
lines[lines.len() - 2].to_string()
131-
} else {
132-
utf8_str.clone()
133-
};
134221
self.first_chunk = false;
135222
}
136-
137-
let trimmed_str = utf8_str.trim_start_matches("data: ");
138-
if trimmed_str.contains("[DONE]") {
139-
return Poll::Ready(Some(ChatCompletionStreamResponse::Done));
140-
}
141-
142-
self.buffer.push_str(trimmed_str);
143-
let json_result: Result<Value, _> = serde_json::from_str(&self.buffer);
144-
145-
if let Ok(json) = json_result {
146-
self.buffer.clear();
147-
148-
if let Some(choices) = json.get("choices") {
149-
if let Some(choice) = choices.get(0) {
150-
if let Some(delta) = choice.get("delta") {
151-
if let Some(tool_calls) = delta.get("tool_calls") {
152-
if let Some(tool_calls_array) = tool_calls.as_array() {
153-
let tool_calls_vec: Vec<ToolCall> = tool_calls_array
154-
.iter()
155-
.filter_map(|v| {
156-
serde_json::from_value(v.clone()).ok()
157-
})
158-
.collect();
159-
160-
return Poll::Ready(Some(
161-
ChatCompletionStreamResponse::ToolCall(
162-
tool_calls_vec,
163-
),
164-
));
165-
}
166-
}
167-
168-
if let Some(content) =
169-
delta.get("content").and_then(|c| c.as_str())
170-
{
171-
let output = content.replace("\\n", "\n");
172-
return Poll::Ready(Some(
173-
ChatCompletionStreamResponse::Content(output),
174-
));
175-
}
176-
}
177-
}
178-
}
179-
}
223+
self.buffer.push_str(&chunk_str);
180224
}
181225
Poll::Ready(Some(Err(error))) => {
182226
eprintln!("Error in stream: {:?}", error);

0 commit comments

Comments
 (0)