Commit e43bedb

add asymmetric customized e5 model

Signed-off-by: Fen Qin <[email protected]>

1 parent 423b9b5 commit e43bedb

8 files changed: +368 -79 lines changed
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# Common Model Deployment

Shared deployment infrastructure for SageMaker endpoints across different model types.

## Structure

```
examples/
├── common/
│   ├── deploy.py                # Shared deployment script
│   └── README.md                # This file
├── semantic_highlighting/       # Highlighting models
│   ├── api_types.py             # Highlighting-specific types
│   ├── modernbert/
│   └── opensearch-semantic-highlighter/
└── embedding_models/            # Embedding models
    ├── api_types.py             # Embedding-specific types
    ├── validate.sh              # Validation script
    └── asymmetric_e5/
```
## Usage

### Deploy a Model

```bash
cd common
python3 deploy.py --model <model_name> [options]
```

The deployment script will:

1. Download the model from HuggingFace
2. Create a model package with the inference code
3. Deploy it to a SageMaker endpoint
4. **Output the endpoint name** for validation
### Validate Deployment

After deployment, you'll see output like:

```
Endpoint deployed successfully: asymmetric-e5-20251113-210834-866f6617
```

Use this endpoint name to validate:

```bash
# For embedding models ONLY
cd ../embedding_models
./validate.sh asymmetric-e5-20251113-210834-866f6617

# For semantic highlighting models
# (no validation script yet; test manually via the AWS console or CLI)
```
**Why specify the endpoint name?**

- SageMaker generates unique endpoint names with timestamps
- Multiple deployments can exist simultaneously
- Naming the endpoint lets you test a specific endpoint version
- It prevents accidentally validating the wrong endpoint

**Note:** The `validate.sh` script is designed specifically for embedding models and tests embedding-specific payloads (query/passage embeddings, OpenSearch connector format). Semantic highlighting models require different validation payloads; see the manual invocation sketch below.
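For a semantic highlighting endpoint, a manual check with the AWS CLI might look like the following sketch. The endpoint name and output file are placeholders, and `--cli-binary-format raw-in-base64-out` is assumed to be needed (AWS CLI v2) so the JSON body is sent raw:

```bash
# Hypothetical manual validation of a highlighting endpoint; adjust names to your deployment.
aws sagemaker-runtime invoke-endpoint \
  --endpoint-name <your-highlighter-endpoint> \
  --content-type application/json \
  --cli-binary-format raw-in-base64-out \
  --body '{"question": "What is the treatment?", "context": "Traditional treatments include cholinesterase inhibitors."}' \
  response.json
cat response.json
```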
### Available Models

**Semantic Highlighting:**

- `opensearch-semantic-highlighter`
- `modernbert`

**Embedding Models:**

- `asymmetric_e5`
### Options

- `--model`: Model to deploy (required)
- `--instance-type`: SageMaker instance type (default: `ml.g5.xlarge`)
- `--instance-count`: Number of instances (default: 1)

### Examples

```bash
# Deploy the asymmetric E5 embedding model
python3 deploy.py --model asymmetric_e5 --instance-type ml.m5.large

# Deploy the semantic highlighter
python3 deploy.py --model opensearch-semantic-highlighter
```
## Environment Variables

- `AWS_REGION`: AWS region (default: `us-east-1`)
- `INSTANCE_TYPE`: Default instance type
- `INSTANCE_COUNT`: Default instance count
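For example, to override the defaults for a single deployment (illustrative values):

```bash
# Illustrative: set the defaults via the environment, then deploy.
export AWS_REGION=us-west-2
export INSTANCE_TYPE=ml.m5.large
python3 deploy.py --model asymmetric_e5
```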
## Finding Existing Endpoints

To list existing endpoints:

```bash
aws sagemaker list-endpoints --region us-east-1
```
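The same lookup from Python, as a minimal boto3 sketch:

```python
# Minimal sketch: list SageMaker endpoints and their status with boto3.
import boto3

sm = boto3.client("sagemaker", region_name="us-east-1")
for endpoint in sm.list_endpoints()["Endpoints"]:
    print(endpoint["EndpointName"], endpoint["EndpointStatus"])
```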
## API Formats

### Embedding Models (asymmetric_e5)

**Request:**

```json
{
  "texts": ["how much protein should a female eat"],
  "content_type": "query"
}
```

**Response:**

```json
[[0.21125227, -0.19419950, ...]]
```
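To send this request from Python rather than through `validate.sh`, a minimal sketch with the SageMaker runtime client (the endpoint name is a placeholder taken from the deployment output above):

```python
# Minimal sketch: invoke the embedding endpoint and parse the returned vectors.
import json

import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")
payload = {"texts": ["how much protein should a female eat"], "content_type": "query"}
response = runtime.invoke_endpoint(
    EndpointName="asymmetric-e5-20251113-210834-866f6617",  # placeholder endpoint name
    ContentType="application/json",
    Body=json.dumps(payload),
)
embeddings = json.loads(response["Body"].read())  # one vector per input text
print(len(embeddings), len(embeddings[0]))
```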
### Semantic Highlighting Models

**Request:**

```json
{
  "question": "What is the treatment?",
  "context": "Traditional treatments include cholinesterase inhibitors."
}
```

**Response:**

```json
{
  "highlights": [{"start": 0, "end": 50}],
  "processing_time_ms": 22.4,
  "device": "cuda"
}
```
## Adding New Models

1. Create a model directory under the appropriate task type
2. Add `inference.py` and `requirements.txt`
3. Update `MODEL_CONFIGS` in `deploy.py` (see the sketch below)
4. Ensure a proper `api_types.py` exists for the task type
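Based on the entries this commit adds to `MODEL_CONFIGS` in `deploy.py`, a new registration follows the same shape. All values below are made-up placeholders:

```python
# Hypothetical MODEL_CONFIGS entry; every value is a placeholder.
'my_new_model': {
    'model_name': 'my-org/my-new-model',    # HuggingFace model ID to download
    'endpoint_prefix': 'my-new-model',      # prefix for the generated endpoint name
    's3_prefix': 'my-new-model',            # S3 key prefix for the model package
    'task_type': 'embedding_models'         # picks ../<task_type>/<model_key>/ for code
}
```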

docs/source/examples/semantic_highlighting/deploy.py renamed to docs/source/examples/common/deploy.py

Lines changed: 18 additions & 7 deletions
```diff
@@ -26,12 +26,20 @@
     'opensearch-semantic-highlighter': {
         'model_name': 'opensearch-project/opensearch-semantic-highlighter-v1',
         'endpoint_prefix': 'opensearch-semantic-highlighter',
-        's3_prefix': 'opensearch-semantic-highlighter'
+        's3_prefix': 'opensearch-semantic-highlighter',
+        'task_type': 'semantic_highlighting'
     },
     'modernbert': {
         'model_name': 'answerdotai/ModernBERT-base',
         'endpoint_prefix': 'modernbert-highlighter',
-        's3_prefix': 'modernbert-highlighter'
+        's3_prefix': 'modernbert-highlighter',
+        'task_type': 'semantic_highlighting'
+    },
+    'asymmetric_e5': {
+        'model_name': 'intfloat/multilingual-e5-small',
+        'endpoint_prefix': 'asymmetric-e5',
+        's3_prefix': 'asymmetric-e5',
+        'task_type': 'embedding_models'
     }
 }
@@ -52,7 +60,7 @@ def create_sagemaker_role():
     sts = boto3.client('sts')
     account_id = sts.get_caller_identity()["Account"]
     role_name = 'SageMakerExecutionRole'
-    role_arn = f'arn:aws:iam::{account_id}:role/{role_name}'
+    role_arn = 'arn:aws:iam::{}:role/{}'.format(account_id, role_name)

     try:
         iam.get_role(RoleName=role_name)
@@ -117,12 +125,15 @@ def prepare_model_files(model_key):

 def create_model_tar(work_dir, model_key):
     """Create model.tar.gz with model files and inference code."""
+    config = MODEL_CONFIGS[model_key]
+    task_type = config['task_type']
+
     os.makedirs(f"{work_dir}/code", exist_ok=True)

     # Copy model-specific inference code
-    inference_src = f"{model_key}/inference.py"
-    requirements_src = f"{model_key}/requirements.txt"
-    api_types_src = "api_types.py"
+    inference_src = f"../{task_type}/{model_key}/inference.py"
+    requirements_src = f"../{task_type}/{model_key}/requirements.txt"
+    api_types_src = f"../{task_type}/api_types.py"

     if not os.path.exists(inference_src):
         raise FileNotFoundError(f"Inference code not found: {inference_src}")
@@ -131,7 +142,7 @@ def create_model_tar(work_dir, model_key):
     if os.path.exists(requirements_src):
         shutil.copy(requirements_src, f"{work_dir}/code/requirements.txt")

-    # Copy shared API types
+    # Copy task-specific API types
     if os.path.exists(api_types_src):
         shutil.copy(api_types_src, f"{work_dir}/code/api_types.py")
```
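Concretely, the new `task_type` field drives where `create_model_tar` looks for model code. A worked example of the path resolution (not code from the commit):

```python
# Worked example: path resolution for asymmetric_e5 after this change.
model_key = 'asymmetric_e5'
task_type = MODEL_CONFIGS[model_key]['task_type']         # 'embedding_models'
inference_src = f"../{task_type}/{model_key}/inference.py"
# -> '../embedding_models/asymmetric_e5/inference.py', relative to examples/common/
```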

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
```python
"""API type definitions for embedding model inference"""
from typing import TypedDict, List, Union, Optional


class EmbeddingRequest(TypedDict):
    """Single embedding request"""
    texts: List[str]
    content_type: Optional[str]  # "query" or "passage"


class BatchEmbeddingRequest(TypedDict):
    """Batch embedding request (OpenSearch connector format)"""
    parameters: EmbeddingRequest


class EmbeddingResponse(TypedDict):
    """Standard embedding response format"""
    embeddings: Union[List[float], List[List[float]]]
    processing_time_ms: float
    device: str
```
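For reference, a request built against these types mirrors the README examples. A sketch (the annotations are checked only by static tooling such as mypy):

```python
# Sketch: constructing typed requests matching the definitions above.
from api_types import BatchEmbeddingRequest, EmbeddingRequest

request: EmbeddingRequest = {
    "texts": ["how much protein should a female eat"],
    "content_type": "query",
}
# The OpenSearch connector wraps the same fields under "parameters".
batch: BatchEmbeddingRequest = {"parameters": request}
```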
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
```python
import os
import sys
import json
import time
import logging
import torch
from transformers import AutoTokenizer, AutoModel

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from api_types import EmbeddingRequest, BatchEmbeddingRequest, EmbeddingResponse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Device: {DEVICE}")

def model_fn(model_dir):
    """Load model and tokenizer"""
    model_name = "intfloat/multilingual-e5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(DEVICE)
    return {"model": model, "tokenizer": tokenizer}

def input_fn(request_body, request_content_type):
    """Parse input and return texts for embedding"""
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")

    input_data = json.loads(request_body)

    # Handle OpenSearch connector format
    if "parameters" in input_data:
        params = input_data["parameters"]
        texts = params.get("texts", [])
        content_type = params.get("content_type")
    else:
        texts = input_data.get("texts", [])
        content_type = input_data.get("content_type")

    # Add content type prefix if specified
    if content_type:
        texts = [f"{content_type}: {text}" for text in texts]

    return texts

def predict_fn(input_data, model_dict):
    """Generate embeddings"""
    start_time = time.time()

    model = model_dict["model"]
    tokenizer = model_dict["tokenizer"]

    inputs = tokenizer(input_data, padding=True, truncation=True,
                       return_tensors="pt", max_length=512).to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)

    processing_time = (time.time() - start_time) * 1000

    return {
        "embeddings": embeddings.cpu().numpy(),
        "processing_time_ms": processing_time,
        "device": str(DEVICE)
    }

def output_fn(prediction, content_type):
    """Format output for OpenSearch compatibility"""
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")

    embeddings = prediction["embeddings"]

    # Return simple array format for OpenSearch
    if len(embeddings.shape) == 2:  # Batch
        result = [embedding.tolist() for embedding in embeddings]
    else:  # Single
        result = embeddings.tolist()

    return json.dumps(result)
```
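A quick local smoke test of the handler chain might look like the following sketch. SageMaker normally drives these functions itself, and this `model_fn` ignores its `model_dir` argument, so any path works:

```python
# Sketch: exercise the SageMaker handlers locally (downloads the model from HuggingFace).
import json

model_dict = model_fn("/opt/ml/model")  # path is unused by this model_fn
body = json.dumps({"texts": ["how much protein should a female eat"],
                   "content_type": "query"})
texts = input_fn(body, "application/json")
prediction = predict_fn(texts, model_dict)
print(output_fn(prediction, "application/json")[:120])
```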
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
```
torch>=2.0.0
transformers>=4.28.0
```
