Skip to content

Commit

Permalink
Make stream include usage as optional
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Gupta <[email protected]>
  • Loading branch information
varungup90 committed Mar 4, 2025
1 parent 31e3cf9 commit 9223769
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 105 deletions.
25 changes: 23 additions & 2 deletions development/vllm/config/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,29 @@ spec:
image: aibrix/vllm-cpu-env:macos
ports:
- containerPort: 8000
command: ["/bin/sh", "-c"]
args: ["vllm serve facebook/opt-125m --served-model-name facebook-opt-125m --chat-template /etc/chat-template-config/chat-template.j2 --trust-remote-code --device cpu --disable_async_output_proc --enforce-eager --dtype float16"]
command:
- python3
- -m
- vllm.entrypoints.openai.api_server
- --host
- "0.0.0.0"
- --port
- "8000"
- --uvicorn-log-level
- warning
- --model
- facebook/opt-125m
- --served-model-name
- facebook-opt-125m
- --chat-template
- /etc/chat-template-config/chat-template.j2
- --trust-remote-code
- --device
- cpu
- --disable_async_output_proc
- --enforce-eager
- --dtype
- float16
env:
- name: DEPLOYMENT_NAME
valueFrom:
Expand Down
107 changes: 5 additions & 102 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -243,7 +242,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req
}
}

routingStrategy, routingStrategyEnabled := GetRoutingStrategy(h.RequestHeaders.Headers.Headers)
routingStrategy, routingStrategyEnabled := getRoutingStrategy(h.RequestHeaders.Headers.Headers)
if routingStrategyEnabled && !validateRoutingStrategy(routingStrategy) {
klog.ErrorS(nil, "incorrect routing strategy", "routing-strategy", routingStrategy)
return generateErrorResponse(
Expand Down Expand Up @@ -338,22 +337,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e
}

stream, ok = jsonMap["stream"].(bool)
if stream && ok {
streamOptions, ok := jsonMap["stream_options"].(map[string]interface{})
if !ok {
klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}},
"no stream option available"), model, targetPodIP, stream, term
}
includeUsage, ok := streamOptions["include_usage"].(bool)
if !includeUsage || !ok {
klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}},
"no stream with usage option available"), model, targetPodIP, stream, term
if ok && stream {
if errRes := validateStreamOptions(requestID, user, jsonMap); errRes != nil {
return errRes, model, targetPodIP, stream, term
}
}

Expand Down Expand Up @@ -466,7 +452,7 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re

func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, targetPodIP string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfSteam", b.ResponseBody.EndOfStream)
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfStream", b.ResponseBody.EndOfStream)

var res openai.ChatCompletion
var usage openai.CompletionUsage
Expand Down Expand Up @@ -717,86 +703,3 @@ func (s *Server) selectTargetPod(ctx context.Context, routingStrategy string, po

return route.Route(ctx, pods, model, message)
}

func validateRoutingStrategy(routingStrategy string) bool {
routingStrategy = strings.TrimSpace(routingStrategy)
return slices.Contains(routingStrategies, routingStrategy)
}

func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
// Set the Content-Type header to application/json
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "Content-Type",
Value: "application/json",
},
})

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extProcPb.ImmediateResponse{
Status: &envoyTypePb.HttpStatus{
Code: statusCode,
},
Headers: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
Body: generateErrorMessage(body, int(statusCode)),
},
},
}
}

func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) {
messages, ok := jsonMap["messages"]
if !ok {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"no messages in the request body")
}
messagesJSON, err := json.Marshal(messages)
if err != nil {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"unable to marshal messages from request body")
}
return string(messagesJSON), nil
}

// GetRoutingStrategy retrieves the routing strategy from the headers or environment variable
// It returns the routing strategy value and whether custom routing strategy is enabled.
func GetRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) {
var routingStrategy string
routingStrategyEnabled := false

// Check headers for routing strategy
for _, header := range headers {
if strings.ToLower(header.Key) == HeaderRoutingStrategy {
routingStrategy = string(header.RawValue)
routingStrategyEnabled = true
break // Prioritize header value over environment variable
}
}

// If header not set, check environment variable
if !routingStrategyEnabled {
if value, exists := utils.CheckEnvExists(EnvRoutingAlgorithm); exists {
routingStrategy = value
routingStrategyEnabled = true
}
}

return routingStrategy, routingStrategyEnabled
}

// generateErrorMessage constructs a JSON error message using fmt.Sprintf
func generateErrorMessage(message string, code int) string {
errorStruct := map[string]interface{}{
"error": map[string]interface{}{
"message": message,
"code": code,
},
}
jsonData, _ := json.Marshal(errorStruct)
return string(jsonData)
}
2 changes: 1 addition & 1 deletion pkg/plugins/gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestGetRoutingStrategy(t *testing.T) {
_ = os.Unsetenv("ROUTING_ALGORITHM")
}

routingStrategy, enabled := GetRoutingStrategy(tt.headers)
routingStrategy, enabled := getRoutingStrategy(tt.headers)
assert.Equal(t, tt.expectedStrategy, routingStrategy, tt.message)
assert.Equal(t, tt.expectedEnabled, enabled, tt.message)

Expand Down
122 changes: 122 additions & 0 deletions pkg/plugins/gateway/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package gateway

import (
"encoding/json"
"slices"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/vllm-project/aibrix/pkg/utils"
"k8s.io/klog/v2"
)

// validateStreamOptions validates whether stream options to include usage is set for user request
func validateStreamOptions(requestID string, user utils.User, jsonMap map[string]interface{}) *extProcPb.ProcessingResponse {
if user.Tpm != 0 {
streamOptions, ok := jsonMap["stream_options"].(map[string]interface{})
if !ok {
klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}},
"no stream option available")
}
includeUsage, ok := streamOptions["include_usage"].(bool)
if !includeUsage || !ok {
klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}},
"no stream with usage option available")
}
}
return nil
}

// getRoutingStrategy retrieves the routing strategy from the headers or environment variable
// It returns the routing strategy value and whether custom routing strategy is enabled.
func getRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) {
var routingStrategy string
routingStrategyEnabled := false

// Check headers for routing strategy
for _, header := range headers {
if strings.ToLower(header.Key) == HeaderRoutingStrategy {
routingStrategy = string(header.RawValue)
routingStrategyEnabled = true
break // Prioritize header value over environment variable
}
}

// If header not set, check environment variable
if !routingStrategyEnabled {
if value, exists := utils.CheckEnvExists(EnvRoutingAlgorithm); exists {
routingStrategy = value
routingStrategyEnabled = true
}
}

return routingStrategy, routingStrategyEnabled
}

// validateRoutingStrategy validates if user provided routing strategy is supported by gateway
func validateRoutingStrategy(routingStrategy string) bool {
routingStrategy = strings.TrimSpace(routingStrategy)
return slices.Contains(routingStrategies, routingStrategy)
}

// getRequestMessage returns input request message field which has user prompt
func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) {
messages, ok := jsonMap["messages"]
if !ok {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"no messages in the request body")
}
messagesJSON, err := json.Marshal(messages)
if err != nil {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"unable to marshal messages from request body")
}
return string(messagesJSON), nil
}

// generateErrorResponse construct envoy proxy error response
func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
// Set the Content-Type header to application/json
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "Content-Type",
Value: "application/json",
},
})

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extProcPb.ImmediateResponse{
Status: &envoyTypePb.HttpStatus{
Code: statusCode,
},
Headers: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
Body: generateErrorMessage(body, int(statusCode)),
},
},
}
}

// generateErrorMessage constructs a JSON error message
func generateErrorMessage(message string, code int) string {
errorStruct := map[string]interface{}{
"error": map[string]interface{}{
"message": message,
"code": code,
},
}
jsonData, _ := json.Marshal(errorStruct)
return string(jsonData)
}

0 comments on commit 9223769

Please sign in to comment.