From 1bc20e1b297ab8a384bba590e181d68bf62c24f3 Mon Sep 17 00:00:00 2001 From: Varun Gupta Date: Wed, 18 Sep 2024 17:14:20 -0700 Subject: [PATCH] Add routing for model adapter (#183) * Add routing for model adapter * nit: logging * nit: gateway error response code refactoring * code review comments * add/delete httproute for model adapter * nit --------- Co-authored-by: varungupta --- config/gateway/gateway.yaml | 15 +- docs/development/app/README.md | 2 +- docs/tutorial/lora/README.md | 33 ++- docs/tutorial/lora/model_adapter.yaml | 22 ++ pkg/cache/cache.go | 196 +++++++++----- .../modeladapter/scheduling/leastadapters.go | 17 +- .../modelrouter/modelrouter_controller.go | 65 ++++- pkg/plugins/gateway/gateway.go | 239 +++++++----------- .../routing_algorithms/least_request.go | 2 +- .../gateway/routing_algorithms/random.go | 11 +- .../gateway/routing_algorithms/router.go | 2 +- .../gateway/routing_algorithms/throughput.go | 2 +- 12 files changed, 367 insertions(+), 239 deletions(-) diff --git a/config/gateway/gateway.yaml b/config/gateway/gateway.yaml index ff4e4ba7..418882f0 100644 --- a/config/gateway/gateway.yaml +++ b/config/gateway/gateway.yaml @@ -58,11 +58,16 @@ spec: name: original_route match: prefix: "/" - # headers: - # # update ip address as needed and in production this config is not needed as backend will derive the pod ip - # - name: "target-pod" - # string_match: - # exact: "10.244.1.3:8000" + headers: + - name: "routing-strategy" + string_match: + exact: "random" + - name: "least-request" + string_match: + exact: "random" + - name: "routing-strategy" + string_match: + exact: "throughput" route: cluster: original_destination_cluster timeout: 1000s # Increase route timeout diff --git a/docs/development/app/README.md b/docs/development/app/README.md index 950dcdae..e9c3f081 100644 --- a/docs/development/app/README.md +++ b/docs/development/app/README.md @@ -84,7 +84,7 @@ curl -v http://localhost:8888/v1/chat/completions \ "model": "llama2-70b", "messages": [{"role": "user", "content": "Say this is a test!"}], "temperature": 0.7 - }' + }' & # least-request based for i in {1..10}; do diff --git a/docs/tutorial/lora/README.md b/docs/tutorial/lora/README.md index 4509bcc5..25dfbd7b 100644 --- a/docs/tutorial/lora/README.md +++ b/docs/tutorial/lora/README.md @@ -12,7 +12,7 @@ docker build -t aibrix/vllm-mock:nightly -f Dockerfile . 2. Deploy mocked model image ```shell -kubectl apply -f deployment.yaml +kubectl apply -f docs/development/app/deployment.yaml ``` 3. Load models @@ -43,7 +43,7 @@ Verified! The model is loaded and unloaded successfully and pod annotations are 5. Deploy the controller and apply the `model_adapter.yaml` ``` -kubectl apply -f model_adapter.yaml +kubectl apply -f docs/tutorial/lora/model_adapter.yaml ``` @@ -70,4 +70,33 @@ curl https://localhost:8000/v1/completions \ "max_tokens": 7, "temperature": 0 }' +``` + +# request via gateway without routing strategy +```shell +curl -v http://localhost:8888/v1/chat/completions \ + -H "user: your-user-name" \ + -H "model: lora-1" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer any_key" \ + -d '{ + "model": "lora-1", + "messages": [{"role": "user", "content": "Say this is a test!"}], + "temperature": 0.7 + }' +``` + +# request via gateway with routing strategy +```shell +curl -v http://localhost:8888/v1/chat/completions \ + -H "user: your-user-name" \ + -H "model: lora-1" \ + -H "routing-strategy: least-request" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer any_key" \ + -d '{ + "model": "lora-1", + "messages": [{"role": "user", "content": "Say this is a test!"}], + "temperature": 0.7 + }' ``` \ No newline at end of file diff --git a/docs/tutorial/lora/model_adapter.yaml b/docs/tutorial/lora/model_adapter.yaml index 389cc22b..3dba328e 100644 --- a/docs/tutorial/lora/model_adapter.yaml +++ b/docs/tutorial/lora/model_adapter.yaml @@ -3,6 +3,9 @@ kind: ModelAdapter metadata: name: lora-1 namespace: aibrix-system + labels: + model.aibrix.ai: "lora-1" + model.aibrix.ai/port: "8000" spec: baseModel: llama2-70b podSelector: @@ -10,3 +13,22 @@ spec: model.aibrix.ai: llama2-70b artifactURL: huggingface://yard1/llama-2-7b-sql-lora-test schedulerName: default +# --- +# # for test-purpose, if need to create HTTPRoute object manually +# apiVersion: gateway.networking.k8s.io/v1 +# kind: HTTPRoute +# metadata: +# name: lora-1-router +# namespace: aibrix-system +# spec: +# parentRefs: +# - name: aibrix-eg +# rules: +# - matches: +# - headers: +# - type: Exact +# name: model +# value: lora-1 +# backendRefs: +# - name: lora-1 +# port: 8000 \ No newline at end of file diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index f3f37b80..d5c81307 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -19,7 +19,6 @@ package cache import ( "errors" "fmt" - "strings" "sync" crdinformers "github.com/aibrix/aibrix/pkg/client/informers/externalversions" @@ -41,18 +40,22 @@ var once sync.Once // type global type Cache struct { - mu sync.RWMutex - initialized bool - pods map[string]*v1.Pod - modelAdapterToPodMapping map[string][]string - podToModelAdapterMapping map[string]map[string]struct{} - podRequestTracker map[string]int + mu sync.RWMutex + initialized bool + pods map[string]*v1.Pod + podToModelMapping map[string]map[string]struct{} // pod_name: map[model_name]struct{} + modelToPodMapping map[string]map[string]*v1.Pod // model_name: map[pod_name]*v1.Pod + podRequestTracker map[string]int } var ( instance Cache ) +const ( + modelIdentifier = "model.aibrix.ai" +) + func GetCache() (*Cache, error) { if !instance.initialized { return nil, errors.New("cache is not initialized") @@ -81,7 +84,7 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache { crdFactory := crdinformers.NewSharedInformerFactoryWithOptions(crdClientSet, 0) podInformer := factory.Core().V1().Pods().Informer() - modeInformer := crdFactory.Model().V1alpha1().ModelAdapters().Informer() + modelInformer := crdFactory.Model().V1alpha1().ModelAdapters().Informer() defer runtime.HandleCrash() factory.Start(stopCh) @@ -90,17 +93,17 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache { // factory.WaitForCacheSync(stopCh) // crdFactory.WaitForCacheSync(stopCh) - if !cache.WaitForCacheSync(stopCh, podInformer.HasSynced, modeInformer.HasSynced) { + if !cache.WaitForCacheSync(stopCh, podInformer.HasSynced, modelInformer.HasSynced) { runtime.HandleError(fmt.Errorf("timed out waiting for caches to sync")) return } instance = Cache{ - initialized: true, - pods: map[string]*v1.Pod{}, - modelAdapterToPodMapping: map[string][]string{}, - podToModelAdapterMapping: map[string]map[string]struct{}{}, - podRequestTracker: map[string]int{}, + initialized: true, + pods: map[string]*v1.Pod{}, + podToModelMapping: map[string]map[string]struct{}{}, + modelToPodMapping: map[string]map[string]*v1.Pod{}, + podRequestTracker: map[string]int{}, } if _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ @@ -111,10 +114,10 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache { panic(err) } - if _, err = modeInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: instance.addModel, - UpdateFunc: instance.updateModel, - DeleteFunc: instance.deleteModel, + if _, err = modelInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: instance.addModelAdapter, + UpdateFunc: instance.updateModelAdapter, + DeleteFunc: instance.deleteModelAdapter, }); err != nil { panic(err) } @@ -128,9 +131,16 @@ func (c *Cache) addPod(obj interface{}) { defer c.mu.Unlock() pod := obj.(*v1.Pod) + // only track pods with model deployments + modelName, ok := pod.Labels[modelIdentifier] + if !ok { + return + } + c.pods[pod.Name] = pod - c.podToModelAdapterMapping[pod.Name] = map[string]struct{}{} + c.addPodAndModelMapping(pod.Name, modelName) klog.Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name) + c.debugInfo() } func (c *Cache) updatePod(oldObj interface{}, newObj interface{}) { @@ -138,8 +148,21 @@ func (c *Cache) updatePod(oldObj interface{}, newObj interface{}) { defer c.mu.Unlock() oldPod := oldObj.(*v1.Pod) + oldModelName, ok := oldPod.Labels[modelIdentifier] + if !ok { + return + } + newPod := newObj.(*v1.Pod) + newModelName, ok := oldPod.Labels[modelIdentifier] + if !ok { + return + } + + c.deletePodAndModelMapping(oldPod.Name, oldModelName) + c.addPodAndModelMapping(newPod.Name, newModelName) klog.Infof("POD UPDATED. %s/%s %s", oldPod.Namespace, oldPod.Name, newPod.Status.Phase) + c.debugInfo() } func (c *Cache) deletePod(obj interface{}) { @@ -147,85 +170,125 @@ func (c *Cache) deletePod(obj interface{}) { defer c.mu.Unlock() pod := obj.(*v1.Pod) + modelName, ok := pod.Labels[modelIdentifier] + if !ok { + return + } + delete(c.pods, pod.Name) + c.deletePodAndModelMapping(pod.Name, modelName) klog.Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name) + c.debugInfo() } -func (c *Cache) addModel(obj interface{}) { +func (c *Cache) addModelAdapter(obj interface{}) { c.mu.Lock() defer c.mu.Unlock() model := obj.(*modelv1alpha1.ModelAdapter) - c.modelAdapterToPodMapping[model.Name] = model.Status.Instances - c.addModelAdapterMapping(model) + for _, pod := range model.Status.Instances { + c.addPodAndModelMapping(pod, model.Name) + } klog.Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name) + c.debugInfo() } -func (c *Cache) updateModel(oldObj interface{}, newObj interface{}) { +func (c *Cache) updateModelAdapter(oldObj interface{}, newObj interface{}) { c.mu.Lock() defer c.mu.Unlock() oldModel := oldObj.(*modelv1alpha1.ModelAdapter) newModel := newObj.(*modelv1alpha1.ModelAdapter) - c.modelAdapterToPodMapping[newModel.Name] = newModel.Status.Instances - c.deleteModelAdapterMapping(oldModel) - c.addModelAdapterMapping(newModel) + + for _, pod := range oldModel.Status.Instances { + c.deletePodAndModelMapping(pod, oldModel.Name) + } + + for _, pod := range newModel.Status.Instances { + c.addPodAndModelMapping(pod, newModel.Name) + } klog.Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase) + c.debugInfo() } -func (c *Cache) deleteModel(obj interface{}) { +func (c *Cache) deleteModelAdapter(obj interface{}) { c.mu.Lock() defer c.mu.Unlock() model := obj.(*modelv1alpha1.ModelAdapter) - delete(c.modelAdapterToPodMapping, model.Name) - c.deleteModelAdapterMapping(model) + for _, pod := range model.Status.Instances { + c.deletePodAndModelMapping(pod, model.Name) + } klog.Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name) + c.debugInfo() } -func (c *Cache) addModelAdapterMapping(model *modelv1alpha1.ModelAdapter) { - for _, pod := range model.Status.Instances { - models, ok := c.podToModelAdapterMapping[pod] - if !ok { - c.podToModelAdapterMapping[pod] = map[string]struct{}{ - model.Name: {}, - } - continue +func (c *Cache) addPodAndModelMapping(podName, modelName string) { + pod, ok := c.pods[podName] + if !ok { + klog.Errorf("pod %s does not exist in internal-cache", podName) + return + } + + models, ok := c.podToModelMapping[podName] + if !ok { + c.podToModelMapping[podName] = map[string]struct{}{ + modelName: {}, } + } else { + models[modelName] = struct{}{} + c.podToModelMapping[podName] = models + } - models[model.Name] = struct{}{} - c.podToModelAdapterMapping[pod] = models + pods, ok := c.modelToPodMapping[modelName] + if !ok { + c.modelToPodMapping[modelName] = map[string]*v1.Pod{ + podName: pod, + } + } else { + pods[podName] = pod + c.modelToPodMapping[modelName] = pods } } -func (c *Cache) deleteModelAdapterMapping(model *modelv1alpha1.ModelAdapter) { - for _, pod := range model.Status.Instances { - modelAdapters := c.podToModelAdapterMapping[pod] - delete(modelAdapters, model.Name) - c.podToModelAdapterMapping[pod] = modelAdapters - } +func (c *Cache) deletePodAndModelMapping(podName, modelName string) { + delete(c.podToModelMapping, podName) + delete(c.modelToPodMapping, modelName) } func (c *Cache) debugInfo() { - for model, instances := range c.modelAdapterToPodMapping { - klog.Infof("modelName: %s, instances: %v", model, instances) + for _, pod := range c.pods { + klog.Info(pod.Name) } - - for pod, models := range c.podToModelAdapterMapping { - if !strings.HasPrefix(pod, "llama") { - continue + for podName, models := range c.podToModelMapping { + var modelList string + for modelName := range models { + modelList += modelName + " " } - - modelsArr := []string{} - for m := range models { - modelsArr = append(modelsArr, m) + klog.Infof("pod: %s, models: %s", podName, modelList) + } + for modelName, pods := range c.modelToPodMapping { + var podList string + for podName := range pods { + podList += podName + " " } + klog.Infof("model: %s, pods: %s", modelName, podList) + } +} + +func (c *Cache) GetPod(podName string) (*v1.Pod, error) { + c.mu.RLock() + defer c.mu.RUnlock() - klog.Infof("podName: %s, modelAdapters: %v", pod, modelsArr) + pod, ok := c.pods[podName] + if !ok { + return nil, fmt.Errorf("pod does not exist in the cache: %s", podName) } + + return pod, nil } func (c *Cache) GetPods() map[string]*v1.Pod { @@ -235,11 +298,28 @@ func (c *Cache) GetPods() map[string]*v1.Pod { return c.pods } -func (c *Cache) GetPodToModelAdapterMapping() map[string]map[string]struct{} { +func (c *Cache) GetPodsForModel(modelName string) (map[string]*v1.Pod, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + podsMap, ok := c.modelToPodMapping[modelName] + if !ok { + return nil, fmt.Errorf("model does not exist in the cache: %s", modelName) + } + + return podsMap, nil +} + +func (c *Cache) GetModelsForPod(podName string) (map[string]struct{}, error) { c.mu.RLock() defer c.mu.RUnlock() - return c.podToModelAdapterMapping + models, ok := c.podToModelMapping[podName] + if !ok { + return nil, fmt.Errorf("pod does not exist in the cache: %s", podName) + } + + return models, nil } func (c *Cache) IncrPodRequestCount(podName string) int { diff --git a/pkg/controller/modeladapter/scheduling/leastadapters.go b/pkg/controller/modeladapter/scheduling/leastadapters.go index f3cdf148..6b289cd1 100644 --- a/pkg/controller/modeladapter/scheduling/leastadapters.go +++ b/pkg/controller/modeladapter/scheduling/leastadapters.go @@ -18,7 +18,6 @@ package scheduling import ( "context" - "errors" "math" "github.com/aibrix/aibrix/pkg/cache" @@ -37,24 +36,20 @@ func NewLeastAdapters(c *cache.Cache) Scheduler { } func (r leastAdapters) SelectPod(ctx context.Context, pods []v1.Pod) (*v1.Pod, error) { - modelAdapterCountMin := math.MaxInt selectedPod := v1.Pod{} - podMap := r.cache.GetPods() - podToModelAdapterMapping := r.cache.GetPodToModelAdapterMapping() + modelAdapterCountMin := math.MaxInt for _, pod := range pods { - if _, ok := podMap[pod.Name]; !ok { - return nil, errors.New("pod not found in the cache") + models, err := r.cache.GetModelsForPod(pod.Name) + if err != nil { + return nil, err } - - modelAdapters := podToModelAdapterMapping[pod.Name] - if len(modelAdapters) < modelAdapterCountMin { + if len(models) < modelAdapterCountMin { selectedPod = pod - modelAdapterCountMin = len(modelAdapters) + modelAdapterCountMin = len(models) } } klog.Infof("pod selected with least model adapters: %s", selectedPod.Name) - return &selectedPod, nil } diff --git a/pkg/controller/modelrouter/modelrouter_controller.go b/pkg/controller/modelrouter/modelrouter_controller.go index a51a1a85..7d84471b 100644 --- a/pkg/controller/modelrouter/modelrouter_controller.go +++ b/pkg/controller/modelrouter/modelrouter_controller.go @@ -30,6 +30,8 @@ import ( "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/manager" + + modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1" gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" ) @@ -54,15 +56,29 @@ func Add(mgr manager.Manager) error { return err } + modelInformer, err := cacher.GetInformer(context.TODO(), &modelv1alpha1.ModelAdapter{}) + if err != nil { + return err + } + utilruntime.Must(gatewayv1.AddToScheme(mgr.GetClient().Scheme())) modelRouter := &ModelRouter{ Client: mgr.GetClient(), } + _, err = deploymentInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ AddFunc: modelRouter.addModel, DeleteFunc: modelRouter.deleteModel, }) + if err != nil { + return err + } + + _, err = modelInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: modelRouter.addModelAdapter, + DeleteFunc: modelRouter.deleteModelAdapter, + }) return err } @@ -74,19 +90,39 @@ type ModelRouter struct { func (m *ModelRouter) addModel(obj interface{}) { deployment := obj.(*appsv1.Deployment) + m.createHTTPRoute(deployment.Namespace, deployment.Labels) +} + +func (m *ModelRouter) deleteModel(obj interface{}) { + deployment := obj.(*appsv1.Deployment) + m.deleteHTTPRoute(deployment.Namespace, deployment.Labels) +} + +func (m *ModelRouter) addModelAdapter(obj interface{}) { + modelAdapter := obj.(*modelv1alpha1.ModelAdapter) + m.createHTTPRoute(modelAdapter.Namespace, modelAdapter.Labels) +} - modelName, ok := deployment.Labels[modelIdentifier] +func (m *ModelRouter) deleteModelAdapter(obj interface{}) { + modelAdapter := obj.(*modelv1alpha1.ModelAdapter) + m.deleteHTTPRoute(modelAdapter.Namespace, modelAdapter.Labels) +} + +func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string) { + modelName, ok := labels[modelIdentifier] if !ok { - fmt.Printf("deployment %s does not have a model, labels: %s\n", deployment.Name, deployment.Labels) return } - modelPort, _ := strconv.ParseInt(deployment.Labels[modelPortIdentifier], 10, 32) + modelPort, err := strconv.ParseInt(labels[modelPortIdentifier], 10, 32) + if err != nil { + return + } httpRoute := gatewayv1.HTTPRoute{ ObjectMeta: metav1.ObjectMeta{ Name: fmt.Sprintf("%s-router", modelName), - Namespace: deployment.Namespace, + Namespace: namespace, }, Spec: gatewayv1.HTTPRouteSpec{ CommonRouteSpec: gatewayv1.CommonRouteSpec{ @@ -124,26 +160,27 @@ func (m *ModelRouter) addModel(obj interface{}) { }, }, } - err := m.Client.Create(context.Background(), &httpRoute) - klog.Errorln(err) + if err := m.Client.Create(context.Background(), &httpRoute); err != nil { + klog.Errorln(err) + } + klog.Infof("httproute: %v created for model: %v", httpRoute.Name, modelName) } -func (m *ModelRouter) deleteModel(obj interface{}) { - deployment := obj.(*appsv1.Deployment) - - modelName, ok := deployment.Labels[modelIdentifier] +func (m *ModelRouter) deleteHTTPRoute(namespace string, labels map[string]string) { + modelName, ok := labels[modelIdentifier] if !ok { - fmt.Printf("deployment %s does not have a model, labels: %s\n", deployment.Name, deployment.Labels) return } httpRoute := gatewayv1.HTTPRoute{ ObjectMeta: metav1.ObjectMeta{ Name: fmt.Sprintf("%s-router", modelName), - Namespace: deployment.Namespace, + Namespace: namespace, }, } - err := m.Client.Delete(context.Background(), &httpRoute) - klog.Errorln(err) + if err := m.Client.Delete(context.Background(), &httpRoute); err != nil { + klog.Errorln(err) + } + klog.Infof("httproute: %v deleted for model: %v", httpRoute.Name, modelName) } diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index 8165e529..0fe3aab5 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -29,8 +29,7 @@ import ( openai "github.com/sashabaranov/go-openai" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - corev1 "k8s.io/api/core/v1" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/api/core/v1" "k8s.io/client-go/kubernetes" "k8s.io/klog" @@ -158,9 +157,12 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces user, err := utils.GetUser(utils.User{Name: username}, s.redisClient) if err != nil { - // TODO: return immediate response - klog.Infoln("user does not exists") - return nil, username, targetPodIP + return generateErrorResponse( + envoyTypePb.StatusCode_Forbidden, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-user-missing", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: username is missing: %v", err.Error())), username, targetPodIP } if user.Rpm == 0 { @@ -172,92 +174,22 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces code, err := s.checkRPM(ctx, username, user.Rpm) if err != nil { - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: code, - }, - Details: err.Error(), - Headers: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: "x-rpm-exceeded", - RawValue: []byte("true"), - }, - }, - }, - }, - }, - }, - }, username, targetPodIP + return generateErrorResponse( + code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-rpm-exceeded", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: error on checking rpm: %v", err.Error())), username, targetPodIP } code, err = s.checkTPM(ctx, username, user.Tpm) if err != nil { - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: code, - }, - Details: err.Error(), - Headers: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: "x-tpm-exceeded", - RawValue: []byte("true"), - }, - }, - }, - }, - }, - }, - }, username, targetPodIP - } - - pods, err := s.client.CoreV1().Pods(utils.NAMESPACE).List(ctx, v1.ListOptions{ - LabelSelector: fmt.Sprintf("model.aibrix.ai=%s", model), - }) - if err != nil { - klog.Error(err) - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: code, - }, - Details: err.Error(), - Headers: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: "x-routing-error", - RawValue: []byte("true"), - }, - }, - }, - }, - }, - }, - }, username, targetPodIP - } - - targetPodIP, err = s.SelectTargetPod(ctx, routingStrategy, pods.Items) - if err != nil { - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: envoyTypePb.StatusCode_InternalServerError, - }, - Details: err.Error(), - Body: "error on selecting target pod", - }, - }, - }, username, targetPodIP + return generateErrorResponse( + code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-tpm-exceeded", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: error on checking tpm: %v", err.Error())), username, targetPodIP } headers := []*configPb.HeaderValueOption{{ @@ -266,14 +198,33 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces RawValue: []byte("true"), }, }} - if targetPodIP != "" { + if routingStrategy != "" { + pods, err := s.cache.GetPodsForModel(model) + if len(pods) == 0 || err != nil { + return generateErrorResponse( + code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-no-model-deployment", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: no models are deployed: %v", err.Error())), username, targetPodIP + } + + targetPodIP, err = s.selectTargetPod(ctx, routingStrategy, pods) + if targetPodIP == "" || err != nil { + return generateErrorResponse( + code, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-select-target-pod", RawValue: []byte("true"), + }}}, + fmt.Sprintf("pre query: error on selecting target pod: %v", err.Error())), username, targetPodIP + } + headers = append(headers, &configPb.HeaderValueOption{ Header: &configPb.HeaderValue{ Key: "target-pod", RawValue: []byte(targetPodIP), }, }) - podRequestCounter := s.cache.IncrPodRequestCount(fmt.Sprintf("%v_REQUEST_COUNT", targetPodIP)) klog.Infof("RequestStart: SelectedTargetPodIP: %s, PodRequestCount: %v", targetPodIP, podRequestCounter) } @@ -330,25 +281,27 @@ func (s *Server) HandleRequestBody(req *extProcPb.ProcessingRequest, targetPodIP func (s *Server) HandleResponseHeaders(req *extProcPb.ProcessingRequest, targetPodIP string) *extProcPb.ProcessingResponse { log.Println("--- In ResponseHeaders processing") + headers := []*configPb.HeaderValueOption{{ + Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }} + if targetPodIP != "" { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "target-pod", + RawValue: []byte(targetPodIP), + }, + }) + } + return &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_ResponseHeaders{ ResponseHeaders: &extProcPb.HeadersResponse{ Response: &extProcPb.CommonResponse{ HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, - { - Header: &configPb.HeaderValue{ - Key: "target-pod", - RawValue: []byte(targetPodIP), - }, - }, - }, + SetHeaders: headers, }, ClearRouteCache: true, }, @@ -363,54 +316,44 @@ func (s *Server) HandleResponseBody(ctx context.Context, req *extProcPb.Processi r := req.Request b := r.(*extProcPb.ProcessingRequest_ResponseBody) + defer func() { + if targetPodIP != "" { + podRequestCounter := s.cache.DecrPodRequestCount(fmt.Sprintf("%v_REQUEST_COUNT", targetPodIP)) + klog.Infof("RequestEnd: SelectedTargetPodIP: %s, PodRequestCount: %v", targetPodIP, podRequestCounter) + } + }() + var res openai.CompletionResponse if err := json.Unmarshal(b.ResponseBody.Body, &res); err != nil { - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: envoyTypePb.StatusCode_InternalServerError, - }, - Details: err.Error(), - }, - }, - } + return generateErrorResponse( + envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-response-unmarshal", RawValue: []byte("true"), + }}}, + err.Error()) } rpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_RPM_CURRENT", user), 1) if err != nil { - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: envoyTypePb.StatusCode_InternalServerError, - }, - Details: err.Error(), - Body: "post query: error on updating rpm", - }, - }, - } + return generateErrorResponse( + envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-update-rpm", RawValue: []byte("true"), + }}}, + fmt.Sprintf("post query: error on updating rpm: %v", err.Error())) } tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens)) if err != nil { - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: envoyTypePb.StatusCode_InternalServerError, - }, - Details: err.Error(), - Body: "post query: error on updating tpm", - }, - }, - } + return generateErrorResponse( + envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: "x-error-update-tpm", RawValue: []byte("true"), + }}}, + fmt.Sprintf("post query: error on updating tpm: %v", err.Error())) } klog.Infof("Updated RPM: %v, TPM: %v for user: %v", rpm, tpm, user) if targetPodIP != "" { - podRequestCounter := s.cache.DecrPodRequestCount(fmt.Sprintf("%v_REQUEST_COUNT", targetPodIP)) - klog.Infof("RequestEnd: SelectedTargetPodIP: %s, PodRequestCount: %v", targetPodIP, podRequestCounter) - podTpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_THROUGHPUT", targetPodIP), int64(res.Usage.TotalTokens)) if err != nil { klog.Error(err) @@ -477,11 +420,9 @@ func (s *Server) checkTPM(ctx context.Context, user string, tpmLimit int64) (env return envoyTypePb.StatusCode_OK, nil } -func (s *Server) SelectTargetPod(ctx context.Context, routingStrategy string, pods []corev1.Pod) (string, error) { +func (s *Server) selectTargetPod(ctx context.Context, routingStrategy string, pods map[string]*v1.Pod) (string, error) { var route routing.Router switch routingStrategy { - case "random": - route = s.routers[routingStrategy] case "least-request": route = s.routers[routingStrategy] case "throughput": @@ -492,3 +433,19 @@ func (s *Server) SelectTargetPod(ctx context.Context, routingStrategy string, po return route.Get(ctx, pods) } + +func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse { + return &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: &extProcPb.ImmediateResponse{ + Status: &envoyTypePb.HttpStatus{ + Code: statusCode, + }, + Headers: &extProcPb.HeaderMutation{ + SetHeaders: headers, + }, + Body: body, + }, + }, + } +} diff --git a/pkg/plugins/gateway/routing_algorithms/least_request.go b/pkg/plugins/gateway/routing_algorithms/least_request.go index 36cacd55..786d1753 100644 --- a/pkg/plugins/gateway/routing_algorithms/least_request.go +++ b/pkg/plugins/gateway/routing_algorithms/least_request.go @@ -44,7 +44,7 @@ func NewLeastRequestRouter(ratelimiter ratelimiter.AccountRateLimiter) Router { } } -func (r leastRequestRouter) Get(ctx context.Context, pods []v1.Pod) (string, error) { +func (r leastRequestRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) { var targetPodIP string minCount := math.MaxInt podRequestCounts := r.cache.GetPodRequestCount() diff --git a/pkg/plugins/gateway/routing_algorithms/random.go b/pkg/plugins/gateway/routing_algorithms/random.go index 220575ba..3f74c20e 100644 --- a/pkg/plugins/gateway/routing_algorithms/random.go +++ b/pkg/plugins/gateway/routing_algorithms/random.go @@ -18,7 +18,6 @@ package routingalgorithms import ( "context" - "math/rand/v2" v1 "k8s.io/api/core/v1" ) @@ -30,7 +29,11 @@ func NewRandomRouter() Router { return randomRouter{} } -func (r randomRouter) Get(ctx context.Context, pods []v1.Pod) (string, error) { - pod := pods[rand.IntN(3)] - return pod.Status.PodIP + ":8000", nil // TODO (varun): remove static port +func (r randomRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) { + var selectedPod *v1.Pod + for _, pod := range pods { + selectedPod = pod + } + + return selectedPod.Status.PodIP + ":8000", nil // TODO (varun): remove static port } diff --git a/pkg/plugins/gateway/routing_algorithms/router.go b/pkg/plugins/gateway/routing_algorithms/router.go index 354a5c58..b6ee67b3 100644 --- a/pkg/plugins/gateway/routing_algorithms/router.go +++ b/pkg/plugins/gateway/routing_algorithms/router.go @@ -25,5 +25,5 @@ import ( type Router interface { // Returns the target pod // TODO (varun): replace with cache util package which can watch on pods - Get(ctx context.Context, pods []v1.Pod) (string, error) + Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) } diff --git a/pkg/plugins/gateway/routing_algorithms/throughput.go b/pkg/plugins/gateway/routing_algorithms/throughput.go index 15db9502..186cb0b5 100644 --- a/pkg/plugins/gateway/routing_algorithms/throughput.go +++ b/pkg/plugins/gateway/routing_algorithms/throughput.go @@ -36,7 +36,7 @@ func NewThroughputRouter(ratelimiter ratelimiter.AccountRateLimiter) Router { } } -func (r throughputRouter) Get(ctx context.Context, pods []v1.Pod) (string, error) { +func (r throughputRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) { var targetPodIP string minCount := math.MaxInt