diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..5991ecf --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,2 @@ +from .fetch import FetchUrl +from .search import SearchWeb diff --git a/src/api/fetch.py b/src/api/fetch.py new file mode 100644 index 0000000..daaca20 --- /dev/null +++ b/src/api/fetch.py @@ -0,0 +1,27 @@ +import httpx +import json +from tenacity import retry, stop_after_attempt, wait_fixed + +import utils +from settings import settings + +client = httpx.AsyncClient(http2=True, follow_redirects=True) + +class FetchUrl(): + """Fetch one single url via API fetch endpoint""" + + def __init__(self, url: str): + self.url = url + self.api = settings.SEARCH_BASE_URL + '/fetch' + self.timeout = 120 # api request timeout, set higher cause api backend might need to try a few times + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=utils.retry_log_warning, reraise=True) + async def get(self): + _data = { + 'url': self.url, + } + response = await client.post(self.api, json=_data, timeout=self.timeout) + _r = response.json() + if _r['status'] != 'ok': + raise Exception(f"Fetch url return status not ok: {self.url}") + return _r['data'] \ No newline at end of file diff --git a/src/api/search.py b/src/api/search.py new file mode 100644 index 0000000..5bef64f --- /dev/null +++ b/src/api/search.py @@ -0,0 +1,58 @@ +import asyncio +import httpx +import json +from tenacity import retry, stop_after_attempt, wait_fixed + +import utils +from settings import settings + +class SearchWeb(): + """ + Web search with a query with session support: + - get more links following the previous searches + - get all links of this session + """ + def __init__(self, query: str): + self.query = query + self.api = settings.SEARCH_BASE_URL + '/search' + self.timeout = 600 # api request timeout, set higher cause search backend might need to try a few times + + self.client = httpx.AsyncClient(http2=True, follow_redirects=True, 
timeout=self.timeout) + self.urls = [] # all urls got + + """ + Get JSON data from API stream output. + + TODO: + - Is there a more standard way to process streamed JSON? + """ + @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=utils.retry_log_warning, reraise=True) + async def get(self, num: int = 10, all: bool = False): + _data = { + 'query': self.query, + 'num': num, # how many more urls to get + 'all': all, + } + async with self.client.stream("POST", self.api, json=_data) as response: + buffer = "" + async for chunk in response.aiter_text(): + if chunk.strip(): # Only process non-empty chunks + buffer += chunk + + # Attempt to load the buffer as JSON + try: + # Keep loading JSON until all data is consumed + while buffer: + # Try to load a complete JSON object + rep, index = json.JSONDecoder().raw_decode(buffer) + _url = rep['url'] + # deduplication + if _url not in self.urls: # TODO: waht if the new one containes same url but better metadata + self.urls.append(_url) + yield rep + + # Remove the processed part from the buffer + buffer = buffer[index:].lstrip() # Remove processed JSON and any leading whitespace + except json.JSONDecodeError: + # If we encounter an error, we may not have a complete JSON object yet + continue # Continue to read more data diff --git a/src/main.py b/src/main.py index 7bdcf12..472b23c 100644 --- a/src/main.py +++ b/src/main.py @@ -17,67 +17,68 @@ app = FastAPI() -""" -Process input string, fact-check and output MARKDOWN -""" -async def fact_check(input): - status = 500 - logger.info(f"Fact checking: {input}") - - # get list of statements - try: - statements = await run_in_threadpool(pipeline.get_statements, input) - logger.info(f"statements: {statements}") - except Exception as e: - logger.error(f"Get statements failed: {e}") - raise HTTPException(status_code=status, detail="No statements found") - - verdicts = [] - fail_search = False - for statement in statements: - if not statement: - continue - 
logger.info(f"Statement: {statement}") - - # get search query - try: - query = await run_in_threadpool(pipeline.get_search_query, statement) - logger.info(f"Search query: {query}") - except Exception as e: - logger.error(f"Getting search query from statement '{statement}' failed: {e}") - continue - - # searching - try: - search = await Search(query) - logger.info(f"Head of search results: {json.dumps(search)[0:500]}") - except Exception as e: - fail_search = True - logger.error(f"Search '{query}' failed: {e}") - continue - - # get verdict - try: - verdict = await run_in_threadpool(pipeline.get_verdict, search_json=search, statement=statement) - logger.info(f"Verdict: {verdict}") - except Exception as e: - logger.error(f"Getting verdict for statement '{statement}' failed: {e}") - continue +# """ +# Process input string, fact-check and output MARKDOWN +# """ +# async def fact_check(input): +# status = 500 +# logger.info(f"Fact checking: {input}") + +# # get list of statements +# try: +# statements = await run_in_threadpool(pipeline.get_statements, input) +# logger.info(f"statements: {statements}") +# except Exception as e: +# logger.error(f"Get statements failed: {e}") +# raise HTTPException(status_code=status, detail="No statements found") + +# verdicts = [] +# fail_search = False +# for statement in statements: +# if not statement: +# continue +# logger.info(f"Statement: {statement}") + +# # get search query +# try: +# query = await run_in_threadpool(pipeline.get_search_query, statement) +# logger.info(f"Search query: {query}") +# except Exception as e: +# logger.error(f"Getting search query from statement '{statement}' failed: {e}") +# continue + +# # searching +# try: +# search = await Search(query) +# logger.info(f"Head of search results: {json.dumps(search)[0:500]}") +# except Exception as e: +# fail_search = True +# logger.error(f"Search '{query}' failed: {e}") +# continue + +# # get verdict +# try: +# verdict = await run_in_threadpool(pipeline.get_verdict, 
search_json=search, statement=statement) +# logger.info(f"Verdict: {verdict}") +# except Exception as e: +# logger.error(f"Getting verdict for statement '{statement}' failed: {e}") +# continue - verdicts.append(verdict) +# verdicts.append(verdict) - if not verdicts: - if fail_search: - raise HTTPException(status_code=status, detail="Search not available") - else: - raise HTTPException(status_code=status, detail="No verdicts found") +# if not verdicts: +# if fail_search: +# raise HTTPException(status_code=status, detail="Search not available") +# else: +# raise HTTPException(status_code=status, detail="No verdicts found") - report = utils.generate_report_markdown(input, verdicts) - return report +# report = utils.generate_report_markdown(input, verdicts) +# return report # TODO: multi-stage response async def stream_response(path): - task = asyncio.create_task(fact_check(path)) + union = pipeline.Union(path) + task = asyncio.create_task(union.final()) # Stream response to prevent timeout, return multi-stage reponses elapsed_time = 0 @@ -109,7 +110,8 @@ async def health(): async def status(): _status = utils.get_status() return _status - + +# TODO: integrate error handling with output @app.get("/{path:path}", response_class=PlainTextResponse) async def catch_all(path: str, accept: str = Header(None)): try: diff --git a/src/modules/retrieve.py b/src/modules/retrieve.py index ab78377..d713aa0 100644 --- a/src/modules/retrieve.py +++ b/src/modules/retrieve.py @@ -158,11 +158,12 @@ def __init__( docs, k: Optional[int] = None, ): - self.retriever = LlamaIndexCustomRetriever(docs=docs) - + self.docs = docs if k: self.k = k + self.retriever = LlamaIndexCustomRetriever(docs=self.docs) + @property def k(self) -> Optional[int]: """Get similarity top k of retriever.""" diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py index 79298b2..582142e 100644 --- a/src/pipeline/__init__.py +++ b/src/pipeline/__init__.py @@ -1,2 +1,200 @@ -from .common import get_search_query, 
get_statements, get_verdict -from .verdict_citation import VerdictCitation +import asyncio +import dspy +import logging +import os +from fastapi.concurrency import run_in_threadpool +from tenacity import retry, stop_after_attempt, wait_fixed +from urllib.parse import urlparse + +import utils +from api import FetchUrl, SearchWeb +from modules import SearchQuery, Statements +from modules import llm_long, Citation, LlamaIndexRM, ContextVerdict +from settings import settings + +# loading compiled ContextVerdict +optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"../optimizers/{settings.OPTIMIZER_FILE_NAME}") +context_verdict = ContextVerdict() +context_verdict.load(optimizer_path) + +class Union(): + """ + Run the full cycle from raw input to verdicts of multiple statements. + Keep data in the class. + + TODO: + - Add support of verdict standards. + - Make better use of the other data of web search. + """ + + def __init__(self, input: str): + """Avoid running I/O-intensive functions here to better support async""" + self.input = input # raw input to analyze + self.data = {} # contains all intermediate and final data + + async def final(self): + await self.get_statements() + _task = [asyncio.create_task(self._pipe_statement(data_statement)) for data_statement in self.data.values()] + await asyncio.gather(*_task) + + # update reports + _sum = [v['summary'] for v in self.data.values()] + self.reports = utils.generate_report_markdown(self.input, _sum) + + return self.reports + + async def _pipe_statement(self, data_statement): + """ + Pipeline to process single statement. + Get all links to generate hostname mapping before fetch content and generate verdict citation for each hostname(source). + + TODO: + - Make unit work on URL instead of hostname level. 
+ """ + await self.get_search_query(data_statement) + _updated_sources = await self.update_source_map(data_statement['sources'], data_statement['query']) + _task = [asyncio.create_task(self._pipe_source(data_statement['sources'][source], data_statement['statement'])) for source in _updated_sources] + await asyncio.gather(*_task) + + # update summary + self.update_summary(data_statement) + + async def _pipe_source(self, data_source, statement): + """Update docs and then update retriever, verdic, citation""" + + # update docs + _task_docs = [] + for _, data_doc in data_source['docs'].items(): + if not data_doc.get('doc'): # TODO: better way to decide if update doc + _task_docs.append(asyncio.create_task(self.update_doc(data_doc))) + await asyncio.gather(*_task_docs) # finish all docs processing + + # update retriever + docs = [v['doc'] for v in data_source['docs'].values()] + data_source["retriever"] = await run_in_threadpool(LlamaIndexRM, docs=docs) + + # update verdict, citation + await run_in_threadpool(self.update_verdict_citation, data_source, statement) + + # Statements has retry set already, do not retry here + async def get_statements(self): + """Get list of statements from a text string""" + try: + _dspy = Statements() + self.statements = await run_in_threadpool(_dspy, self.input) + except Exception as e: + logging.error(f"Get statements failed: {e}") + self.statements = [] + + if not self.statements: + raise HTTPException(status_code=500, detail="No statements found") + logging.info(f"statements: {self.statements}") + + # add statements to data with order + for i, v in enumerate(self.statements, start=1): + _key = utils.get_md5(v) + self.data.setdefault(_key, {'order': i, 'statement': v, 'sources': {}}) + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=utils.retry_log_warning, reraise=True) + async def get_search_query(self, data_statement): + """Get search query for one statement and add to the data""" + _dspy = SearchQuery() + 
data_statement['query'] = await run_in_threadpool(_dspy, data_statement['statement']) + + async def update_source_map(self, data_sources, query): + """ + Update map of sources(web URLs for now), + add to the data and return list of updated sources. + """ + _updated = [] + _search_web = SearchWeb(query=query) + async for url_dict in _search_web.get(): + url = url_dict.get('url') + if not url: # TODO: necessary? + continue + url_hash = utils.get_md5(url) + hostname = urlparse(url).hostname + data_sources.setdefault(hostname, {}).setdefault('docs', {}).update({url_hash: {'url': url}}) + _updated.append(hostname) if hostname not in _updated else None + return _updated + + async def update_doc(self, data_doc): + """Update doc (URL content for now)""" + _rep = await FetchUrl(url=data_doc['url']).get() + data_doc['raw'] = _rep # dict including URL content and metadata, etc. + data_doc['title'] = _rep['title'] + data_doc['doc'] = utils.search_result_to_doc(_rep) # TODO: better process + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True) + def update_verdict_citation(self, data_source, statement): + """Update a single source""" + + with dspy.context(rm=data_source['retriever']): + rep = context_verdict(statement) + + context = rep.context + verdict = rep.answer + + # Use the LLM with higher token limit for citation generation call + with dspy.context(lm=llm_long): + rep = Citation()(statement=statement, context=context, verdict=verdict) + citation = rep.citation + + data_source['context'] = context + data_source['verdict'] = verdict + data_source['citation'] = citation + + def update_summary(self, data_statement): + """ + Calculate and summarize the verdicts of multiple sources. 
+ Introduce some weights: + - total: the number of total verdicts + - valid: the number of verdicts in the desired categories + - winning: the count of the winning verdict + - and count of verdicts of each desiered categories + """ + + weight_total = 0 + weight_valid = 0 + sum_score = 0 + sum_citation = { + "true": {"citation": [], "weight": 0}, + "false": {"citation": [], "weight": 0}, + "irrelevant": {"citation": [], "weight": 0}, + } + + for hostname, verdict in data_statement['sources'].items(): + weight_total += 1 + v = verdict['verdict'].lower() + if v in sum_citation: + weight_valid += 1 + citation = f"{verdict['citation']} *source: {hostname}*\n\n" + sum_citation[v]['citation'].append(citation) + sum_citation[v]['weight'] += 1 + if v == 'true': + sum_score += 1 + elif v == 'false': + sum_score -= 1 + + if sum_score > 0: + verdict = "true" + elif sum_score < 0: + verdict = "false" + else: + verdict = "irrelevant" + + citation = ''.join(sum_citation[verdict]['citation']) + if not citation: + raise Exception("No citation found after summarize") + + weights = {"total": weight_total, "valid": weight_valid, "winning": sum_citation[verdict]['weight']} + for key in sum_citation.keys(): + weights[key] = sum_citation[key]['weight'] + + data_statement['summary'] = { + "verdict": verdict, + "citation": citation, + "weights": weights, + "statement": data_statement['statement'], + } + \ No newline at end of file diff --git a/src/pipeline/common.py b/src/pipeline/common.py deleted file mode 100644 index 0bfb468..0000000 --- a/src/pipeline/common.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -import utils -from tenacity import retry, stop_after_attempt, wait_fixed - -from modules import SearchQuery, Statements -from .verdict_citation import VerdictCitation - -# Statements has retry set already, do not retry here -def get_statements(content): - """Get list of statements from a text string""" - - statements = Statements()(content=content) - return statements - 
-@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.1), before_sleep=utils.retry_log_warning, reraise=True) -def get_search_query(statement): - """Get search query from one statement""" - - query = SearchQuery()(statement=statement) - return query - -def get_verdict_summary(verdicts_data, statement): - """ - Calculate and summarize the verdicts of multiple sources. - Introduce some weights: - - total: the number of total verdicts - - valid: the number of verdicts in the desired categories - - winning: the count of the winning verdict - - and count of verdicts of each desiered categories - """ - - weight_total = 0 - weight_valid = 0 - sum_score = 0 - sum_citation = { - "true": {"citation": [], "weight": 0}, - "false": {"citation": [], "weight": 0}, - "irrelevant": {"citation": [], "weight": 0}, - } - - for hostname, verdict in verdicts_data.items(): - weight_total += 1 - v = verdict['verdict'].lower() - if v in sum_citation: - weight_valid += 1 - citation = f"{verdict['citation']} *source: {hostname}*\n\n" - sum_citation[v]['citation'].append(citation) - sum_citation[v]['weight'] += 1 - if v == 'true': - sum_score += 1 - elif v == 'false': - sum_score -= 1 - - if sum_score > 0: - verdict = "true" - elif sum_score < 0: - verdict = "false" - else: - verdict = "irrelevant" - - citation = ''.join(sum_citation[verdict]['citation']) - if not citation: - raise Exception("No citation found after summarize") - - weights = {"total": weight_total, "valid": weight_valid, "winning": sum_citation[verdict]['weight']} - for key in sum_citation.keys(): - weights[key] = sum_citation[key]['weight'] - - return {"verdict": verdict, "citation": citation, "weights": weights, "statement": statement} - -def get_verdict(search_json, statement): - _verdict_citation = VerdictCitation(search_json=search_json) - verdicts_data = _verdict_citation.get(statement=statement) - - summary = get_verdict_summary(verdicts_data, statement) - return summary - \ No newline at end of file diff --git 
a/src/pipeline/verdict_citation.py b/src/pipeline/verdict_citation.py deleted file mode 100644 index ca1de94..0000000 --- a/src/pipeline/verdict_citation.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import concurrent.futures -import dspy -from llama_index.core import Document -from tenacity import retry, stop_after_attempt, wait_fixed -from urllib.parse import urlparse - -import utils -from settings import settings -from modules import llm_long -from modules import Citation, LlamaIndexRM, ContextVerdict - -# loading compiled ContextVerdict -optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"../optimizers/{settings.OPTIMIZER_FILE_NAME}") -context_verdict = ContextVerdict() -context_verdict.load(optimizer_path) - -""" -All web pages of the same hostname as one source. -For each sources, get verdict and citation seperately. -""" -class VerdictCitation(): - def __init__( - self, - search_json, - ): - raw = search_json['data'] - self.data = {} # main container - self.update_retriever(raw) - - def update_retriever(self, raw): - update_list = [] - - # update doc - for r in raw: - url = r.get('url') - url_hash = utils.get_md5(url) - hostname = urlparse(url).hostname - doc = utils.search_result_to_doc(r) - self.data.setdefault(hostname, {}).setdefault('docs', {}).update({url_hash: {"doc": doc, "raw": r}}) - self.data[hostname]['new'] = True - update_list.append(hostname) if hostname not in update_list else None - - # update retriever - # TODO: what if we have a lot of small retriever to create and RM server latency high - for hostname in update_list: - docs = [v['doc'] for v in self.data[hostname]['docs'].values()] - self.data[hostname]["retriever"] = LlamaIndexRM(docs=docs) - - @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5), before_sleep=utils.retry_log_warning, reraise=True) - def _update_verdict_citation_single(self, hostname, statement): - with dspy.context(rm=self.data[hostname]['retriever']): - rep = context_verdict(statement) - - 
context = rep.context - verdict = rep.answer - - # Use the LLM with higher token limit for citation generation call - with dspy.context(lm=llm_long): - rep = Citation()(statement=statement, context=context, verdict=verdict) - citation = rep.citation - - self.data[hostname]['statement'] = statement - self.data[hostname]['context'] = context - self.data[hostname]['verdict'] = verdict - self.data[hostname]['citation'] = citation - self.data[hostname]['new'] = False - - def update_verdict_citation(self, statement): - with concurrent.futures.ThreadPoolExecutor(max_workers=settings.CONCURRENCY_VERDICT) as executor: - futures = [] - for hostname in self.data: - if self.data[hostname].get('new'): - futures.append(executor.submit(self._update_verdict_citation_single, hostname, statement)) - - concurrent.futures.wait(futures) # wait for all futures to complete - - """Get verdict and citation""" - def get(self, statement): - self.update_verdict_citation(statement) - return self.data diff --git a/src/utils.py b/src/utils.py index 75c4f6c..e6842ce 100644 --- a/src/utils.py +++ b/src/utils.py @@ -45,11 +45,13 @@ def generate_report_markdown(input_text, verdicts): # Add verdicts markdown.append("## Fact Check\n") for i, verdict in enumerate(verdicts, start=1): + weights = verdict['weights'] + percentage = calculate_percentage(weights['winning'], weights['valid']) markdown.append(f"### Statement {i}\n") markdown.append(f"**Statement**: {verdict['statement']}\n") - markdown.append(f"**Verdict**: `{verdict['verdict']}`\n") - markdown.append(f"**Weight**: {verdict['weights']['winning']} out of {verdict['weights']['valid']} ({verdict['weights']['irrelevant']} irrelevant)\n") - markdown.append(f"**Citation**:\n\n{verdict['citation']}\n") + markdown.append(f"**Verdict**: `{verdict['verdict'].capitalize()}`\n") + markdown.append(f"**Weight**: {percentage} (false: {weights['false']}, true: {weights['true']}, irrelevant: {weights['irrelevant']})\n") + 
markdown.append(f"**Citations**:\n\n{verdict['citation']}\n") markdown_str = "\n".join(markdown) return markdown_str @@ -152,6 +154,14 @@ def get_md5(input): md5_hash = hashlib.md5(input.encode()) return md5_hash.hexdigest() +def calculate_percentage(part, whole): + # Check to avoid division by zero + if whole == 0: + return "N/A" + + percentage = round((part / whole) * 100) + return f"{percentage}%" + # generate str for stream def get_stream(stage: str = 'wait', content = None): message = {"stage": stage, "content": content}