Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make stream include usage as optional #788

Merged
merged 2 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions development/vllm/config/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,29 @@ spec:
image: aibrix/vllm-cpu-env:macos
ports:
- containerPort: 8000
command: ["/bin/sh", "-c"]
args: ["vllm serve facebook/opt-125m --served-model-name facebook-opt-125m --chat-template /etc/chat-template-config/chat-template.j2 --trust-remote-code --device cpu --disable_async_output_proc --enforce-eager --dtype float16"]
command:
- python3
- -m
- vllm.entrypoints.openai.api_server
- --host
- "0.0.0.0"
- --port
- "8000"
- --uvicorn-log-level
- warning
- --model
- facebook/opt-125m
- --served-model-name
- facebook-opt-125m
- --chat-template
- /etc/chat-template-config/chat-template.j2
- --trust-remote-code
- --device
- cpu
- --disable_async_output_proc
- --enforce-eager
- --dtype
- float16
env:
- name: DEPLOYMENT_NAME
valueFrom:
Expand Down
107 changes: 5 additions & 102 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -243,7 +242,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req
}
}

routingStrategy, routingStrategyEnabled := GetRoutingStrategy(h.RequestHeaders.Headers.Headers)
routingStrategy, routingStrategyEnabled := getRoutingStrategy(h.RequestHeaders.Headers.Headers)
if routingStrategyEnabled && !validateRoutingStrategy(routingStrategy) {
klog.ErrorS(nil, "incorrect routing strategy", "routing-strategy", routingStrategy)
return generateErrorResponse(
Expand Down Expand Up @@ -338,22 +337,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e
}

stream, ok = jsonMap["stream"].(bool)
if stream && ok {
streamOptions, ok := jsonMap["stream_options"].(map[string]interface{})
if !ok {
klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}},
"no stream option available"), model, targetPodIP, stream, term
}
includeUsage, ok := streamOptions["include_usage"].(bool)
if !includeUsage || !ok {
klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}},
"no stream with usage option available"), model, targetPodIP, stream, term
if ok && stream {
if errRes := validateStreamOptions(requestID, user, jsonMap); errRes != nil {
return errRes, model, targetPodIP, stream, term
}
}

Expand Down Expand Up @@ -466,7 +452,7 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re

func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, targetPodIP string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfSteam", b.ResponseBody.EndOfStream)
klog.InfoS("-- In ResponseBody processing ...", "requestID", requestID, "endOfStream", b.ResponseBody.EndOfStream)

var res openai.ChatCompletion
var usage openai.CompletionUsage
Expand Down Expand Up @@ -717,86 +703,3 @@ func (s *Server) selectTargetPod(ctx context.Context, routingStrategy string, po

return route.Route(ctx, pods, model, message)
}

func validateRoutingStrategy(routingStrategy string) bool {
routingStrategy = strings.TrimSpace(routingStrategy)
return slices.Contains(routingStrategies, routingStrategy)
}

func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
// Set the Content-Type header to application/json
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "Content-Type",
Value: "application/json",
},
})

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extProcPb.ImmediateResponse{
Status: &envoyTypePb.HttpStatus{
Code: statusCode,
},
Headers: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
Body: generateErrorMessage(body, int(statusCode)),
},
},
}
}

func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) {
messages, ok := jsonMap["messages"]
if !ok {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"no messages in the request body")
}
messagesJSON, err := json.Marshal(messages)
if err != nil {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"unable to marshal messages from request body")
}
return string(messagesJSON), nil
}

// GetRoutingStrategy retrieves the routing strategy from the headers or environment variable
// It returns the routing strategy value and whether custom routing strategy is enabled.
func GetRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) {
var routingStrategy string
routingStrategyEnabled := false

// Check headers for routing strategy
for _, header := range headers {
if strings.ToLower(header.Key) == HeaderRoutingStrategy {
routingStrategy = string(header.RawValue)
routingStrategyEnabled = true
break // Prioritize header value over environment variable
}
}

// If header not set, check environment variable
if !routingStrategyEnabled {
if value, exists := utils.CheckEnvExists(EnvRoutingAlgorithm); exists {
routingStrategy = value
routingStrategyEnabled = true
}
}

return routingStrategy, routingStrategyEnabled
}

// generateErrorMessage constructs a JSON error message using fmt.Sprintf
func generateErrorMessage(message string, code int) string {
errorStruct := map[string]interface{}{
"error": map[string]interface{}{
"message": message,
"code": code,
},
}
jsonData, _ := json.Marshal(errorStruct)
return string(jsonData)
}
2 changes: 1 addition & 1 deletion pkg/plugins/gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestGetRoutingStrategy(t *testing.T) {
_ = os.Unsetenv("ROUTING_ALGORITHM")
}

routingStrategy, enabled := GetRoutingStrategy(tt.headers)
routingStrategy, enabled := getRoutingStrategy(tt.headers)
assert.Equal(t, tt.expectedStrategy, routingStrategy, tt.message)
assert.Equal(t, tt.expectedEnabled, enabled, tt.message)

Expand Down
138 changes: 138 additions & 0 deletions pkg/plugins/gateway/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
Copyright 2024 The Aibrix Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package gateway

import (
"encoding/json"
"slices"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/vllm-project/aibrix/pkg/utils"
"k8s.io/klog/v2"
)

// validateStreamOptions validates whether stream options to include usage is set for user request
func validateStreamOptions(requestID string, user utils.User, jsonMap map[string]interface{}) *extProcPb.ProcessingResponse {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we know user are using heterogenous features? If we have the feature flag, we can do some validation here as well. not something we need to fix in this PR, just surface this requirement

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking we should have a feature flag for heterogenous feature. Right now it is enabled by default and incurs some performance penalty.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, please help document this issue.

if user.Tpm != 0 {
streamOptions, ok := jsonMap["stream_options"].(map[string]interface{})
if !ok {
klog.ErrorS(nil, "no stream option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoStreamOptions, RawValue: []byte("stream options not set")}}},
"no stream option available")
}
includeUsage, ok := streamOptions["include_usage"].(bool)
if !includeUsage || !ok {
klog.ErrorS(nil, "no stream with usage option available", "requestID", requestID, "jsonMap", jsonMap)
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorStreamOptionsIncludeUsage, RawValue: []byte("include usage for stream options not set")}}},
"no stream with usage option available")
}
}
return nil
}

// getRoutingStrategy retrieves the routing strategy from the headers or environment variable
// It returns the routing strategy value and whether custom routing strategy is enabled.
func getRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) {
var routingStrategy string
routingStrategyEnabled := false

// Check headers for routing strategy
for _, header := range headers {
if strings.ToLower(header.Key) == HeaderRoutingStrategy {
routingStrategy = string(header.RawValue)
routingStrategyEnabled = true
break // Prioritize header value over environment variable
}
}

// If header not set, check environment variable
if !routingStrategyEnabled {
if value, exists := utils.CheckEnvExists(EnvRoutingAlgorithm); exists {
routingStrategy = value
routingStrategyEnabled = true
}
}

return routingStrategy, routingStrategyEnabled
}

// validateRoutingStrategy validates if user provided routing strategy is supported by gateway
func validateRoutingStrategy(routingStrategy string) bool {
routingStrategy = strings.TrimSpace(routingStrategy)
return slices.Contains(routingStrategies, routingStrategy)
}

// getRequestMessage returns input request message field which has user prompt
func getRequestMessage(jsonMap map[string]interface{}) (string, *extProcPb.ProcessingResponse) {
messages, ok := jsonMap["messages"]
if !ok {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"no messages in the request body")
}
messagesJSON, err := json.Marshal(messages)
if err != nil {
return "", generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{Key: HeaderErrorRequestBodyProcessing, RawValue: []byte("true")}}},
"unable to marshal messages from request body")
}
return string(messagesJSON), nil
}

// generateErrorResponse construct envoy proxy error response
func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
// Set the Content-Type header to application/json
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "Content-Type",
Value: "application/json",
},
})

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extProcPb.ImmediateResponse{
Status: &envoyTypePb.HttpStatus{
Code: statusCode,
},
Headers: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
Body: generateErrorMessage(body, int(statusCode)),
},
},
}
}

// generateErrorMessage constructs a JSON error message
func generateErrorMessage(message string, code int) string {
errorStruct := map[string]interface{}{
"error": map[string]interface{}{
"message": message,
"code": code,
},
}
jsonData, _ := json.Marshal(errorStruct)
return string(jsonData)
}
Loading