Skip to content

Commit

Permalink
Add routing for model adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
varungupta authored and varungupta committed Sep 16, 2024
1 parent 3e74a14 commit 4993f0c
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 35 deletions.
18 changes: 16 additions & 2 deletions docs/tutorial/lora/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ docker build -t aibrix/vllm-mock:nightly -f Dockerfile .

2. Deploy mocked model image
```shell
kubectl apply -f deployment.yaml
kubectl apply -f docs/development/app/deployment.yaml
```

3. Load models
Expand Down Expand Up @@ -43,7 +43,7 @@ Verified! The model is loaded and unloaded successfully and pod annotations are
5. Deploy the controller and apply the `model_adapter.yaml`

```
kubectl apply -f model_adapter.yaml
kubectl apply -f docs/tutorial/lora/model_adapter.yaml
```


Expand All @@ -70,4 +70,18 @@ curl https://localhost:8000/v1/completions \
"max_tokens": 7,
"temperature": 0
}'
```

```shell
curl -v http://localhost:8888/v1/chat/completions \
-H "user: your-user-name" \
-H "model: lora-1" \
-H "routing-strategy: random" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer any_key" \
-d '{
"model": "lora-1",
"messages": [{"role": "user", "content": "Say this is a test!"}],
"temperature": 0.7
}'
```
146 changes: 133 additions & 13 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ var once sync.Once

// type global
type Cache struct {
mu sync.RWMutex
initialized bool
pods map[string]*v1.Pod
modelAdapterToPodMapping map[string][]string
podToModelAdapterMapping map[string]map[string]struct{}
mu sync.RWMutex
initialized bool
pods map[string]*v1.Pod

podToModelMapping map[string]map[string]struct{} // pod_name: map[model_name]struct{}
modelToPodMapping map[string]map[string]*v1.Pod // model_name: map[pod_name]*v1.Pod

modelAdapterToPodMapping map[string][]string // TODO deprecate: model_adapter_name: []pods
podToModelAdapterMapping map[string]map[string]struct{} // TODO deprecate: pod_name: map[model_adapter_name]struct{}
podRequestTracker map[string]int
}

Expand Down Expand Up @@ -96,8 +100,12 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
}

instance = Cache{
initialized: true,
pods: map[string]*v1.Pod{},
initialized: true,
pods: map[string]*v1.Pod{},

podToModelMapping: map[string]map[string]struct{}{},
modelToPodMapping: map[string]map[string]*v1.Pod{},

modelAdapterToPodMapping: map[string][]string{},
podToModelAdapterMapping: map[string]map[string]struct{}{},
podRequestTracker: map[string]int{},
Expand All @@ -112,9 +120,9 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
}

if _, err = modeInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: instance.addModel,
UpdateFunc: instance.updateModel,
DeleteFunc: instance.deleteModel,
AddFunc: instance.addModelAdapter,
UpdateFunc: instance.updateModelAdapter,
DeleteFunc: instance.deleteModelAdapter,
}); err != nil {
panic(err)
}
Expand All @@ -128,62 +136,143 @@ func (c *Cache) addPod(obj interface{}) {
defer c.mu.Unlock()

pod := obj.(*v1.Pod)
// only track pods with model deployments
modelName, ok := pod.Labels[modelv1alpha1.GroupVersion.Group]
if !ok {
return
}

c.pods[pod.Name] = pod
c.addPodAndModelMapping(pod.Name, modelName)
c.podToModelAdapterMapping[pod.Name] = map[string]struct{}{}
klog.Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name)
c.debugInfo()
}

func (c *Cache) updatePod(oldObj interface{}, newObj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

oldPod := oldObj.(*v1.Pod)
oldModelName, ok := oldPod.Labels[modelv1alpha1.GroupVersion.Group]
if !ok {
return
}

newPod := newObj.(*v1.Pod)
newModelName, ok := oldPod.Labels[modelv1alpha1.GroupVersion.Group]
if !ok {
return
}

c.deletePodAndModelMapping(oldPod.Name, oldModelName)
c.addPodAndModelMapping(newPod.Name, newModelName)
klog.Infof("POD UPDATED. %s/%s %s", oldPod.Namespace, oldPod.Name, newPod.Status.Phase)
c.debugInfo()
}

func (c *Cache) deletePod(obj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

pod := obj.(*v1.Pod)
modelName, ok := pod.Labels[modelv1alpha1.GroupVersion.Group]
if !ok {
return
}

delete(c.pods, pod.Name)
c.deletePodAndModelMapping(pod.Name, modelName)
klog.Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name)
c.debugInfo()
}

func (c *Cache) addModel(obj interface{}) {
func (c *Cache) addModelAdapter(obj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

model := obj.(*modelv1alpha1.ModelAdapter)
for _, pod := range model.Status.Instances {
c.addPodAndModelMapping(pod, model.Name)
}

c.modelAdapterToPodMapping[model.Name] = model.Status.Instances
c.addModelAdapterMapping(model)

klog.Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name)
c.debugInfo()
}

func (c *Cache) updateModel(oldObj interface{}, newObj interface{}) {
func (c *Cache) updateModelAdapter(oldObj interface{}, newObj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

oldModel := oldObj.(*modelv1alpha1.ModelAdapter)
newModel := newObj.(*modelv1alpha1.ModelAdapter)

for _, pod := range oldModel.Status.Instances {
c.deletePodAndModelMapping(pod, oldModel.Name)
}

for _, pod := range newModel.Status.Instances {
c.addPodAndModelMapping(pod, newModel.Name)
}

c.modelAdapterToPodMapping[newModel.Name] = newModel.Status.Instances
c.deleteModelAdapterMapping(oldModel)
c.addModelAdapterMapping(newModel)

klog.Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase)
c.debugInfo()
}

func (c *Cache) deleteModel(obj interface{}) {
func (c *Cache) deleteModelAdapter(obj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

model := obj.(*modelv1alpha1.ModelAdapter)
for _, pod := range model.Status.Instances {
c.deletePodAndModelMapping(pod, model.Name)
}

delete(c.modelAdapterToPodMapping, model.Name)
c.deleteModelAdapterMapping(model)

klog.Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name)
c.debugInfo()
}

func (c *Cache) addPodAndModelMapping(podName, modelName string) {
pod, ok := c.pods[podName]
if !ok {
klog.Errorf("pod %s does not exist in internal-cache", podName)
return
}

models, ok := c.podToModelMapping[podName]
if !ok {
c.podToModelMapping[podName] = map[string]struct{}{
modelName: {},
}
} else {
models[modelName] = struct{}{}
c.podToModelMapping[podName] = models
}

pods, ok := c.modelToPodMapping[modelName]
if !ok {
c.modelToPodMapping[modelName] = map[string]*v1.Pod{
podName: pod,
}
} else {
pods[podName] = pod
c.modelToPodMapping[modelName] = pods
}
}

func (c *Cache) deletePodAndModelMapping(podName, modelName string) {
delete(c.podToModelMapping, podName)
delete(c.modelToPodMapping, modelName)
}

func (c *Cache) addModelAdapterMapping(model *modelv1alpha1.ModelAdapter) {
Expand All @@ -210,6 +299,24 @@ func (c *Cache) deleteModelAdapterMapping(model *modelv1alpha1.ModelAdapter) {
}

func (c *Cache) debugInfo() {
for _, pod := range c.pods {
klog.Info(pod.Name)
}
for podName, models := range c.podToModelMapping {
var modelList string
for modelName := range models {
modelList += modelName + " "
}
klog.Infof("pod: %s, models: %s", podName, modelList)
}
for modelName, pods := range c.modelToPodMapping {
var podList string
for podName := range pods {
podList += podName + " "
}
klog.Infof("model: %s, pods: %s", modelName, podList)
}

for model, instances := range c.modelAdapterToPodMapping {
klog.Infof("modelName: %s, instances: %v", model, instances)
}
Expand All @@ -235,6 +342,19 @@ func (c *Cache) GetPods() map[string]*v1.Pod {
return c.pods
}

func (c *Cache) GetPodsForModel(modelName string) map[string]*v1.Pod {
c.mu.RLock()
defer c.mu.RUnlock()

pods := map[string]*v1.Pod{}
for podName, pod := range c.modelToPodMapping[modelName] {
klog.Info(podName)
pods[podName] = pod
}

return pods
}

func (c *Cache) GetPodToModelAdapterMapping() map[string]map[string]struct{} {
c.mu.RLock()
defer c.mu.RUnlock()
Expand Down
21 changes: 8 additions & 13 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ import (
openai "github.com/sashabaranov/go-openai"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/klog"

"github.com/aibrix/aibrix/pkg/cache"
ratelimiter "github.com/aibrix/aibrix/pkg/plugins/gateway/rate_limiter"
routing "github.com/aibrix/aibrix/pkg/plugins/gateway/routing_algorithms"
podutils "github.com/aibrix/aibrix/pkg/utils"
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
filterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
Expand Down Expand Up @@ -201,23 +199,20 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces
}, user, targetPodIP
}

pods, err := s.client.CoreV1().Pods(podutils.NAMESPACE).List(ctx, v1.ListOptions{
LabelSelector: fmt.Sprintf("model.aibrix.ai=%s", model),
})
if err != nil {
klog.Error(err)
pods := s.cache.GetPodsForModel(model)
if len(pods) == 0 {
return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
ImmediateResponse: &extProcPb.ImmediateResponse{
Status: &envoyTypePb.HttpStatus{
Code: code,
Code: envoyTypePb.StatusCode_ServiceUnavailable,
},
Details: err.Error(),
Details: "no models are deployed",
Headers: &extProcPb.HeaderMutation{
SetHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Key: "x-routing-error",
Key: "x-no-model-deployment",
RawValue: []byte("true"),
},
},
Expand All @@ -228,7 +223,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, req *extProcPb.Proces
}, user, targetPodIP
}

targetPodIP, err = s.SelectTargetPod(ctx, routingStrategy, pods.Items)
targetPodIP, err = s.SelectTargetPod(ctx, routingStrategy, pods)
if err != nil {
return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ImmediateResponse{
Expand Down Expand Up @@ -470,7 +465,7 @@ func (s *Server) checkTPM(ctx context.Context, user string) (envoyTypePb.StatusC
return envoyTypePb.StatusCode_OK, nil
}

func (s *Server) SelectTargetPod(ctx context.Context, routingStrategy string, pods []corev1.Pod) (string, error) {
func (s *Server) SelectTargetPod(ctx context.Context, routingStrategy string, pods map[string]*v1.Pod) (string, error) {
var route routing.Router
switch routingStrategy {
case "random":
Expand Down
2 changes: 1 addition & 1 deletion pkg/plugins/gateway/routing_algorithms/least_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func NewLeastRequestRouter(ratelimiter ratelimiter.AccountRateLimiter) Router {
}
}

func (r leastRequestRouter) Get(ctx context.Context, pods []v1.Pod) (string, error) {
func (r leastRequestRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
var targetPodIP string
minCount := math.MaxInt
podRequestCounts := r.cache.GetPodRequestCount()
Expand Down
11 changes: 7 additions & 4 deletions pkg/plugins/gateway/routing_algorithms/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package routingalgorithms

import (
"context"
"math/rand/v2"

v1 "k8s.io/api/core/v1"
)
Expand All @@ -30,7 +29,11 @@ func NewRandomRouter() Router {
return randomRouter{}
}

func (r randomRouter) Get(ctx context.Context, pods []v1.Pod) (string, error) {
pod := pods[rand.IntN(3)]
return pod.Status.PodIP + ":8000", nil // TODO (varun): remove static port
func (r randomRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
var selectedPod *v1.Pod
for _, pod := range pods {
selectedPod = pod
}

return selectedPod.Status.PodIP + ":8000", nil // TODO (varun): remove static port
}
2 changes: 1 addition & 1 deletion pkg/plugins/gateway/routing_algorithms/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ import (
type Router interface {
// Returns the target pod
// TODO (varun): replace with cache util package which can watch on pods
Get(ctx context.Context, pods []v1.Pod) (string, error)
Get(ctx context.Context, pods map[string]*v1.Pod) (string, error)
}
2 changes: 1 addition & 1 deletion pkg/plugins/gateway/routing_algorithms/throughput.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func NewThroughputRouter(ratelimiter ratelimiter.AccountRateLimiter) Router {
}
}

func (r throughputRouter) Get(ctx context.Context, pods []v1.Pod) (string, error) {
func (r throughputRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
var targetPodIP string
minCount := math.MaxInt

Expand Down

0 comments on commit 4993f0c

Please sign in to comment.