From 7e507003ed56a68172804db4601eed1196ba20e4 Mon Sep 17 00:00:00 2001 From: etwk <48991073+etwk@users.noreply.github.com> Date: Thu, 15 Aug 2024 23:10:38 +0000 Subject: [PATCH 1/4] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c4072d0..8d8bd98 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,7 @@ DSPy: ### Reports - [ ] AI-generated misinformation ### Factcheck +- https://www.snopes.com - https://www.bmi.bund.de/SharedDocs/schwerpunkte/EN/disinformation/examples-of-russian-disinformation-and-the-facts.html ### Resources #### Inference From ed8db5034041426d775af60da5117ee3a690f50d Mon Sep 17 00:00:00 2001 From: etwk <48991073+etwk@users.noreply.github.com> Date: Thu, 15 Aug 2024 23:11:44 +0000 Subject: [PATCH 2/4] change variable name RAG_CHUNK_SIZES to INDEX_CHUNK_SIZES --- .env | 4 ++-- src/retrieve.py | 2 +- src/settings.py | 8 +++----- src/utils.py | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/.env b/.env index 1abf8eb..a4e95ee 100644 --- a/.env +++ b/.env @@ -1,6 +1,7 @@ EMBEDDING_API_KEY=ollama:abc EMBEDDING_MODEL_DEPLOY=api EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en +INDEX_CHUNK_SIZES=[2048, 512, 128] LLM_MODEL_NAME=google/gemma-2-27b-it OLLAMA_BASE_URL=http://ollama:11434 OPENAI_API_KEY=sk-proj-aaaaaaaaaaaaaaaaa @@ -10,5 +11,4 @@ RERANK_MODEL_DEPLOY=local RERANK_MODEL_NAME=BAAI/bge-reranker-v2-m3 RERANK_BASE_URL=http://xinference:9997/v1 SEARCH_BASE_URL=https://s.jina.ai -THREAD_BUILD_INDEX=12 -RAG_CHUNK_SIZES=[4096, 1024, 256] \ No newline at end of file +THREAD_BUILD_INDEX=12 \ No newline at end of file diff --git a/src/retrieve.py b/src/retrieve.py index 31ad6cf..3f16c64 100644 --- a/src/retrieve.py +++ b/src/retrieve.py @@ -132,7 +132,7 @@ def build_index(self, docs): if docs: self.index, self.storage_context = self.build_automerging_index( docs, - chunk_sizes=settings.RAG_CHUNK_SIZES, + chunk_sizes=settings.INDEX_CHUNK_SIZES, ) # TODO: try to retrieve directly def retrieve(self, query): diff --git a/src/settings.py b/src/settings.py index 1481a1d..9a052b9 100644 --- a/src/settings.py +++ b/src/settings.py @@ -16,13 +16,11 @@ def __init__(self): self.EMBEDDING_MODEL_DEPLOY = os.environ.get("EMBEDDING_MODEL_DEPLOY") or "local" self.RERANK_MODEL_DEPLOY = os.environ.get("RERANK_MODEL_DEPLOY") or "local" - # set RAG chunk sizes - self.RAG_CHUNK_SIZES = [1024, 256] - _chunk_sizes = os.environ.get("RAG_CHUNK_SIZES") + # set Index chunk sizes try: - self.RAG_CHUNK_SIZES = ast.literal_eval(_chunk_sizes) + self.INDEX_CHUNK_SIZES = ast.literal_eval(os.environ.get("INDEX_CHUNK_SIZES")) except: - pass + self.INDEX_CHUNK_SIZES = [1024, 256] # threads self.THREAD_BUILD_INDEX = int(os.environ.get("THREAD_BUILD_INDEX", 12)) diff --git a/src/utils.py b/src/utils.py index ad190c2..e6268bd 100644 --- a/src/utils.py +++ b/src/utils.py @@ -137,7 +137,7 @@ async def get_stack(): "LLM model": settings.LLM_MODEL_NAME, "Embedding model": settings.EMBEDDING_MODEL_NAME, "Rerank model": settings.RERANK_MODEL_NAME, - "RAG chunk sizes": settings.RAG_CHUNK_SIZES, + "Index chunk sizes": settings.INDEX_CHUNK_SIZES, "Embedding deploy mode": settings.EMBEDDING_MODEL_DEPLOY, "Rerank deploy mode": settings.RERANK_MODEL_DEPLOY, } From 148b4c6de2822d19e0f59807e8c32858c8a15e4c Mon Sep 17 00:00:00 2001 From: etwk <48991073+etwk@users.noreply.github.com> Date: Fri, 16 Aug 2024 08:28:17 +0000 Subject: [PATCH 3/4] change all LLM calling to DSPy predict --- src/llm.py | 76 --------------------- src/main.py | 12 ++-- 
src/modules/__init__.py | 14 ++++ src/modules/citation.py | 20 ++++++ src/{ => modules}/ollama_embedding.py | 0 src/{ => modules}/retrieve.py | 4 +- src/modules/search_query.py | 18 +++++ src/modules/statements.py | 24 +++++++ src/{dspy_modules.py => modules/verdict.py} | 56 +-------------- src/pipeline.py | 13 ---- src/pipeline/__init__.py | 2 + src/pipeline/common.py | 36 ++++++++++ src/pipeline/verdict_citation.py | 31 +++++++++ src/utils.py | 15 +--- 14 files changed, 155 insertions(+), 166 deletions(-) delete mode 100644 src/llm.py create mode 100644 src/modules/__init__.py create mode 100644 src/modules/citation.py rename src/{ => modules}/ollama_embedding.py (100%) rename src/{ => modules}/retrieve.py (98%) create mode 100644 src/modules/search_query.py create mode 100644 src/modules/statements.py rename src/{dspy_modules.py => modules/verdict.py} (53%) delete mode 100644 src/pipeline.py create mode 100644 src/pipeline/__init__.py create mode 100644 src/pipeline/common.py create mode 100644 src/pipeline/verdict_citation.py diff --git a/src/llm.py b/src/llm.py deleted file mode 100644 index 201dc97..0000000 --- a/src/llm.py +++ /dev/null @@ -1,76 +0,0 @@ -from openai import OpenAI -import concurrent.futures -import logging - -import utils -from settings import settings - -""" -About models: - - Gemma 2 does not support system rule -""" - -llm_client = OpenAI( - base_url=settings.OPENAI_BASE_URL, - api_key="token", -) - -def get_llm_reply(prompt, temperature=0): - completion = llm_client.chat.completions.create( - model=settings.LLM_MODEL_NAME, - messages=[ - {"role": "user", "content": prompt} - ], - temperature=temperature, - ) - return completion.choices[0].message.content - -""" -Get list of statements from input. -""" -def get_statements(input): - system_message = '''You are a helpful AI assistant. -Solve tasks using your fact extraction skills. -Extract key statements from the given content. -Provide in array format as response only.''' - - prompt = f'''{system_message} -``` -Content: -{input} -```''' - - reply = get_llm_reply(prompt) - logging.debug(f"get_statements LLM reply: {reply}") - return utils.llm2list(reply) - -""" -Get search keywords from statements -""" -def get_search_keywords(statement): - system_message = '''You are a helpful AI assistant. -Solve tasks using your searching skills. -Generate search keyword used for fact check on the given statement. 
-Include only the keyword in your response.''' - - prompt = f'''{system_message} -``` -Statement: -{statement} -```''' - reply = get_llm_reply(prompt) - return reply.strip() - -def get_context_prompt(contexts): - prompt = "Context:" - - for ind, node in enumerate(contexts): - _text = node.get('text') - if not _text: - continue - prompt = f"""{prompt} - -``` -{_text} -```""" - return prompt diff --git a/src/main.py b/src/main.py index 7665015..1c45c18 100644 --- a/src/main.py +++ b/src/main.py @@ -4,7 +4,7 @@ from fastapi.responses import Response, JSONResponse, HTMLResponse, PlainTextResponse, FileResponse import logging -import llm, utils, pipeline +import utils, pipeline logging.basicConfig( level=logging.INFO, @@ -18,7 +18,7 @@ async def fact_check(input): status = 500 logger.info(f"Fact checking: {input}") - statements = await run_in_threadpool(llm.get_statements, input) + statements = await run_in_threadpool(pipeline.get_statements, input) logger.info(f"statements: {statements}") if not statements: raise HTTPException(status_code=status, detail="No statements found") @@ -29,11 +29,11 @@ async def fact_check(input): if not statement: continue logger.info(f"statement: {statement}") - keywords = await run_in_threadpool(llm.get_search_keywords, statement) - if not keywords: + query = await run_in_threadpool(pipeline.get_search_query, statement) + if not query: continue - logger.info(f"keywords: {keywords}") - search = await utils.search(keywords) + logger.info(f"search query: {query}") + search = await utils.search(query) if not search: fail_search = True continue diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..dcd5395 --- /dev/null +++ b/src/modules/__init__.py @@ -0,0 +1,14 @@ +import dspy + +from settings import settings + +# set DSPy default language model +llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n') +dspy.settings.configure(lm=llm) + +from .citation import Citation +from .ollama_embedding import OllamaEmbedding +from .retrieve import LlamaIndexRM +from .search_query import SearchQuery +from .statements import Statements +from .verdict import Verdict \ No newline at end of file diff --git a/src/modules/citation.py b/src/modules/citation.py new file mode 100644 index 0000000..d29595a --- /dev/null +++ b/src/modules/citation.py @@ -0,0 +1,20 @@ +import dspy + +# TODO: citation needs higher token limits +class GenerateCitedParagraph(dspy.Signature): + """Generate a paragraph with citations.""" + context = dspy.InputField(desc="may contain relevant facts") + statement = dspy.InputField() + verdict = dspy.InputField() + paragraph = dspy.OutputField(desc="includes citations") + +"""Generate citation from context and verdict""" +class Citation(dspy.Module): + def __init__(self): + super().__init__() + self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph) + + def forward(self, statement, context, verdict): + citation = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict) + pred = dspy.Prediction(verdict=verdict, citation=citation.paragraph, context=context) + return pred diff --git a/src/ollama_embedding.py b/src/modules/ollama_embedding.py similarity index 100% rename from src/ollama_embedding.py rename to src/modules/ollama_embedding.py diff --git a/src/retrieve.py b/src/modules/retrieve.py similarity index 98% rename from src/retrieve.py rename to src/modules/retrieve.py index 3f16c64..c42c925 100644 --- 
a/src/retrieve.py +++ b/src/modules/retrieve.py @@ -8,11 +8,9 @@ from llama_index.core import ( Document, - ServiceContext, Settings, StorageContext, VectorStoreIndex, - load_index_from_storage, ) from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes from llama_index.core.retrievers import AutoMergingRetriever @@ -29,7 +27,7 @@ jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank" # switch to on-premise # todo: high lantency between client and the ollama embedding server will slow down embedding a lot -from ollama_embedding import OllamaEmbedding +from . import OllamaEmbedding # todo: improve embedding performance if settings.EMBEDDING_MODEL_DEPLOY == "local": diff --git a/src/modules/search_query.py b/src/modules/search_query.py new file mode 100644 index 0000000..76067aa --- /dev/null +++ b/src/modules/search_query.py @@ -0,0 +1,18 @@ +import dspy +import logging + +"""Notes: LLM will choose a direction based on known facts""" +class GenerateSearchEngineQuery(dspy.Signature): + """Write a search engine query that will help retrieve info related to the statement.""" + statement = dspy.InputField() + query = dspy.OutputField() + +class SearchQuery(dspy.Module): + def __init__(self): + super().__init__() + self.generate_query = dspy.ChainOfThought(GenerateSearchEngineQuery) + + def forward(self, statement): + query = self.generate_query(statement=statement) + logging.info(f"DSPy CoT search query: {query}") + return query.query \ No newline at end of file diff --git a/src/modules/statements.py b/src/modules/statements.py new file mode 100644 index 0000000..d76d20a --- /dev/null +++ b/src/modules/statements.py @@ -0,0 +1,24 @@ +import dspy +import logging +from pydantic import BaseModel, Field +from typing import List + +# references: https://github.com/weaviate/recipes/blob/main/integrations/llm-frameworks/dspy/4.Structured-Outputs-with-DSPy.ipynb +class Output(BaseModel): + statements: List = Field(description="A list of key statements") + +# TODO: test consistency especially when content contains false claims +class GenerateStatements(dspy.Signature): + """Extract the original statements from given content without fact check.""" + content: str = dspy.InputField(desc="The content to summarize") + output: Output = dspy.OutputField() + +class Statements(dspy.Module): + def __init__(self): + super().__init__() + self.generate_statements = dspy.TypedChainOfThought(GenerateStatements, max_retries=6) + + def forward(self, content): + statements = self.generate_statements(content=content) + logging.info(f"DSPy CoT statements: {statements}") + return statements.output.statements \ No newline at end of file diff --git a/src/dspy_modules.py b/src/modules/verdict.py similarity index 53% rename from src/dspy_modules.py rename to src/modules/verdict.py index 3df8a27..319aa04 100644 --- a/src/dspy_modules.py +++ b/src/modules/verdict.py @@ -1,12 +1,6 @@ import dspy from dsp.utils import deduplicate -from retrieve import LlamaIndexRM -from settings import settings - -llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n') -dspy.settings.configure(lm=llm) - class CheckStatementFaithfulness(dspy.Signature): """Verify that the statement is based on the provided context.""" context = dspy.InputField(desc="facts here are assumed to be true") @@ -19,14 +13,6 @@ class GenerateSearchQuery(dspy.Signature): statement = dspy.InputField() query = dspy.OutputField() -# TODO: citation needs higher token limits -class 
GenerateCitedParagraph(dspy.Signature): - """Generate a paragraph with citations.""" - context = dspy.InputField(desc="may contain relevant facts") - statement = dspy.InputField() - verdict = dspy.InputField() - paragraph = dspy.OutputField(desc="includes citations") - """ SimplifiedBaleen module Avoid unnecessary content in module cause MIPROv2 optimizer will analize modules. @@ -39,7 +25,7 @@ class GenerateCitedParagraph(dspy.Signature): - remove some contexts incase token reaches to max - does different InputField name other than answer compateble with dspy evaluate """ -class ContextVerdict(dspy.Module): +class Verdict(dspy.Module): def __init__(self, retrieve, passages_per_hop=3, max_hops=3): super().__init__() # self.generate_query = dspy.ChainOfThought(GenerateSearchQuery) # IMPORTANT: solves error `list index out of range` @@ -59,42 +45,4 @@ def forward(self, statement): verdict = self.generate_verdict(context=context, statement=statement) pred = dspy.Prediction(answer=verdict.verdict, rationale=verdict.rationale, context=context) return pred - -"""Generate citation from context and verdict""" -class Citation(dspy.Module): - def __init__(self): - super().__init__() - self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph) - - def forward(self, statement, context, verdict): - citation = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict) - pred = dspy.Prediction(verdict=verdict, citation=citation.paragraph, context=context) - return pred - -""" -Get both verdict and citation. - -Args: - retrieve: dspy.Retrieve -""" -class VerdictCitation(): - def __init__( - self, - docs, - ): - self.retrieve = LlamaIndexRM(docs=docs) - - # loading compiled ContextVerdict - self.context_verdict = ContextVerdict(retrieve=self.retrieve) - self.context_verdict.load("./optimizers/verdict_MIPROv2.json") - - def get(self, statement): - rep = self.context_verdict(statement) - context = rep.context - verdict = rep.answer - - rep = Citation()(statement=statement, context=context, verdict=verdict) - citation = rep.citation - - return verdict, citation - \ No newline at end of file + \ No newline at end of file diff --git a/src/pipeline.py b/src/pipeline.py deleted file mode 100644 index 38c08d1..0000000 --- a/src/pipeline.py +++ /dev/null @@ -1,13 +0,0 @@ -import utils -from dspy_modules import VerdictCitation - -def get_verdict(search_json, statement): - docs = utils.search_json_to_docs(search_json) - rep = VerdictCitation(docs=docs).get(statement=statement) - - return { - "verdict": rep[0], - "citation": rep[1], - "statement": statement, - } - \ No newline at end of file diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py new file mode 100644 index 0000000..79298b2 --- /dev/null +++ b/src/pipeline/__init__.py @@ -0,0 +1,2 @@ +from .common import get_search_query, get_statements, get_verdict +from .verdict_citation import VerdictCitation diff --git a/src/pipeline/common.py b/src/pipeline/common.py new file mode 100644 index 0000000..9ed3e4a --- /dev/null +++ b/src/pipeline/common.py @@ -0,0 +1,36 @@ +import logging +import utils +from modules import SearchQuery, Statements +from .verdict_citation import VerdictCitation + +def get_statements(content): + """Get list of statements from a text string""" + try: + statements = Statements()(content=content) + except Exception as e: + logging.error(f"Getting statements failed: {e}") + statements = [] + + return statements + +def get_search_query(statement): + """Get search query from one statement""" 
+ + try: + query = SearchQuery()(statement=statement) + except Exception as e: + logging.error(f"Getting search query from statement '{statement}' failed: {e}") + query = "" + + return query + +def get_verdict(search_json, statement): + docs = utils.search_json_to_docs(search_json) + rep = VerdictCitation(docs=docs).get(statement=statement) + + return { + "verdict": rep[0], + "citation": rep[1], + "statement": statement, + } + \ No newline at end of file diff --git a/src/pipeline/verdict_citation.py b/src/pipeline/verdict_citation.py new file mode 100644 index 0000000..ed65924 --- /dev/null +++ b/src/pipeline/verdict_citation.py @@ -0,0 +1,31 @@ +import os +from modules import Citation, LlamaIndexRM, Verdict + +""" +Get both verdict and citation. + +Args: + retrieve: dspy.Retrieve +""" +class VerdictCitation(): + def __init__( + self, + docs, + ): + self.retrieve = LlamaIndexRM(docs=docs) + + # loading compiled Verdict + self.context_verdict = Verdict(retrieve=self.retrieve) + optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../optimizers/verdict_MIPROv2.json") + self.context_verdict.load(optimizer_path) + + def get(self, statement): + rep = self.context_verdict(statement) + context = rep.context + verdict = rep.answer + + rep = Citation()(statement=statement, context=context, verdict=verdict) + citation = rep.citation + + return verdict, citation + diff --git a/src/utils.py b/src/utils.py index e6268bd..7dc6864 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,4 +1,4 @@ -import re, json, ast +import re, json import aiohttp import itertools import logging @@ -6,19 +6,6 @@ from settings import settings -def llm2list(text): - list_obj = [] - try: - # find the first pair of square brackets and their content - match = re.search(r'\[.*?\]', text, re.DOTALL) - - if match: - list_obj = ast.literal_eval(match.group()) - except Exception as e: - logging.warning(f"Failed convert LLM response to list: {e}") - pass - return list_obj - def llm2json(text): json_object = {} try: From 4e9326e4fe69c4d35e1f455fd9fa2edcd48eedbd Mon Sep 17 00:00:00 2001 From: etwk <48991073+etwk@users.noreply.github.com> Date: Fri, 16 Aug 2024 08:40:44 +0000 Subject: [PATCH 4/4] increase token limit for generate citation --- src/modules/__init__.py | 3 +++ src/pipeline/verdict_citation.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/modules/__init__.py b/src/modules/__init__.py index dcd5395..b04b2df 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -6,6 +6,9 @@ llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n') dspy.settings.configure(lm=llm) +# LM with higher token limits +llm_long = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=500, stop='\n\n') + from .citation import Citation from .ollama_embedding import OllamaEmbedding from .retrieve import LlamaIndexRM diff --git a/src/pipeline/verdict_citation.py b/src/pipeline/verdict_citation.py index ed65924..db36b71 100644 --- a/src/pipeline/verdict_citation.py +++ b/src/pipeline/verdict_citation.py @@ -1,4 +1,7 @@ import os +import dspy + +from modules import llm_long from modules import Citation, LlamaIndexRM, Verdict """ @@ -23,9 +26,11 @@ def get(self, statement): rep = self.context_verdict(statement) context = rep.context verdict = rep.answer - - rep = Citation()(statement=statement, context=context, verdict=verdict) - citation = rep.citation + + # Use the LLM with higher 
token limit for citation generation call + with dspy.context(lm=llm_long): + rep = Citation()(statement=statement, context=context, verdict=verdict) + citation = rep.citation return verdict, citation
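
A few usage sketches follow; they are illustrative only and not part of the patch series. First, the environment variable renamed in PATCH 2: settings.py parses INDEX_CHUNK_SIZES with ast.literal_eval and falls back to [1024, 256]. The helper name load_index_chunk_sizes below is made up for the example, and the exception types are narrowed for clarity where the hunk uses a bare except.

```python
import ast
import os

def load_index_chunk_sizes(default=(1024, 256)):
    """Parse INDEX_CHUNK_SIZES from the environment, mirroring the settings.py hunk."""
    raw = os.environ.get("INDEX_CHUNK_SIZES")  # e.g. "[2048, 512, 128]" as set in .env
    try:
        sizes = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        # unset or malformed value -> same fallback as settings.py
        sizes = list(default)
    return sizes

if __name__ == "__main__":
    os.environ["INDEX_CHUNK_SIZES"] = "[2048, 512, 128]"
    print(load_index_chunk_sizes())  # [2048, 512, 128]
```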
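
Second, the pipeline helpers introduced in PATCH 3 under src/pipeline replace the old llm.py calls; the updated fact_check handler in main.py drives them roughly as sketched below. The sample content string is an assumption; the function names come from the diff.

```python
from pipeline import get_statements, get_search_query

content = "The Great Wall of China is visible from the Moon with the naked eye."

for statement in get_statements(content):   # Statements(): DSPy TypedChainOfThought -> list of key claims
    query = get_search_query(statement)     # SearchQuery(): DSPy ChainOfThought -> search engine query
    if not query:
        continue
    print(f"{statement!r} -> {query!r}")
    # main.py then awaits utils.search(query) and hands the resulting search_json,
    # together with the statement, to pipeline.get_verdict() for the verdict and citation.
```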
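
Finally, the token-limit change in PATCH 4 relies on DSPy's context manager to swap the configured LM for a single call. A condensed sketch of that pattern: the model name is taken from the sample .env, while the api_base is a placeholder.

```python
import dspy

# Default LM for short outputs (statements, search queries, verdicts), max_tokens=200.
llm = dspy.OpenAI(model="google/gemma-2-27b-it",
                  api_base="http://localhost:8000/v1/", max_tokens=200, stop='\n\n')
# Higher-limit LM for citation paragraphs, max_tokens=500.
llm_long = dspy.OpenAI(model="google/gemma-2-27b-it",
                       api_base="http://localhost:8000/v1/", max_tokens=500, stop='\n\n')

dspy.settings.configure(lm=llm)

with dspy.context(lm=llm_long):
    # Only calls made inside this block use llm_long, e.g.
    # Citation()(statement=..., context=..., verdict=...)
    pass
```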