@@ -388,6 +388,117 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
 	}
 }
 
+func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		// Send test responses
+		var dataBytes []byte
+		//nolint:lll
+		data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		//nolint:lll
+		data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
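+		// With IncludeUsage set, the final chunk reports token usage and has an empty choices array.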
+		//nolint:lll
+		data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+		StreamOptions: &openai.StreamOptions{
+			IncludeUsage: true,
+		},
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:                "1",
+			Object:            "completion",
+			Created:           1598069254,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "response1",
+					},
+					FinishReason: "max_tokens",
+				},
+			},
+		},
+		{
+			ID:                "2",
+			Object:            "completion",
+			Created:           1598069255,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "response2",
+					},
+					FinishReason: "max_tokens",
+				},
+			},
+		},
+		{
+			ID:                "3",
+			Object:            "completion",
+			Created:           1598069256,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices:           []openai.ChatCompletionStreamChoice{},
+			Usage: &openai.Usage{
+				PromptTokens:     1,
+				CompletionTokens: 1,
+				TotalTokens:      2,
+			},
+		},
+	}
+
+	for ix, expectedResponse := range expectedResponses {
+		b, _ := json.Marshal(expectedResponse)
+		t.Logf("%d: %s", ix, string(b))
+
+		receivedResponse, streamErr := stream.Recv()
+		checks.NoError(t, streamErr, "stream.Recv() failed")
+		if !compareChatResponses(expectedResponse, receivedResponse) {
+			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+		}
+	}
+
+	_, streamErr := stream.Recv()
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+	}
+
+	_, streamErr = stream.Recv()
+	checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished")
+}
+
 // Helper funcs.
 func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
@@ -401,6 +512,16 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 			return false
 		}
 	}
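+	// Usage is equal only when both sides are nil, or both are set and match field by field.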
+	if r1.Usage != nil || r2.Usage != nil {
+		if r1.Usage == nil || r2.Usage == nil {
+			return false
+		}
+		if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens ||
+			r1.Usage.TotalTokens != r2.Usage.TotalTokens {
+			return false
+		}
+	}
 	return true
 }
 
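For context, here is a minimal sketch of how a caller might consume a stream with `StreamOptions.IncludeUsage` enabled, which is the behavior the new test exercises. Only `CreateChatCompletionStream`, `Recv`, `StreamOptions`, and the `Usage` field appear in this diff; the import path, the `NewClient` setup, and the API key placeholder are assumptions based on the rest of the library.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Assumed setup; replace the placeholder with a real API key.
	client := openai.NewClient("your-api-key")

	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
		Stream: true,
		// Ask the API to append a final chunk that reports token usage.
		StreamOptions: &openai.StreamOptions{IncludeUsage: true},
	})
	if err != nil {
		panic(err)
	}
	defer stream.Close()

	for {
		resp, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			break // stream finished
		}
		if recvErr != nil {
			panic(recvErr)
		}
		if len(resp.Choices) > 0 {
			fmt.Print(resp.Choices[0].Delta.Content)
		}
		// Only the final chunk carries usage; it arrives with an empty choices slice.
		if resp.Usage != nil {
			fmt.Printf("\ntokens: prompt=%d, completion=%d, total=%d\n",
				resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
		}
	}
}
```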