From 9e11187afd9949dd111e9f52e6fe4449dde11d1e Mon Sep 17 00:00:00 2001
From: etwk <48991073+etwk@users.noreply.github.com>
Date: Thu, 15 Aug 2024 01:02:03 +0000
Subject: [PATCH] add api key support to LlamaIndex OllamaEmbedding
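
Ollama's HTTP API has no built-in auth, so the stock
llama-index-embeddings-ollama client never sends a key. This patch
vendors that class and attaches settings.EMBEDDING_API_KEY as a
Bearer token, for setups where the Ollama endpoint sits behind an
authenticating proxy.

Minimal usage sketch (the model name and base URL are illustrative,
taken from .env and the class default; nothing new is pinned here):

    from ollama_embedding import OllamaEmbedding

    # EMBEDDING_API_KEY is read in src/settings.py and sent as an
    # "Authorization: Bearer <key>" header on every embedding call.
    embedding = OllamaEmbedding(
        model_name="jina/jina-embeddings-v2-base-en",
        base_url="http://localhost:11434",
    )
    vector = embedding.get_general_text_embedding("hello world")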
---
 .env                    |   1 +
 requirements.base.txt   |   1 -
 src/dspy_modules.py     |   1 +
 src/ollama_embedding.py | 145 ++++++++++++++++++++++++++++++++++++++++
 src/retrieve.py         |   5 +-
 src/settings.py         |   3 +
 6 files changed, 153 insertions(+), 3 deletions(-)
 create mode 100644 src/ollama_embedding.py

diff --git a/.env b/.env
index 5f92aed..1abf8eb 100644
--- a/.env
+++ b/.env
@@ -1,3 +1,4 @@
+EMBEDDING_API_KEY=ollama:abc
 EMBEDDING_MODEL_DEPLOY=api
 EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
 LLM_MODEL_NAME=google/gemma-2-27b-it
diff --git a/requirements.base.txt b/requirements.base.txt
index 2cb77d2..5809d9e 100644
--- a/requirements.base.txt
+++ b/requirements.base.txt
@@ -2,7 +2,6 @@ aiohttp
 dspy-ai
 fastapi
 llama-index
-llama-index-embeddings-ollama
 llama-index-postprocessor-jinaai-rerank
 openai
 uvicorn
\ No newline at end of file
diff --git a/src/dspy_modules.py b/src/dspy_modules.py
index 07d230e..3df8a27 100644
--- a/src/dspy_modules.py
+++ b/src/dspy_modules.py
@@ -19,6 +19,7 @@ class GenerateSearchQuery(dspy.Signature):
     statement = dspy.InputField()
     query = dspy.OutputField()
 
+# TODO: citation needs higher token limits
 class GenerateCitedParagraph(dspy.Signature):
     """Generate a paragraph with citations."""
     context = dspy.InputField(desc="may contain relevant facts")
diff --git a/src/ollama_embedding.py b/src/ollama_embedding.py
new file mode 100644
index 0000000..20ad812
--- /dev/null
+++ b/src/ollama_embedding.py
@@ -0,0 +1,145 @@
+"""
+Added API key support.
+
+Source: https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py
+"""
+
+import httpx
+from typing import Any, Dict, List, Optional
+
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.callbacks.base import CallbackManager
+from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE
+
+from settings import settings
+
+class OllamaEmbedding(BaseEmbedding):
+    """Class for Ollama embeddings."""
+
+    base_url: str = Field(description="Base URL of the Ollama server hosting the model.")
+    model_name: str = Field(description="The Ollama model to use.")
+    embed_batch_size: int = Field(
+        default=DEFAULT_EMBED_BATCH_SIZE,
+        description="The batch size for embedding calls.",
+        gt=0,
+        le=2048,
+    )
+    ollama_additional_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional kwargs for the Ollama API."
+    )
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str = "http://localhost:11434",
+        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
+        ollama_additional_kwargs: Optional[Dict[str, Any]] = None,
+        callback_manager: Optional[CallbackManager] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            embed_batch_size=embed_batch_size,
+            ollama_additional_kwargs=ollama_additional_kwargs or {},
+            callback_manager=callback_manager,
+        )
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "OllamaEmbedding"
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding."""
+        return self.get_general_text_embedding(query)
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """The asynchronous version of _get_query_embedding."""
+        return await self.aget_general_text_embedding(query)
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        """Get text embedding."""
+        return self.get_general_text_embedding(text)
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        """Asynchronously get text embedding."""
+        return await self.aget_general_text_embedding(text)
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get text embeddings."""
+        embeddings_list: List[List[float]] = []
+        for text in texts:
+            embeddings = self.get_general_text_embedding(text)
+            embeddings_list.append(embeddings)
+
+        return embeddings_list
+
+    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronously get text embeddings."""
+        return [await self.aget_general_text_embedding(text) for text in texts]
+
+    def get_general_text_embedding(self, prompt: str) -> List[float]:
+        """Get Ollama embedding."""
+        try:
+            import requests
+        except ImportError:
+            raise ImportError(
+                "Could not import requests library."
+                "Please install requests with `pip install requests`"
+            )
+
+        ollama_request_body = {
+            "prompt": prompt,
+            "model": self.model_name,
+            "options": self.ollama_additional_kwargs,
+        }
+
+        response = requests.post(
+            url=f"{self.base_url}/api/embeddings",
+            headers={"Content-Type": "application/json", "Authorization": f"Bearer {settings.EMBEDDING_API_KEY}"},
+            json=ollama_request_body,
+        )
+        response.encoding = "utf-8"
+        if response.status_code != 200:
+            optional_detail = response.json().get("error")
+            raise ValueError(
+                f"Ollama call failed with status code {response.status_code}."
+                f" Details: {optional_detail}"
+            )
+
+        try:
+            return response.json()["embedding"]
+        except requests.exceptions.JSONDecodeError as e:
+            raise ValueError(
+                f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
+            )
+
+    async def aget_general_text_embedding(self, prompt: str) -> List[float]:
+        """Asynchronously get Ollama embedding."""
+        async with httpx.AsyncClient() as client:
+            ollama_request_body = {
+                "prompt": prompt,
+                "model": self.model_name,
+                "options": self.ollama_additional_kwargs,
+            }
+
+            response = await client.post(
+                url=f"{self.base_url}/api/embeddings",
+                headers={"Content-Type": "application/json", "Authorization": f"Bearer {settings.EMBEDDING_API_KEY}"},
+                json=ollama_request_body,
+            )
+            response.encoding = "utf-8"
+            if response.status_code != 200:
+                optional_detail = response.json().get("error")
+                raise ValueError(
+                    f"Ollama call failed with status code {response.status_code}."
+                    f" Details: {optional_detail}"
+                )
+
+            try:
+                return response.json()["embedding"]
+            except ValueError as e:  # response.json() raises json.JSONDecodeError, a ValueError subclass
+                raise ValueError(
+                    f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
+                )
\ No newline at end of file
diff --git a/src/retrieve.py b/src/retrieve.py
index 9de6c1d..31ad6cf 100644
--- a/src/retrieve.py
+++ b/src/retrieve.py
@@ -29,7 +29,7 @@ jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank"
 
 # switch to on-premise
 # todo: high latency between client and the ollama embedding server will slow down embedding a lot
-from llama_index.embeddings.ollama import OllamaEmbedding
+from ollama_embedding import OllamaEmbedding
 
 # todo: improve embedding performance
 if settings.EMBEDDING_MODEL_DEPLOY == "local":
@@ -111,7 +111,8 @@ def get_automerging_query_engine(
     retriever = AutoMergingRetriever(
         base_retriever, storage_context, verbose=True
     )
-    
+
+    # TODO: load model files at app start
     if settings.RERANK_MODEL_DEPLOY == "local":
         rerank = SentenceTransformerRerank(
             top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME,
diff --git a/src/settings.py b/src/settings.py
index 7f9794e..1481a1d 100644
--- a/src/settings.py
+++ b/src/settings.py
@@ -26,5 +26,8 @@ def __init__(self):
 
         # threads
         self.THREAD_BUILD_INDEX = int(os.environ.get("THREAD_BUILD_INDEX", 12))
+
+        # keys
+        self.EMBEDDING_API_KEY = os.environ.get("EMBEDDING_API_KEY") or ""
 
 settings = Settings()
\ No newline at end of file