Skip to content

Commit

Permalink
Add custom cache and interface for model adapter scheduling (#100)
Browse files Browse the repository at this point in the history
* Add custom CRD clientset

* nit

* Add custom cache

* test

* test

* nit

* clean up .DS_Store files

* add interface

* fix lint errors

* revert controller name and tag

---------

Co-authored-by: varungupta <[email protected]>
  • Loading branch information
varungup90 and varungupta authored Aug 30, 2024
1 parent 246dbf1 commit 4a53415
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ COPY cmd/main.go cmd/main.go
COPY api/ api/
COPY pkg/controller/ pkg/controller/
COPY pkg/utils/ pkg/utils/
COPY pkg/cache/ pkg/cache/
COPY pkg/client/ pkg/client/

# Build
# the GOARCH has not a default value to allow the binary be built according to the host where the command
Expand Down
6 changes: 6 additions & 0 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
autoscalingv1alpha1 "github.com/aibrix/aibrix/api/autoscaling/v1alpha1"
modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
orchestrationv1alpha1 "github.com/aibrix/aibrix/api/orchestration/v1alpha1"
"github.com/aibrix/aibrix/pkg/cache"
"github.com/aibrix/aibrix/pkg/controller"
//+kubebuilder:scaffold:imports
)
Expand Down Expand Up @@ -161,6 +162,11 @@ func main() {
os.Exit(1)
}

setupLog.Info("starting cache")
stopCh := make(chan struct{})
defer close(stopCh)
cache.NewCache(stopCh)

// Kind controller registration is encapsulated inside the pkg/controller/controller.go
// So here we can use more clean registration flow and there's no need to change logics in future.
if err = controller.SetupWithManager(mgr); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorial/lora/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ curl -X POST http://localhost:8000/v1/load_lora_adapter \

```
# check available models
curl http://localhost:8000/v1/models
curl http://localhost:8000/v1/models | jq .
```

4. Unload Model
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
k8s.io/apimachinery v0.29.2
k8s.io/client-go v0.29.2
k8s.io/code-generator v0.29.2
k8s.io/klog v0.2.0
k8s.io/klog/v2 v2.110.1
k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00
k8s.io/utils v0.0.0-20230726121419-3b25d923346b
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ k8s.io/component-base v0.29.2 h1:lpiLyuvPA9yV1aQwGLENYyK7n/8t6l3nn3zAtFTJYe8=
k8s.io/component-base v0.29.2/go.mod h1:BfB3SLrefbZXiBfbM+2H1dlat21Uewg/5qtKOl8degM=
k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01 h1:pWEwq4Asjm4vjW7vcsmijwBhOr1/shsbSYiWXmNGlks=
k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E=
k8s.io/klog v0.2.0 h1:0ElL0OHzF3N+OhoJTL0uca20SxtYt4X4+bzHeqrB83c=
k8s.io/klog v0.2.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y=
k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0=
k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo=
Expand Down
258 changes: 258 additions & 0 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
/*
Copyright 2024 The Aibrix Team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package cache

import (
"errors"
"fmt"
"log"
"strings"
"sync"

crdinformers "github.com/aibrix/aibrix/pkg/client/informers/externalversions"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog"

modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
v1alpha1 "github.com/aibrix/aibrix/pkg/client/clientset/versioned"
v1alpha1scheme "github.com/aibrix/aibrix/pkg/client/clientset/versioned/scheme"
"k8s.io/client-go/kubernetes/scheme"
)

var once sync.Once

// type global
type Cache struct {
mu sync.RWMutex
initialized bool
pods map[string]*v1.Pod
modelAdapterToPodMapping map[string][]string
podToModelAdapterMapping map[string]map[string]struct{}
}

var (
instance Cache
kubeconfig string
)

func GetCache() (*Cache, error) {
if !instance.initialized {
return nil, errors.New("cache is not initialized")
}
return &instance, nil
}

func NewCache(stopCh <-chan struct{}) *Cache {
once.Do(func() {
var config *rest.Config
var err error

if kubeconfig == "" {
log.Printf("using in-cluster configuration")
config, err = rest.InClusterConfig()
} else {
log.Printf("using configuration from '%s'", kubeconfig)
config, err = clientcmd.BuildConfigFromFlags("", kubeconfig)
}

if err != nil {
panic(err)
}

if err := v1alpha1scheme.AddToScheme(scheme.Scheme); err != nil {
panic(err)
}

k8sClientSet, err := kubernetes.NewForConfig(config)
if err != nil {
panic(err)
}

crdClientSet, err := v1alpha1.NewForConfig(config)
if err != nil {
panic(err)
}

factory := informers.NewSharedInformerFactoryWithOptions(k8sClientSet, 0)
crdFactory := crdinformers.NewSharedInformerFactoryWithOptions(crdClientSet, 0)

podInformer := factory.Core().V1().Pods().Informer()
modeInformer := crdFactory.Model().V1alpha1().ModelAdapters().Informer()

defer runtime.HandleCrash()
factory.Start(stopCh)
crdFactory.Start(stopCh)

// factory.WaitForCacheSync(stopCh)
// crdFactory.WaitForCacheSync(stopCh)

if !cache.WaitForCacheSync(stopCh, podInformer.HasSynced, modeInformer.HasSynced) {
runtime.HandleError(fmt.Errorf("timed out waiting for caches to sync"))
return
}

instance = Cache{
initialized: true,
pods: map[string]*v1.Pod{},
modelAdapterToPodMapping: map[string][]string{},
podToModelAdapterMapping: map[string]map[string]struct{}{},
}

if _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: instance.addPod,
UpdateFunc: instance.updatePod,
DeleteFunc: instance.deletePod,
}); err != nil {
panic(err)
}

if _, err = modeInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: instance.addModel,
UpdateFunc: instance.updateModel,
DeleteFunc: instance.deleteModel,
}); err != nil {
panic(err)
}
})

return &instance
}

func (c *Cache) addPod(obj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

pod := obj.(*v1.Pod)
c.pods[pod.Name] = pod
c.podToModelAdapterMapping[pod.Name] = map[string]struct{}{}
klog.Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name)
}

func (c *Cache) updatePod(oldObj interface{}, newObj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

oldPod := oldObj.(*v1.Pod)
newPod := newObj.(*v1.Pod)
klog.Infof("POD UPDATED. %s/%s %s", oldPod.Namespace, oldPod.Name, newPod.Status.Phase)
}

func (c *Cache) deletePod(obj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

pod := obj.(*v1.Pod)
delete(c.pods, pod.Name)
klog.Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name)
}

func (c *Cache) addModel(obj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

model := obj.(*modelv1alpha1.ModelAdapter)
c.modelAdapterToPodMapping[model.Name] = model.Status.Instances
c.addModelAdapterMapping(model)

klog.Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name)
}

func (c *Cache) updateModel(oldObj interface{}, newObj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

oldModel := oldObj.(*modelv1alpha1.ModelAdapter)
newModel := newObj.(*modelv1alpha1.ModelAdapter)
c.modelAdapterToPodMapping[newModel.Name] = newModel.Status.Instances
c.deleteModelAdapterMapping(oldModel)
c.addModelAdapterMapping(newModel)

klog.Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase)
}

func (c *Cache) deleteModel(obj interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

model := obj.(*modelv1alpha1.ModelAdapter)
delete(c.modelAdapterToPodMapping, model.Name)
c.deleteModelAdapterMapping(model)

klog.Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name)
}

func (c *Cache) addModelAdapterMapping(model *modelv1alpha1.ModelAdapter) {
for _, pod := range model.Status.Instances {
models, ok := c.podToModelAdapterMapping[pod]
if !ok {
c.podToModelAdapterMapping[pod] = map[string]struct{}{
model.Name: {},
}
continue
}

models[model.Name] = struct{}{}
c.podToModelAdapterMapping[pod] = models
}
}

func (c *Cache) deleteModelAdapterMapping(model *modelv1alpha1.ModelAdapter) {
for _, pod := range model.Status.Instances {
modelAdapters := c.podToModelAdapterMapping[pod]
delete(modelAdapters, model.Name)
c.podToModelAdapterMapping[pod] = modelAdapters
}
}

func (c *Cache) debugInfo() {
for model, instances := range c.modelAdapterToPodMapping {
klog.Infof("modelName: %s, instances: %v", model, instances)
}

for pod, models := range c.podToModelAdapterMapping {
if !strings.HasPrefix(pod, "llama") {
continue
}

modelsArr := []string{}
for m := range models {
modelsArr = append(modelsArr, m)
}

klog.Infof("podName: %s, modelAdapters: %v", pod, modelsArr)
}
}

func (c *Cache) GetPods() map[string]*v1.Pod {
c.mu.RLock()
defer c.mu.RUnlock()

return c.pods
}

func (c *Cache) GetPodToModelAdapterMapping() map[string]map[string]struct{} {
c.mu.RLock()
defer c.mu.RUnlock()

return c.podToModelAdapterMapping
}
17 changes: 14 additions & 3 deletions pkg/controller/modeladapter/modeladapter_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"time"

modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1"
"github.com/aibrix/aibrix/pkg/cache"
"github.com/aibrix/aibrix/pkg/controller/modeladapter/scheduling"
corev1 "k8s.io/api/core/v1"
discoveryv1 "k8s.io/api/discovery/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -112,13 +114,21 @@ func newReconciler(mgr manager.Manager) (reconcile.Reconciler, error) {
eventBroadcaster.StartRecordingToSink(&clientv1core.EventSinkImpl{Interface: k8sClient.CoreV1().Events("")})
recorder := eventBroadcaster.NewRecorder(mgr.GetScheme(), corev1.EventSource{Component: "model-adapter-controller"})

c, err := cache.GetCache()
if err != nil {
klog.Fatal(err.Error())
}

scheduler := scheduling.NewLeastAdapters(c)

reconciler := &ModelAdapterReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
PodLister: podLister,
ServiceLister: serviceLister,
EndpointSliceLister: endpointSliceLister,
Recorder: recorder,
scheduler: scheduler,
}
return reconciler, nil
}
Expand All @@ -142,8 +152,9 @@ var _ reconcile.Reconciler = &ModelAdapterReconciler{}
// ModelAdapterReconciler reconciles a ModelAdapter object
type ModelAdapterReconciler struct {
client.Client
Scheme *runtime.Scheme
Recorder record.EventRecorder
Scheme *runtime.Scheme
Recorder record.EventRecorder
scheduler scheduling.Scheduler
// PodLister is able to list/get pods from a shared informer's cache store
PodLister corelisters.PodLister
// ServiceLister is able to list/get services from a shared informer's cache store
Expand Down Expand Up @@ -394,7 +405,7 @@ func (r *ModelAdapterReconciler) schedulePod(ctx context.Context, instance *mode
// TODO: let's build the scheduling algorithm later
// we should also fetch <pod, list<lora>> mappings later.

return &podList.Items[0], nil // Returning the first Pod for simplicity
return r.scheduler.SelectPod(ctx, podList.Items)
}

// GetEnvKey retrieves the value of the environment variable named by the key.
Expand Down
Loading

0 comments on commit 4a53415

Please sign in to comment.