Skip to content

Commit

Permalink
Make stream include usage as optional (#788)
Browse files Browse the repository at this point in the history
* Make stream include usage as optional

---------

Signed-off-by: Varun Gupta <[email protected]>
  • Loading branch information
varungup90 authored Mar 4, 2025
1 parent 31e3cf9 commit 3ad08a1
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 105 deletions.
25 changes: 23 additions & 2 deletions development/vllm/config/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,29 @@ spec:
image: aibrix/vllm-cpu-env:macos
ports:
- containerPort: 8000
command: ["/bin/sh", "-c"]
args: ["vllm serve facebook/opt-125m --served-model-name facebook-opt-125m --chat-template /etc/chat-template-config/chat-template.j2 --trust-remote-code --device cpu --disable_async_output_proc --enforce-eager --dtype float16"]
command:
- python3
- -m
- vllm.entrypoints.openai.api_server
- --host
- "0.0.0.0"
- --port
- "8000"
- --uvicorn-log-level
- warning
- --model
- facebook/opt-125m
- --served-model-name
- facebook-opt-125m
- --chat-template
- /etc/chat-template-config/chat-template.j2
- --trust-remote-code
- --device
- cpu
- --disable_async_output_proc
- --enforce-eager
- --dtype
- float16
env:
- name: DEPLOYMENT_NAME
valueFrom:
Expand Down
107 changes: 5 additions & 102 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -243,7 +242,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req
}
}

routingStrategy, routingStrategyEnabled := GetRoutingStrategy(h.RequestHeaders.Headers.Headers)
routingStrategy, routingStrategyEnabled := getRoutingStrategy(h.RequestHeaders.Headers.Headers)
if routingStrategyEnabled && !validateRoutingStrategy(routingStrategy) {
klog.ErrorS(nil, "incorrect routing strategy", "routing-strategy", routingStrategy)
return generateErrorResponse(
Expand Down Expand Up @@ -338,22 +337,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e
}

stream, ok = jsonMap["stream"].(bool)
if stream && ok {
streamOptions, ok := jsonMap["stream_options"].(map[string]interface{})
if !ok {
klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}},
"no stream option available"), model, targetPodIP, stream, term
}
includeUsage, ok := streamOptions["include_usage"].(bool)
if !includeUsage || !ok {
klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}},
"no stream with usage option available"), model, targetPodIP, stream, term
if ok && stream {
if errRes := validateStreamOptions(requestID, user, jsonMap); errRes != nil {
return errRes, model, targetPodIP, stream, term
}
}

Expand Down Expand Up @@ -466,7 +452,7 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re

func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, targetPodIP string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfSteam", b.ResponseBody.EndOfStream)
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfStream", b.ResponseBody.EndOfStream)

var res openai.ChatCompletion
var usage openai.CompletionUsage
Expand Down Expand Up @@ -717,86 +703,3 @@ func (s *Server) selectTargetPod(ctx context.Context, routingStrategy string, po

return route.Route(ctx, pods, model, message)
}

// validateRoutingStrategy reports whether routingStrategy, after surrounding
// whitespace is trimmed, is one of the entries in routingStrategies.
func validateRoutingStrategy(routingStrategy string) bool {
	trimmed := strings.TrimSpace(routingStrategy)
	return slices.Index(routingStrategies, trimmed) >= 0
}

// generateErrorResponse builds an ext_proc immediate response carrying the
// given status code, the caller-supplied headers plus a JSON Content-Type
// header, and a JSON error body produced by generateErrorMessage.
func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
	// The body is always JSON, so advertise it after the caller's headers.
	jsonContentType := &configPb.HeaderValueOption{
		Header: &configPb.HeaderValue{Key: "Content-Type", Value: "application/json"},
	}
	headers = append(headers, jsonContentType)

	immediate := &extProcPb.ImmediateResponse{
		Status:  &envoyTypePb.HttpStatus{Code: statusCode},
		Headers: &extProcPb.HeaderMutation{SetHeaders: headers},
		Body:    generateErrorMessage(body, int(statusCode)),
	}
	return &extProcPb.ProcessingResponse{
		Response: &extProcPb.ProcessingResponse_ImmediateResponse{ImmediateResponse: immediate},
	}
}

// getRequestMessage returns the "messages" entry of the request body
// re-encoded as a JSON string. When the entry is missing or cannot be
// marshalled it returns an empty string and an immediate error response
// tagged with HeaderErrorRequestBodyProcessing.
func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) {
	reason := "no messages in the request body"
	if msgs, found := jsonMap["messages"]; found {
		data, err := json.Marshal(msgs)
		if err == nil {
			return string(data), nil
		}
		reason = "unable to marshal messages from request body"
	}

	errHeader := &configPb.HeaderValueOption{
		Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")},
	}
	return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
		[]*configPb.HeaderValueOption{errHeader}, reason)
}

// GetRoutingStrategy resolves the routing strategy for a request. The first
// header whose lower-cased key equals HeaderRoutingStrategy wins; otherwise
// the value looked up via EnvRoutingAlgorithm is used. The boolean result
// reports whether either source supplied a strategy.
func GetRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) {
	for _, h := range headers {
		if strings.ToLower(h.Key) != HeaderRoutingStrategy {
			continue
		}
		// A request header takes priority over the environment default.
		return string(h.RawValue), true
	}

	strategy, enabled := utils.CheckEnvExists(EnvRoutingAlgorithm)
	if !enabled {
		return "", false
	}
	return strategy, true
}

// generateErrorMessage renders message and code as a JSON document of the
// form {"error":{"code":<code>,"message":"<message>"}} using json.Marshal.
func generateErrorMessage(message string, code int) string {
	detail := map[string]interface{}{
		"code":    code,
		"message": message,
	}
	// The Marshal error is discarded: the payload holds only a string and an int.
	encoded, _ := json.Marshal(map[string]interface{}{"error": detail})
	return string(encoded)
}
2 changes: 1 addition & 1 deletion pkg/plugins/gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestGetRoutingStrategy(t *testing.T) {
_ = os.Unsetenv("ROUTING_ALGORITHM")
}

routingStrategy, enabled := GetRoutingStrategy(tt.headers)
routingStrategy, enabled := getRoutingStrategy(tt.headers)
assert.Equal(t, tt.expectedStrategy, routingStrategy, tt.message)
assert.Equal(t, tt.expectedEnabled, enabled, tt.message)

Expand Down
138 changes: 138 additions & 0 deletions pkg/plugins/gateway/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
Copyright 2024 The Aibrix Team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package gateway

import (
"encoding/json"
"slices"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/vllm-project/aibrix/pkg/utils"
"k8s.io/klog/v2"
)

// validateStreamOptions checks that a streaming request carries
// stream_options.include_usage set to true. The check only applies when
// user.Tpm is non-zero (presumably a tokens-per-minute limit, which needs
// usage data to be enforced -- confirm against utils.User). It returns nil
// when the request is acceptable, otherwise an immediate error response.
func validateStreamOptions(requestID string, user utils.User, jsonMap map[string]interface{}) *extProcPb.ProcessingResponse {
	if user.Tpm == 0 {
		return nil
	}

	opts, hasOpts := jsonMap["stream_options"].(map[string]interface{})
	if !hasOpts {
		klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap)
		missingOpts := &configPb.HeaderValueOption{Header: &configPb.HeaderValue{
			Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}
		return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
			[]*configPb.HeaderValueOption{missingOpts},
			"no stream option available")
	}

	// include_usage must be present, be a boolean, and be true.
	if usage, isBool := opts["include_usage"].(bool); isBool && usage {
		return nil
	}

	klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap)
	missingUsage := &configPb.HeaderValueOption{Header: &configPb.HeaderValue{
		Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}
	return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
		[]*configPb.HeaderValueOption{missingUsage},
		"no stream with usage option available")
}

// getRoutingStrategy resolves the routing strategy for a request. A header
// whose lower-cased key equals HeaderRoutingStrategy takes precedence; if no
// such header exists, the value looked up via EnvRoutingAlgorithm is used.
// The boolean result reports whether a custom routing strategy was found.
func getRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) {
	// The first matching request header wins over the environment default.
	for i := range headers {
		if hdr := headers[i]; strings.ToLower(hdr.Key) == HeaderRoutingStrategy {
			return string(hdr.RawValue), true
		}
	}

	if envStrategy, found := utils.CheckEnvExists(EnvRoutingAlgorithm); found {
		return envStrategy, true
	}
	return "", false
}

// validateRoutingStrategy reports whether the given routing strategy, with
// leading and trailing whitespace removed, appears in routingStrategies.
func validateRoutingStrategy(routingStrategy string) bool {
	return slices.Contains(routingStrategies, strings.TrimSpace(routingStrategy))
}

// getRequestMessage extracts the "messages" field of the request body and
// returns it re-encoded as a JSON string. If the field is absent or cannot be
// marshalled, it returns an empty string together with an immediate error
// response tagged with HeaderErrorRequestBodyProcessing.
func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) {
	// fail builds the error return shared by both failure paths.
	fail := func(reason string) (string, *extProcPb.ProcessingResponse) {
		tag := &configPb.HeaderValueOption{
			Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")},
		}
		return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
			[]*configPb.HeaderValueOption{tag}, reason)
	}

	rawMessages, present := jsonMap["messages"]
	if !present {
		return fail("no messages in the request body")
	}

	encoded, err := json.Marshal(rawMessages)
	if err != nil {
		return fail("unable to marshal messages from request body")
	}
	return string(encoded), nil
}

// generateErrorResponse constructs an ext_proc immediate response for an
// error. The response carries statusCode, the supplied headers followed by a
// "Content-Type: application/json" header, and a JSON body built by
// generateErrorMessage from body and the numeric status code.
func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
	contentType := &configPb.HeaderValueOption{
		Header: &configPb.HeaderValue{
			Key:   "Content-Type",
			Value: "application/json",
		},
	}

	immediate := &extProcPb.ImmediateResponse{
		Status: &envoyTypePb.HttpStatus{Code: statusCode},
		// The JSON content type goes after whatever the caller supplied.
		Headers: &extProcPb.HeaderMutation{SetHeaders: append(headers, contentType)},
		Body:    generateErrorMessage(body, int(statusCode)),
	}

	return &extProcPb.ProcessingResponse{
		Response: &extProcPb.ProcessingResponse_ImmediateResponse{
			ImmediateResponse: immediate,
		},
	}
}

// generateErrorMessage serialises message and code into a JSON error document
// of the form {"error":{"code":<code>,"message":"<message>"}}.
func generateErrorMessage(message string, code int) string {
	// Field order mirrors the alphabetical key order json.Marshal emits for a
	// map, so the encoded bytes stay the same as a map-based payload.
	type errorDetail struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}
	payload := struct {
		Error errorDetail `json:"error"`
	}{
		Error: errorDetail{Code: code, Message: message},
	}

	// The Marshal error is discarded: the payload holds only a string and an int.
	encoded, _ := json.Marshal(payload)
	return string(encoded)
}

0 comments on commit 3ad08a1

Please sign in to comment.