Skip to content

Commit

Permalink
Fix concurrency issue with gateway RPM plugin (#244)
Browse files Browse the repository at this point in the history
* Fix concurrency issue with gateway RPM plugin

* update new pod to fix IP issue

* update log level
  • Loading branch information
varungup90 authored Sep 27, 2024
1 parent 2ce3b56 commit 0df904d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 45 deletions.
5 changes: 3 additions & 2 deletions config/gateway/gateway.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ spec:
- name: aibrix-gateway-plugins
port: 50052
processingMode:
request: {}
request:
body: Buffered
response:
body: Streamed
body: Buffered
---
apiVersion: gateway.envoyproxy.io/v1alpha1
kind: EnvoyPatchPolicy
Expand Down
20 changes: 11 additions & 9 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (c *Cache) addPod(obj interface{}) {

c.pods[pod.Name] = pod
c.addPodAndModelMapping(pod.Name, modelName)
klog.Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name)
klog.V(4).Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name)
c.debugInfo()
}

Expand All @@ -159,9 +159,11 @@ func (c *Cache) updatePod(oldObj interface{}, newObj interface{}) {
return
}

delete(c.pods, oldPod.Name)
c.pods[newPod.Name] = newPod
c.deletePodAndModelMapping(oldPod.Name, oldModelName)
c.addPodAndModelMapping(newPod.Name, newModelName)
klog.Infof("POD UPDATED. %s/%s %s", oldPod.Namespace, oldPod.Name, newPod.Status.Phase)
klog.V(4).Infof("POD UPDATED. %s/%s %s", newPod.Namespace, newPod.Name, newPod.Status.Phase)
c.debugInfo()
}

Expand All @@ -177,7 +179,7 @@ func (c *Cache) deletePod(obj interface{}) {

delete(c.pods, pod.Name)
c.deletePodAndModelMapping(pod.Name, modelName)
klog.Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name)
klog.V(4).Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name)
c.debugInfo()
}

Expand All @@ -190,7 +192,7 @@ func (c *Cache) addModelAdapter(obj interface{}) {
c.addPodAndModelMapping(pod, model.Name)
}

klog.Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name)
klog.V(4).Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name)
c.debugInfo()
}

Expand All @@ -209,7 +211,7 @@ func (c *Cache) updateModelAdapter(oldObj interface{}, newObj interface{}) {
c.addPodAndModelMapping(pod, newModel.Name)
}

klog.Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase)
klog.V(4).Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase)
c.debugInfo()
}

Expand All @@ -222,7 +224,7 @@ func (c *Cache) deleteModelAdapter(obj interface{}) {
c.deletePodAndModelMapping(pod, model.Name)
}

klog.Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name)
klog.V(4).Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name)
c.debugInfo()
}

Expand Down Expand Up @@ -261,21 +263,21 @@ func (c *Cache) deletePodAndModelMapping(podName, modelName string) {

func (c *Cache) debugInfo() {
for _, pod := range c.pods {
klog.Info(pod.Name)
klog.V(4).Infof("pod: %s, podIP: %v", pod.Name, pod.Status.PodIP)
}
for podName, models := range c.podToModelMapping {
var modelList string
for modelName := range models {
modelList += modelName + " "
}
klog.Infof("pod: %s, models: %s", podName, modelList)
klog.V(4).Infof("pod: %s, models: %s", podName, modelList)
}
for modelName, pods := range c.modelToPodMapping {
var podList string
for podName := range pods {
podList += podName + " "
}
klog.Infof("model: %s, pods: %s", modelName, podList)
klog.V(4).Infof("model: %s, pods: %s", modelName, podList)
}
}

Expand Down
79 changes: 45 additions & 34 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strings"
"time"

"github.com/google/uuid"
"github.com/redis/go-redis/v9"
openai "github.com/sashabaranov/go-openai"
"google.golang.org/grpc/codes"
Expand All @@ -38,7 +39,6 @@ import (
routing "github.com/aibrix/aibrix/pkg/plugins/gateway/routing_algorithms"
"github.com/aibrix/aibrix/pkg/utils"
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
filterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
healthPb "google.golang.org/grpc/health/grpc_health_v1"
Expand Down Expand Up @@ -93,6 +93,7 @@ func (s *HealthServer) Watch(in *healthPb.HealthCheckRequest, srv healthPb.Healt
func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
var user, targetPodIP string
ctx := srv.Context()
requestID := uuid.New().String()

for {
select {
Expand All @@ -113,7 +114,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
switch v := req.Request.(type) {

case *extProcPb.ProcessingRequest_RequestHeaders:
resp, user, targetPodIP = s.HandleRequestHeaders(ctx, req)
resp, user, targetPodIP = s.HandleRequestHeaders(ctx, requestID, req)

case *extProcPb.ProcessingRequest_RequestBody:
resp = s.HandleRequestBody(req, targetPodIP)
Expand All @@ -122,7 +123,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
resp = s.HandleResponseHeaders(req, targetPodIP)

case *extProcPb.ProcessingRequest_ResponseBody:
resp = s.HandleResponseBody(ctx, req, user, targetPodIP)
resp = s.HandleResponseBody(ctx, requestID, req, user, targetPodIP)

default:
log.Printf("Unknown Request type %+v\n", v)
Expand All @@ -134,7 +135,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
}
}

func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) {
func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, string, string) {
log.Println("--- In RequestHeaders processing ...")
var username, model, routingStrategy, targetPodIP string
r := req.Request
Expand Down Expand Up @@ -182,6 +183,17 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces
fmt.Sprintf("pre query: error on checking rpm: %v", err.Error())), username, targetPodIP
}

rpm, code, err := s.incrRPM(ctx, username)
if err != nil {
return generateErrorResponse(
code,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-rpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("pre query: error on updating rpm: %v", err.Error())), username, targetPodIP
}
klog.Infof("RequestStart %s: RPM: %v for user: %v", reqeustID, rpm, user.Name)

code, err = s.checkTPM(ctx, username, user.Tpm)
if err != nil {
return generateErrorResponse(
Expand All @@ -192,12 +204,20 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces
fmt.Sprintf("pre query: error on checking tpm: %v", err.Error())), username, targetPodIP
}

headers := []*configPb.HeaderValueOption{{
Header: &configPb.HeaderValue{
Key: "x-went-into-req-headers",
RawValue: []byte("true"),
headers := []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Key: "x-went-into-req-headers",
RawValue: []byte("true"),
},
},
}}
{
Header: &configPb.HeaderValue{
Key: "x-updated-rpm",
RawValue: []byte(fmt.Sprintf("%d", rpm)),
},
},
}
if routingStrategy != "" {
pods, err := s.cache.GetPodsForModel(model)
if len(pods) == 0 || err != nil {
Expand Down Expand Up @@ -226,7 +246,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces
},
})
podRequestCounter := s.cache.IncrPodRequestCount(fmt.Sprintf("%v_REQUEST_COUNT", targetPodIP))
klog.Infof("RequestStart: SelectedTargetPodIP: %s, PodRequestCount: %v", targetPodIP, podRequestCounter)
klog.Infof("RequestStart %s: SelectedTargetPodIP: %s, PodRequestCount: %v", reqeustID, targetPodIP, podRequestCounter)
}

resp := &extProcPb.ProcessingResponse{
Expand All @@ -240,10 +260,6 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces
},
},
},
ModeOverride: &filterPb.ProcessingMode{
ResponseHeaderMode: filterPb.ProcessingMode_SEND,
RequestBodyMode: filterPb.ProcessingMode_NONE,
},
}

return resp, username, targetPodIP
Expand Down Expand Up @@ -310,16 +326,16 @@ func (s *Server) HandleResponseHeaders(req *extProcPb.ProcessingRequest, targetP
}
}

func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.ProcessingRequest, user string, targetPodIP string) *extProcPb.ProcessingResponse {
log.Println("--- In ResponseBody processing")
func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *extProcPb.ProcessingRequest, user string, targetPodIP string) *extProcPb.ProcessingResponse {
klog.Infof("--- In ResponseBody processing %s", reqeustID)

r := req.Request
b := r.(*extProcPb.ProcessingRequest_ResponseBody)

defer func() {
if targetPodIP != "" {
podRequestCounter := s.cache.DecrPodRequestCount(fmt.Sprintf("%v_REQUEST_COUNT", targetPodIP))
klog.Infof("RequestEnd: SelectedTargetPodIP: %s, PodRequestCount: %v", targetPodIP, podRequestCounter)
klog.Infof("RequestEnd %s: SelectedTargetPodIP: %s, PodRequestCount: %v", reqeustID, targetPodIP, podRequestCounter)
}
}()

Expand All @@ -333,15 +349,6 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi
err.Error())
}

rpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_RPM_CURRENT", user), 1)
if err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: "x-error-update-rpm", RawValue: []byte("true"),
}}},
fmt.Sprintf("post query: error on updating rpm: %v", err.Error()))
}
tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens))
if err != nil {
return generateErrorResponse(
Expand All @@ -351,14 +358,14 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi
}}},
fmt.Sprintf("post query: error on updating tpm: %v", err.Error()))
}
klog.Infof("Updated RPM: %v, TPM: %v for user: %v", rpm, tpm, user)
klog.Infof("RequestEnd %s: TPM: %v for user: %v", reqeustID, tpm, user)

if targetPodIP != "" {
podTpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_THROUGHPUT", targetPodIP), int64(res.Usage.TotalTokens))
if err != nil {
klog.Error(err)
} else {
klog.Infof("RequestEnd: SelectedTargetPodIP: %s, PodThroughput: %v", targetPodIP, podTpm)
klog.Infof("RequestEnd %s: SelectedTargetPodIP: %s, PodThroughput: %v", reqeustID, targetPodIP, podTpm)
}
}

Expand All @@ -368,12 +375,6 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Key: "x-updated-rpm",
RawValue: []byte(fmt.Sprintf("%d", rpm)),
},
},
{
Header: &configPb.HeaderValue{
Key: "x-updated-tpm",
Expand Down Expand Up @@ -404,6 +405,16 @@ func (s *Server) checkRPM(ctx context.Context, user string, rpmLimit int64) (env
return envoyTypePb.StatusCode_OK, nil
}

func (s *Server) incrRPM(ctx context.Context, user string) (int64, envoyTypePb.StatusCode, error) {
rpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_RPM_CURRENT", user), 1)
if err != nil {
return rpm, envoyTypePb.StatusCode_InternalServerError, err
}

klog.Infof("Updated RPM: %v for user: %v", rpm, user)
return rpm, envoyTypePb.StatusCode_OK, nil
}

func (s *Server) checkTPM(ctx context.Context, user string, tpmLimit int64) (envoyTypePb.StatusCode, error) {
tpmCurrent, err := s.ratelimiter.Get(ctx, fmt.Sprintf("%v_TPM_CURRENT", user))
if err != nil {
Expand Down

0 comments on commit 0df904d

Please sign in to comment.