@@ -113,70 +113,114 @@ pub struct ChatCompletionStream<S: Stream<Item = Result<bytes::Bytes, reqwest::E
113113 pub first_chunk : bool ,
114114}
115115
116+ impl < S > ChatCompletionStream < S >
117+ where
118+ S : Stream < Item = Result < bytes:: Bytes , reqwest:: Error > > + Unpin ,
119+ {
120+ fn find_event_delimiter ( buffer : & str ) -> Option < ( usize , usize ) > {
121+ let carriage_idx = buffer. find ( "\r \n \r \n " ) ;
122+ let newline_idx = buffer. find ( "\n \n " ) ;
123+
124+ match ( carriage_idx, newline_idx) {
125+ ( Some ( r_idx) , Some ( n_idx) ) => {
126+ if r_idx <= n_idx {
127+ Some ( ( r_idx, 4 ) )
128+ } else {
129+ Some ( ( n_idx, 2 ) )
130+ }
131+ }
132+ ( Some ( r_idx) , None ) => Some ( ( r_idx, 4 ) ) ,
133+ ( None , Some ( n_idx) ) => Some ( ( n_idx, 2 ) ) ,
134+ ( None , None ) => None ,
135+ }
136+ }
137+
138+ fn next_response_from_buffer ( & mut self ) -> Option < ChatCompletionStreamResponse > {
139+ while let Some ( ( idx, delimiter_len) ) = Self :: find_event_delimiter ( & self . buffer ) {
140+ let event = self . buffer [ ..idx] . to_owned ( ) ;
141+ self . buffer = self . buffer [ idx + delimiter_len..] . to_owned ( ) ;
142+
143+ let mut data_payload = String :: new ( ) ;
144+ for line in event. lines ( ) {
145+ let trimmed_line = line. trim_end_matches ( '\r' ) ;
146+ if let Some ( content) = trimmed_line
147+ . strip_prefix ( "data: " )
148+ . or_else ( || trimmed_line. strip_prefix ( "data:" ) )
149+ {
150+ if !content. is_empty ( ) {
151+ if !data_payload. is_empty ( ) {
152+ data_payload. push ( '\n' ) ;
153+ }
154+ data_payload. push_str ( content) ;
155+ }
156+ }
157+ }
158+
159+ if data_payload. is_empty ( ) {
160+ continue ;
161+ }
162+
163+ if data_payload == "[DONE]" {
164+ return Some ( ChatCompletionStreamResponse :: Done ) ;
165+ }
166+
167+ match serde_json:: from_str :: < Value > ( & data_payload) {
168+ Ok ( json) => {
169+ if let Some ( delta) = json
170+ . get ( "choices" )
171+ . and_then ( |choices| choices. get ( 0 ) )
172+ . and_then ( |choice| choice. get ( "delta" ) )
173+ {
174+ if let Some ( tool_call_response) = delta
175+ . get ( "tool_calls" )
176+ . and_then ( |tool_calls| tool_calls. as_array ( ) )
177+ . map ( |tool_calls_array| {
178+ tool_calls_array
179+ . iter ( )
180+ . filter_map ( |v| serde_json:: from_value ( v. clone ( ) ) . ok ( ) )
181+ . collect :: < Vec < ToolCall > > ( )
182+ } )
183+ . filter ( |tool_calls_vec| !tool_calls_vec. is_empty ( ) )
184+ . map ( ChatCompletionStreamResponse :: ToolCall )
185+ {
186+ return Some ( tool_call_response) ;
187+ }
188+
189+ if let Some ( content) = delta. get ( "content" ) . and_then ( |c| c. as_str ( ) ) {
190+ let output = content. replace ( "\\ n" , "\n " ) ;
191+ return Some ( ChatCompletionStreamResponse :: Content ( output) ) ;
192+ }
193+ }
194+ }
195+ Err ( error) => {
196+ eprintln ! ( "Failed to parse SSE chunk as JSON: {}" , error) ;
197+ }
198+ }
199+ }
200+
201+ None
202+ }
203+ }
204+
116205impl < S : Stream < Item = Result < bytes:: Bytes , reqwest:: Error > > + Unpin > Stream
117206 for ChatCompletionStream < S >
118207{
119208 type Item = ChatCompletionStreamResponse ;
120209
121210 fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
122211 loop {
212+ if let Some ( response) = self . next_response_from_buffer ( ) {
213+ return Poll :: Ready ( Some ( response) ) ;
214+ }
215+
123216 match Pin :: new ( & mut self . as_mut ( ) . response ) . poll_next ( cx) {
124217 Poll :: Ready ( Some ( Ok ( chunk) ) ) => {
125- let mut utf8_str = String :: from_utf8_lossy ( & chunk) . to_string ( ) ;
218+ let chunk_str = String :: from_utf8_lossy ( & chunk) . to_string ( ) ;
126219
127220 if self . first_chunk {
128- let lines: Vec < & str > = utf8_str. lines ( ) . collect ( ) ;
129- utf8_str = if lines. len ( ) >= 2 {
130- lines[ lines. len ( ) - 2 ] . to_string ( )
131- } else {
132- utf8_str. clone ( )
133- } ;
134221 self . first_chunk = false ;
135222 }
136-
137- let trimmed_str = utf8_str. trim_start_matches ( "data: " ) ;
138- if trimmed_str. contains ( "[DONE]" ) {
139- return Poll :: Ready ( Some ( ChatCompletionStreamResponse :: Done ) ) ;
140- }
141-
142- self . buffer . push_str ( trimmed_str) ;
143- let json_result: Result < Value , _ > = serde_json:: from_str ( & self . buffer ) ;
144-
145- if let Ok ( json) = json_result {
146- self . buffer . clear ( ) ;
147-
148- if let Some ( choices) = json. get ( "choices" ) {
149- if let Some ( choice) = choices. get ( 0 ) {
150- if let Some ( delta) = choice. get ( "delta" ) {
151- if let Some ( tool_calls) = delta. get ( "tool_calls" ) {
152- if let Some ( tool_calls_array) = tool_calls. as_array ( ) {
153- let tool_calls_vec: Vec < ToolCall > = tool_calls_array
154- . iter ( )
155- . filter_map ( |v| {
156- serde_json:: from_value ( v. clone ( ) ) . ok ( )
157- } )
158- . collect ( ) ;
159-
160- return Poll :: Ready ( Some (
161- ChatCompletionStreamResponse :: ToolCall (
162- tool_calls_vec,
163- ) ,
164- ) ) ;
165- }
166- }
167-
168- if let Some ( content) =
169- delta. get ( "content" ) . and_then ( |c| c. as_str ( ) )
170- {
171- let output = content. replace ( "\\ n" , "\n " ) ;
172- return Poll :: Ready ( Some (
173- ChatCompletionStreamResponse :: Content ( output) ,
174- ) ) ;
175- }
176- }
177- }
178- }
179- }
223+ self . buffer . push_str ( & chunk_str) ;
180224 }
181225 Poll :: Ready ( Some ( Err ( error) ) ) => {
182226 eprintln ! ( "Error in stream: {:?}" , error) ;
0 commit comments