diff --git a/.github/workflows/installation-tests.yml b/.github/workflows/installation-tests.yml
index 48fb6926..c8ac1521 100644
--- a/.github/workflows/installation-tests.yml
+++ b/.github/workflows/installation-tests.yml
@@ -9,6 +9,7 @@ on:
       - 'build/container/**'
       - 'config/**'
       - 'pkg/**'
+      - 'test/**'
       - 'cmd/**'
       - 'python/**'
       - 'Makefile'
@@ -21,6 +22,7 @@ on:
       - 'build/container/**'
       - 'config/**'
       - 'pkg/**'
+      - 'test/**'
       - 'cmd/**'
       - 'python/**'
       - 'Makefile'
diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go
index e29811c0..65185e47 100644
--- a/test/e2e/e2e_test.go
+++ b/test/e2e/e2e_test.go
@@ -27,17 +27,10 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-const (
-	baseURL   = "http://localhost:8888"
-	apiKey    = "test-key-1234567890"
-	modelName = "llama2-7b"
-	namespace = "aibrix-system"
-)
-
 func TestBaseModelInference(t *testing.T) {
 	initializeClient(context.Background(), t)
 
-	client := createOpenAIClient(baseURL, apiKey)
+	client := createOpenAIClient(gatewayURL, apiKey)
 	chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
 		Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
 			openai.UserMessage("Say this is a test"),
@@ -88,10 +81,10 @@ func TestBaseModelInferenceFailures(t *testing.T) {
 			var client *openai.Client
 			if tc.routingStrategy != "" {
 				var dst *http.Response
-				client = createOpenAIClientWithRoutingStrategy(baseURL, tc.apiKey,
+				client = createOpenAIClientWithRoutingStrategy(gatewayURL, tc.apiKey,
 					tc.routingStrategy, option.WithResponseInto(&dst))
 			} else {
-				client = createOpenAIClient(baseURL, tc.apiKey)
+				client = createOpenAIClient(gatewayURL, tc.apiKey)
 			}
 
 			_, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
diff --git a/test/e2e/model_adapter_test.go b/test/e2e/model_adapter_test.go
new file mode 100644
index 00000000..71046142
--- /dev/null
+++ b/test/e2e/model_adapter_test.go
@@ -0,0 +1,93 @@
+package e2e
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	modelv1alpha1 "github.com/vllm-project/aibrix/api/model/v1alpha1"
+	v1alpha1 "github.com/vllm-project/aibrix/pkg/client/clientset/versioned"
+	apierrors "k8s.io/apimachinery/pkg/api/errors"
+	v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/util/wait"
+)
+
+const (
+	loraName = "text2sql-lora-2"
+)
+
+func TestModelAdapter(t *testing.T) {
+	adapter := createModelAdapterConfig(loraName, modelName)
+	k8sClient, v1alpha1Client := initializeClient(context.Background(), t)
+
+	t.Cleanup(func() {
+		assert.NoError(t, v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Delete(context.Background(), adapter.Name, v1.DeleteOptions{}))
+		wait.PollImmediate(1*time.Second, 30*time.Second,
+			func() (done bool, err error) {
+				adapter, err = v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Get(context.Background(), adapter.Name, v1.GetOptions{})
+				if apierrors.IsNotFound(err) {
+					return true, nil
+				}
+				return false, nil
+			})
+	})
+
+	// create model adapter
+	fmt.Println("creating model adapter")
+	adapter, err := v1alpha1Client.ModelV1alpha1().ModelAdapters("default").Create(context.Background(), adapter, v1.CreateOptions{})
+	assert.NoError(t, err)
+	adapter = validateModelAdapter(t, v1alpha1Client, adapter.Name)
+	oldPod := adapter.Status.Instances[0]
+
+	// delete pod and ensure model adapter is rescheduled
+	fmt.Println("deleting pod instance to force model adapter rescheduling")
+	assert.NoError(t, k8sClient.CoreV1().Pods("default").Delete(context.Background(), oldPod, v1.DeleteOptions{}))
+	time.Sleep(3 * time.Second)
+	adapter = validateModelAdapter(t, v1alpha1Client, adapter.Name)
+	newPod := adapter.Status.Instances[0]
+
+	assert.NotEqual(t, newPod, oldPod, "ensure old and new pods are different")
+
+	// run inference for model adapter
+	validateInference(t, loraName)
+}
+
+func createModelAdapterConfig(name, model string) *modelv1alpha1.ModelAdapter {
+	return &modelv1alpha1.ModelAdapter{
+		ObjectMeta: v1.ObjectMeta{
+			Name: name,
+			Labels: map[string]string{
+				"model.aibrix.ai/name": name,
+				"model.aibrix.ai/port": "8000",
+			},
+		},
+		Spec: modelv1alpha1.ModelAdapterSpec{
+			BaseModel: &model,
+			PodSelector: &v1.LabelSelector{
+				MatchLabels: map[string]string{
+					"model.aibrix.ai/name": model,
+				},
+			},
+			ArtifactURL: "huggingface://yard1/llama-2-7b-sql-lora-test",
+			AdditionalConfig: map[string]string{
+				"api-key": "test-key-1234567890",
+			},
+		},
+	}
+}
+
+func validateModelAdapter(t *testing.T, client *v1alpha1.Clientset, name string) *modelv1alpha1.ModelAdapter {
+	var adapter *modelv1alpha1.ModelAdapter
+	wait.PollImmediate(1*time.Second, 30*time.Second,
+		func() (done bool, err error) {
+			adapter, err = client.ModelV1alpha1().ModelAdapters("default").Get(context.Background(), name, v1.GetOptions{})
+			if err != nil || adapter.Status.Phase != modelv1alpha1.ModelAdapterRunning {
+				return false, nil
+			}
+			return true, nil
+		})
+	assert.True(t, len(adapter.Status.Instances) > 0, "model adapter scheduled on at least one pod")
+	return adapter
+}
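Review note on the polling in `model_adapter_test.go`: `wait.PollImmediate` is deprecated in recent `k8s.io/apimachinery` releases, and its return value is discarded here, so a timeout only surfaces later as an index panic or a failed status assertion. If the vendored apimachinery is new enough to ship `wait.PollUntilContextTimeout` (an assumption about this repo's dependencies), the helper could assert the poll error directly. A minimal sketch; the function name is invented for illustration:

```go
// Sketch of validateModelAdapter using the non-deprecated poll helper.
// Assumes a k8s.io/apimachinery version that provides
// wait.PollUntilContextTimeout (the replacement for PollImmediate).
func validateModelAdapterPolled(t *testing.T, client *v1alpha1.Clientset, name string) *modelv1alpha1.ModelAdapter {
	var adapter *modelv1alpha1.ModelAdapter
	err := wait.PollUntilContextTimeout(context.Background(), 1*time.Second, 30*time.Second, true,
		func(ctx context.Context) (done bool, err error) {
			adapter, err = client.ModelV1alpha1().ModelAdapters("default").Get(ctx, name, v1.GetOptions{})
			if err != nil || adapter.Status.Phase != modelv1alpha1.ModelAdapterRunning {
				return false, nil
			}
			return true, nil
		})
	// Failing here gives a clearer signal than a later nil dereference.
	assert.NoError(t, err, "model adapter never reached Running phase")
	assert.True(t, len(adapter.Status.Instances) > 0, "model adapter scheduled on at least one pod")
	return adapter
}
```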
diff --git a/test/e2e/routing_strategy_test.go b/test/e2e/routing_strategy_test.go
index cf8db9cb..ac48649a 100644
--- a/test/e2e/routing_strategy_test.go
+++ b/test/e2e/routing_strategy_test.go
@@ -56,7 +56,7 @@ func TestPrefixCacheModelInference(t *testing.T) {
 
 func getTargetPodFromChatCompletion(t *testing.T, message string) string {
 	var dst *http.Response
-	client := createOpenAIClientWithRoutingStrategy(baseURL, apiKey, "prefix-cache", option.WithResponseInto(&dst))
+	client := createOpenAIClientWithRoutingStrategy(gatewayURL, apiKey, "prefix-cache", option.WithResponseInto(&dst))
 
 	chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
 		Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
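For readers following the rename: judging by its use here, `createOpenAIClientWithRoutingStrategy` (defined in `util.go`, below) attaches the strategy to every request so the gateway can pick a router. A hedged sketch of an equivalent client; the `routing-strategy` header name and the `/v1` path suffix are assumptions, not confirmed by this diff:

```go
// Hypothetical equivalent of createOpenAIClientWithRoutingStrategy.
// The "routing-strategy" header name and "/v1" suffix are assumptions.
func newRoutingClient(strategy string) *openai.Client {
	return openai.NewClient(
		option.WithBaseURL(gatewayURL+"/v1"),
		option.WithAPIKey(apiKey),
		option.WithHeader("routing-strategy", strategy),
	)
}
```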
diff --git a/test/e2e/util.go b/test/e2e/util.go
index 0f75500a..4505ddde 100644
--- a/test/e2e/util.go
+++ b/test/e2e/util.go
@@ -24,6 +24,7 @@ import (
 
 	"github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
+	"github.com/stretchr/testify/assert"
 	v1alpha1 "github.com/vllm-project/aibrix/pkg/client/clientset/versioned"
 	crdinformers "github.com/vllm-project/aibrix/pkg/client/informers/externalversions"
 	"k8s.io/apimachinery/pkg/util/runtime"
@@ -35,6 +36,14 @@ import (
 	"k8s.io/klog/v2"
 )
 
+const (
+	gatewayURL = "http://localhost:8888"
+	engineURL  = "http://localhost:8000"
+	apiKey     = "test-key-1234567890"
+	modelName  = "llama2-7b"
+	namespace  = "aibrix-system"
+)
+
 func initializeClient(ctx context.Context, t *testing.T) (*kubernetes.Clientset, *v1alpha1.Clientset) {
 	var err error
 	var config *rest.Config
@@ -101,3 +110,23 @@
 		respOpt,
 	)
 }
+
+func validateInference(t *testing.T, modelName string) {
+	client := createOpenAIClient(gatewayURL, apiKey)
+	validateInferenceWithClient(t, client, modelName)
+}
+
+func validateInferenceWithClient(t *testing.T, client *openai.Client, modelName string) {
+	chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
+		Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
+			openai.UserMessage("Say this is a test"),
+		}),
+		Model: openai.F(openai.ChatModel(modelName)),
+	})
+	if err != nil {
+		t.Fatalf("chat completions failed: %v", err)
+	}
+	assert.Equal(t, modelName, chatCompletion.Model)
+	assert.NotEmpty(t, chatCompletion.Choices, "chat completion has no choices returned")
+	assert.NotNil(t, chatCompletion.Choices[0].Message.Content, "chat completion has no message returned")
+}
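One loose end: the new `engineURL` constant is not used anywhere in this diff. Presumably it is intended for direct-to-engine smoke checks; since `validateInferenceWithClient` already accepts an arbitrary client, such a helper would be a one-liner. A sketch under that assumption; the helper name is invented:

```go
// Hypothetical helper, not part of this diff: query the vLLM engine
// directly, bypassing the gateway, to separate routing failures from
// engine failures.
func validateEngineInference(t *testing.T) {
	client := createOpenAIClient(engineURL, apiKey)
	validateInferenceWithClient(t, client, modelName)
}
```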