Merge pull request #4 from ittia-research/dev
[Feature] get verdict of each source separately
etwk authored Aug 5, 2024
2 parents efc1809 + b871435 commit 5d0e761
Showing 7 changed files with 217 additions and 178 deletions.
3 changes: 2 additions & 1 deletion .env
@@ -1,10 +1,11 @@
EMBEDDING_MODEL_DEPLOY=api
EMBEDDING_MODEL_NAME=jinaai/jina-embeddings-v2-base-en
LLM_MODEL_NAME=google/gemma-2-27b-it
OLLAMA_BASE_URL=http://ollama:11434
OPENAI_API_KEY=sk-proj-aaaaaaaaaaaaaaaaa
OPENAI_BASE_URL=http://localhost:8000/v1
PROJECT_HOSTING_BASE_URL=http://127.0.0.1:8000
RAG_MODEL_DEPLOY=local
RERANK_MODEL_DEPLOY=local
RERANK_MODEL_NAME=BAAI/bge-reranker-v2-m3
RERANK_BASE_URL=http://xinference:9997/v1
SEARCH_BASE_URL=https://s.jina.ai
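For reference, these keys are read through `os.environ` in `src/settings.py` (see the settings diff below). A minimal sketch of how the new deploy-mode keys might be consumed; the `"local"` fallbacks mirror the defaults in that file:

```python
import os

# Sketch of reading the new deploy-mode keys (see src/settings.py);
# the "local" fallbacks match the defaults added in that file.
EMBEDDING_MODEL_DEPLOY = os.environ.get("EMBEDDING_MODEL_DEPLOY") or "local"
RERANK_MODEL_DEPLOY = os.environ.get("RERANK_MODEL_DEPLOY") or "local"
RERANK_BASE_URL = os.environ.get("RERANK_BASE_URL")  # e.g. http://xinference:9997/v1
```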
3 changes: 3 additions & 0 deletions docs/experience.md
@@ -0,0 +1,3 @@
## Prompt
- Unclear reasoning or logic in a prompt can have a large impact on the LLM's reasoning ability.

142 changes: 73 additions & 69 deletions src/index.py
@@ -17,7 +17,7 @@
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.llms import MockLLM

Settings.llm = MockLLM() # retrieve only, do not use the LLM to synthesize
Settings.llm = MockLLM(max_tokens=256) # retrieve only, do not use the LLM to synthesize

from settings import settings

@@ -27,57 +27,6 @@
# todo: high latency between the client and the Ollama embedding server will slow down embedding a lot
from llama_index.embeddings.ollama import OllamaEmbedding

def build_automerging_index(
documents,
chunk_sizes=None,
):
chunk_sizes = chunk_sizes or [2048, 512, 128]

if settings.RAG_MODEL_DEPLOY == "local":
embed_model="local:" + settings.EMBEDDING_MODEL_NAME
else:
embed_model = OllamaEmbedding(
model_name=settings.EMBEDDING_MODEL_NAME,
base_url=os.environ.get("OLLAMA_BASE_URL"), # todo: any other configs here?
)

node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes)
nodes = node_parser.get_nodes_from_documents(documents)
leaf_nodes = get_leaf_nodes(nodes)
merging_context = ServiceContext.from_defaults(
embed_model=embed_model,
)
storage_context = StorageContext.from_defaults()
storage_context.docstore.add_documents(nodes)

automerging_index = VectorStoreIndex(
leaf_nodes, storage_context=storage_context, service_context=merging_context
)
return automerging_index

def get_automerging_query_engine(
automerging_index,
similarity_top_k=12,
rerank_top_n=6,
):
base_retriever = automerging_index.as_retriever(similarity_top_k=similarity_top_k)
retriever = AutoMergingRetriever(
base_retriever, automerging_index.storage_context, verbose=True
)

if settings.RAG_MODEL_DEPLOY == "local":
rerank = SentenceTransformerRerank(
top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME,
) # todo: add support `trust_remote_code=True`
else:
rerank = jinaai_rerank.JinaRerank(api_key='', top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME)

auto_merging_engine = RetrieverQueryEngine.from_args(
retriever, node_postprocessors=[rerank]
)

return auto_merging_engine

def nodes2list(nodes):
nodes_list = []
for ind, source_node in enumerate(nodes):
@@ -87,22 +36,77 @@ def nodes2list(nodes):
_sub['text'] = source_node.node.get_content().strip()
nodes_list.append(_sub)
return nodes_list

def get_contexts(statement, keywords, text):
"""
Get list of contexts.

Todo: resources re-use for multiple run.
"""
document = Document(text=text)
index = build_automerging_index(
[document],
chunk_sizes=settings.RAG_CHUNK_SIZES,
) # todo: will it better to use retriever directly?

query_engine = get_automerging_query_engine(index, similarity_top_k=16)
query = f"{keywords} | {statement}" # todo: better way
auto_merging_response = query_engine.query(query)
contexts = nodes2list(auto_merging_response.source_nodes)
class Index():
def build_automerging_index(
self,
documents,
chunk_sizes=None,
):
chunk_sizes = chunk_sizes or [2048, 512, 128]

return contexts
# todo: improve embedding performance
if settings.EMBEDDING_MODEL_DEPLOY == "local":
embed_model="local:" + settings.EMBEDDING_MODEL_NAME
else:
embed_model = OllamaEmbedding(
model_name=settings.EMBEDDING_MODEL_NAME,
base_url=os.environ.get("OLLAMA_BASE_URL"), # todo: any other configs here?
)

node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes)
nodes = node_parser.get_nodes_from_documents(documents)
leaf_nodes = get_leaf_nodes(nodes)
merging_context = ServiceContext.from_defaults(
embed_model=embed_model,
)
storage_context = StorageContext.from_defaults()
storage_context.docstore.add_documents(nodes)

automerging_index = VectorStoreIndex(
leaf_nodes, storage_context=storage_context, service_context=merging_context
)
return automerging_index

def get_automerging_query_engine(
self,
automerging_index,
similarity_top_k=12,
rerank_top_n=6,
):
base_retriever = automerging_index.as_retriever(similarity_top_k=similarity_top_k)
retriever = AutoMergingRetriever(
base_retriever, automerging_index.storage_context, verbose=True
)

if settings.RERANK_MODEL_DEPLOY == "local":
rerank = SentenceTransformerRerank(
top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME,
) # todo: add support `trust_remote_code=True`
else:
rerank = jinaai_rerank.JinaRerank(api_key='', top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME)

auto_merging_engine = RetrieverQueryEngine.from_args(
retriever, node_postprocessors=[rerank]
)

return auto_merging_engine

def get_contexts(self, statement, keywords, text):
"""
Get list of contexts.
Todo: reuse resources across multiple runs.
"""
document = Document(text=text)
index = self.build_automerging_index(
[document],
chunk_sizes=settings.RAG_CHUNK_SIZES,
) # todo: would it be better to use the retriever directly?

query_engine = self.get_automerging_query_engine(index, similarity_top_k=16, rerank_top_n=3)
query = f"{keywords} | {statement}" # todo: better way
auto_merging_response = query_engine.query(query)
contexts = nodes2list(auto_merging_response.source_nodes)

return contexts
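The hunk above moves the auto-merging helpers into an `Index` class so each source document gets its own index. A minimal usage sketch, assuming `statement`, `keywords`, and `page_text` are already available (the names are placeholders):

```python
from index import Index

# Build a per-source index and retrieve matching contexts
# (mirrors the call site added in src/llm.py).
contexts = Index().get_contexts(statement, keywords, page_text)
for node in contexts:
    print(node["text"][:80])  # each entry is a dict produced by nodes2list
```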
140 changes: 119 additions & 21 deletions src/llm.py
@@ -1,7 +1,9 @@
from openai import OpenAI
import concurrent.futures
import logging

import utils
from index import Index
from settings import settings

"""
@@ -60,37 +62,133 @@ def get_search_keywords(statement):
reply = get_llm_reply(prompt)
return reply.strip()

def get_verdict(statement, contexts):
def get_context_prompt(contexts):
prompt = "Context:"

for ind, node in enumerate(contexts):
_text = node.get('text')
if not _text:
continue
prompt = f"""{prompt}
```
{_text}
```"""
return prompt

def get_verdict_single(statement, contexts):
# This prompt allows the model to use its own knowledge
system_message = '''You are a helpful AI assistant.
Solve tasks using your fact-check skills.
You will be given a statement followed by some contexts.
Use the contexts and facts you know to check if the statements are true, false, or uncheckable.
Ignore contexts that are irrelevant or stale.
Provide detailed reasons for your verdict, ensuring that each reason is supported by corresponding facts.
Be thorough in your explanations, avoiding any duplication of information.
You will be given a statement followed by context.
Use the contexts and the facts you know to check whether the statement is true, false, or irrelevant.
Provide a detailed reason for your verdict, ensuring that each reason is supported by corresponding facts.
Provide the response as JSON with the structure:{verdict, reason}'''

prompt = f'''{system_message}
```
Statement:
{statement}
```
Contexts:'''

for ind, node in enumerate(contexts):
_text = node.get('text')
if not _text:
continue
prompt = f"""{prompt}
```
Context {ind + 1}:
{_text}
```"""

{get_context_prompt(contexts)}'''

reply = get_llm_reply(prompt)
logging.info(f"Verdict reply from LLM: {reply}")
logging.debug(f"Verdict reply from LLM: {reply}")
verdict = utils.llm2json(reply)
if verdict:
verdict['statement'] = statement

return verdict

def get_verdict_summary(verdicts, statement):
"""
Calculate and summarize the verdicts of multiple sources.
Introduce some weights:
- total: the total number of verdicts
- valid: the number of verdicts in the desired categories
- winning: the count of the winning verdict
- the count of verdicts in each desired category
"""

weight_total = 0
weight_valid = 0
sum_score = 0
sum_reason = {
"true": {"reason": [], "weight": 0},
"false": {"reason": [], "weight": 0},
"irrelevant": {"reason": [], "weight": 0},
}

for verdict in verdicts:
weight_total += 1
v = verdict['verdict']
if v in sum_reason:
weight_valid += 1
reason = f"{verdict['reason']}\nsource url: {verdict['url']}\nsource title: {verdict['title']}\n\n"
sum_reason[v]['reason'].append(reason)
sum_reason[v]['weight'] += 1
if v == 'true':
sum_score += 1
elif v == 'false':
sum_score -= 1

if sum_score > 0:
verdict = "true"
elif sum_score < 0:
verdict = "false"
else:
verdict = "irrelevant"

reason = ''.join(sum_reason[verdict]['reason'])
if not reason:
raise Exception("No reason found after summary")

weights = {"total": weight_total, "valid": weight_valid, "winning": sum_reason[verdict]['weight']}
for key in sum_reason.keys():
weights[key] = sum_reason[key]['weight']

return {"verdict": verdict, "reason": reason, "weights": weights, "statement": statement}

def get_verdict(statement, keywords, search_json):
"""
Get a verdict from each of the context sources.
"""
verdicts = []

def process_result(result):
content = utils.clear_md_links(result.get('content'))

try:
contexts = Index().get_contexts(statement, keywords, content)
except Exception as e:
logging.warning(f"Getting contexts failed: {e}")
return None

try:
verdict = get_verdict_single(statement, contexts)
except Exception as e:
logging.warning(f"Getting verdit failed: {e}")
return None

verdict = {
"verdict": verdict.get('verdict', '').lower(),
"reason": verdict.get('reason'),
"url": result.get('url'),
"title": result.get('title'),
}

logging.debug(f"single source verdict: {verdict}")
return verdict

with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor:
# Submit one task per search result and map each future back to its source result
future_to_result = {executor.submit(process_result, result): result for result in search_json['data']}

for future in concurrent.futures.as_completed(future_to_result):
verdict = future.result()
if verdict is not None:
verdicts.append(verdict)

summary = get_verdict_summary(verdicts, statement)
return summary
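For context, a sketch of the payload shape `get_verdict` now expects from the search step; the field names come from `process_result` above, while the values are placeholders:

```python
search_json = {
    "data": [
        {
            "url": "https://example.com/article",
            "title": "Example article",
            "content": "Full page text returned by the search backend ...",
        },
    ]
}
# One verdict is produced per entry in search_json["data"], then summarized.
summary = get_verdict("Statement to check", "search keywords", search_json)
```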

14 changes: 6 additions & 8 deletions src/main.py
@@ -1,9 +1,11 @@
import json
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, JSONResponse, HTMLResponse, PlainTextResponse, FileResponse
import logging

import llm, index, utils
import llm, utils
from index import Index

logging.basicConfig(
level=logging.INFO,
@@ -36,15 +38,11 @@ async def fact_check(input):
if not search:
fail_search = True
continue
logger.info(f"head of search results: {search[0:200]}")
contexts = await run_in_threadpool(index.get_contexts, statement, keywords, search)
if not contexts:
continue
logger.info(f"contexts: {contexts}")
verdict = await run_in_threadpool(llm.get_verdict, statement, contexts)
logger.info(f"head of search results: {json.dumps(search)[0:500]}")
verdict = await run_in_threadpool(llm.get_verdict, statement, keywords, search)
if not verdict:
continue
logger.info(f"verdict: {verdict}")
logger.info(f"final verdict: {verdict}")
verdicts.append(verdict)

if not verdicts:
7 changes: 4 additions & 3 deletions src/settings.py
@@ -12,11 +12,12 @@ def __init__(self):
self.PROJECT_HOSTING_BASE_URL = os.environ.get("PROJECT_HOSTING_BASE_URL") or "https://check.ittia.net"
self.SEARCH_BASE_URL = os.environ.get("SEARCH_BASE_URL") or "https://s.jina.ai"

# set RAG model deploy mode
self.RAG_MODEL_DEPLOY = os.environ.get("RAG_MODEL_DEPLOY") or "local"
# set deploy mode for the RAG models
self.EMBEDDING_MODEL_DEPLOY = os.environ.get("EMBEDDING_MODEL_DEPLOY") or "local"
self.RERANK_MODEL_DEPLOY = os.environ.get("RERANK_MODEL_DEPLOY") or "local"

# set RAG chunk sizes
self.RAG_CHUNK_SIZES = [4096, 1024, 256]
self.RAG_CHUNK_SIZES = [1024, 256]
_chunk_sizes = os.environ.get("RAG_CHUNK_SIZES")
try:
self.RAG_CHUNK_SIZES = ast.literal_eval(_chunk_sizes)
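The chunk-size default drops to `[1024, 256]`; an environment override still goes through `ast.literal_eval`, so the value must be a Python list literal. A minimal sketch (the override value is only an example):

```python
import ast
import os

os.environ["RAG_CHUNK_SIZES"] = "[2048, 512, 128]"  # example override
chunk_sizes = ast.literal_eval(os.environ["RAG_CHUNK_SIZES"])
assert chunk_sizes == [2048, 512, 128]
```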