diff --git a/config/gateway/gateway.yaml b/config/gateway/gateway.yaml index aedd8dd0..dae12ceb 100644 --- a/config/gateway/gateway.yaml +++ b/config/gateway/gateway.yaml @@ -19,6 +19,19 @@ spec: port: 80 --- apiVersion: gateway.envoyproxy.io/v1alpha1 +kind: ClientTrafficPolicy +metadata: + name: client-connection-buffersize + namespace: aibrix-system +spec: + targetRefs: + - group: gateway.networking.k8s.io + kind: Gateway + name: aibrix-eg + connection: + bufferLimit: 262144 +--- +apiVersion: gateway.envoyproxy.io/v1alpha1 kind: EnvoyExtensionPolicy metadata: name: gateway-plugins-extension-policy @@ -36,7 +49,8 @@ spec: request: body: Buffered response: - body: Buffered + body: Streamed + messageTimeout: 5s --- apiVersion: gateway.envoyproxy.io/v1alpha1 kind: EnvoyPatchPolicy @@ -66,7 +80,7 @@ spec: regex: .* route: cluster: original_destination_cluster - timeout: 1000s # Increase route timeout + timeout: 120s # Increase route timeout typed_per_filter_config: "envoy.filters.http.ext_proc/envoyextensionpolicy/aibrix-system/aibrix-gateway-plugins-extension-policy/extproc/0": "@type": "type.googleapis.com/envoy.config.route.v3.FilterConfig" diff --git a/go.mod b/go.mod index 3618d36d..dd6e96bb 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,9 @@ require ( github.com/gorilla/mux v1.8.1 github.com/onsi/ginkgo/v2 v2.17.2 github.com/onsi/gomega v1.33.1 + github.com/openai/openai-go v0.1.0-alpha.37 github.com/ray-project/kuberay/ray-operator v1.2.1 github.com/redis/go-redis/v9 v9.6.1 - github.com/sashabaranov/go-openai v1.29.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 google.golang.org/grpc v1.65.0 @@ -72,15 +72,19 @@ require ( github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/crypto v0.24.0 // indirect + golang.org/x/crypto v0.25.0 // indirect golang.org/x/mod v0.18.0 // indirect - golang.org/x/net v0.26.0 // indirect + golang.org/x/net v0.27.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.21.0 // indirect - golang.org/x/term v0.21.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/term v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.22.0 // indirect diff --git a/go.sum b/go.sum index b890acb5..a003f4fa 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,10 @@ github.com/onsi/ginkgo/v2 v2.17.2 h1:7eMhcy3GimbsA3hEnVKdw/PQM9XN9krpKVXsZdph0/g github.com/onsi/ginkgo/v2 v2.17.2/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc= github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= +github.com/openai/openai-go v0.1.0-alpha.34 h1:mKz2UYTlGOQvsN3piK1wdYzpJP769aLyrWuEJ5Qi7xc= +github.com/openai/openai-go v0.1.0-alpha.34/go.mod h1:3SdE6BffOX9HPEQv8IL/fi3LYZ5TUpRYaqGQZbyk11A= +github.com/openai/openai-go v0.1.0-alpha.37 h1:dstNWRmODNmcvVrNhJ1tzmD8J9hy+aaycwKAqLZVx2Q= +github.com/openai/openai-go v0.1.0-alpha.37/go.mod h1:3SdE6BffOX9HPEQv8IL/fi3LYZ5TUpRYaqGQZbyk11A= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -119,14 +123,24 @@ github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0 github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/sashabaranov/go-openai v1.29.0 h1:eBH6LSjtX4md5ImDCX8hNhHQvaRf22zujiERoQpsvLo= -github.com/sashabaranov/go-openai v1.29.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.35.6 h1:oi0rwCvyxMxgFALDGnyqFTyCJm6n72OnEG3sybIFR0g= +github.com/sashabaranov/go-openai v1.35.6/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -138,8 +152,8 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -150,8 +164,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -162,10 +176,10 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= +golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 959bf44a..04a60698 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -510,7 +510,7 @@ func parseMetricFromBody(body []byte, metricName string) (float64, error) { return 0, fmt.Errorf("metrics %s not found", metricName) } -func (c *Cache) AddRequestTrace(modelName string, inputTokens, outputTokens int) { +func (c *Cache) AddRequestTrace(modelName string, inputTokens, outputTokens int64) { c.mu.Lock() defer c.mu.Unlock() diff --git a/pkg/controller/modelrouter/modelrouter_controller.go b/pkg/controller/modelrouter/modelrouter_controller.go index 45e7e761..c2229691 100644 --- a/pkg/controller/modelrouter/modelrouter_controller.go +++ b/pkg/controller/modelrouter/modelrouter_controller.go @@ -163,6 +163,9 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string }, }, }, + Timeouts: &gatewayv1.HTTPRouteTimeouts{ + Request: ptr.To(gatewayv1.Duration("120s")), + }, }, }, }, diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index 76827c85..fd06bd48 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -17,17 +17,20 @@ limitations under the License. package gateway import ( + "bytes" "context" "encoding/json" "fmt" "io" + "net/http" "slices" "strings" "time" "github.com/google/uuid" + "github.com/openai/openai-go" + "github.com/openai/openai-go/packages/ssestream" "github.com/redis/go-redis/v9" - openai "github.com/sashabaranov/go-openai" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" v1 "k8s.io/api/core/v1" @@ -95,6 +98,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { var user utils.User var rpm int64 var routingStrategy, targetPodIP string + var stream bool ctx := srv.Context() requestID := uuid.New().String() @@ -120,13 +124,13 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { resp, user, rpm, routingStrategy = s.HandleRequestHeaders(ctx, requestID, req) case *extProcPb.ProcessingRequest_RequestBody: - resp, targetPodIP = s.HandleRequestBody(ctx, requestID, req, user, routingStrategy) + resp, targetPodIP, stream = s.HandleRequestBody(ctx, requestID, req, user, routingStrategy) case *extProcPb.ProcessingRequest_ResponseHeaders: resp = s.HandleResponseHeaders(ctx, requestID, req, targetPodIP) case *extProcPb.ProcessingRequest_ResponseBody: - resp = s.HandleResponseBody(ctx, requestID, req, user, rpm, targetPodIP) + resp = s.HandleResponseBody(ctx, requestID, req, user, rpm, targetPodIP, stream) default: klog.Infof("Unknown Request type %+v\n", v) @@ -140,7 +144,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, utils.User, int64, string) { klog.Info("\n\n") - klog.Info("-- In RequestHeaders processing ...") + klog.InfoS("-- In RequestHeaders processing ...", "requestID", requestID) var username string var user utils.User var rpm int64 @@ -203,25 +207,46 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req }, user, rpm, routingStrategy } -func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, routingStrategy string) (*extProcPb.ProcessingResponse, string) { - klog.Info("--- In RequestBody processing") +func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, routingStrategy string) (*extProcPb.ProcessingResponse, string, bool) { + klog.InfoS("-- In RequestBody processing ...", "requestID", requestID) var model, targetPodIP string - var ok bool + var ok, stream bool + var jsonMap map[string]interface{} body := req.Request.(*extProcPb.ProcessingRequest_RequestBody) if err := json.Unmarshal(body.RequestBody.GetBody(), &jsonMap); err != nil { + klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "requestBody", string(body.RequestBody.GetBody())) return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ Key: "x-request-body-processing-error", RawValue: []byte("true")}}}, - "error processing request body"), targetPodIP + "error processing request body"), targetPodIP, stream } - if model, ok = jsonMap["model"].(string); !ok || model == "" || !s.cache.CheckModelExists(model) { + if model, ok = jsonMap["model"].(string); !ok || model == "" { // || !s.cache.CheckModelExists(model) # enable when dynamic lora is enabled + klog.ErrorS(nil, "model error in request", "requestID", requestID, "jsonMap", jsonMap) return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ Key: "x-no-model", RawValue: []byte(model)}}}, - fmt.Sprintf("no model in request body or model %s does not exist", model)), targetPodIP + fmt.Sprintf("no model in request body or model %s does not exist", model)), targetPodIP, stream + } + + stream, ok = jsonMap["stream"].(bool) + if stream && ok { + streamOptions, ok := jsonMap["stream_options"].(map[string]interface{}) + if !ok { + return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-stream-options", RawValue: []byte("stream options not set")}}}, + "error processing request body"), targetPodIP, stream + } + includeUsage, ok := streamOptions["include_usage"].(bool) + if !includeUsage || !ok { + return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-stream-options-include-usage", RawValue: []byte("include usage for stream options not set")}}}, + "error processing request body"), targetPodIP, stream + } } headers := []*configPb.HeaderValueOption{} @@ -240,7 +265,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ Key: "x-no-model-deployment", RawValue: []byte("true")}}}, - fmt.Sprintf("error on getting pods for model %s", model)), targetPodIP + fmt.Sprintf("error on getting pods for model %s", model)), targetPodIP, stream } targetPodIP, err = s.selectTargetPod(ctx, routingStrategy, pods) @@ -249,7 +274,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ Key: "x-error-routing", RawValue: []byte("true")}}}, - "error on selecting target pod"), targetPodIP + "error on selecting target pod"), targetPodIP, stream } headers = append(headers, &configPb.HeaderValueOption{ @@ -277,11 +302,12 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e }, }, }, - }, targetPodIP + }, targetPodIP, stream } func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, targetPodIP string) *extProcPb.ProcessingResponse { - klog.Info("--- In ResponseHeaders processing") + klog.InfoS("-- In ResponseHeaders processing ...", "requestID", requestID) + headers := []*configPb.HeaderValueOption{{ Header: &configPb.HeaderValue{ Key: "x-went-into-resp-headers", @@ -311,64 +337,97 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re } } -func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, targetPodIP string) *extProcPb.ProcessingResponse { - klog.Infof("--- In ResponseBody processing") +func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, targetPodIP string, stream bool) *extProcPb.ProcessingResponse { + klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID) b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody) - var res openai.CompletionResponse - if err := json.Unmarshal(b.ResponseBody.Body, &res); err != nil { - klog.ErrorS(err, "error to unmarshal response", "requestID", requestID) - return generateErrorResponse( - envoyTypePb.StatusCode_InternalServerError, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-error-response-unmarshal", RawValue: []byte("true"), - }}}, - err.Error()) - } - - defer func() { - go func() { - s.cache.AddRequestTrace(res.Model, res.Usage.PromptTokens, res.Usage.CompletionTokens) - }() - }() - + var res openai.ChatCompletion + var model string + var usage openai.CompletionUsage headers := []*configPb.HeaderValueOption{} - if user.Name != "" { - tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens)) - if err != nil { + + switch stream { + case true: + t := &http.Response{ + Body: io.NopCloser(bytes.NewReader(b.ResponseBody.GetBody())), + } + streaming := ssestream.NewStream[openai.ChatCompletionChunk](ssestream.NewDecoder(t), nil) + for streaming.Next() { + evt := streaming.Current() + if len(evt.Choices) == 0 { + model = evt.Model + usage = evt.Usage + } + } + if err := streaming.Err(); err != nil { + klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody())) return generateErrorResponse( envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-error-update-tpm", RawValue: []byte("true"), + Key: "x-streaming-error", RawValue: []byte("true"), }}}, err.Error()) } + case false: + if err := json.Unmarshal(b.ResponseBody.Body, &res); err != nil { + klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody())) + return generateErrorResponse( + envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-response-unmarshal", RawValue: []byte("true"), + }}}, + err.Error()) + } + model = res.Model + usage = res.Usage + } + var requestEnd string + if usage.TotalTokens != 0 { + defer func() { + go func() { + s.cache.AddRequestTrace(model, usage.PromptTokens, usage.CompletionTokens) + }() + }() - headers = append(headers, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "x-update-rpm", - RawValue: []byte(fmt.Sprintf("%d", rpm)), + if user.Name != "" { + tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), res.Usage.TotalTokens) + if err != nil { + return generateErrorResponse( + envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-update-tpm", RawValue: []byte("true"), + }}}, + err.Error()) + } + + headers = append(headers, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "x-update-rpm", + RawValue: []byte(fmt.Sprintf("%d", rpm)), + }, }, - }, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "x-update-tpm", - RawValue: []byte(fmt.Sprintf("%d", tpm)), + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "x-update-tpm", + RawValue: []byte(fmt.Sprintf("%d", tpm)), + }, }, - }, - ) - klog.InfoS("request end", "requestID", requestID, "rpm", rpm, "tpm", tpm) - } - if targetPodIP != "" { - headers = append(headers, - &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "target-pod", - RawValue: []byte(targetPodIP), + ) + requestEnd = fmt.Sprintf(requestEnd+"rpm: %s, tpm: %s", rpm, tpm) + } + if targetPodIP != "" { + headers = append(headers, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "target-pod", + RawValue: []byte(targetPodIP), + }, }, - }, - ) + ) + requestEnd = fmt.Sprintf(requestEnd+", targetPod: %s", targetPodIP) + } + klog.Infof("request end, requestID: %s - %s", requestID, requestEnd) } return &extProcPb.ProcessingResponse{