Use response buffer to address stream request issue (#679)
* Increase the connection bufferLimit to avoid the end_of_stream=false case

ResponseBody.EndOfStream is sometimes false even when buffered mode is used. The current implementation does not take this into consideration.

Signed-off-by: Jiaxin Shan <[email protected]>

* Use a response buffer to resolve the end_of_stream issue

Signed-off-by: Jiaxin Shan <[email protected]>

---------

Signed-off-by: Jiaxin Shan <[email protected]>
Signed-off-by: Varun Gupta <[email protected]>
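Context for the two-part fix: even with Envoy's buffered body mode, the ext_proc filter can deliver a response across multiple ResponseBody messages, with end_of_stream true only on the last one (presumably when the body outgrows the connection buffer). Raising bufferLimit makes that less frequent; the per-request buffer makes the plugin correct whenever it still happens.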
Jeffwan authored and varungup90 committed Feb 20, 2025
1 parent a06f9b1 commit 9aacdd6
Showing 3 changed files with 32 additions and 5 deletions.
config/gateway/gateway.yaml (1 addition, 1 deletion)
@@ -29,7 +29,7 @@ spec:
     kind: Gateway
     name: aibrix-eg
   connection:
-    bufferLimit: 262144
+    bufferLimit: 1048576
 ---
 apiVersion: gateway.envoyproxy.io/v1alpha1
 kind: EnvoyExtensionPolicy
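For reference, the old limit of 262144 bytes is 256 KiB and the new 1048576 bytes is 1 MiB, so the per-connection buffer is quadrupled and fewer responses should arrive split across multiple chunks.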
pkg/plugins/gateway/gateway.go (29 additions, 3 deletions)
@@ -26,6 +26,7 @@ import (
 	"net/http"
 	"slices"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/google/uuid"
@@ -101,6 +102,8 @@ var (
 	routingStrategies = []string{"random", "least-request", "throughput", "prefix-cache", "least-kv-cache", "least-busy-time", "least-latency"}
 
 	ErrorUnknownResponse = errors.New("unknown response")
+
+	requestBuffers sync.Map // Thread-safe map to track buffers per request
 )
 
 // routerConstructors maps router names to their initialization functions.
@@ -433,8 +436,8 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re
 }
 
 func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, targetPodIP string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
-	klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID)
 	b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
+	klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfStream", b.ResponseBody.EndOfStream)
 
 	var res openai.ChatCompletion
 	var usage openai.CompletionUsage
@@ -444,7 +447,7 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
 
 	defer func() {
 		// Wrapped in a function to delay the evaluation of parameters. Using complete to make sure DoneRequestTrace is only called once per request.
-		if !hasCompleted && complete {
+		if !hasCompleted && complete && b.ResponseBody.EndOfStream {
 			s.cache.DoneRequestTrace(requestID, model, promptTokens, completionTokens, traceTerm)
 		}
 	}()
Expand Down Expand Up @@ -472,7 +475,30 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
err.Error()), complete
}
} else {
if err := json.Unmarshal(b.ResponseBody.Body, &res); err != nil {
// Use request ID as a key to store per-request buffer
// Retrieve or create buffer
buf, _ := requestBuffers.LoadOrStore(requestID, &bytes.Buffer{})
buffer := buf.(*bytes.Buffer)
// Append data to per-request buffer
buffer.Write(b.ResponseBody.Body)

if !b.ResponseBody.EndOfStream {
// Partial data received, wait for more chunks, we just return a common response here.
return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{},
},
},
}, complete
}

// Last part received, process the full response
finalBody := buffer.Bytes()
// Clean up the buffer after final processing
requestBuffers.Delete(requestID)

if err := json.Unmarshal(finalBody, &res); err != nil {
klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody()))
complete = true
return generateErrorResponse(
Expand Down
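The new branch boils down to a small reusable pattern: accumulate chunks in a per-request bytes.Buffer held in a sync.Map, and parse only once the final chunk arrives. Below is a minimal standalone sketch of that pattern; `onResponseChunk` and the demo loop in `main` are illustrative stand-ins, not the plugin's actual ext_proc call sites:

```go
package main

import (
	"bytes"
	"fmt"
	"sync"
)

// requestBuffers maps requestID -> *bytes.Buffer so that concurrent
// requests each accumulate their own response chunks.
var requestBuffers sync.Map

// onResponseChunk appends one body chunk; when endOfStream is true it
// returns the fully assembled body and releases the buffer entry.
func onResponseChunk(requestID string, chunk []byte, endOfStream bool) ([]byte, bool) {
	buf, _ := requestBuffers.LoadOrStore(requestID, &bytes.Buffer{})
	buffer := buf.(*bytes.Buffer)
	buffer.Write(chunk)

	if !endOfStream {
		// Partial data: keep waiting for more chunks.
		return nil, false
	}

	// Final chunk: hand back the complete body and clean up.
	requestBuffers.Delete(requestID)
	return buffer.Bytes(), true
}

func main() {
	chunks := [][]byte{[]byte(`{"id":"abc",`), []byte(`"model":"m"}`)}
	for i, c := range chunks {
		if body, done := onResponseChunk("req-1", c, i == len(chunks)-1); done {
			fmt.Printf("assembled %d bytes: %s\n", len(body), body)
		}
	}
}
```

sync.Map suits this access pattern (disjoint keys, each entry written a few times and then deleted) better than a single mutex-guarded map, since concurrent requests never contend on each other's entries.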
test/e2e/routing_strategy_test.go (2 additions, 1 deletion)
@@ -26,6 +26,7 @@ import (
 	"github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestPrefixCacheModelInference(t *testing.T) {
@@ -63,7 +64,7 @@ func getTargetPodFromChatCompletion(t *testing.T, message string) string {
 		}),
 		Model: openai.F(modelName),
 	})
-	assert.NoError(t, err, "chat completitions failed")
+	require.NoError(t, err, "chat completions failed %v", err)
 	assert.Equal(t, modelName, chatCompletion.Model)
 
 	return dst.Header.Get("target-pod")
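The assert-to-require switch matters here: assert.NoError records the failure and lets the test keep running, so the following chatCompletion.Model access could panic on a nil response, while require.NoError calls t.FailNow() and aborts the test function immediately. A minimal, deliberately failing illustration:

```go
package demo

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestRequireStopsExecution(t *testing.T) {
	err := errors.New("boom")
	assert.NoError(t, err)  // marks the test failed, execution continues
	require.NoError(t, err) // marks it failed and calls t.FailNow()
	t.Log("never reached")  // skipped: require aborted the test above
}
```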
