diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index 479a3b4d..1e74013c 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -215,7 +215,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { } func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, utils.User, int64, string) { - klog.Info("\n\n") + klog.Info("\n") klog.InfoS("-- In RequestHeaders processing ...", "requestID", requestID) var username string var user utils.User @@ -482,15 +482,19 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req * }}}, err.Error()), complete } else if len(res.Model) == 0 { - err = ErrorUnknownResponse - klog.ErrorS(err, "unexpected response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody())) + msg := ErrorUnknownResponse.Error() + responseBodyContent := string(b.ResponseBody.GetBody()) + if len(responseBodyContent) != 0 { + msg = responseBodyContent + } + klog.ErrorS(err, "unexpected response", "requestID", requestID, "responseBody", responseBodyContent) complete = true return generateErrorResponse( envoyTypePb.StatusCode_InternalServerError, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ Key: HeaderErrorResponseUnknown, RawValue: []byte("true"), }}}, - err.Error()), complete + msg), complete } // Do not overwrite model, res can be empty. 
usage = res.Usage @@ -663,6 +667,14 @@ func validateRoutingStrategy(routingStrategy string) bool { } func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse { + // Set the Content-Type header to application/json + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "Content-Type", + Value: "application/json", + }, + }) + return &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_ImmediateResponse{ ImmediateResponse: &extProcPb.ImmediateResponse{ @@ -672,7 +684,7 @@ func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configP Headers: &extProcPb.HeaderMutation{ SetHeaders: headers, }, - Body: body, + Body: generateErrorMessage(body, int(statusCode)), }, }, } @@ -719,3 +731,15 @@ func GetRoutingStrategy(headers []*configPb.HeaderValue) (string, bool) { return routingStrategy, routingStrategyEnabled } + +// generateErrorMessage constructs a JSON error message using json.Marshal +func generateErrorMessage(message string, code int) string { + errorStruct := map[string]interface{}{ + "error": map[string]interface{}{ + "message": message, + "code": code, + }, + } + jsonData, _ := json.Marshal(errorStruct) + return string(jsonData) +} diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 57cd15f4..e29811c0 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -18,6 +18,8 @@ package e2e import ( "context" + "errors" + "net/http" "testing" "github.com/openai/openai-go" @@ -43,52 +45,70 @@ func TestBaseModelInference(t *testing.T) { Model: openai.F(modelName), }) if err != nil { - t.Error("chat completions failed", err) + t.Fatalf("chat completions failed: %v", err) } + assert.Equal(t, modelName, chatCompletion.Model) + assert.NotEmpty(t, chatCompletion.Choices, "chat completion has no choices returned") + assert.NotNil(t, chatCompletion.Choices[0].Message.Content, "chat completion has no 
message returned") } func TestBaseModelInferenceFailures(t *testing.T) { - // error on invalid api key - client := createOpenAIClient(baseURL, "fake-api-key") - _, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("Say this is a test"), - }), - Model: openai.F(modelName), - }) - if err == nil { - t.Error("500 Internal Server Error expected for invalid api-key") + testCases := []struct { + name string + apiKey string + modelName string + routingStrategy string + expectErrCode int + }{ + { + name: "Invalid API Key", + apiKey: "fake-api-key", + modelName: modelName, + // TODO: this is supposed to be 401. Handle this case and update the expected code. + expectErrCode: 500, + }, + { + name: "Invalid Model Name", + apiKey: apiKey, + modelName: "fake-model-name", + expectErrCode: 400, + }, + { + name: "Invalid Routing Strategy", + apiKey: apiKey, + modelName: modelName, + routingStrategy: "invalid-routing-strategy", + expectErrCode: 400, + }, } - // error on invalid model name - client = createOpenAIClient(baseURL, apiKey) - _, err = client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("Say this is a test"), - }), - Model: openai.F("fake-model-name"), - }) - assert.Contains(t, err.Error(), "400 Bad Request") - if err == nil { - t.Error("400 Bad Request expected for invalid api-key") - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var client *openai.Client + if tc.routingStrategy != "" { + var dst *http.Response + client = createOpenAIClientWithRoutingStrategy(baseURL, tc.apiKey, + tc.routingStrategy, option.WithResponseInto(&dst)) + } else { + client = createOpenAIClient(baseURL, tc.apiKey) + } - // invalid routing strategy - client = openai.NewClient( - option.WithBaseURL(baseURL), - option.WithAPIKey(apiKey), - 
option.WithHeader("routing-strategy", "invalid-routing-strategy"), - ) - client.Options = append(client.Options, option.WithHeader("routing-strategy", "invalid-routing-strategy")) - _, err = client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("Say this is a test"), - }), - Model: openai.F(modelName), - }) - if err == nil { - t.Error("400 Bad Request expected for invalid routing-strategy") + _, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), + Model: openai.F(tc.modelName), + }) + + assert.Error(t, err) + var apiErr *openai.Error + if !errors.As(err, &apiErr) { + t.Fatalf("Error is not an APIError: %+v", err) + } + if assert.ErrorAs(t, err, &apiErr) { + assert.Equal(t, apiErr.StatusCode, tc.expectErrCode) + } + }) } - assert.Contains(t, err.Error(), "400 Bad Request") } diff --git a/test/e2e/util.go b/test/e2e/util.go index a3879c1a..0ffbd04f 100644 --- a/test/e2e/util.go +++ b/test/e2e/util.go @@ -83,6 +83,7 @@ func createOpenAIClient(baseURL, apiKey string) *openai.Client { r.URL.Path = "/v1" + r.URL.Path return mn(r) }), + option.WithMaxRetries(0), ) } @@ -96,6 +97,7 @@ func createOpenAIClientWithRoutingStrategy(baseURL, apiKey, routingStrategy stri return mn(r) }), option.WithHeader("routing-strategy", routingStrategy), + option.WithMaxRetries(0), respOpt, ) } diff --git a/test/run-e2e-tests.sh b/test/run-e2e-tests.sh index 6f8a6b6c..3860a6af 100755 --- a/test/run-e2e-tests.sh +++ b/test/run-e2e-tests.sh @@ -66,7 +66,7 @@ if [ -n "$INSTALL_AIBRIX" ]; then cd ../.. 
kubectl port-forward svc/llama2-7b 8000:8000 & - kubectl -n envoy-gateway-system port-forward service/envoy-aibrix-system-aibrix-eg-903790dc 8888:80 & + kubectl -n envoy-gateway-system port-forward service/envoy-aibrix-system-aibrix-eg-903790dc 8888:80 & function cleanup { echo "Cleaning up..."