Skip to content

Commit

Permalink
Add model adapter e2e tests (#701)
Browse files Browse the repository at this point in the history
* Add model adapter e2e tests

* trigger e2e test on change in test package

* add more tests

Signed-off-by: Varun Gupta <[email protected]>

* remove unused util function

Signed-off-by: Varun Gupta <[email protected]>

---------

Signed-off-by: Varun Gupta <[email protected]>
  • Loading branch information
varungup90 authored Feb 21, 2025
1 parent 728c7c4 commit 41126e4
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/installation-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ on:
- 'build/container/**'
- 'config/**'
- 'pkg/**'
- 'test/**'
- 'cmd/**'
- 'python/**'
- 'Makefile'
Expand All @@ -21,6 +22,7 @@ on:
- 'build/container/**'
- 'config/**'
- 'pkg/**'
- 'test/**'
- 'cmd/**'
- 'python/**'
- 'Makefile'
Expand Down
13 changes: 3 additions & 10 deletions test/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,10 @@ import (
"github.com/stretchr/testify/assert"
)

const (
baseURL = "http://localhost:8888"
apiKey = "test-key-1234567890"
modelName = "llama2-7b"
namespace = "aibrix-system"
)

func TestBaseModelInference(t *testing.T) {
initializeClient(context.Background(), t)

client := createOpenAIClient(baseURL, apiKey)
client := createOpenAIClient(gatewayURL, apiKey)
chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Say this is a test"),
Expand Down Expand Up @@ -88,10 +81,10 @@ func TestBaseModelInferenceFailures(t *testing.T) {
var client *openai.Client
if tc.routingStrategy != "" {
var dst *http.Response
client = createOpenAIClientWithRoutingStrategy(baseURL, tc.apiKey,
client = createOpenAIClientWithRoutingStrategy(gatewayURL, tc.apiKey,
tc.routingStrategy, option.WithResponseInto(&dst))
} else {
client = createOpenAIClient(baseURL, tc.apiKey)
client = createOpenAIClient(gatewayURL, tc.apiKey)
}

_, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
Expand Down
93 changes: 93 additions & 0 deletions test/e2e/model_adapter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package e2e

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
modelv1alpha1 "github.com/vllm-project/aibrix/api/model/v1alpha1"
v1alpha1 "github.com/vllm-project/aibrix/pkg/client/clientset/versioned"
apierrors "k8s.io/apimachinery/pkg/api/errors"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
)

const (
	// loraName is the name given to the LoRA ModelAdapter created by
	// TestModelAdapter; it is also the model name used when running
	// inference against the adapter through the gateway.
	loraName = "text2sql-lora-2"
)

// TestModelAdapter exercises the model-adapter lifecycle end to end:
// it creates a LoRA adapter on top of the base model, verifies it reaches
// the Running phase, forces a reschedule by deleting the pod the adapter
// is bound to, and finally runs an inference request against the adapter.
func TestModelAdapter(t *testing.T) {
	// Use the shared constants so the adapter/model names stay consistent
	// with the rest of the e2e suite.
	adapter := createModelAdapterConfig(loraName, modelName)
	k8sClient, v1alpha1Client := initializeClient(context.Background(), t)

	t.Cleanup(func() {
		assert.NoError(t, v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Delete(context.Background(), adapter.Name, v1.DeleteOptions{}))
		// Wait until the API server reports the adapter gone; surface a
		// timeout as a test error instead of silently ignoring it.
		err := wait.PollImmediate(1*time.Second, 30*time.Second,
			func() (done bool, err error) {
				_, err = v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Get(context.Background(), adapter.Name, v1.GetOptions{})
				return apierrors.IsNotFound(err), nil
			})
		assert.NoError(t, err, "model adapter was not deleted within timeout")
	})

	// create model adapter; fail fast on error, otherwise the nil adapter
	// would panic on the Status access below.
	fmt.Println("creating model adapter")
	adapter, err := v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Create(context.Background(), adapter, v1.CreateOptions{})
	if err != nil {
		t.Fatalf("failed to create model adapter: %v", err)
	}
	adapter = validateModelAdapter(t, v1alpha1Client, adapter.Name)
	oldPod := adapter.Status.Instances[0]

	// delete pod and ensure model adapter is rescheduled onto a new pod
	fmt.Println("deleting pod instance to force model adapter rescheduling")
	assert.NoError(t, k8sClient.CoreV1().Pods("default").Delete(context.Background(), oldPod, v1.DeleteOptions{}))
	// Give the controller a moment to observe the deletion before polling.
	time.Sleep(3 * time.Second)
	adapter = validateModelAdapter(t, v1alpha1Client, adapter.Name)
	newPod := adapter.Status.Instances[0]

	assert.NotEqual(t, newPod, oldPod, "ensure old and new pods are different")

	// run inference for model adapter through the gateway
	validateInference(t, loraName)
}

// createModelAdapterConfig builds the ModelAdapter object used by the e2e
// tests: a LoRA artifact (fixed test checkpoint from HuggingFace) layered on
// top of the given base model. Pods serving the base model are matched via
// the model.aibrix.ai/name label.
func createModelAdapterConfig(name, model string) *modelv1alpha1.ModelAdapter {
	meta := v1.ObjectMeta{
		Name: name,
		Labels: map[string]string{
			"model.aibrix.ai/name": name,
			"model.aibrix.ai/port": "8000",
		},
	}
	selector := &v1.LabelSelector{
		MatchLabels: map[string]string{"model.aibrix.ai/name": model},
	}
	spec := modelv1alpha1.ModelAdapterSpec{
		BaseModel:   &model,
		PodSelector: selector,
		ArtifactURL: "huggingface://yard1/llama-2-7b-sql-lora-test",
		AdditionalConfig: map[string]string{
			"api-key": "test-key-1234567890",
		},
	}
	return &modelv1alpha1.ModelAdapter{ObjectMeta: meta, Spec: spec}
}

// validateModelAdapter polls until the named model adapter reports the
// Running phase (up to 30s), then asserts it has been scheduled onto at
// least one pod. It fails the test immediately on timeout — without that
// check a timed-out poll would leave adapter nil and panic on the Status
// access below.
func validateModelAdapter(t *testing.T, client *v1alpha1.Clientset, name string) *modelv1alpha1.ModelAdapter {
	var adapter *modelv1alpha1.ModelAdapter
	err := wait.PollImmediate(1*time.Second, 30*time.Second,
		func() (done bool, err error) {
			adapter, err = client.ModelV1alpha1().ModelAdapters("default").Get(context.Background(), name, v1.GetOptions{})
			// Transient Get errors and non-Running phases just retry.
			if err != nil || adapter.Status.Phase != modelv1alpha1.ModelAdapterRunning {
				return false, nil
			}
			return true, nil
		})
	if err != nil {
		t.Fatalf("model adapter %q did not reach Running phase: %v", name, err)
	}
	assert.True(t, len(adapter.Status.Instances) > 0, "model adapter scheduled on at least one pod")
	return adapter
}
2 changes: 1 addition & 1 deletion test/e2e/routing_strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestPrefixCacheModelInference(t *testing.T) {

func getTargetPodFromChatCompletion(t *testing.T, message string) string {
var dst *http.Response
client := createOpenAIClientWithRoutingStrategy(baseURL, apiKey, "prefix-cache", option.WithResponseInto(&dst))
client := createOpenAIClientWithRoutingStrategy(gatewayURL, apiKey, "prefix-cache", option.WithResponseInto(&dst))

chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
Expand Down
29 changes: 29 additions & 0 deletions test/e2e/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/stretchr/testify/assert"
v1alpha1 "github.com/vllm-project/aibrix/pkg/client/clientset/versioned"
crdinformers "github.com/vllm-project/aibrix/pkg/client/informers/externalversions"
"k8s.io/apimachinery/pkg/util/runtime"
Expand All @@ -35,6 +36,14 @@ import (
"k8s.io/klog/v2"
)

// Shared configuration for the e2e suite. The URLs assume the gateway and
// engine have been port-forwarded to localhost before the tests run.
const (
	// gatewayURL is the AIBrix gateway endpoint requests are routed through.
	gatewayURL = "http://localhost:8888"
	// engineURL is the direct inference-engine endpoint.
	engineURL = "http://localhost:8000"
	// apiKey is the static test credential accepted by the gateway.
	apiKey = "test-key-1234567890"
	// modelName is the base model deployed for the tests.
	modelName = "llama2-7b"
	// namespace is where the AIBrix control-plane components run.
	namespace = "aibrix-system"
)

func initializeClient(ctx context.Context, t *testing.T) (*kubernetes.Clientset, *v1alpha1.Clientset) {
var err error
var config *rest.Config
Expand Down Expand Up @@ -101,3 +110,23 @@ func createOpenAIClientWithRoutingStrategy(baseURL, apiKey, routingStrategy stri
respOpt,
)
}

func validateInference(t *testing.T, modelName string) {
client := createOpenAIClient(gatewayURL, apiKey)
validateInferenceWithClient(t, client, modelName)
}

// validateInferenceWithClient sends a fixed chat prompt with the supplied
// client and asserts the response names the expected model and carries at
// least one choice with non-nil message content.
func validateInferenceWithClient(t *testing.T, client *openai.Client, modelName string) {
	params := openai.ChatCompletionNewParams{
		Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
			openai.UserMessage("Say this is a test"),
		}),
		Model: openai.F(openai.ChatModel(modelName)),
	}
	resp, err := client.Chat.Completions.New(context.TODO(), params)
	if err != nil {
		t.Fatalf("chat completions failed : %v", err)
	}
	assert.Equal(t, modelName, resp.Model)
	assert.NotEmpty(t, resp.Choices, "chat completion has no choices returned")
	assert.NotNil(t, resp.Choices[0].Message.Content, "chat completion has no message returned")
}

0 comments on commit 41126e4

Please sign in to comment.