Skip to content

Commit

Permalink
Use a response buffer to resolve the end_of_stream issue
Browse files Browse the repository at this point in the history
Signed-off-by: Jiaxin Shan <[email protected]>
  • Loading branch information
Jeffwan committed Feb 14, 2025
1 parent 97dfb62 commit 6ddf774
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"slices"
"strings"
"sync"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -101,6 +102,8 @@ var (
routingStrategies = []string{"random", "least-request", "throughput", "prefix-cache", "least-kv-cache", "least-busy-time", "least-latency"}

ErrorUnknownResponse = errors.New("unknown response")

requestBuffers sync.Map // Thread-safe map to track buffers per request
)

// routerConstructors maps router names to their initialization functions.
Expand Down Expand Up @@ -433,8 +436,8 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re
}

func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, targetPodIP string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID)
b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfSteam", b.ResponseBody.EndOfStream)

var res openai.ChatCompletion
var usage openai.CompletionUsage
Expand All @@ -444,7 +447,7 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *

defer func() {
// Wrapped in a closure so the arguments are evaluated when the deferred call runs rather than when it is registered. The complete flag ensures DoneRequestTrace is called only once per request.
if !hasCompleted && complete {
if !hasCompleted && complete && b.ResponseBody.EndOfStream {
s.cache.DoneRequestTrace(requestID, model, promptTokens, completionTokens, traceTerm)
}
}()
Expand Down Expand Up @@ -472,7 +475,30 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
err.Error()), complete
}
} else {
if err := json.Unmarshal(b.ResponseBody.Body, &res); err != nil {
// Use request ID as a key to store per-request buffer
// Retrieve or create buffer
buf, _ := requestBuffers.LoadOrStore(requestID, &bytes.Buffer{})
buffer := buf.(*bytes.Buffer)
// Append data to per-request buffer
buffer.Write(b.ResponseBody.Body)

if !b.ResponseBody.EndOfStream {
// Partial data received: keep buffering and return an empty common response until the final chunk arrives.
return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{},
},
},
}, complete
}

// Last part received, process the full response
finalBody := buffer.Bytes()
// Clean up the buffer after final processing
requestBuffers.Delete(requestID)

if err := json.Unmarshal(finalBody, &res); err != nil {
klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody()))
complete = true
return generateErrorResponse(
Expand Down

0 comments on commit 6ddf774

Please sign in to comment.