diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index b8a47a0f..ad7e68ab 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -25,6 +25,7 @@ import ( "io" "net/http" "slices" + "strconv" "strings" "sync" "time" @@ -171,8 +172,9 @@ func (s *HealthServer) Watch(in *healthPb.HealthCheckRequest, srv healthPb.Healt func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { var user utils.User var rpm, traceTerm int64 + var respErrorCode int var model, routingStrategy, targetPodIP string - var stream bool + var stream, isRespError bool ctx := srv.Context() requestID := uuid.New().String() completed := false @@ -202,11 +204,16 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { resp, model, targetPodIP, stream, traceTerm = s.HandleRequestBody(ctx, requestID, req, user, routingStrategy) case *extProcPb.ProcessingRequest_ResponseHeaders: - resp = s.HandleResponseHeaders(ctx, requestID, req, targetPodIP) + resp, isRespError, respErrorCode = s.HandleResponseHeaders(ctx, requestID, req, targetPodIP) case *extProcPb.ProcessingRequest_ResponseBody: - resp, completed = s.HandleResponseBody(ctx, requestID, req, user, rpm, model, targetPodIP, stream, traceTerm, completed) - + respBody := req.Request.(*extProcPb.ProcessingRequest_ResponseBody) + if isRespError { + klog.ErrorS(errors.New("request end"), string(respBody.ResponseBody.GetBody()), "requestID", requestID) + generateErrorResponse(envoyTypePb.StatusCode(respErrorCode), nil, string(respBody.ResponseBody.GetBody())) + } else { + resp, completed = s.HandleResponseBody(ctx, requestID, req, user, rpm, model, targetPodIP, stream, traceTerm, completed) + } default: klog.Infof("Unknown Request type %+v\n", v) } @@ -218,7 +225,6 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { } func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, utils.User, int64, string) { - klog.Info("\n") klog.InfoS("-- In RequestHeaders processing ...", "requestID", requestID) var username string var user utils.User @@ -358,7 +364,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e klog.InfoS("request start", "requestID", requestID, "model", model) } else { message, extErr := getRequestMessage(jsonMap) - if err != nil { + if extErr != nil { return extErr, model, targetPodIP, stream, term } @@ -403,8 +409,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e }, model, targetPodIP, stream, term } -func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, targetPodIP string) *extProcPb.ProcessingResponse { +func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, targetPodIP string) (*extProcPb.ProcessingResponse, bool, int) { klog.InfoS("-- In ResponseHeaders processing ...", "requestID", requestID) + b := req.Request.(*extProcPb.ProcessingRequest_ResponseHeaders) headers := []*configPb.HeaderValueOption{{ Header: &configPb.HeaderValue{ @@ -421,6 +428,24 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re }) } + var isProcessingError bool + var processingErrorCode int + for _, headerValue := range b.ResponseHeaders.Headers.Headers { + if headerValue.Key == ":status" { + code, _ := strconv.Atoi(string(headerValue.RawValue)) + if code != 200 { + isProcessingError = true + processingErrorCode = code + } + } + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: headerValue.Key, + RawValue: headerValue.RawValue, + }, + }) + } + return &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_ResponseHeaders{ ResponseHeaders: &extProcPb.HeadersResponse{ @@ -432,7 +457,7 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, re }, }, }, - } + }, isProcessingError, processingErrorCode } func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, targetPodIP string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) { diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 65185e47..74d4b82b 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -55,11 +55,10 @@ func TestBaseModelInferenceFailures(t *testing.T) { expectErrCode int }{ { - name: "Invalid API Key", - apiKey: "fake-api-key", - modelName: modelName, - // TODO: it is supposed to be 401. Let's handle such case and fix this. - expectErrCode: 500, + name: "Invalid API Key", + apiKey: "fake-api-key", + modelName: modelName, + expectErrCode: 401, }, { name: "Invalid Model Name", @@ -100,7 +99,7 @@ func TestBaseModelInferenceFailures(t *testing.T) { t.Fatalf("Error is not an APIError: %+v", err) } if assert.ErrorAs(t, err, &apiErr) { - assert.Equal(t, apiErr.StatusCode, tc.expectErrCode) + assert.Equal(t, tc.expectErrCode, apiErr.StatusCode, t.Name()) } }) } diff --git a/test/e2e/model_adapter_test.go b/test/e2e/model_adapter_test.go index 71046142..c434925b 100644 --- a/test/e2e/model_adapter_test.go +++ b/test/e2e/model_adapter_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2024 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package e2e import ( @@ -23,20 +39,24 @@ func TestModelAdapter(t *testing.T) { k8sClient, v1alpha1Client := initializeClient(context.Background(), t) t.Cleanup(func() { - assert.NoError(t, v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Delete(context.Background(), adapter.Name, v1.DeleteOptions{})) - wait.PollImmediate(1*time.Second, 30*time.Second, - func() (done bool, err error) { - adapter, err = v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Get(context.Background(), adapter.Name, v1.GetOptions{}) + assert.NoError(t, v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Delete(context.Background(), + adapter.Name, v1.DeleteOptions{})) + assert.NoError(t, wait.PollUntilContextTimeout(context.Background(), 1*time.Second, 30*time.Second, true, + func(ctx context.Context) (done bool, err error) { + adapter, err = v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Get(context.Background(), + adapter.Name, v1.GetOptions{}) if apierrors.IsNotFound(err) { return true, nil } return false, nil - }) + })) + }) // create model adapter fmt.Println("creating model adapter") - adapter, err := v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Create(context.Background(), adapter, v1.CreateOptions{}) + adapter, err := v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Create(context.Background(), + adapter, v1.CreateOptions{}) assert.NoError(t, err) adapter = validateModelAdapter(t, v1alpha1Client, adapter.Name) oldPod := adapter.Status.Instances[0] @@ -80,14 +100,14 @@ func createModelAdapterConfig(name, model string) *modelv1alpha1.ModelAdapter { func validateModelAdapter(t *testing.T, client *v1alpha1.Clientset, name string) *modelv1alpha1.ModelAdapter { var adapter *modelv1alpha1.ModelAdapter - wait.PollImmediate(1*time.Second, 30*time.Second, - func() (done bool, err error) { + assert.NoError(t, wait.PollUntilContextTimeout(context.Background(), 1*time.Second, 30*time.Second, true, + func(ctx context.Context) (done bool, err error) { adapter, err = client.ModelV1alpha1().ModelAdapters("default").Get(context.Background(), name, v1.GetOptions{}) if err != nil || adapter.Status.Phase != modelv1alpha1.ModelAdapterRunning { return false, nil } return true, nil - }) + })) assert.True(t, len(adapter.Status.Instances) > 0, "model adapter scheduled on atleast one pod") return adapter } diff --git a/test/e2e/util.go b/test/e2e/util.go index 4505ddde..a00e80ae 100644 --- a/test/e2e/util.go +++ b/test/e2e/util.go @@ -121,7 +121,7 @@ func validateInferenceWithClient(t *testing.T, client *openai.Client, modelName Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ openai.UserMessage("Say this is a test"), }), - Model: openai.F(openai.ChatModel(modelName)), + Model: openai.F(modelName), }) if err != nil { t.Fatalf("chat completions failed : %v", err)