From 3ad08a12a72bcb6f40b8c7443da93ec62ccc2ecc Mon Sep 17 00:00:00 2001 From: Varun Gupta Date: Mon, 3 Mar 2025 18:16:31 -0800 Subject: [PATCH] Make stream include usage as optional (#788) * Make stream include usage as optional --------- Signed-off-by: Varun Gupta --- development/vllm/config/deployment.yaml | 25 ++++- pkg/plugins/gateway/gateway.go | 107 +----------------- pkg/plugins/gateway/gateway_test.go | 2 +- pkg/plugins/gateway/util.go | 138 ++++++++++++++++++++++++ 4 files changed, 167 insertions(+), 105 deletions(-) create mode 100644 pkg/plugins/gateway/util.go diff --git a/development/vllm/config/deployment.yaml b/development/vllm/config/deployment.yaml index f1ed0dbb..511db812 100644 --- a/development/vllm/config/deployment.yaml +++ b/development/vllm/config/deployment.yaml @@ -27,8 +27,29 @@ spec: image: aibrix/vllm-cpu-env:macos ports: - containerPort: 8000 - command: ["/bin/sh", "-c"] - args: ["vllm serve facebook/opt-125m --served-model-name facebook-opt-125m --chat-template /etc/chat-template-config/chat-template.j2 --trust-remote-code --device cpu --disable_async_output_proc --enforce-eager --dtype float16"] + command: + - python3 + - -m + - vllm.entrypoints.openai.api_server + - --host + - "0.0.0.0" + - --port + - "8000" + - --uvicorn-log-level + - warning + - --model + - facebook/opt-125m + - --served-model-name + - facebook-opt-125m + - --chat-template + - /etc/chat-template-config/chat-template.j2 + - --trust-remote-code + - --device + - cpu + - --disable_async_output_proc + - --enforce-eager + - --dtype + - float16 env: - name: DEPLOYMENT_NAME valueFrom: diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index e61f1e88..f18c6a29 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -24,7 +24,6 @@ import ( "fmt" "io" "net/http" - "slices" "strconv" "strings" "sync" @@ -243,7 +242,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req } } - routingStrategy, 
routingStrategyEnabled := GetRoutingStrategy(h.RequestHeaders.Headers.Headers) + routingStrategy, routingStrategyEnabled := getRoutingStrategy(h.RequestHeaders.Headers.Headers) if routingStrategyEnabled && !validateRoutingStrategy(routingStrategy) { klog.ErrorS(nil, "incorrect routing strategy", "routing-strategy", routingStrategy) return generateErrorResponse( @@ -338,22 +337,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e } stream, ok = jsonMap["stream"].(bool) - if stream && ok { - streamOptions, ok := jsonMap["stream_options"].(map[string]interface{}) - if !ok { - klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap) - return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}}, - "no stream option available"), model, targetPodIP, stream, term - } - includeUsage, ok := streamOptions["include_usage"].(bool) - if !includeUsage || !ok { - klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap) - return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}}, - "no stream with usage option available"), model, targetPodIP, stream, term + if ok && stream { + if errRes := validateStreamOptions(requestID, user, jsonMap); errRes != nil { + return errRes, model, targetPodIP, stream, term } } @@ -466,7 +452,7 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, targetPodIP string, stream bool, traceTerm int64, hasCompleted bool) 
(*extProcPb.ProcessingResponse, bool) { b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody) - klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfSteam", b.ResponseBody.EndOfStream) + klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfStream", b.ResponseBody.EndOfStream) var res openai.ChatCompletion var usage openai.CompletionUsage @@ -717,86 +703,3 @@ func (s *Server) selectTargetPod(ctx context.Context, routingStrategy string, po return route.Route(ctx, pods, model, message) } - -func validateRoutingStrategy(routingStrategy string) bool { - routingStrategy = strings.TrimSpace(routingStrategy) - return slices.Contains(routingStrategies, routingStrategy) -} - -func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse { - // Set the Content-Type header to application/json - headers = append(headers, &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "Content-Type", - Value: "application/json", - }, - }) - - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: statusCode, - }, - Headers: &extProcPb.HeaderMutation{ - SetHeaders: headers, - }, - Body: generateErrorMessage(body, int(statusCode)), - }, - }, - } -} - -func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) { - messages, ok := jsonMap["messages"] - if !ok { - return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}}, - "no messages in the request body") - } - messagesJSON, err := json.Marshal(messages) - if err != nil { - return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, - 
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}}, - "unable to marshal messages from request body") - } - return string(messagesJSON), nil -} - -// GetRoutingStrategy retrieves the routing strategy from the headers or environment variable -// It returns the routing strategy value and whether custom routing strategy is enabled. -func GetRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) { - var routingStrategy string - routingStrategyEnabled := false - - // Check headers for routing strategy - for _, header := range headers { - if strings.ToLower(header.Key) == HeaderRoutingStrategy { - routingStrategy = string(header.RawValue) - routingStrategyEnabled = true - break // Prioritize header value over environment variable - } - } - - // If header not set, check environment variable - if !routingStrategyEnabled { - if value, exists := utils.CheckEnvExists(EnvRoutingAlgorithm); exists { - routingStrategy = value - routingStrategyEnabled = true - } - } - - return routingStrategy, routingStrategyEnabled -} - -// generateErrorMessage constructs a JSON error message using fmt.Sprintf -func generateErrorMessage(message string, code int) string { - errorStruct := map[string]interface{}{ - "error": map[string]interface{}{ - "message": message, - "code": code, - }, - } - jsonData, _ := json.Marshal(errorStruct) - return string(jsonData) -} diff --git a/pkg/plugins/gateway/gateway_test.go b/pkg/plugins/gateway/gateway_test.go index 50675a91..a556100e 100644 --- a/pkg/plugins/gateway/gateway_test.go +++ b/pkg/plugins/gateway/gateway_test.go @@ -115,7 +115,7 @@ func TestGetRoutingStrategy(t *testing.T) { _ = os.Unsetenv("ROUTING_ALGORITHM") } - routingStrategy, enabled := GetRoutingStrategy(tt.headers) + routingStrategy, enabled := getRoutingStrategy(tt.headers) assert.Equal(t, tt.expectedStrategy, routingStrategy, tt.message) assert.Equal(t, tt.expectedEnabled, enabled, tt.message) diff 
--git a/pkg/plugins/gateway/util.go b/pkg/plugins/gateway/util.go new file mode 100644 index 00000000..ec768ffc --- /dev/null +++ b/pkg/plugins/gateway/util.go @@ -0,0 +1,138 @@ +/* +Copyright 2024 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gateway + +import ( + "encoding/json" + "slices" + "strings" + + configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/vllm-project/aibrix/pkg/utils" + "k8s.io/klog/v2" +) + +// validateStreamOptions returns an error response when user.Tpm is non-zero and the request body does not carry stream_options.include_usage set to true; in every other case it returns nil +func validateStreamOptions(requestID string, user utils.User, jsonMap map[string]interface{}) *extProcPb.ProcessingResponse { + if user.Tpm != 0 { + streamOptions, ok := jsonMap["stream_options"].(map[string]interface{}) + if !ok { + klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap) + return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}}, + "no stream option available") + } + includeUsage, ok := streamOptions["include_usage"].(bool) + if !includeUsage || !ok { + klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap) +
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}}, + "no stream with usage option available") + } + } + return nil +} + +// getRoutingStrategy retrieves the routing strategy from the request headers, falling back to the EnvRoutingAlgorithm environment variable when no matching header is present. +// It returns the routing strategy value and whether a custom routing strategy is enabled. +func getRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) { + var routingStrategy string + routingStrategyEnabled := false + + // Check headers for routing strategy + for _, header := range headers { + if strings.ToLower(header.Key) == HeaderRoutingStrategy { + routingStrategy = string(header.RawValue) + routingStrategyEnabled = true + break // Prioritize header value over environment variable + } + } + + // If header not set, check environment variable + if !routingStrategyEnabled { + if value, exists := utils.CheckEnvExists(EnvRoutingAlgorithm); exists { + routingStrategy = value + routingStrategyEnabled = true + } + } + + return routingStrategy, routingStrategyEnabled +} + +// validateRoutingStrategy reports whether the whitespace-trimmed routing strategy is one of the entries in routingStrategies +func validateRoutingStrategy(routingStrategy string) bool { + routingStrategy = strings.TrimSpace(routingStrategy) + return slices.Contains(routingStrategies, routingStrategy) +} + +// getRequestMessage returns the request body's "messages" field re-marshaled as a JSON string, or an error response when the field is missing or cannot be marshaled +func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) { + messages, ok := jsonMap["messages"] + if !ok { + return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}}, + "no messages in the request body") + } + messagesJSON, err :=
json.Marshal(messages) + if err != nil { + return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}}, + "unable to marshal messages from request body") + } + return string(messagesJSON), nil +} + +// generateErrorResponse constructs an ext_proc ImmediateResponse carrying the given status code, the given headers plus an appended JSON Content-Type header, and a JSON error body built by generateErrorMessage +func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse { + // Set the Content-Type header to application/json + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "Content-Type", + Value: "application/json", + }, + }) + + return &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: &extProcPb.ImmediateResponse{ + Status: &envoyTypePb.HttpStatus{ + Code: statusCode, + }, + Headers: &extProcPb.HeaderMutation{ + SetHeaders: headers, + }, + Body: generateErrorMessage(body, int(statusCode)), + }, + }, + } +} + +// generateErrorMessage constructs a JSON error message of the form {"error": {"code": ..., "message": ...}}; the error returned by json.Marshal is discarded +func generateErrorMessage(message string, code int) string { + errorStruct := map[string]interface{}{ + "error": map[string]interface{}{ + "message": message, + "code": code, + }, + } + jsonData, _ := json.Marshal(errorStruct) + return string(jsonData) +}