Skip to content

Commit

Permalink
Merge pull request #12 from ittia-research/dev
Browse files Browse the repository at this point in the history
add script to create dataset based on HotPotQA, update infra
  • Loading branch information
etwk authored Aug 19, 2024
2 parents 147b2b7 + 6f2e268 commit 43541bc
Show file tree
Hide file tree
Showing 10 changed files with 459 additions and 54 deletions.
14 changes: 0 additions & 14 deletions .env

This file was deleted.

399 changes: 399 additions & 0 deletions datasets/HotPotQA/HotPotQA_statement_verdict.ipynb

Large diffs are not rendered by default.

22 changes: 12 additions & 10 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ services:
dockerfile: Dockerfile
container_name: check
env_file:
- .env
- ./infra/env.d/check
ports:
- 8000:8000
restart: always
Expand All @@ -23,17 +23,19 @@ services:
count: all
capabilities: [gpu]
restart: always
xinference:
build:
context: infra/xinference
dockerfile: Dockerfile
container_name: xinference

# infinity supports embedding and rerank models; the v2 version supports serving multiple models
infinity:
image: michaelf34/infinity:latest
container_name: infinity
ports:
- 9997:9997
- 7997:7997
volumes:
- /data/volumes/xinference:/data
environment:
- XINFERENCE_HOME=/data
- /data/cache/huggingface:/cache/huggingface
env_file:
- ./infra/env.d/infinity
- ./infra/env.d/huggingface
command: ["v2"]
deploy:
resources:
reservations:
Expand Down
19 changes: 19 additions & 0 deletions infra/env.d/check
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
EMBEDDING_API_KEY=<CHANGE_ME>
EMBEDDING_BASE_URL=http://ollama:11434
EMBEDDING_MODEL_DEPLOY=api
EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
INDEX_CHUNK_SIZES=[2048, 512, 128]
THREAD_BUILD_INDEX=12

LLM_MODEL_NAME=google/gemma-2-27b-it
OPENAI_API_KEY=<CHANGE_ME>
OPENAI_BASE_URL=http://localhost:8000/v1

RERANK_API_KEY=<CHANGE_ME>
RERANK_BASE_URL=http://infinity:7997
RERANK_MODEL_DEPLOY=api
RERANK_MODEL_NAME=jinaai/jina-reranker-v2-base-multilingual

SEARCH_BASE_URL=https://s.jina.ai

PROJECT_HOSTING_BASE_URL=http://127.0.0.1:8000
3 changes: 3 additions & 0 deletions infra/env.d/huggingface
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
HF_HOME=/cache/huggingface
HUGGING_FACE_HUB_TOKEN=<CHANGE_ME>

4 changes: 4 additions & 0 deletions infra/env.d/infinity
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
INFINITY_API_KEY=<CHANGE_ME>
INFINITY_LOG_LEVEL=trace
INFINITY_MODEL_ID=jinaai/jina-reranker-v2-base-multilingual

8 changes: 0 additions & 8 deletions infra/xinference/Dockerfile

This file was deleted.

14 changes: 0 additions & 14 deletions infra/xinference/init.sh

This file was deleted.

28 changes: 20 additions & 8 deletions src/modules/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
LlamaIndexCustomRetriever
"""

import os, logging
from typing import Optional
import logging
import concurrent.futures
from typing import Optional

from llama_index.core import (
Document,
Expand All @@ -23,8 +23,7 @@
import utils
from settings import settings

import llama_index.postprocessor.jinaai_rerank.base as jinaai_rerank # todo: shall we lock package version?
jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank" # switch to on-premise
from llama_index.postprocessor.jinaai_rerank import JinaRerank

# todo: high latency between client and the ollama embedding server will slow down embedding a lot
from . import OllamaEmbedding
Expand All @@ -33,9 +32,10 @@
if settings.EMBEDDING_MODEL_DEPLOY == "local":
embed_model="local:" + settings.EMBEDDING_MODEL_NAME
else:
# TODO: debug Ollama embedding with chunk size [4096, 2048, 1024] compared to local
embed_model = OllamaEmbedding(
model_name=settings.EMBEDDING_MODEL_NAME,
base_url=os.environ.get("OLLAMA_BASE_URL"), # todo: any other configs here?
base_url=settings.EMBEDDING_BASE_URL,
)
Settings.embed_model = embed_model

Expand Down Expand Up @@ -116,7 +116,12 @@ def get_automerging_query_engine(
top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME,
) # TODO: add support `trust_remote_code=True`
else:
rerank = jinaai_rerank.JinaRerank(api_key='', top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME)
rerank = JinaRerank(
base_url = settings.RERANK_BASE_URL,
api_key=settings.RERANK_API_KEY,
top_n=rerank_top_n,
model=settings.RERANK_MODEL_NAME,
)

auto_merging_engine = RetrieverQueryEngine.from_args(
retriever, node_postprocessors=[rerank]
Expand All @@ -134,15 +139,22 @@ def build_index(self, docs):
) # TODO: try to retrieve directly

def retrieve(self, query):
rerank_top_n=self.similarity_top_k
query_engine = self.get_automerging_query_engine(
automerging_index=self.index,
storage_context=self.storage_context,
similarity_top_k=self.similarity_top_k * 3,
rerank_top_n=self.similarity_top_k
similarity_top_k=rerank_top_n * 3,
rerank_top_n=rerank_top_n
)
self.query_engine = query_engine
auto_merging_response = self.query_engine.query(query)
contexts = utils.llama_index_nodes_to_list(auto_merging_response.source_nodes)

# select top_n here because some rerank services do not support this feature
if len(contexts) > rerank_top_n:
contexts.sort(key=lambda x: x['score'], reverse=True) # sort by score in descending order
contexts = contexts[:rerank_top_n]

return contexts

import dspy
Expand Down
2 changes: 2 additions & 0 deletions src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def __init__(self):
self.RERANK_MODEL_NAME = os.environ.get("RERANK_MODEL_NAME") or "BAAI/bge-reranker-v2-m3"

self.OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
self.EMBEDDING_BASE_URL = os.environ.get("EMBEDDING_BASE_URL") or "http://ollama:11434"
self.RERANK_BASE_URL = os.environ.get("RERANK_BASE_URL") or "http://xinference:9997/v1"
self.PROJECT_HOSTING_BASE_URL = os.environ.get("PROJECT_HOSTING_BASE_URL") or "https://check.ittia.net"
self.SEARCH_BASE_URL = os.environ.get("SEARCH_BASE_URL") or "https://s.jina.ai"
Expand All @@ -27,5 +28,6 @@ def __init__(self):

# keys
self.EMBEDDING_API_KEY = os.environ.get("EMBEDDING_API_KEY") or ""
self.RERANK_API_KEY = os.environ.get("RERANK_API_KEY") or ""

settings = Settings()

0 comments on commit 43541bc

Please sign in to comment.