@@ -113,71 +113,116 @@ pub struct ChatCompletionStream<S: Stream<Item = Result<bytes::Bytes, reqwest::E
113113 pub first_chunk : bool ,
114114}
115115
116- impl < S : Stream < Item = Result < bytes:: Bytes , reqwest:: Error > > + Unpin > Stream
117- for ChatCompletionStream < S >
116+ impl < S > ChatCompletionStream < S >
117+ where
118+ S : Stream < Item = Result < bytes:: Bytes , reqwest:: Error > > + Unpin ,
118119{
119- type Item = ChatCompletionStreamResponse ;
120-
121- fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
122- loop {
123- match Pin :: new ( & mut self . as_mut ( ) . response ) . poll_next ( cx) {
124- Poll :: Ready ( Some ( Ok ( chunk) ) ) => {
125- let mut utf8_str = String :: from_utf8_lossy ( & chunk) . to_string ( ) ;
120+ fn find_event_delimiter ( buffer : & str ) -> Option < ( usize , usize ) > {
121+ let carriage_idx = buffer. find ( "\r \n \r \n " ) ;
122+ let newline_idx = buffer. find ( "\n \n " ) ;
123+
124+ match ( carriage_idx, newline_idx) {
125+ ( Some ( r_idx) , Some ( n_idx) ) => {
126+ if r_idx <= n_idx {
127+ Some ( ( r_idx, 4 ) )
128+ } else {
129+ Some ( ( n_idx, 2 ) )
130+ }
131+ }
132+ ( Some ( r_idx) , None ) => Some ( ( r_idx, 4 ) ) ,
133+ ( None , Some ( n_idx) ) => Some ( ( n_idx, 2 ) ) ,
134+ ( None , None ) => None ,
135+ }
136+ }
126137
127- if self . first_chunk {
128- let lines: Vec < & str > = utf8_str. lines ( ) . collect ( ) ;
129- utf8_str = if lines. len ( ) >= 2 {
130- lines[ lines. len ( ) - 2 ] . to_string ( )
131- } else {
132- utf8_str. clone ( )
133- } ;
134- self . first_chunk = false ;
138+ fn next_response_from_buffer ( & mut self ) -> Option < ChatCompletionStreamResponse > {
139+ while let Some ( ( idx, delimiter_len) ) = Self :: find_event_delimiter ( & self . buffer ) {
140+ let event = self . buffer [ ..idx] . to_owned ( ) ;
141+ self . buffer = self . buffer [ idx + delimiter_len..] . to_owned ( ) ;
142+
143+ let mut data_payload = String :: new ( ) ;
144+ for line in event. lines ( ) {
145+ let trimmed_line = line. trim_end_matches ( '\r' ) ;
146+ if let Some ( content) = trimmed_line
147+ . strip_prefix ( "data: " )
148+ . or_else ( || trimmed_line. strip_prefix ( "data:" ) )
149+ {
150+ if !content. is_empty ( ) {
151+ if !data_payload. is_empty ( ) {
152+ data_payload. push ( '\n' ) ;
153+ }
154+ data_payload. push_str ( content) ;
135155 }
156+ }
157+ }
136158
137- let trimmed_str = utf8_str. trim_start_matches ( "data: " ) ;
138- if trimmed_str. contains ( "[DONE]" ) {
139- return Poll :: Ready ( Some ( ChatCompletionStreamResponse :: Done ) ) ;
140- }
159+ if data_payload. is_empty ( ) {
160+ continue ;
161+ }
162+
163+ if data_payload == "[DONE]" {
164+ return Some ( ChatCompletionStreamResponse :: Done ) ;
165+ }
141166
142- self . buffer . push_str ( trimmed_str) ;
143- let json_result: Result < Value , _ > = serde_json:: from_str ( & self . buffer ) ;
144-
145- if let Ok ( json) = json_result {
146- self . buffer . clear ( ) ;
147-
148- if let Some ( choices) = json. get ( "choices" ) {
149- if let Some ( choice) = choices. get ( 0 ) {
150- if let Some ( delta) = choice. get ( "delta" ) {
151- if let Some ( tool_calls) = delta. get ( "tool_calls" ) {
152- if let Some ( tool_calls_array) = tool_calls. as_array ( ) {
153- let tool_calls_vec: Vec < ToolCall > = tool_calls_array
154- . iter ( )
155- . filter_map ( |v| {
156- serde_json:: from_value ( v. clone ( ) ) . ok ( )
157- } )
158- . collect ( ) ;
159-
160- return Poll :: Ready ( Some (
161- ChatCompletionStreamResponse :: ToolCall (
162- tool_calls_vec,
163- ) ,
167+ match serde_json:: from_str :: < Value > ( & data_payload) {
168+ Ok ( json) => {
169+ if let Some ( choices) = json. get ( "choices" ) {
170+ if let Some ( choice) = choices. get ( 0 ) {
171+ if let Some ( delta) = choice. get ( "delta" ) {
172+ if let Some ( tool_calls) = delta. get ( "tool_calls" ) {
173+ if let Some ( tool_calls_array) = tool_calls. as_array ( ) {
174+ let tool_calls_vec: Vec < ToolCall > = tool_calls_array
175+ . iter ( )
176+ . filter_map ( |v| serde_json:: from_value ( v. clone ( ) ) . ok ( ) )
177+ . collect ( ) ;
178+
179+ if !tool_calls_vec. is_empty ( ) {
180+ return Some ( ChatCompletionStreamResponse :: ToolCall (
181+ tool_calls_vec,
164182 ) ) ;
165183 }
166184 }
185+ }
167186
168- if let Some ( content) =
169- delta. get ( "content" ) . and_then ( |c| c. as_str ( ) )
170- {
171- let output = content. replace ( "\\ n" , "\n " ) ;
172- return Poll :: Ready ( Some (
173- ChatCompletionStreamResponse :: Content ( output) ,
174- ) ) ;
175- }
187+ if let Some ( content) = delta. get ( "content" ) . and_then ( |c| c. as_str ( ) )
188+ {
189+ let output = content. replace ( "\\ n" , "\n " ) ;
190+ return Some ( ChatCompletionStreamResponse :: Content ( output) ) ;
176191 }
177192 }
178193 }
179194 }
180195 }
196+ Err ( error) => {
197+ eprintln ! ( "Failed to parse SSE chunk as JSON: {}" , error) ;
198+ }
199+ }
200+ }
201+
202+ None
203+ }
204+ }
205+
206+ impl < S : Stream < Item = Result < bytes:: Bytes , reqwest:: Error > > + Unpin > Stream
207+ for ChatCompletionStream < S >
208+ {
209+ type Item = ChatCompletionStreamResponse ;
210+
211+ fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
212+ loop {
213+ if let Some ( response) = self . next_response_from_buffer ( ) {
214+ return Poll :: Ready ( Some ( response) ) ;
215+ }
216+
217+ match Pin :: new ( & mut self . as_mut ( ) . response ) . poll_next ( cx) {
218+ Poll :: Ready ( Some ( Ok ( chunk) ) ) => {
219+ let chunk_str = String :: from_utf8_lossy ( & chunk) . to_string ( ) ;
220+
221+ if self . first_chunk {
222+ self . first_chunk = false ;
223+ }
224+ self . buffer . push_str ( & chunk_str) ;
225+ }
181226 Poll :: Ready ( Some ( Err ( error) ) ) => {
182227 eprintln ! ( "Error in stream: {:?}" , error) ;
183228 return Poll :: Ready ( None ) ;
0 commit comments