Skip to content

Commit

Permalink
Enhance dynamic lora adapter support for auth enabled scenario (#571)
Browse files Browse the repository at this point in the history
* Support auth for bearer tokens in mocked app

* Get credentials from Lora CR and append to HTTP header

* Add lora api-key integration testing scripts

* Fix the lint issue

---------

Signed-off-by: Jiaxin Shan <[email protected]>
  • Loading branch information
Jeffwan authored Jan 17, 2025
1 parent 0e9dd75 commit a931b76
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 4 deletions.
31 changes: 31 additions & 0 deletions development/app/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from flask import Flask, request, Response, jsonify
from flask_httpauth import HTTPTokenAuth
from functools import wraps
from werkzeug import serving
import random
import re
Expand Down Expand Up @@ -45,6 +47,30 @@
tokenizer = None
simulator: Optional[Simulator] = None

# Pull the optional `--api_key <value>` pair straight out of the raw argument
# vector. When the flag is absent, or is the last argument with nothing after
# it, api_key stays None.
api_key = None
if "--api_key" in sys.argv:
    index = sys.argv.index("--api_key")
    if index + 1 < len(sys.argv):
        api_key = sys.argv[index + 1]

auth = HTTPTokenAuth(scheme='Bearer')


@auth.verify_token
def verify_token(token):
    """Decide whether a bearer token may access the protected routes.

    Args:
        token: Token value passed in by the ``auth`` object for the request.

    Returns:
        True when no ``api_key`` was configured (every request is accepted)
        or when ``token`` equals the configured ``api_key``; False otherwise.
    """
    import hmac

    # No key configured: the app runs without authentication.
    if api_key is None:
        return True
    # A missing or non-string token can never equal the configured key.
    if not isinstance(token, str):
        return False
    # Compare in constant time so response timing does not reveal how many
    # leading characters of the key were guessed correctly. Encoding to bytes
    # avoids the TypeError compare_digest raises for non-ASCII str input.
    return hmac.compare_digest(token.encode("utf-8"), api_key.encode("utf-8"))


@auth.error_handler
def auth_error(status):
    """Build the JSON response returned when authentication fails.

    Args:
        status: Status code supplied by the auth error hook; it is not used,
            the response is always sent with 401.

    Returns:
        A ``(response, 401)`` tuple whose body is ``{"error": "Unauthorized"}``.
    """
    payload = {"error": "Unauthorized"}
    return jsonify(payload), 401


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -151,6 +177,7 @@ def log_request(self, *args, **kwargs):


@app.route('/v1/models', methods=['GET'])
@auth.login_required
def get_models():
return jsonify({
"object": "list",
Expand All @@ -159,6 +186,7 @@ def get_models():


@app.route('/v1/load_lora_adapter', methods=['POST'])
@auth.login_required
def load_model():
lora_name = request.json.get('lora_name')
# Check if the model already exists
Expand All @@ -179,6 +207,7 @@ def load_model():


@app.route('/v1/unload_lora_adapter', methods=['POST'])
@auth.login_required
def unload_model():
model_id = request.json.get('lora_name')
global models
Expand All @@ -187,6 +216,7 @@ def unload_model():


@app.route('/v1/completions', methods=['POST'])
@auth.login_required
def completion():
try:
prompt = request.json.get('prompt')
Expand Down Expand Up @@ -249,6 +279,7 @@ def completion():


@app.route('/v1/chat/completions', methods=['POST'])
@auth.login_required
def chat_completions():
try:
messages = request.json.get('messages')
Expand Down
22 changes: 22 additions & 0 deletions development/app/config/mock/api-key-patch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: mock-llama2-7b
spec:
replicas: 3
selector:
matchLabels:
adapter.model.aibrix.ai/enabled: "true"
model.aibrix.ai/name: "llama2-7b"
app: "mock-llama2-7b"
template:
spec:
serviceAccountName: mocked-app-sa
containers:
- name: llm-engine
image: aibrix/vllm-mock:nightly
command:
- python3
- app.py
- --api_key
- test-key-1234567890
4 changes: 4 additions & 0 deletions development/app/config/mock/kustomization.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
resources:
- ../templates/deployment
- components.yaml

# Enable the following patch when testing LoRA with API-key authentication
patches:
- path: api-key-patch.yaml
1 change: 1 addition & 0 deletions development/app/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
flask
Flask-HTTPAuth
kubernetes
numpy
pandas
Expand Down
17 changes: 17 additions & 0 deletions development/tutorials/lora/model_adapter_api_key.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
apiVersion: model.aibrix.ai/v1alpha1
kind: ModelAdapter
metadata:
name: text2sql-lora-2
namespace: default
labels:
model.aibrix.ai/name: "text2sql-lora-2"
model.aibrix.ai/port: "8000"
spec:
baseModel: llama2-7b
podSelector:
matchLabels:
model.aibrix.ai/name: llama2-7b
artifactURL: huggingface://yard1/llama-2-7b-sql-lora-test
additionalConfig:
api-key: test-key-1234567890
schedulerName: default
27 changes: 23 additions & 4 deletions pkg/controller/modeladapter/modeladapter_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
}

// Check if the model is already loaded
exists, err = r.modelAdapterExists(host, instance.Name)
exists, err = r.modelAdapterExists(host, instance)
if err != nil {
return err
}
Expand All @@ -546,10 +546,21 @@ func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance
}

// Separate method to check if the model already exists
func (r *ModelAdapterReconciler) modelAdapterExists(host, modelName string) (bool, error) {
func (r *ModelAdapterReconciler) modelAdapterExists(host string, instance *modelv1alpha1.ModelAdapter) (bool, error) {
// TODO: /v1/models is the vLLM entrypoint; support multiple engines in the future
url := fmt.Sprintf("%s/v1/models", host)
resp, err := http.Get(url)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return false, err
}
// Check if "api-key" exists in the map and set the Authorization header accordingly
if token, ok := instance.Spec.AdditionalConfig["api-key"]; ok {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}

c := &http.Client{}
resp, err := c.Do(req)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -579,7 +590,7 @@ func (r *ModelAdapterReconciler) modelAdapterExists(host, modelName string) (boo
if !ok {
continue
}
if model["id"] == modelName {
if model["id"] == instance.Name {
return true, nil
}
}
Expand Down Expand Up @@ -615,6 +626,10 @@ func (r *ModelAdapterReconciler) loadModelAdapter(host string, instance *modelv1
return err
}
req.Header.Set("Content-Type", "application/json")
// Check if "api-key" exists in the map and set the Authorization header accordingly
if token, ok := instance.Spec.AdditionalConfig["api-key"]; ok {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}

client := &http.Client{}
resp, err := client.Do(req)
Expand Down Expand Up @@ -680,6 +695,10 @@ func (r *ModelAdapterReconciler) unloadModelAdapter(instance *modelv1alpha1.Mode
return err
}
req.Header.Set("Content-Type", "application/json")
// Check if "api-key" exists in the map and set the Authorization header accordingly
if token, ok := instance.Spec.AdditionalConfig["api-key"]; ok {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}

httpClient := &http.Client{}
resp, err := httpClient.Do(req)
Expand Down

0 comments on commit a931b76

Please sign in to comment.