add api key support to LlamaIndex OllamaEmbedding
etwk committed Aug 15, 2024
1 parent 6d0389c commit 9e11187
Showing 6 changed files with 153 additions and 3 deletions.
1 change: 1 addition & 0 deletions .env
@@ -1,3 +1,4 @@
EMBEDDING_API_KEY=ollama:abc
EMBEDDING_MODEL_DEPLOY=api
EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
LLM_MODEL_NAME=google/gemma-2-27b-it
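
With this key set, every embedding request carries it as a bearer token. A minimal sketch of the resulting call, mirroring the client added in src/ollama_embedding.py below (the key and model are the placeholder values from this .env, and the URL assumes a local Ollama instance; plain Ollama has no built-in auth, so the header only matters when a proxy in front of it checks the key):

import requests

response = requests.post(
    "http://localhost:11434/api/embeddings",
    headers={
        "Content-Type": "application/json",
        "Authorization": "Bearer ollama:abc",  # placeholder key from the .env above
    },
    json={"prompt": "hello", "model": "jina/jina-embeddings-v2-base-en", "options": {}},
)
print(len(response.json()["embedding"]))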
1 change: 0 additions & 1 deletion requirements.base.txt
@@ -2,7 +2,6 @@ aiohttp
dspy-ai
fastapi
llama-index
llama-index-embeddings-ollama
llama-index-postprocessor-jinaai-rerank
openai
uvicorn
1 change: 1 addition & 0 deletions src/dspy_modules.py
@@ -19,6 +19,7 @@ class GenerateSearchQuery(dspy.Signature):
    statement = dspy.InputField()
    query = dspy.OutputField()

# TODO: citation needs higher token limits
class GenerateCitedParagraph(dspy.Signature):
    """Generate a paragraph with citations."""
    context = dspy.InputField(desc="may contain relevant facts")
145 changes: 145 additions & 0 deletions src/ollama_embedding.py
@@ -0,0 +1,145 @@
"""
Added API key support.
Source: https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py
"""

import httpx
from typing import Any, Dict, List, Optional

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE

from settings import settings

class OllamaEmbedding(BaseEmbedding):
    """Class for Ollama embeddings."""

    base_url: str = Field(description="Base URL of the Ollama server hosting the model.")
    model_name: str = Field(description="The Ollama model to use.")
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        le=2048,
    )
    ollama_additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Ollama API."
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = "http://localhost:11434",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        ollama_additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            embed_batch_size=embed_batch_size,
            ollama_additional_kwargs=ollama_additional_kwargs or {},
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OllamaEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self.get_general_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return await self.aget_general_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self.get_general_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        return await self.aget_general_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        embeddings_list: List[List[float]] = []
        for text in texts:
            embeddings = self.get_general_text_embedding(text)
            embeddings_list.append(embeddings)

        return embeddings_list

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        # Route each text through the async single-embedding call; calling
        # this method itself here would recurse forever without awaiting.
        return [await self.aget_general_text_embedding(text) for text in texts]

    def get_general_text_embedding(self, prompt: str) -> List[float]:
        """Get Ollama embedding."""
        try:
            import requests
        except ImportError:
            raise ImportError(
                "Could not import requests library. "
                "Please install requests with `pip install requests`."
            )

        ollama_request_body = {
            "prompt": prompt,
            "model": self.model_name,
            "options": self.ollama_additional_kwargs,
        }

        response = requests.post(
            url=f"{self.base_url}/api/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {settings.EMBEDDING_API_KEY}",
            },
            json=ollama_request_body,
        )
        response.encoding = "utf-8"
        if response.status_code != 200:
            optional_detail = response.json().get("error")
            raise ValueError(
                f"Ollama call failed with status code {response.status_code}."
                f" Details: {optional_detail}"
            )

        try:
            return response.json()["embedding"]
        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
            )

    async def aget_general_text_embedding(self, prompt: str) -> List[float]:
        """Asynchronously get Ollama embedding."""
        async with httpx.AsyncClient() as client:
            ollama_request_body = {
                "prompt": prompt,
                "model": self.model_name,
                "options": self.ollama_additional_kwargs,
            }

            response = await client.post(
                url=f"{self.base_url}/api/embeddings",
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {settings.EMBEDDING_API_KEY}",
                },
                json=ollama_request_body,
            )
            response.encoding = "utf-8"
            if response.status_code != 200:
                optional_detail = response.json().get("error")
                raise ValueError(
                    f"Ollama call failed with status code {response.status_code}."
                    f" Details: {optional_detail}"
                )

            try:
                return response.json()["embedding"]
            except ValueError as e:  # response.json() raises a JSON decode error (a ValueError) on a bad body
                raise ValueError(
                    f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
                )
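
A minimal usage sketch of the vendored class (values are illustrative; get_text_embedding is inherited from the BaseEmbedding interface):

from ollama_embedding import OllamaEmbedding

embed_model = OllamaEmbedding(
    model_name="jina/jina-embeddings-v2-base-en",  # model name taken from the .env above
    base_url="http://localhost:11434",
)
vector = embed_model.get_text_embedding("hello world")
print(len(vector))  # embedding dimension, e.g. 768 for jina-embeddings-v2-base-en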
5 changes: 3 additions & 2 deletions src/retrieve.py
@@ -29,7 +29,7 @@
jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank" # switch to on-premise

# todo: high latency between the client and the ollama embedding server will slow down embedding a lot
from llama_index.embeddings.ollama import OllamaEmbedding
from ollama_embedding import OllamaEmbedding

# todo: improve embedding performance
if settings.EMBEDDING_MODEL_DEPLOY == "local":
@@ -111,7 +111,8 @@ def get_automerging_query_engine(
    retriever = AutoMergingRetriever(
        base_retriever, storage_context, verbose=True
    )


    # TODO: load model files at app start
    if settings.RERANK_MODEL_DEPLOY == "local":
        rerank = SentenceTransformerRerank(
            top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME,
3 changes: 3 additions & 0 deletions src/settings.py
@@ -26,5 +26,8 @@ def __init__(self):

        # threads
        self.THREAD_BUILD_INDEX = int(os.environ.get("THREAD_BUILD_INDEX", 12))

        # keys
        self.EMBEDDING_API_KEY = os.environ.get("EMBEDDING_API_KEY") or ""

settings = Settings()
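
When the variable is unset, EMBEDDING_API_KEY resolves to an empty string, so the client sends "Authorization: Bearer " with an empty token; a plain Ollama server has no authentication and ignores it, which keeps the default usable. A quick sketch of the resolution (values are hypothetical):

import os

for value in (None, "", "ollama:abc"):
    if value is None:
        os.environ.pop("EMBEDDING_API_KEY", None)  # simulate an unset variable
    else:
        os.environ["EMBEDDING_API_KEY"] = value
    key = os.environ.get("EMBEDDING_API_KEY") or ""
    print(repr(f"Bearer {key}"))  # 'Bearer ', 'Bearer ', 'Bearer ollama:abc'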
