
Commit 9897f0e

move to modal.com for embeddings as well

1 parent 1561210 commit 9897f0e

File tree

3 files changed: +184 -6 lines changed

document_qa/custom_embeddings.py

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
from typing import List, Optional

import requests
from langchain_core.embeddings import Embeddings


class ModalEmbeddings(Embeddings):
    def __init__(self, url: str, model_name: str, api_key: Optional[str] = None):
        self.url = url
        self.model_name = model_name
        self.api_key = api_key

    def embed(self, text: List[str]) -> List[List[float]]:
        # Remove newlines from each text: the endpoint uses newlines to separate
        # the individual texts, and they can also confuse the embedding model.
        cleaned_text = [t.replace("\n", " ") for t in text]

        payload = {'text': "\n".join(cleaned_text)}

        headers = {}
        if self.api_key:
            headers = {'x-api-key': self.api_key}

        response = requests.post(
            self.url,
            data=payload,
            files=[],
            headers=headers
        )
        response.raise_for_status()

        return response.json()

    def embed_documents(self, text: List[str]) -> List[List[float]]:
        """
        Embed a list of documents using the embedding model.
        """
        return self.embed(text)

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a query.
        """
        return self.embed([text])[0]

    def get_model_name(self) -> str:
        return self.model_name


if __name__ == "__main__":
    embeds = ModalEmbeddings(
        url="https://lfoppiano--intfloat-multilingual-e5-large-instruct-embed-5da184.modal.run/",
        model_name="intfloat/multilingual-e5-large-instruct"
    )

    print(embeds.embed(
        ["We are surrounded by stupid kids",
         "We are interested in the future of AI"]
    ))
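For reference, `ModalEmbeddings` and the Modal endpoint in the next file agree on a very small wire protocol: the texts are newline-joined into a single `text` form field, and the server returns one vector per non-empty line. A minimal sketch of that exchange with plain `requests` (the URL is a placeholder for your own deployment, and `EMBEDS_API_KEY` is the variable the Streamlit app reads; neither value is part of this commit):

import os
import requests

# Placeholder URL: substitute your own *.modal.run deployment.
URL = "https://<workspace>--<endpoint>.modal.run/"
api_key = os.environ.get("EMBEDS_API_KEY")

# Two texts, newline-joined into one form field, exactly as ModalEmbeddings.embed() does.
response = requests.post(
    URL,
    data={"text": "first document\nsecond document"},
    headers={"x-api-key": api_key} if api_key else {},
)
response.raise_for_status()

vectors = response.json()  # a list of float vectors, one per input line
print(len(vectors), len(vectors[0]))  # e.g. 2 1024 for multilingual-e5-large-instruct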
Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
import os
from typing import Annotated, List

from fastapi import Request, HTTPException, Form

import modal
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "transformers",
        "huggingface_hub[hf_transfer]==0.26.2",
        "flashinfer-python==0.2.0.post2",  # pinning, very unstable
        "fastapi[standard]",
        extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster model transfers
)

MODELS_DIR = "/llamas"
MODEL_NAME = "intfloat/multilingual-e5-large-instruct"
MODEL_REVISION = "84344a23ee1820ac951bc365f1e91d094a911763"

hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)

app = modal.App("intfloat-multilingual-e5-large-instruct-embeddings")


def get_device():
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_model():
    print("Loading model...")
    device = get_device()
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME).to(device)
    print("Model loaded successfully.")

    return tokenizer, model, device


N_GPU = 1
MINUTES = 60  # seconds
VLLM_PORT = 8000


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Zero out the padding positions, then average the remaining token embeddings.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


@app.function(
    image=image,
    gpu=f"L40S:{N_GPU}",
    # gpu=f"A10G:{N_GPU}",
    # how long should we stay up with no requests?
    scaledown_window=3 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
    secrets=[modal.Secret.from_name("document-qa-embedding-key")]
)
@modal.concurrent(
    max_inputs=5
)  # how many requests can one replica handle? tune carefully!
@modal.fastapi_endpoint(method="POST")
def embed(request: Request, text: Annotated[str, Form()]):
    api_key = request.headers.get("x-api-key")
    expected_key = os.environ["API_KEY"]

    if api_key != expected_key:
        raise HTTPException(status_code=401, detail="Unauthorized")

    texts = [t for t in text.split("\n") if t.strip()]
    if not texts:
        return []

    tokenizer, model, device = load_model()
    model.eval()

    print(f"Start embedding {len(texts)} texts")
    try:
        with torch.no_grad():
            # Move inputs to the same device as the model
            batch_dict = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
            batch_dict = {k: v.to(device) for k, v in batch_dict.items()}

            # Forward pass
            outputs = model(**batch_dict)

            # Pool token embeddings into one vector per text
            embeddings = average_pool(
                outputs.last_hidden_state,
                batch_dict['attention_mask']
            )
            embeddings = F.normalize(embeddings, p=2, dim=1)

            # Move to CPU and convert to list for serialization
            embeddings = embeddings.cpu().numpy().tolist()

        print("Finished embedding texts.")
        return embeddings

    except RuntimeError as e:
        print(f"Error during embedding: {str(e)}")
        if "CUDA out of memory" in str(e):
            print("CUDA out of memory error. Try reducing batch size or using a smaller model.")
        raise
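The `average_pool` helper above implements masked mean pooling: padding positions are zeroed via the attention mask so they contribute nothing to the average. A tiny self-contained check of that arithmetic (toy tensors, not real model outputs):

import torch

# Batch of 1, seq_len 3, hidden size 2; the last position is padding (mask 0).
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
mask = torch.tensor([[1, 1, 0]])

# Same steps as average_pool(): zero the padded row, sum over the sequence,
# then divide by the number of real tokens.
pooled = hidden.masked_fill(~mask[..., None].bool(), 0.0).sum(dim=1) / mask.sum(dim=1)[..., None]
print(pooled)  # tensor([[2., 3.]]) -- the padded [9., 9.] row is ignored

Deployment follows the usual Modal flow: `modal deploy` on this file publishes the endpoint, and the `document-qa-embedding-key` secret referenced in the decorator is what supplies the `API_KEY` environment variable checked in `embed` (it can be created with `modal secret create document-qa-embedding-key API_KEY=...`).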

streamlit_app.py

Lines changed: 9 additions & 6 deletions

@@ -6,10 +6,10 @@
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
 from langchain.memory import ConversationBufferMemory
-from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings
 from langchain_openai import ChatOpenAI
 from streamlit_pdf_viewer import pdf_viewer
 
+from document_qa.custom_embeddings import ModalEmbeddings
 from document_qa.ner_client_generic import NERClientGeneric
 
 dotenv.load_dotenv(override=True)
@@ -19,11 +19,11 @@
 from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
 
 API_MODELS = {
-    "microsoft/Phi-4-mini-instruct": os.environ["MODAL_1_URL"]
+    "microsoft/Phi-4-mini-instruct": os.environ["LM_URL"]
 }
 
 API_EMBEDDINGS = {
-    'intfloat/multilingual-e5-large-instruct': 'intfloat/multilingual-e5-large-instruct'
+    'intfloat/multilingual-e5-large-instruct-modal': os.environ['EMBEDS_URL']
 }
 
 if 'rqa' not in st.session_state:
@@ -112,6 +112,7 @@ def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
     st.session_state['uploaded'] = True
+    st.session_state['annotations'] = []
     if st.session_state['memory']:
         st.session_state['memory'].clear()
@@ -133,8 +134,10 @@ def init_qa(model_name, embeddings_name):
         api_key=os.environ.get('API_KEY')
     )
 
-    embeddings = HuggingFaceEndpointEmbeddings(
-        repo_id=API_EMBEDDINGS[embeddings_name]
+    embeddings = ModalEmbeddings(
+        url=API_EMBEDDINGS[embeddings_name],
+        model_name=embeddings_name,
+        api_key=os.environ.get('EMBEDS_API_KEY')
     )
 
     storage = DataStorage(embeddings)
@@ -195,7 +198,7 @@ def play_old_messages(container):
 st.markdown("Upload a scientific article in PDF, ask questions, get insights.")
 st.markdown(
     ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
-st.markdown("Powered by [Huggingface](https://huggingface.co) and [Modal.com](https://modal.com/)")
+st.markdown("LM and Embeddings are powered by [Modal.com](https://modal.com/)")
 
 st.divider()
 st.session_state['model'] = model = st.selectbox(
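With this change the app's wiring lives entirely in environment variables loaded by `dotenv.load_dotenv(override=True)`: `LM_URL` for the Phi-4-mini endpoint, `EMBEDS_URL` and `EMBEDS_API_KEY` for the embedding endpoint, plus the existing `API_KEY` for the LM. A quick hypothetical preflight check before launching Streamlit (variable names taken from the diff above; this snippet is not part of the commit):

import os

import dotenv

dotenv.load_dotenv(override=True)

# Report which of the URLs/keys this commit reads are present in the environment.
for var in ("LM_URL", "API_KEY", "EMBEDS_URL", "EMBEDS_API_KEY"):
    print(f"{var}: {'set' if os.environ.get(var) else 'MISSING'}")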
