diff --git a/.env b/.env
index 22018fa..76a413b 100644
--- a/.env
+++ b/.env
@@ -1,10 +1,11 @@
+EMBEDDING_MODEL_DEPLOY=api
 EMBEDDING_MODEL_NAME=jinaai/jina-embeddings-v2-base-en
 LLM_MODEL_NAME=google/gemma-2-27b-it
 OLLAMA_BASE_URL=http://ollama:11434
 OPENAI_API_KEY=sk-proj-aaaaaaaaaaaaaaaaa
 OPENAI_BASE_URL=http://localhost:8000/v1
 PROJECT_HOSTING_BASE_URL=http://127.0.0.1:8000
-RAG_MODEL_DEPLOY=local
+RERANK_MODEL_DEPLOY=local
 RERANK_MODEL_NAME=BAAI/bge-reranker-v2-m3
 RERANK_BASE_URL=http://xinference:9997/v1
 SEARCH_BASE_URL=https://s.jina.ai
diff --git a/docs/experience.md b/docs/experience.md
new file mode 100644
index 0000000..a8f7cb9
--- /dev/null
+++ b/docs/experience.md
@@ -0,0 +1,3 @@
+## Prompt
+- Unclear reasoning or logic in a prompt might have a huge impact on the LLM's reasoning ability.
+
diff --git a/src/index.py b/src/index.py
index ef82bcb..cb03b24 100644
--- a/src/index.py
+++ b/src/index.py
@@ -17,7 +17,7 @@
 from llama_index.core.query_engine import RetrieverQueryEngine
 from llama_index.core.llms import MockLLM
 
-Settings.llm = MockLLM() # retrieve only, do not use LLM for synthesize
+Settings.llm = MockLLM(max_tokens=256) # retrieve only, do not use LLM for synthesize
 
 from settings import settings
 
@@ -27,57 +27,6 @@
 # todo: high lantency between client and the ollama embedding server will slow down embedding a lot
 from llama_index.embeddings.ollama import OllamaEmbedding
 
-def build_automerging_index(
-    documents,
-    chunk_sizes=None,
-):
-    chunk_sizes = chunk_sizes or [2048, 512, 128]
-
-    if settings.RAG_MODEL_DEPLOY == "local":
-        embed_model="local:" + settings.EMBEDDING_MODEL_NAME
-    else:
-        embed_model = OllamaEmbedding(
-            model_name=settings.EMBEDDING_MODEL_NAME,
-            base_url=os.environ.get("OLLAMA_BASE_URL"), # todo: any other configs here?
-        )
-
-    node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes)
-    nodes = node_parser.get_nodes_from_documents(documents)
-    leaf_nodes = get_leaf_nodes(nodes)
-    merging_context = ServiceContext.from_defaults(
-        embed_model=embed_model,
-    )
-    storage_context = StorageContext.from_defaults()
-    storage_context.docstore.add_documents(nodes)
-
-    automerging_index = VectorStoreIndex(
-        leaf_nodes, storage_context=storage_context, service_context=merging_context
-    )
-    return automerging_index
-
-def get_automerging_query_engine(
-    automerging_index,
-    similarity_top_k=12,
-    rerank_top_n=6,
-):
-    base_retriever = automerging_index.as_retriever(similarity_top_k=similarity_top_k)
-    retriever = AutoMergingRetriever(
-        base_retriever, automerging_index.storage_context, verbose=True
-    )
-
-    if settings.RAG_MODEL_DEPLOY == "local":
-        rerank = SentenceTransformerRerank(
-            top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME,
-        ) # todo: add support `trust_remote_code=True`
-    else:
-        rerank = jinaai_rerank.JinaRerank(api_key='', top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME)
-
-    auto_merging_engine = RetrieverQueryEngine.from_args(
-        retriever, node_postprocessors=[rerank]
-    )
-
-    return auto_merging_engine
-
 def nodes2list(nodes):
     nodes_list = []
     for ind, source_node in enumerate(nodes):
@@ -87,22 +36,77 @@ def nodes2list(nodes):
         _sub['score'] = source_node.score
         _sub['text'] = source_node.node.get_content().strip()
         nodes_list.append(_sub)
     return nodes_list
-
-def get_contexts(statement, keywords, text):
-    """
-    Get list of contexts.
-
-    Todo: resources re-use for multiple run.
- """ - document = Document(text=text) - index = build_automerging_index( - [document], - chunk_sizes=settings.RAG_CHUNK_SIZES, - ) # todo: will it better to use retriever directly? - - query_engine = get_automerging_query_engine(index, similarity_top_k=16) - query = f"{keywords} | {statement}" # todo: better way - auto_merging_response = query_engine.query(query) - contexts = nodes2list(auto_merging_response.source_nodes) +class Index(): + def build_automerging_index( + self, + documents, + chunk_sizes=None, + ): + chunk_sizes = chunk_sizes or [2048, 512, 128] - return contexts \ No newline at end of file + # todo: improve embedding performance + if settings.EMBEDDING_MODEL_DEPLOY == "local": + embed_model="local:" + settings.EMBEDDING_MODEL_NAME + else: + embed_model = OllamaEmbedding( + model_name=settings.EMBEDDING_MODEL_NAME, + base_url=os.environ.get("OLLAMA_BASE_URL"), # todo: any other configs here? + ) + + node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes) + nodes = node_parser.get_nodes_from_documents(documents) + leaf_nodes = get_leaf_nodes(nodes) + merging_context = ServiceContext.from_defaults( + embed_model=embed_model, + ) + storage_context = StorageContext.from_defaults() + storage_context.docstore.add_documents(nodes) + + automerging_index = VectorStoreIndex( + leaf_nodes, storage_context=storage_context, service_context=merging_context + ) + return automerging_index + + def get_automerging_query_engine( + self, + automerging_index, + similarity_top_k=12, + rerank_top_n=6, + ): + base_retriever = automerging_index.as_retriever(similarity_top_k=similarity_top_k) + retriever = AutoMergingRetriever( + base_retriever, automerging_index.storage_context, verbose=True + ) + + if settings.RERANK_MODEL_DEPLOY == "local": + rerank = SentenceTransformerRerank( + top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME, + ) # todo: add support `trust_remote_code=True` + else: + rerank = jinaai_rerank.JinaRerank(api_key='', top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME) + + auto_merging_engine = RetrieverQueryEngine.from_args( + retriever, node_postprocessors=[rerank] + ) + + return auto_merging_engine + + def get_contexts(self, statement, keywords, text): + """ + Get list of contexts. + + Todo: resources re-use for multiple run. + """ + document = Document(text=text) + index = self.build_automerging_index( + [document], + chunk_sizes=settings.RAG_CHUNK_SIZES, + ) # todo: will it better to use retriever directly? 
+
+        query_engine = self.get_automerging_query_engine(index, similarity_top_k=16, rerank_top_n=3)
+        query = f"{keywords} | {statement}" # todo: better way
+        auto_merging_response = query_engine.query(query)
+        contexts = nodes2list(auto_merging_response.source_nodes)
+
+        return contexts
\ No newline at end of file
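The index.py change above folds the former module-level functions into an `Index` class. A minimal usage sketch of the new entry point (illustrative only; the statement, keywords, and page text are invented, and resource re-use across calls is still an open todo in the patch):

```python
from index import Index

# Hypothetical inputs; in the service they come from the fact-check request
# and from the content of a single search result.
statement = "The Eiffel Tower is located in Berlin."
keywords = "Eiffel Tower location"
text = "The Eiffel Tower is a wrought-iron lattice tower in Paris, France."

# Builds a fresh auto-merging index over `text`, retrieves, reranks,
# and returns the merged source nodes as plain dicts (see nodes2list).
contexts = Index().get_contexts(statement, keywords, text)
for ctx in contexts:
    print(ctx["score"], ctx["text"][:80])
```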
diff --git a/src/llm.py b/src/llm.py
index d406e41..51eb2e0 100644
--- a/src/llm.py
+++ b/src/llm.py
@@ -1,7 +1,9 @@
 from openai import OpenAI
+import concurrent.futures
 import logging
 
 import utils
+from index import Index
 from settings import settings
 
 """
@@ -60,37 +62,133 @@ def get_search_keywords(statement):
     reply = get_llm_reply(prompt)
     return reply.strip()
 
-def get_verdict(statement, contexts):
+def get_context_prompt(contexts):
+    prompt = "Context:"
+
+    for ind, node in enumerate(contexts):
+        _text = node.get('text')
+        if not _text:
+            continue
+        prompt = f"""{prompt}
+
+```
+{_text}
+```"""
+    return prompt
+
+def get_verdict_single(statement, contexts):
     # This prompt allows model to use their own knowedge
     system_message = '''You are a helpful AI assistant.
 Solve tasks using your fact-check skills.
-You will be given a statement followed by some contexts.
-Use the contexts and facts you know to check if the statements are true, false, or uncheckable.
-Ignore contexts that are irrelevant or stale.
-Provide detailed reasons for your verdict, ensuring that each reason is supported by corresponding facts.
-Be thorough in your explanations, avoiding any duplication of information.
+You will be given a statement followed by context.
+Use the contexts and facts you know to check if the statements are true, false, irrelevant.
+Provide detailed reason for your verdict, ensuring that each reason is supported by corresponding facts.
 Provide the response as JSON with the structure:{verdict, reason}'''
 
     prompt = f'''{system_message}
+
 ```
 Statement: {statement}
 ```
-Contexts:'''
-
-    for ind, node in enumerate(contexts):
-        _text = node.get('text')
-        if not _text:
-            continue
-        prompt = f"""{prompt}
-    ```
-    Context {ind + 1}:
-    {_text}
-    ```"""
-
+
+{get_context_prompt(contexts)}'''
+
     reply = get_llm_reply(prompt)
-    logging.info(f"Verdict reply from LLM: {reply}")
+    logging.debug(f"Verdict reply from LLM: {reply}")
     verdict = utils.llm2json(reply)
-    if verdict:
-        verdict['statement'] = statement
+
     return verdict
+
+def get_verdict_summary(verdicts, statement):
+    """
+    Calculate and summarize the verdicts of multiple sources.
+
+    Introduce some weights:
+      - total: the number of total verdicts
+      - valid: the number of verdicts in the desired categories
+      - winning: the count of the winning verdict
+      - and the count of verdicts in each desired category
+    """
+
+    weight_total = 0
+    weight_valid = 0
+    sum_score = 0
+    sum_reason = {
+        "true": {"reason": [], "weight": 0},
+        "false": {"reason": [], "weight": 0},
+        "irrelevant": {"reason": [], "weight": 0},
+    }
+
+    for verdict in verdicts:
+        weight_total += 1
+        v = verdict['verdict']
+        if v in sum_reason:
+            weight_valid += 1
+            reason = f"{verdict['reason']}\nsource url: {verdict['url']}\nsource title: {verdict['title']}\n\n"
+            sum_reason[v]['reason'].append(reason)
+            sum_reason[v]['weight'] += 1
+            if v == 'true':
+                sum_score += 1
+            elif v == 'false':
+                sum_score -= 1
+
+    if sum_score > 0:
+        verdict = "true"
+    elif sum_score < 0:
+        verdict = "false"
+    else:
+        verdict = "irrelevant"
+
+    reason = ''.join(sum_reason[verdict]['reason'])
+    if not reason:
+        raise Exception("No reason found after summary")
+
+    weights = {"total": weight_total, "valid": weight_valid, "winning": sum_reason[verdict]['weight']}
+    for key in sum_reason.keys():
+        weights[key] = sum_reason[key]['weight']
+
+    return {"verdict": verdict, "reason": reason, "weights": weights, "statement": statement}
+
+def get_verdict(statement, keywords, search_json):
+    """
+    Get verdict from each of the context sources.
+    """
+    verdicts = []
+
+    def process_result(result):
+        content = utils.clear_md_links(result.get('content'))
+
+        try:
+            contexts = Index().get_contexts(statement, keywords, content)
+        except Exception as e:
+            logging.warning(f"Getting contexts failed: {e}")
+            return None
+
+        try:
+            verdict = get_verdict_single(statement, contexts)
+        except Exception as e:
+            logging.warning(f"Getting verdict failed: {e}")
+            return None
+
+        verdict = {
+            "verdict": verdict.get('verdict', '').lower(),
+            "reason": verdict.get('reason'),
+            "url": result.get('url'),
+            "title": result.get('title'),
+        }
+
+        logging.debug(f"single source verdict: {verdict}")
+        return verdict
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor:
+        # Submit one task per search result and collect verdicts as they complete
+        future_to_result = {executor.submit(process_result, result): result for result in search_json['data']}
+
+        for future in concurrent.futures.as_completed(future_to_result):
+            verdict = future.result()
+            if verdict is not None:
+                verdicts.append(verdict)
+
+    summary = get_verdict_summary(verdicts, statement)
+    return summary
+
\ No newline at end of file
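For the new summarization step in llm.py, a small worked example of how `get_verdict_summary` is expected to score per-source verdicts (hypothetical inputs; the URLs, titles, and reasons are placeholders):

```python
from llm import get_verdict_summary

# Hypothetical per-source verdicts, e.g. returned by process_result for four search results.
verdicts = [
    {"verdict": "true",       "reason": "r1", "url": "https://example.com/a", "title": "A"},
    {"verdict": "true",       "reason": "r2", "url": "https://example.com/b", "title": "B"},
    {"verdict": "false",      "reason": "r3", "url": "https://example.com/c", "title": "C"},
    {"verdict": "irrelevant", "reason": "r4", "url": "https://example.com/d", "title": "D"},
]

summary = get_verdict_summary(verdicts, "some statement")
# sum_score = +1 +1 -1 = 1 > 0, so the overall verdict is "true", and the reason
# field concatenates only the reasons of the winning category.
# summary["weights"] == {"total": 4, "valid": 4, "winning": 2,
#                        "true": 2, "false": 1, "irrelevant": 1}
```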
diff --git a/src/main.py b/src/main.py
index 8218439..698b51e 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,9 +1,11 @@
+import json
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.concurrency import run_in_threadpool
 from fastapi.responses import Response, JSONResponse, HTMLResponse, PlainTextResponse, FileResponse
 import logging
 
-import llm, index, utils
+import llm, utils
+from index import Index
 
 logging.basicConfig(
     level=logging.INFO,
@@ -36,15 +38,11 @@ async def fact_check(input):
         if not search:
             fail_search = True
             continue
-        logger.info(f"head of search results: {search[0:200]}")
-        contexts = await run_in_threadpool(index.get_contexts, statement, keywords, search)
-        if not contexts:
-            continue
-        logger.info(f"contexts: {contexts}")
-        verdict = await run_in_threadpool(llm.get_verdict, statement, contexts)
+        logger.info(f"head of search results: {json.dumps(search)[0:500]}")
+        verdict = await run_in_threadpool(llm.get_verdict, statement, keywords, search)
         if not verdict:
             continue
-        logger.info(f"verdict: {verdict}")
+        logger.info(f"final verdict: {verdict}")
         verdicts.append(verdict)
 
     if not verdicts:
diff --git a/src/settings.py b/src/settings.py
index 41e3227..e4ad452 100644
--- a/src/settings.py
+++ b/src/settings.py
@@ -12,11 +12,12 @@ def __init__(self):
         self.PROJECT_HOSTING_BASE_URL = os.environ.get("PROJECT_HOSTING_BASE_URL") or "https://check.ittia.net"
         self.SEARCH_BASE_URL = os.environ.get("SEARCH_BASE_URL") or "https://s.jina.ai"
 
-        # set RAG model deploy mode
-        self.RAG_MODEL_DEPLOY = os.environ.get("RAG_MODEL_DEPLOY") or "local"
+        # set RAG models deploy mode
+        self.EMBEDDING_MODEL_DEPLOY = os.environ.get("EMBEDDING_MODEL_DEPLOY") or "local"
+        self.RERANK_MODEL_DEPLOY = os.environ.get("RERANK_MODEL_DEPLOY") or "local"
 
         # set RAG chunk sizes
-        self.RAG_CHUNK_SIZES = [4096, 1024, 256]
+        self.RAG_CHUNK_SIZES = [1024, 256]
         _chunk_sizes = os.environ.get("RAG_CHUNK_SIZES")
         try:
             self.RAG_CHUNK_SIZES = ast.literal_eval(_chunk_sizes)
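With the settings.py change, the default chunk sizes shrink to `[1024, 256]`; they can still be overridden from the environment, since the value is parsed with `ast.literal_eval` and expects a Python-style list literal (the sizes below are only an example, not a recommendation):

```
# .env
RAG_CHUNK_SIZES=[2048, 512, 128]
```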
diff --git a/src/utils.py b/src/utils.py
index 4b7fd57..b8bb071 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -32,7 +32,7 @@ def llm2json(text):
 
 async def search(keywords):
     """
-    Constructs a URL from given keywords and search via Jina Reader.
+    Search and get a list of websites content.
 
     Todo:
       - Enhance response clear.
@@ -46,17 +46,15 @@ async def search(keywords):
     async with aiohttp.ClientSession() as session:
         try:
             async with session.get(constructed_url, headers=headers) as response:
-                response_data = await response.json()
-                response_code = response_data.get('code')
-                if response_code != 200:
-                    raise Exception(f"Search response code: {response_code}")
-                text = "\n\n".join([doc['content'] for doc in response_data['data']])
-                result = clear_md_links(text)
+                rep = await response.json()
+                rep_code = rep.get('code')
+                if rep_code != 200:
+                    raise Exception(f"Search response code: {rep_code}")
         except Exception as e:
             logging.error(f"Search '{keywords}' failed: {e}")
-            result = ''
+            rep = {}
 
-    return result
+    return rep
 
 def clear_md_links(text):
     """
@@ -87,76 +85,12 @@ def generate_report_markdown(input_text, verdicts):
         markdown.append(f"### Statement {i}\n")
         markdown.append(f"**Statement**: {verdict['statement']}\n")
         markdown.append(f"**Verdict**: `{verdict['verdict']}`\n")
-        markdown.append(f"**Reason**: {verdict['reason']}\n")
+        markdown.append(f"**Weight**: {verdict['weights']['winning']} out of {verdict['weights']['valid']} ({verdict['weights']['irrelevant']} irrelevant)\n")
+        markdown.append(f"**Reason**:\n\n{verdict['reason']}\n")
 
     markdown_str = "\n".join(markdown)
 
     return markdown_str
 
-def generate_report_html(input_text, verdicts):
-    html = []
-
-    # Add basic HTML structure and styles
-    html.append("""
-
-
-
-
-
-
-    """)
-
-    # Add original input
-    html.append("" + input_text + "\n")
-
-    # Add verdicts
-    html.append("
{verdict["verdict"]}