diff --git a/config/gateway/gateway.yaml b/config/gateway/gateway.yaml index 418882f0..aef47be5 100644 --- a/config/gateway/gateway.yaml +++ b/config/gateway/gateway.yaml @@ -33,9 +33,10 @@ spec: - name: aibrix-gateway-plugins port: 50052 processingMode: - request: {} + request: + body: Buffered response: - body: Streamed + body: Buffered --- apiVersion: gateway.envoyproxy.io/v1alpha1 kind: EnvoyPatchPolicy diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index d5c81307..b95d777f 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -139,7 +139,7 @@ func (c *Cache) addPod(obj interface{}) { c.pods[pod.Name] = pod c.addPodAndModelMapping(pod.Name, modelName) - klog.Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name) + klog.V(4).Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name) c.debugInfo() } @@ -159,9 +159,11 @@ func (c *Cache) updatePod(oldObj interface{}, newObj interface{}) { return } + delete(c.pods, oldPod.Name) + c.pods[newPod.Name] = newPod c.deletePodAndModelMapping(oldPod.Name, oldModelName) c.addPodAndModelMapping(newPod.Name, newModelName) - klog.Infof("POD UPDATED. %s/%s %s", oldPod.Namespace, oldPod.Name, newPod.Status.Phase) + klog.V(4).Infof("POD UPDATED. %s/%s %s", newPod.Namespace, newPod.Name, newPod.Status.Phase) c.debugInfo() } @@ -177,7 +179,7 @@ func (c *Cache) deletePod(obj interface{}) { delete(c.pods, pod.Name) c.deletePodAndModelMapping(pod.Name, modelName) - klog.Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name) + klog.V(4).Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name) c.debugInfo() } @@ -190,7 +192,7 @@ func (c *Cache) addModelAdapter(obj interface{}) { c.addPodAndModelMapping(pod, model.Name) } - klog.Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name) + klog.V(4).Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name) c.debugInfo() } @@ -209,7 +211,7 @@ func (c *Cache) updateModelAdapter(oldObj interface{}, newObj interface{}) { c.addPodAndModelMapping(pod, newModel.Name) } - klog.Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase) + klog.V(4).Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase) c.debugInfo() } @@ -222,7 +224,7 @@ func (c *Cache) deleteModelAdapter(obj interface{}) { c.deletePodAndModelMapping(pod, model.Name) } - klog.Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name) + klog.V(4).Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name) c.debugInfo() } @@ -261,21 +263,21 @@ func (c *Cache) deletePodAndModelMapping(podName, modelName string) { func (c *Cache) debugInfo() { for _, pod := range c.pods { - klog.Info(pod.Name) + klog.V(4).Infof("pod: %s, podIP: %v", pod.Name, pod.Status.PodIP) } for podName, models := range c.podToModelMapping { var modelList string for modelName := range models { modelList += modelName + " " } - klog.Infof("pod: %s, models: %s", podName, modelList) + klog.V(4).Infof("pod: %s, models: %s", podName, modelList) } for modelName, pods := range c.modelToPodMapping { var podList string for podName := range pods { podList += podName + " " } - klog.Infof("model: %s, pods: %s", modelName, podList) + klog.V(4).Infof("model: %s, pods: %s", modelName, podList) } } diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index 0fe3aab5..ae6da224 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/redis/go-redis/v9" openai "github.com/sashabaranov/go-openai" "google.golang.org/grpc/codes" @@ -38,7 +39,6 @@ import ( routing "github.com/aibrix/aibrix/pkg/plugins/gateway/routing_algorithms" "github.com/aibrix/aibrix/pkg/utils" configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - filterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" healthPb "google.golang.org/grpc/health/grpc_health_v1" @@ -93,6 +93,7 @@ func (s *HealthServer) Watch(in *healthPb.HealthCheckRequest, srv healthPb.Healt func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { var user, targetPodIP string ctx := srv.Context() + requestID := uuid.New().String() for { select { @@ -113,7 +114,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { switch v := req.Request.(type) { case *extProcPb.ProcessingRequest_RequestHeaders: - resp, user, targetPodIP = s.HandleRequestHeaders(ctx, req) + resp, user, targetPodIP = s.HandleRequestHeaders(ctx, requestID, req) case *extProcPb.ProcessingRequest_RequestBody: resp = s.HandleRequestBody(req, targetPodIP) @@ -122,7 +123,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { resp = s.HandleResponseHeaders(req, targetPodIP) case *extProcPb.ProcessingRequest_ResponseBody: - resp = s.HandleResponseBody(ctx, req, user, targetPodIP) + resp = s.HandleResponseBody(ctx, requestID, req, user, targetPodIP) default: log.Printf("Unknown Request type %+v\n", v) @@ -134,7 +135,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { } } -func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) { +func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) { log.Println("--- In RequestHeaders processing ...") var username, model, routingStrategy, targetPodIP string r := req.Request @@ -182,6 +183,17 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces fmt.Sprintf("pre query: error on checking rpm: %v", err.Error())), username, targetPodIP } + rpm, code, err := s.incrRPM(ctx, username) + if err != nil { + return generateErrorResponse( + code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-update-rpm", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: error on updating rpm: %v", err.Error())), username, targetPodIP + } + klog.Infof("RequestStart %s: RPM: %v for user: %v", reqeustID, rpm, user.Name) + code, err = s.checkTPM(ctx, username, user.Tpm) if err != nil { return generateErrorResponse( @@ -192,12 +204,20 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces fmt.Sprintf("pre query: error on checking tpm: %v", err.Error())), username, targetPodIP } - headers := []*configPb.HeaderValueOption{{ - Header: &configPb.HeaderValue{ - Key: "x-went-into-req-headers", - RawValue: []byte("true"), + headers := []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-went-into-req-headers", + RawValue: []byte("true"), + }, }, - }} + { + Header: &configPb.HeaderValue{ + Key: "x-updated-rpm", + RawValue: []byte(fmt.Sprintf("%d", rpm)), + }, + }, + } if routingStrategy != "" { pods, err := s.cache.GetPodsForModel(model) if len(pods) == 0 || err != nil { @@ -226,7 +246,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces }, }) podRequestCounter := s.cache.IncrPodRequestCount(fmt.Sprintf("%v_REQUEST_COUNT", targetPodIP)) - klog.Infof("RequestStart: SelectedTargetPodIP: %s, PodRequestCount: %v", targetPodIP, podRequestCounter) + klog.Infof("RequestStart %s: SelectedTargetPodIP: %s, PodRequestCount: %v", reqeustID, targetPodIP, podRequestCounter) } resp := &extProcPb.ProcessingResponse{ @@ -240,10 +260,6 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces }, }, }, - ModeOverride: &filterPb.ProcessingMode{ - ResponseHeaderMode: filterPb.ProcessingMode_SEND, - RequestBodyMode: filterPb.ProcessingMode_NONE, - }, } return resp, username, targetPodIP @@ -310,8 +326,8 @@ func (s *Server) HandleResponseHeaders(req *extProcPb.ProcessingRequest, targetP } } -func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.ProcessingRequest, user string, targetPodIP string) *extProcPb.ProcessingResponse { - log.Println("--- In ResponseBody processing") +func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *extProcPb.ProcessingRequest, user string, targetPodIP string) *extProcPb.ProcessingResponse { + klog.Infof("--- In ResponseBody processing %s", reqeustID) r := req.Request b := r.(*extProcPb.ProcessingRequest_ResponseBody) @@ -319,7 +335,7 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi defer func() { if targetPodIP != "" { podRequestCounter := s.cache.DecrPodRequestCount(fmt.Sprintf("%v_REQUEST_COUNT", targetPodIP)) - klog.Infof("RequestEnd: SelectedTargetPodIP: %s, PodRequestCount: %v", targetPodIP, podRequestCounter) + klog.Infof("RequestEnd %s: SelectedTargetPodIP: %s, PodRequestCount: %v", reqeustID, targetPodIP, podRequestCounter) } }() @@ -333,15 +349,6 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi err.Error()) } - rpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_RPM_CURRENT", user), 1) - if err != nil { - return generateErrorResponse( - envoyTypePb.StatusCode_InternalServerError, - []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: "x-error-update-rpm", RawValue: []byte("true"), - }}}, - fmt.Sprintf("post query: error on updating rpm: %v", err.Error())) - } tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens)) if err != nil { return generateErrorResponse( @@ -351,14 +358,14 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi }}}, fmt.Sprintf("post query: error on updating tpm: %v", err.Error())) } - klog.Infof("Updated RPM: %v, TPM: %v for user: %v", rpm, tpm, user) + klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user) if targetPodIP != "" { podTpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_THROUGHPUT", targetPodIP), int64(res.Usage.TotalTokens)) if err != nil { klog.Error(err) } else { - klog.Infof("RequestEnd: SelectedTargetPodIP: %s, PodThroughput: %v", targetPodIP, podTpm) + klog.Infof("RequestEnd %s: SelectedTargetPodIP: %s, PodThroughput: %v", reqeustID, targetPodIP, podTpm) } } @@ -368,12 +375,6 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi Response: &extProcPb.CommonResponse{ HeaderMutation: &extProcPb.HeaderMutation{ SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: "x-updated-rpm", - RawValue: []byte(fmt.Sprintf("%d", rpm)), - }, - }, { Header: &configPb.HeaderValue{ Key: "x-updated-tpm", @@ -404,6 +405,16 @@ func (s *Server) checkRPM(ctx context.Context, user string, rpmLimit int64) (env return envoyTypePb.StatusCode_OK, nil } +func (s *Server) incrRPM(ctx context.Context, user string) (int64, envoyTypePb.StatusCode, error) { + rpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_RPM_CURRENT", user), 1) + if err != nil { + return rpm, envoyTypePb.StatusCode_InternalServerError, err + } + + klog.Infof("Updated RPM: %v for user: %v", rpm, user) + return rpm, envoyTypePb.StatusCode_OK, nil +} + func (s *Server) checkTPM(ctx context.Context, user string, tpmLimit int64) (envoyTypePb.StatusCode, error) { tpmCurrent, err := s.ratelimiter.Get(ctx, fmt.Sprintf("%v_TPM_CURRENT", user)) if err != nil {