From 9750179e787c86deb71e0136eea3ada70b1d6c0f Mon Sep 17 00:00:00 2001
From: etwk <48991073+etwk@users.noreply.github.com>
Date: Mon, 19 Aug 2024 08:46:17 +0000
Subject: [PATCH 1/6] lock some pip packages

---
 requirements.base.txt  | 6 +++---
 requirements.local.txt | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/requirements.base.txt b/requirements.base.txt
index 5809d9e..934ee32 100644
--- a/requirements.base.txt
+++ b/requirements.base.txt
@@ -1,7 +1,7 @@
 aiohttp
-dspy-ai
+dspy-ai==2.4.13
 fastapi
-llama-index
-llama-index-postprocessor-jinaai-rerank
+llama-index==0.10.65
+llama-index-postprocessor-jinaai-rerank==0.1.7
 openai
 uvicorn
\ No newline at end of file
diff --git a/requirements.local.txt b/requirements.local.txt
index 69dc005..d78c31a 100644
--- a/requirements.local.txt
+++ b/requirements.local.txt
@@ -1 +1 @@
-llama-index-embeddings-huggingface
\ No newline at end of file
+llama-index-embeddings-huggingface==0.2.3
\ No newline at end of file

From ac38f59f995d82450d7e219b0480d3e9b8f2567a Mon Sep 17 00:00:00 2001
From: etwk <48991073+etwk@users.noreply.github.com>
Date: Mon, 19 Aug 2024 11:55:43 +0000
Subject: [PATCH 2/6] update docker compose

---
 docker-compose.yml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docker-compose.yml b/docker-compose.yml
index cc61cba..5558b52 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,13 +1,13 @@
 services:
   check:
-    build:
-      dockerfile: Dockerfile
+    image: ittia/check:remote
     container_name: check
     env_file:
       - ./infra/env.d/check
     ports:
      - 8000:8000
     restart: always
+
   ollama:
     image: ollama/ollama
     container_name: ollama
@@ -24,7 +24,7 @@ services:
           capabilities: [gpu]
     restart: always
 
-  # infinity supports embedding and rerank models, v2 version supports serving multiple models
+  # Infinity supports embedding and rerank models; the v2 image can serve multiple models
   infinity:
     image: michaelf34/infinity:latest
     container_name: infinity

From 2204ad27d1be0aaba3212f336c3bda937b370faf Mon Sep 17 00:00:00 2001
From: etwk <48991073+etwk@users.noreply.github.com>
Date: Mon, 19 Aug 2024 11:56:30 +0000
Subject: [PATCH 3/6] add retry to search

---
 src/main.py             | 10 ++++++++--
 src/modules/__init__.py |  1 +
 src/modules/search.py   | 30 ++++++++++++++++++++++++++++++
 src/utils.py            | 32 ++++----------------------------
 4 files changed, 43 insertions(+), 30 deletions(-)
 create mode 100644 src/modules/search.py

diff --git a/src/main.py b/src/main.py
index 1c45c18..aab6967 100644
--- a/src/main.py
+++ b/src/main.py
@@ -5,6 +5,7 @@
 import logging
 
 import utils, pipeline
+from modules import Search
 
 logging.basicConfig(
     level=logging.INFO,
@@ -33,10 +34,15 @@ async def fact_check(input):
         if not query:
             continue
         logger.info(f"search query: {query}")
-        search = await utils.search(query)
-        if not search:
+
+        # searching
+        try:
+            search = await Search(query)
+        except Exception as e:
             fail_search = True
+            logger.error(f"Search '{query}' failed: {e}")
             continue
+
         logger.info(f"head of search results: {json.dumps(search)[0:500]}")
         verdict = await run_in_threadpool(pipeline.get_verdict, search_json=search, statement=statement)
         if not verdict:
diff --git a/src/modules/__init__.py b/src/modules/__init__.py
index b04b2df..14b304b 100644
--- a/src/modules/__init__.py
+++ b/src/modules/__init__.py
@@ -12,6 +12,7 @@ from .citation import Citation
 from .ollama_embedding import OllamaEmbedding
 from .retrieve import LlamaIndexRM
+from .search import Search
 from .search_query import SearchQuery
 from .statements import Statements
 from .verdict import Verdict
\ No newline at end of file
diff --git a/src/modules/search.py b/src/modules/search.py
new file mode 100644
index 0000000..2bf3b2b
--- /dev/null
+++ b/src/modules/search.py
@@ -0,0 +1,30 @@
+import aiohttp
+from tenacity import retry, stop_after_attempt, wait_fixed
+
+from settings import settings
+import utils
+
+@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=utils.retry_log_warning, reraise=True)
+async def Search(keywords):
+    """
+    Search and get a list of website contents.
+
+    Todo:
+      - Improve response cleanup.
+    """
+    constructed_url = settings.SEARCH_BASE_URL + '/' + keywords
+
+    headers = {
+        "Accept": "application/json"
+    }
+
+    async with aiohttp.ClientSession() as session:
+        async with session.get(constructed_url, headers=headers) as response:
+            rep = await response.json()
+            if not rep:
+                raise Exception(f"Search '{keywords}' result empty")
+            rep_code = rep.get('code')
+            if rep_code != 200:
+                raise Exception(f"Search '{keywords}' response code: {rep_code}")
+
+            return rep
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index 7dc6864..4c28df4 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,5 +1,4 @@
 import re, json
-import aiohttp
 import itertools
 import logging
 from llama_index.core import Document
@@ -17,32 +16,6 @@ def llm2json(text):
             logging.warning(f"Failed convert LLM response to JSON: {e}")
             pass
     return json_object
-
-async def search(keywords):
-    """
-    Search and get a list of websites content.
-
-    Todo:
-      - Enhance response clear.
-    """
-    constructed_url = settings.SEARCH_BASE_URL + '/' + keywords
-
-    headers = {
-        "Accept": "application/json"
-    }
-
-    async with aiohttp.ClientSession() as session:
-        try:
-            async with session.get(constructed_url, headers=headers) as response:
-                rep = await response.json()
-                rep_code = rep.get('code')
-                if rep_code != 200:
-                    raise Exception(f"Search response code: {rep_code}")
-        except Exception as e:
-            logging.error(f"Search '{keywords}' failed: {e}")
-            rep = {}
-
-    return rep
 
 def clear_md_links(text):
     """
@@ -160,4 +133,7 @@ def search_json_to_docs(search_json):
         }
         document = Document(text=content, metadata=metadata)
         documents.append(document)
-    return documents
\ No newline at end of file
+    return documents
+
+def retry_log_warning(retry_state):
+    logging.warning(f"Retrying attempt {retry_state.attempt_number} due to: {retry_state.outcome.exception()}")
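
The decorator added above is the standard tenacity pattern, and it wraps coroutines transparently, so the same decorator works for both sync and async callables. A minimal self-contained sketch of the same idea, using a hypothetical flaky coroutine fetch_json in place of the real Search:

import asyncio
import logging
from tenacity import retry, stop_after_attempt, wait_fixed

def retry_log_warning(retry_state):
    # same shape as the helper added to src/utils.py above
    logging.warning(f"Retrying attempt {retry_state.attempt_number} due to: {retry_state.outcome.exception()}")

@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=retry_log_warning, reraise=True)
async def fetch_json(url):
    # hypothetical flaky call: any exception triggers a retry;
    # after 3 attempts, reraise=True re-raises the last exception
    raise ConnectionError(f"temporary failure fetching {url}")

# asyncio.run(fetch_json("http://search/example"))  # ConnectionError after 3 attempts, 2 warnings logged
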
{statement}") - query = await run_in_threadpool(pipeline.get_search_query, statement) - if not query: + logger.info(f"Statement: {statement}") + + # get search query + try: + query = await run_in_threadpool(pipeline.get_search_query, statement) + logger.info(f"Search query: {query}") + except Exception as e: + logger.error(f"Getting search query from statement '{statement}' failed: {e}") continue - logger.info(f"search query: {query}") # searching try: search = await Search(query) + logger.info(f"Head of search results: {json.dumps(search)[0:500]}") except Exception as e: fail_search = True logger.error(f"Search '{query}' failed: {e}") continue - - logger.info(f"head of search results: {json.dumps(search)[0:500]}") - verdict = await run_in_threadpool(pipeline.get_verdict, search_json=search, statement=statement) - if not verdict: + + # get verdict + try: + verdict = await run_in_threadpool(pipeline.get_verdict, search_json=search, statement=statement) + logger.info(f"Verdict: {verdict}") + except Exception as e: + logger.error(f"Getting verdict for statement {statement} failed: {e}") continue - logger.info(f"final verdict: {verdict}") + verdicts.append(verdict) if not verdicts: diff --git a/src/pipeline/common.py b/src/pipeline/common.py index 9ed3e4a..5c8e5a3 100644 --- a/src/pipeline/common.py +++ b/src/pipeline/common.py @@ -1,29 +1,25 @@ import logging import utils +from tenacity import retry, stop_after_attempt, wait_fixed + from modules import SearchQuery, Statements from .verdict_citation import VerdictCitation +# Statements has retry set already, do not retry here def get_statements(content): """Get list of statements from a text string""" - try: - statements = Statements()(content=content) - except Exception as e: - logging.error(f"Getting statements failed: {e}") - statements = [] - + + statements = Statements()(content=content) return statements +@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=utils.retry_log_warning, reraise=True) def get_search_query(statement): """Get search query from one statement""" - - try: - query = SearchQuery()(statement=statement) - except Exception as e: - logging.error(f"Getting search query from statement '{statement}' failed: {e}") - query = "" - - return query + query = SearchQuery()(statement=statement) + return query + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=utils.retry_log_warning, reraise=True) def get_verdict(search_json, statement): docs = utils.search_json_to_docs(search_json) rep = VerdictCitation(docs=docs).get(statement=statement) From 0675dc411027d2081dd3e0fa455bdf516d45be4c Mon Sep 17 00:00:00 2001 From: etwk <48991073+etwk@users.noreply.github.com> Date: Tue, 20 Aug 2024 00:40:33 +0000 Subject: [PATCH 5/6] update documents metadata process --- src/utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/utils.py b/src/utils.py index 4c28df4..092a974 100644 --- a/src/utils.py +++ b/src/utils.py @@ -123,15 +123,22 @@ def llama_index_nodes_to_list(nodes): return nodes_list def search_json_to_docs(search_json): - """Search JSON results to Llama-Index documents""" + """ + Search JSON results to Llama-Index documents + + Do not add metadata for now + cause LlamaIndex uses `node.get_content(metadata_mode=MetadataMode.EMBED)` which addeds metadata to text for generate embeddings + + TODO: pr to llama-index for metadata_mode setting + """ documents = [] for result in search_json['data']: content = clear_md_links(result.get('content')) - metadata = 
From 42ec7aa0b9a24975d08fae80dc1304a73be7ea34 Mon Sep 17 00:00:00 2001
From: etwk <48991073+etwk@users.noreply.github.com>
Date: Tue, 20 Aug 2024 13:53:31 +0000
Subject: [PATCH 6/6] integrate Infinity embedding, add batching to Ollama embedding, make retrieval async

---
 infra/env.d/check                      |   1 -
 requirements.base.txt                  |   1 +
 src/_types.py                          |  22 ++++
 src/integrations/__init__.py           |   2 +
 src/integrations/infinity_embedding.py | 132 ++++++++++++++++++++++
 src/integrations/ollama_embedding.py   | 135 +++++++++++++++++++++++
 src/modules/ollama_embedding.py        | 145 -------------------------
 src/modules/retrieve.py                |  61 +++--------
 src/settings.py                        |   3 -
 9 files changed, 307 insertions(+), 195 deletions(-)
 create mode 100644 src/_types.py
 create mode 100644 src/integrations/__init__.py
 create mode 100644 src/integrations/infinity_embedding.py
 create mode 100644 src/integrations/ollama_embedding.py
 delete mode 100644 src/modules/ollama_embedding.py

diff --git a/infra/env.d/check b/infra/env.d/check
index 8a2069f..6dbc4b6 100644
--- a/infra/env.d/check
+++ b/infra/env.d/check
@@ -3,7 +3,6 @@ EMBEDDING_BASE_URL=http://ollama:11434
 EMBEDDING_MODEL_DEPLOY=api
 EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
 INDEX_CHUNK_SIZES=[2048, 512, 128]
-THREAD_BUILD_INDEX=12
 
 LLM_MODEL_NAME=google/gemma-2-27b-it
 OPENAI_API_KEY=
diff --git a/requirements.base.txt b/requirements.base.txt
index 934ee32..4c231da 100644
--- a/requirements.base.txt
+++ b/requirements.base.txt
@@ -1,6 +1,7 @@
 aiohttp
 dspy-ai==2.4.13
 fastapi
+httpx[http2]
 llama-index==0.10.65
 llama-index-postprocessor-jinaai-rerank==0.1.7
 openai
diff --git a/src/_types.py b/src/_types.py
new file mode 100644
index 0000000..81f7fd2
--- /dev/null
+++ b/src/_types.py
@@ -0,0 +1,22 @@
+import json
+
+# reference: https://github.com/ollama/ollama-python/blob/main/ollama/_types.py
+class ResponseError(Exception):
+    """
+    Common class for response errors.
+    """
+
+    def __init__(self, error: str, status_code: int = -1):
+        try:
+            # try to parse content as JSON and extract 'error'
+            # fallback to raw content if JSON parsing fails
+            error = json.loads(error).get('error', error)
+        except json.JSONDecodeError:
+            ...
+
+        super().__init__(error)
+        self.error = error
+        'Reason for the error.'
+
+        self.status_code = status_code
+        'HTTP status code of the response.'
\ No newline at end of file
diff --git a/src/integrations/__init__.py b/src/integrations/__init__.py
new file mode 100644
index 0000000..2276b56
--- /dev/null
+++ b/src/integrations/__init__.py
@@ -0,0 +1,2 @@
+from .infinity_embedding import InfinityEmbedding
+from .ollama_embedding import OllamaEmbedding
\ No newline at end of file
diff --git a/src/integrations/infinity_embedding.py b/src/integrations/infinity_embedding.py
new file mode 100644
index 0000000..6599b04
--- /dev/null
+++ b/src/integrations/infinity_embedding.py
@@ -0,0 +1,132 @@
+# reference: https://github.com/ollama/ollama-python/blob/main/ollama/_client.py
+
+import os
+import httpx
+from typing import Any, List
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.core.bridge.pydantic import PrivateAttr
+from tenacity import retry, stop_after_attempt, wait_fixed
+
+import utils
+from _types import ResponseError
+
+DEFAULT_INFINITY_BASE_URL = "http://localhost:7997"
+
+class InfinityEmbedding(BaseEmbedding):
+    """Class for Infinity embeddings.
+
+    Retry is used here because one failed request could crash the whole embedding process.
+
+    Args:
+        api_key (str): Server API key.
+        model_name (str): Model for embedding.
+        base_url (str): Infinity URL. Defaults to http://localhost:7997.
+    """
+
+    _aclient: httpx.AsyncClient = PrivateAttr()
+    _client: httpx.Client = PrivateAttr()
+    _settings: dict = PrivateAttr()
+    _url: str = PrivateAttr()
+
+    def __init__(
+        self,
+        model_name: str,
+        api_key: str = "key",
+        base_url: str = DEFAULT_INFINITY_BASE_URL,
+        http2: bool = True,
+        follow_redirects: bool = True,
+        timeout: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            model_name=model_name,
+            **kwargs,
+        )
+
+        self._settings = {
+            'follow_redirects': follow_redirects,
+            'headers': {
+                'Content-Type': 'application/json',
+                'Accept': 'application/json',
+                'Authorization': f"Bearer {api_key}",
+            },
+            'http2': http2,
+            'timeout': timeout,
+        }
+
+        self._url = os.path.join(base_url, "embeddings")
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "InfinityEmbedding"
+
+    def _get_client(self, _async: bool = False):
+        """Return the httpx sync or async client, creating it on first use"""
+        if _async:
+            if not hasattr(self, "_aclient"):
+                self._aclient = httpx.AsyncClient(**self._settings)
+            return self._aclient
+        else:
+            if not hasattr(self, "_client"):
+                self._client = httpx.Client(**self._settings)
+            return self._client
+
+    def _process_response(self, response: httpx.Response) -> List[List[float]]:
+        embeddings = [item['embedding'] for item in response.json()['data']]
+        return embeddings
+
+    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True)
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get text embeddings."""
+        client = self._get_client()
+        response = client.request(
+            'POST',
+            self._url,
+            json={
+                "input": texts,
+                "model": self.model_name,
+            },
+        )
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError as e:
+            raise ResponseError(e.response.text, e.response.status_code) from None
+
+        return self._process_response(response)
+
+    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True)
+    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronously get text embeddings."""
+        client = self._get_client(_async=True)
+        response = await client.request(
+            'POST',
+            self._url,
+            json={
+                "input": texts,
+                "model": self.model_name,
+            },
+        )
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError as e:
+            raise ResponseError(e.response.text, e.response.status_code) from None
+
+        return self._process_response(response)
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding."""
+        return self._get_text_embeddings([query])[0]
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """The asynchronous version of _get_query_embedding."""
+        return (await self._aget_text_embeddings([query]))[0]
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        """Get text embedding."""
+        return self._get_text_embeddings([text])[0]
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        """Asynchronously get text embedding."""
+        return (await self._aget_text_embeddings([text]))[0]
\ No newline at end of file
diff --git a/src/integrations/ollama_embedding.py b/src/integrations/ollama_embedding.py
new file mode 100644
index 0000000..418a332
--- /dev/null
+++ b/src/integrations/ollama_embedding.py
@@ -0,0 +1,135 @@
+# reference: https://github.com/ollama/ollama-python/blob/main/ollama/_client.py
+
+import os
+import httpx
+from typing import Any, List
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.core.bridge.pydantic import PrivateAttr
+from tenacity import retry, stop_after_attempt, wait_fixed
+
+import utils
+from _types import ResponseError
+
+DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"
+
+class OllamaEmbedding(BaseEmbedding):
+    """Class for Ollama embeddings.
+
+    Retry is used here because one failed request could crash the whole embedding process.
+
+    Args:
+        api_key (str): Server API key.
+        model_name (str): Model for embedding.
+        base_url (str): Ollama URL. Defaults to http://localhost:11434.
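+        http2 (bool): Use HTTP/2 when available. Defaults to True.
+        timeout (Any): Request timeout passed to httpx; None disables the timeout.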
+ """ + + _aclient: httpx.AsyncClient = PrivateAttr() + _client: httpx.Client = PrivateAttr() + _settings: dict = PrivateAttr() + _url: str = PrivateAttr() + + def __init__( + self, + model_name: str, + api_key: str = "key", + base_url: str = DEFAULT_OLLAMA_BASE_URL, + http2: bool = True, + follow_redirects: bool = True, + timeout: Any = None, + **kwargs: Any, + ) -> None: + super().__init__( + model_name=model_name, + **kwargs, + ) + + self._settings = { + 'follow_redirects': follow_redirects, + 'headers': { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'Authorization': f"Bearer {api_key}", + }, + 'http2': http2, + 'timeout': timeout, + } + + self._url = os.path.join(base_url, "api/embed") + + @classmethod + def class_name(cls) -> str: + return "OllamaEmbedding" + + def _get_client(self, _async: bool = False): + """Set and return httpx sync or async client""" + if _async: + if not hasattr(self, "_aclient"): + self._aclient = httpx.AsyncClient(**self._settings) + return self._aclient + else: + if not hasattr(self, "_client"): + self._client = httpx.Client(**self._settings) + return self._client + + def _process_response(self, response: httpx.Response) -> List[List[float]]: + embeddings = response.json()['embeddings'] + return embeddings + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True) + def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get text embeddings.""" + client = self._get_client() + response = client.request( + 'POST', + self._url, + json={ + "input": texts, + "model": self.model_name, + }, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise ResponseError(e.response.text, e.response.status_code) from None + + return self._process_response(response) + + # TODO: debug `Event loop is closed` + @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True) + async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Asynchronously get text embeddings.""" + client = self._get_client(_async=True) + response = await client.request( + 'POST', + self._url, + json={ + "input": texts, + "model": self.model_name, + }, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise ResponseError(e.response.text, e.response.status_code) from None + + return self._process_response(response) + + def _get_query_embedding(self, query: str) -> List[float]: + """Get query embedding.""" + return self._get_text_embeddings([query])[0] + + async def _aget_query_embedding(self, query: str) -> List[float]: + """The asynchronous version of _get_query_embedding.""" + return await self._aget_text_embeddings([query])[0] + + def _get_text_embedding(self, text: str) -> List[float]: + """Get text embedding.""" + return self._get_text_embeddings([text])[0] + + async def _aget_text_embedding(self, text: str) -> List[float]: + """Asynchronously get text embedding.""" + return await self._aget_text_embeddings([text])[0] \ No newline at end of file diff --git a/src/modules/ollama_embedding.py b/src/modules/ollama_embedding.py deleted file mode 100644 index 20ad812..0000000 --- a/src/modules/ollama_embedding.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -Added API key support. 
-
-Source: https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py
-"""
-
-import httpx
-from typing import Any, Dict, List, Optional
-
-from llama_index.core.base.embeddings.base import BaseEmbedding
-from llama_index.core.bridge.pydantic import Field
-from llama_index.core.callbacks.base import CallbackManager
-from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE
-
-from settings import settings
-
-class OllamaEmbedding(BaseEmbedding):
-    """Class for Ollama embeddings."""
-
-    base_url: str = Field(description="Base url the model is hosted by Ollama")
-    model_name: str = Field(description="The Ollama model to use.")
-    embed_batch_size: int = Field(
-        default=DEFAULT_EMBED_BATCH_SIZE,
-        description="The batch size for embedding calls.",
-        gt=0,
-        lte=2048,
-    )
-    ollama_additional_kwargs: Dict[str, Any] = Field(
-        default_factory=dict, description="Additional kwargs for the Ollama API."
-    )
-
-    def __init__(
-        self,
-        model_name: str,
-        base_url: str = "http://localhost:11434",
-        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
-        ollama_additional_kwargs: Optional[Dict[str, Any]] = None,
-        callback_manager: Optional[CallbackManager] = None,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(
-            model_name=model_name,
-            base_url=base_url,
-            embed_batch_size=embed_batch_size,
-            ollama_additional_kwargs=ollama_additional_kwargs or {},
-            callback_manager=callback_manager,
-        )
-
-    @classmethod
-    def class_name(cls) -> str:
-        return "OllamaEmbedding"
-
-    def _get_query_embedding(self, query: str) -> List[float]:
-        """Get query embedding."""
-        return self.get_general_text_embedding(query)
-
-    async def _aget_query_embedding(self, query: str) -> List[float]:
-        """The asynchronous version of _get_query_embedding."""
-        return await self.aget_general_text_embedding(query)
-
-    def _get_text_embedding(self, text: str) -> List[float]:
-        """Get text embedding."""
-        return self.get_general_text_embedding(text)
-
-    async def _aget_text_embedding(self, text: str) -> List[float]:
-        """Asynchronously get text embedding."""
-        return await self.aget_general_text_embedding(text)
-
-    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Get text embeddings."""
-        embeddings_list: List[List[float]] = []
-        for text in texts:
-            embeddings = self.get_general_text_embedding(text)
-            embeddings_list.append(embeddings)
-
-        return embeddings_list
-
-    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Asynchronously get text embeddings."""
-        return self._aget_text_embeddings(texts)
-
-    def get_general_text_embedding(self, prompt: str) -> List[float]:
-        """Get Ollama embedding."""
-        try:
-            import requests
-        except ImportError:
-            raise ImportError(
-                "Could not import requests library."
-                "Please install requests with `pip install requests`"
-            )
-
-        ollama_request_body = {
-            "prompt": prompt,
-            "model": self.model_name,
-            "options": self.ollama_additional_kwargs,
-        }
-
-        response = requests.post(
-            url=f"{self.base_url}/api/embeddings",
-            headers={"Content-Type": "application/json", "Authorization": f"Bearer {settings.EMBEDDING_API_KEY}"},
-            json=ollama_request_body,
-        )
-        response.encoding = "utf-8"
-        if response.status_code != 200:
-            optional_detail = response.json().get("error")
-            raise ValueError(
-                f"Ollama call failed with status code {response.status_code}."
-                f" Details: {optional_detail}"
-            )
-
-        try:
-            return response.json()["embedding"]
-        except requests.exceptions.JSONDecodeError as e:
-            raise ValueError(
-                f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
-            )
-
-    async def aget_general_text_embedding(self, prompt: str) -> List[float]:
-        """Asynchronously get Ollama embedding."""
-        async with httpx.AsyncClient() as client:
-            ollama_request_body = {
-                "prompt": prompt,
-                "model": self.model_name,
-                "options": self.ollama_additional_kwargs,
-            }
-
-            response = await client.post(
-                url=f"{self.base_url}/api/embeddings",
-                headers={"Content-Type": "application/json", "Authorization": f"Bearer {settings.EMBEDDING_API_KEY}"},
-                json=ollama_request_body,
-            )
-            response.encoding = "utf-8"
-            if response.status_code != 200:
-                optional_detail = response.json().get("error")
-                raise ValueError(
-                    f"Ollama call failed with status code {response.status_code}."
-                    f" Details: {optional_detail}"
-                )
-
-            try:
-                return response.json()["embedding"]
-            except httpx.HTTPStatusError as e:
-                raise ValueError(
-                    f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
-                )
\ No newline at end of file
diff --git a/src/modules/retrieve.py b/src/modules/retrieve.py
index 35bc0df..abeb133 100644
--- a/src/modules/retrieve.py
+++ b/src/modules/retrieve.py
@@ -25,18 +25,19 @@
 from llama_index.postprocessor.jinaai_rerank import JinaRerank
 
-# todo: high lantency between client and the ollama embedding server will slow down embedding a lot
-from . import OllamaEmbedding
+from integrations import OllamaEmbedding
 
 # todo: improve embedding performance
 if settings.EMBEDDING_MODEL_DEPLOY == "local":
     embed_model="local:" + settings.EMBEDDING_MODEL_NAME
 else:
-    # TODO: debug Ollama embedding with chunk size [4096, 2048, 1024] compare to local
     embed_model = OllamaEmbedding(
-        model_name=settings.EMBEDDING_MODEL_NAME,
+        api_key=settings.EMBEDDING_API_KEY,
         base_url=settings.EMBEDDING_BASE_URL,
+        embed_batch_size=32,  # TODO: what's the best batch size for Ollama
+        model_name=settings.EMBEDDING_MODEL_NAME,
     )
+
 Settings.embed_model = embed_model
 
 class LlamaIndexCustomRetriever():
@@ -48,29 +49,6 @@ def __init__(
         self.similarity_top_k = similarity_top_k
         if docs:
             self.build_index(docs)
-
-    def create_index(self, nodes):
-        storage_context = StorageContext.from_defaults()
-        storage_context.docstore.add_documents(nodes)
-        leaf_nodes = get_leaf_nodes(nodes)
-        return VectorStoreIndex(leaf_nodes, storage_context=storage_context)
-
-    def merge_index(self, indexs):
-        """
-        Args:
-          - indexs: list of indexs
-        """
-        nodes = []
-        for index in indexs:
-            vector_store_dict = index.storage_context.vector_store.to_dict()
-            embedding_dict = vector_store_dict['embedding_dict']
-            for doc_id, node in index.storage_context.docstore.docs.items():
-                # necessary to avoid re-calc of embeddings
-                node.embedding = embedding_dict[doc_id]
-                nodes.append(node)
-
-        merged_index = VectorStoreIndex(nodes=nodes)
-        return merged_index
 
     def build_automerging_index(
         self,
@@ -78,36 +56,27 @@ def build_automerging_index(
         chunk_sizes=[2048, 512, 128],
     ):
        node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes)
-        nodes = node_parser.get_nodes_from_documents(documents)
-        leaf_nodes = get_leaf_nodes(nodes)
+        self.nodes = node_parser.get_nodes_from_documents(documents)
+        leaf_nodes = get_leaf_nodes(self.nodes)
         storage_context = StorageContext.from_defaults()
-        storage_context.docstore.add_documents(nodes)
+        storage_context.docstore.add_documents(self.nodes)
 
-        leaf_indexs = []
-
-        # TODO: better concurrency, possibly async
-        with concurrent.futures.ThreadPoolExecutor(max_workers=settings.THREAD_BUILD_INDEX) as executor:
-            future_to_index = {executor.submit(self.create_index, [_node]): _node for _node in leaf_nodes}
-
-            for future in concurrent.futures.as_completed(future_to_index):
-                index = future.result()
-                leaf_indexs.append(index)
-
-        automerging_index = self.merge_index(leaf_indexs)
+        automerging_index = VectorStoreIndex(
+            leaf_nodes, storage_context=storage_context, use_async=True
+        )
 
-        return automerging_index, storage_context
+        return automerging_index
 
     def get_automerging_query_engine(
         self,
         automerging_index,
-        storage_context,
         similarity_top_k=6,
         rerank_top_n=3,
     ):
         base_retriever = automerging_index.as_retriever(similarity_top_k=similarity_top_k)
         retriever = AutoMergingRetriever(
-            base_retriever, storage_context, verbose=True
+            base_retriever, automerging_index.storage_context, verbose=True
         )
 
         # TODO: load model files at app start
@@ -133,16 +102,16 @@ def build_index(self, docs):
         """Initiate index or build a new one."""
 
         if docs:
-            self.index, self.storage_context = self.build_automerging_index(
+            self.index = self.build_automerging_index(
                 docs,
                 chunk_sizes=settings.INDEX_CHUNK_SIZES,
             )
 
     # TODO: try to retrieve directly
     def retrieve(self, query):
+        # TODO: get query engine performance costs
         rerank_top_n=self.similarity_top_k
         query_engine = self.get_automerging_query_engine(
             automerging_index=self.index,
-            storage_context=self.storage_context,
             similarity_top_k=rerank_top_n * 3,
             rerank_top_n=rerank_top_n
         )
diff --git a/src/settings.py b/src/settings.py
index 294de5d..286f66a 100644
--- a/src/settings.py
+++ b/src/settings.py
@@ -23,9 +23,6 @@ def __init__(self):
         except:
             self.INDEX_CHUNK_SIZES = [1024, 256]
 
-        # threads
-        self.THREAD_BUILD_INDEX = int(os.environ.get("THREAD_BUILD_INDEX", 12))
-
         # keys
         self.EMBEDDING_API_KEY = os.environ.get("EMBEDDING_API_KEY") or ""
         self.RERANK_API_KEY = os.environ.get("RERANK_API_KEY") or ""
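
As a usage sketch of the new integrations (the model name is illustrative, and an Infinity server is assumed on the default port), both classes are drop-in llama-index embeddings:

from llama_index.core import Settings
from integrations import InfinityEmbedding

embed_model = InfinityEmbedding(
    model_name="jinaai/jina-embeddings-v2-base-en",  # illustrative; any model served by Infinity
    base_url="http://localhost:7997",
)

# the public BaseEmbedding API drives the private hooks defined in the patch
vector = embed_model.get_text_embedding("hello world")
vectors = embed_model.get_text_embedding_batch(["first text", "second text"])

# or register globally, as src/modules/retrieve.py does for OllamaEmbedding
Settings.embed_model = embed_model
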
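Failed requests surface through the shared ResponseError from src/_types.py, which tries to pull a JSON 'error' field out of the response body; a quick check of that behavior:

from _types import ResponseError

try:
    raise ResponseError('{"error": "model not found"}', 404)
except ResponseError as e:
    print(e.error, e.status_code)  # model not found 404

# non-JSON bodies fall back to the raw text
print(ResponseError("plain text failure").error)  # plain text failure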