From ba2ad67b4196af8213c2501e7e70f29a9f6a6fce Mon Sep 17 00:00:00 2001 From: etwk <48991073+etwk@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:14:22 +0000 Subject: [PATCH] add DSPy pipeline --- Dockerfile.local | 3 +- Dockerfile.remote | 4 +- README.md | 4 + docs/changelog.md | 4 + ...ements.remote.txt => requirements.base.txt | 1 + requirements.local.txt | 9 +- src/dspy_modules.py | 99 +++++++++ src/index.py | 113 ----------- src/llm.py | 122 ----------- src/main.py | 5 +- src/optimizers/verdict_MIPROv2.json | 111 ++++++++++ src/pipeline.py | 13 ++ src/retrieve.py | 191 ++++++++++++++++++ src/utils.py | 29 ++- 14 files changed, 457 insertions(+), 251 deletions(-) rename requirements.remote.txt => requirements.base.txt (93%) create mode 100644 src/dspy_modules.py delete mode 100644 src/index.py create mode 100644 src/optimizers/verdict_MIPROv2.json create mode 100644 src/pipeline.py create mode 100644 src/retrieve.py diff --git a/Dockerfile.local b/Dockerfile.local index 425ae6d..c9774ac 100644 --- a/Dockerfile.local +++ b/Dockerfile.local @@ -1,6 +1,7 @@ FROM intel/intel-optimized-pytorch:2.3.0-serving-cpu WORKDIR /app -COPY requirements.local.txt /app +COPY requirements.*.txt /app +RUN pip install --no-cache-dir -r requirements.base.txt RUN pip install --no-cache-dir -r requirements.local.txt COPY . /app EXPOSE 8000 diff --git a/Dockerfile.remote b/Dockerfile.remote index a7f8198..a744bd1 100644 --- a/Dockerfile.remote +++ b/Dockerfile.remote @@ -1,7 +1,7 @@ FROM python:3.11-slim-bookworm WORKDIR /app -COPY requirements.remote.txt /app -RUN pip install --no-cache-dir -r requirements.remote.txt +COPY requirements.base.txt /app +RUN pip install --no-cache-dir -r requirements.base.txt COPY . /app EXPOSE 8000 ENV NAME "Fact-check API" diff --git a/README.md b/README.md index 2dd3724..6b5909a 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,10 @@ Contexts Retrieval - [ ] Retrieve the latest info when facts might change +### pipeline +DSPy: +- [ ] make dspy.settings apply to sessions only in order to support multiple retrieve index + ### Toolchain - [ ] Evaluate MLOps pipeline - https://kitops.ml diff --git a/docs/changelog.md b/docs/changelog.md index d57ee5d..d9752d9 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,3 +1,7 @@ ## application 2024/8/3: - Change from AutoGen to plain OpenAI, since AutoGen AssistantAgent adds system role which are not compateble with Gemma 2 + vllm. + +## pipeline +2024/8/13: + - Introduce DSPy to replace the get verdict part, with multi-step reasoning. \ No newline at end of file diff --git a/requirements.remote.txt b/requirements.base.txt similarity index 93% rename from requirements.remote.txt rename to requirements.base.txt index fab4dae..2cb77d2 100644 --- a/requirements.remote.txt +++ b/requirements.base.txt @@ -1,4 +1,5 @@ aiohttp +dspy-ai fastapi llama-index llama-index-embeddings-ollama diff --git a/requirements.local.txt b/requirements.local.txt index 51d735b..69dc005 100644 --- a/requirements.local.txt +++ b/requirements.local.txt @@ -1,8 +1 @@ -aiohttp -fastapi -llama-index -llama-index-embeddings-huggingface -llama-index-embeddings-ollama -llama-index-postprocessor-jinaai-rerank -openai -uvicorn +llama-index-embeddings-huggingface \ No newline at end of file diff --git a/src/dspy_modules.py b/src/dspy_modules.py new file mode 100644 index 0000000..07d230e --- /dev/null +++ b/src/dspy_modules.py @@ -0,0 +1,99 @@ +import dspy +from dsp.utils import deduplicate + +from retrieve import LlamaIndexRM +from settings import settings + +llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n') +dspy.settings.configure(lm=llm) + +class CheckStatementFaithfulness(dspy.Signature): + """Verify that the statement is based on the provided context.""" + context = dspy.InputField(desc="facts here are assumed to be true") + statement = dspy.InputField() + verdict = dspy.OutputField(desc="True/False/Irrelevant indicating if statement is faithful to context") + +class GenerateSearchQuery(dspy.Signature): + """Write a simple search query that will help retrieve info related to the statement.""" + context = dspy.InputField(desc="may contain relevant facts") + statement = dspy.InputField() + query = dspy.OutputField() + +class GenerateCitedParagraph(dspy.Signature): + """Generate a paragraph with citations.""" + context = dspy.InputField(desc="may contain relevant facts") + statement = dspy.InputField() + verdict = dspy.InputField() + paragraph = dspy.OutputField(desc="includes citations") + +""" +SimplifiedBaleen module +Avoid unnecessary content in module cause MIPROv2 optimizer will analize modules. + +Args: + retrieve: dspy.Retrieve + +To-do: + - retrieve latest facts + - remove some contexts incase token reaches to max + - does different InputField name other than answer compateble with dspy evaluate +""" +class ContextVerdict(dspy.Module): + def __init__(self, retrieve, passages_per_hop=3, max_hops=3): + super().__init__() + # self.generate_query = dspy.ChainOfThought(GenerateSearchQuery) # IMPORTANT: solves error `list index out of range` + self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)] + self.retrieve = retrieve + self.retrieve.k = passages_per_hop + self.generate_verdict = dspy.ChainOfThought(CheckStatementFaithfulness) + self.max_hops = max_hops + + def forward(self, statement): + context = [] + for hop in range(self.max_hops): + query = self.generate_query[hop](context=context, statement=statement).query + passages = self.retrieve(query=query, text_only=True) + context = deduplicate(context + passages) + + verdict = self.generate_verdict(context=context, statement=statement) + pred = dspy.Prediction(answer=verdict.verdict, rationale=verdict.rationale, context=context) + return pred + +"""Generate citation from context and verdict""" +class Citation(dspy.Module): + def __init__(self): + super().__init__() + self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph) + + def forward(self, statement, context, verdict): + citation = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict) + pred = dspy.Prediction(verdict=verdict, citation=citation.paragraph, context=context) + return pred + +""" +Get both verdict and citation. + +Args: + retrieve: dspy.Retrieve +""" +class VerdictCitation(): + def __init__( + self, + docs, + ): + self.retrieve = LlamaIndexRM(docs=docs) + + # loading compiled ContextVerdict + self.context_verdict = ContextVerdict(retrieve=self.retrieve) + self.context_verdict.load("./optimizers/verdict_MIPROv2.json") + + def get(self, statement): + rep = self.context_verdict(statement) + context = rep.context + verdict = rep.answer + + rep = Citation()(statement=statement, context=context, verdict=verdict) + citation = rep.citation + + return verdict, citation + \ No newline at end of file diff --git a/src/index.py b/src/index.py deleted file mode 100644 index 0e504b3..0000000 --- a/src/index.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -llama-index -""" -import os - -from llama_index.core import ( - Document, - ServiceContext, - Settings, - StorageContext, - VectorStoreIndex, - load_index_from_storage, -) -from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes -from llama_index.core.retrievers import AutoMergingRetriever -from llama_index.core.indices.postprocessor import SentenceTransformerRerank -from llama_index.core.query_engine import RetrieverQueryEngine -from llama_index.core.llms import MockLLM - -Settings.llm = MockLLM(max_tokens=256) # retrieve only, do not use LLM for synthesize - -from settings import settings - -import llama_index.postprocessor.jinaai_rerank.base as jinaai_rerank # todo: shall we lock package version? -jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank" # switch to on-premise - -# todo: high lantency between client and the ollama embedding server will slow down embedding a lot -from llama_index.embeddings.ollama import OllamaEmbedding - -# todo: improve embedding performance -if settings.EMBEDDING_MODEL_DEPLOY == "local": - embed_model="local:" + settings.EMBEDDING_MODEL_NAME -else: - embed_model = OllamaEmbedding( - model_name=settings.EMBEDDING_MODEL_NAME, - base_url=os.environ.get("OLLAMA_BASE_URL"), # todo: any other configs here? - ) -Settings.embed_model = embed_model - -def nodes2list(nodes): - nodes_list = [] - for node in nodes: - _sub = { - 'id': node.node_id, - 'score': node.score, - 'text': node.get_content().strip(), - 'metadata': node.metadata, - } - nodes_list.append(_sub) - return nodes_list - -class Index(): - def build_automerging_index( - self, - documents, - chunk_sizes=None, - ): - chunk_sizes = chunk_sizes or [2048, 512, 128] - - node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes) - nodes = node_parser.get_nodes_from_documents(documents) - leaf_nodes = get_leaf_nodes(nodes) - - storage_context = StorageContext.from_defaults() - storage_context.docstore.add_documents(nodes) - - automerging_index = VectorStoreIndex( - leaf_nodes, storage_context=storage_context - ) - return automerging_index - - def get_automerging_query_engine( - self, - automerging_index, - similarity_top_k=12, - rerank_top_n=6, - ): - base_retriever = automerging_index.as_retriever(similarity_top_k=similarity_top_k) - retriever = AutoMergingRetriever( - base_retriever, automerging_index.storage_context, verbose=True - ) - - if settings.RERANK_MODEL_DEPLOY == "local": - rerank = SentenceTransformerRerank( - top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME, - ) # todo: add support `trust_remote_code=True` - else: - rerank = jinaai_rerank.JinaRerank(api_key='', top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME) - - auto_merging_engine = RetrieverQueryEngine.from_args( - retriever, node_postprocessors=[rerank] - ) - - return auto_merging_engine - - def get_contexts(self, statement, keywords, text, metadata): - """ - Get list of contexts. - - Todo: resources re-use for multiple run. - """ - document = Document(text=text, metadata=metadata) - index = self.build_automerging_index( - [document], - chunk_sizes=settings.RAG_CHUNK_SIZES, - ) # todo: will it better to use retriever directly? - - query_engine = self.get_automerging_query_engine(index, similarity_top_k=16, rerank_top_n=3) - query = f"{keywords} | {statement}" # todo: better way - auto_merging_response = query_engine.query(query) - contexts = nodes2list(auto_merging_response.source_nodes) - - return contexts \ No newline at end of file diff --git a/src/llm.py b/src/llm.py index e236d3c..201dc97 100644 --- a/src/llm.py +++ b/src/llm.py @@ -3,7 +3,6 @@ import logging import utils -from index import Index from settings import settings """ @@ -75,124 +74,3 @@ def get_context_prompt(contexts): {_text} ```""" return prompt - -def get_verdict_single(statement, contexts): - # This prompt allows model to use their own knowedge - system_message = '''You are a helpful AI assistant. -Solve tasks using your fact-check skills. -You will be given a statement followed by context. -Use the contexts and facts you know to check if the statements are true, false, irrelevant. -Provide detailed reason for your verdict, ensuring that each reason is supported by corresponding facts. -Provide the response as JSON with the structure:{verdict, reason}''' - - prompt = f'''{system_message} - -``` -Statement: -{statement} -``` - -{get_context_prompt(contexts)}''' - - reply = get_llm_reply(prompt) - logging.debug(f"Verdict reply from LLM: {reply}") - verdict = utils.llm2json(reply) - - return verdict - -def get_verdict_summary(verdicts, statement): - """ - Calculate and summarize the verdicts of multiple sources. - Introduce some weights: - - total: the number of total verdicts - - valid: the number of verdicts in the desired categories - - winning: the count of the winning verdict - - and count of verdicts of each desiered categories - """ - - weight_total = 0 - weight_valid = 0 - sum_score = 0 - sum_reason = { - "true": {"reason": [], "weight": 0}, - "false": {"reason": [], "weight": 0}, - "irrelevant": {"reason": [], "weight": 0}, - } - - for verdict in verdicts: - weight_total += 1 - v = verdict['verdict'] - if v in sum_reason: - weight_valid += 1 - reason = f"{verdict['reason']}\nsource url: {verdict['url']}\nsource title: {verdict['title']}\n\n" - sum_reason[v]['reason'].append(reason) - sum_reason[v]['weight'] += 1 - if v == 'true': - sum_score += 1 - elif v == 'false': - sum_score -= 1 - - if sum_score > 0: - verdict = "true" - elif sum_score < 0: - verdict = "false" - else: - verdict = "irrelevant" - - reason = ''.join(sum_reason[verdict]['reason']) - if not reason: - raise Exception("No reason found after summary") - - weights = {"total": weight_total, "valid": weight_valid, "winning": sum_reason[verdict]['weight']} - for key in sum_reason.keys(): - weights[key] = sum_reason[key]['weight'] - - return {"verdict": verdict, "reason": reason, "weights": weights, "statement": statement} - -def get_verdict(statement, keywords, search_json): - """ - Get verdit from every one of the context sources - """ - verdicts = [] - - def process_result(result): - content = utils.clear_md_links(result.get('content')) - metadata = { - "url": result.get('url'), - "title": result.get('title'), - } - - try: - contexts = Index().get_contexts(statement, keywords, content, metadata) - except Exception as e: - logging.warning(f"Getting contexts failed: {e}") - return None - - try: - verdict = get_verdict_single(statement, contexts) - except Exception as e: - logging.warning(f"Getting verdit failed: {e}") - return None - - verdict = { - "verdict": verdict.get('verdict', '').lower(), - "reason": verdict.get('reason'), - "url": result.get('url'), - "title": result.get('title'), - } - - logging.debug(f"single source verdict: {verdict}") - return verdict - - with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor: - # Use a list comprehension to submit tasks and collect results - future_to_result = {executor.submit(process_result, result): result for result in search_json['data']} - - for future in concurrent.futures.as_completed(future_to_result): - verdict = future.result() - if verdict is not None: - verdicts.append(verdict) - - summary = get_verdict_summary(verdicts, statement) - return summary - \ No newline at end of file diff --git a/src/main.py b/src/main.py index 698b51e..7665015 100644 --- a/src/main.py +++ b/src/main.py @@ -4,8 +4,7 @@ from fastapi.responses import Response, JSONResponse, HTMLResponse, PlainTextResponse, FileResponse import logging -import llm, utils -from index import Index +import llm, utils, pipeline logging.basicConfig( level=logging.INFO, @@ -39,7 +38,7 @@ async def fact_check(input): fail_search = True continue logger.info(f"head of search results: {json.dumps(search)[0:500]}") - verdict = await run_in_threadpool(llm.get_verdict, statement, keywords, search) + verdict = await run_in_threadpool(pipeline.get_verdict, search_json=search, statement=statement) if not verdict: continue logger.info(f"final verdict: {verdict}") diff --git a/src/optimizers/verdict_MIPROv2.json b/src/optimizers/verdict_MIPROv2.json new file mode 100644 index 0000000..531f39f --- /dev/null +++ b/src/optimizers/verdict_MIPROv2.json @@ -0,0 +1,111 @@ +{ + "generate_query[0]": { + "lm": null, + "traces": [], + "train": [], + "demos": [ + { + "statement": "Rod Strickland is the godfather of an NBA player who won an NBA championship with the Cavaliers in 2017.", + "answer": "False" + }, + { + "statement": "Olivia Diaz is a member of the upper house of the Nevada legislature.", + "answer": "False" + } + ], + "signature_instructions": "Write a simple search query that will help retrieve info related to the statement.", + "signature_prefix": "Query:", + "extended_signature_instructions": "Given a statement and a context, generate a search query that will help retrieve information relevant to the statement within the context. Make sure the query includes key entities and concepts from both the statement and the context to effectively narrow down the search results.", + "extended_signature_prefix": "Query:" + }, + "generate_query[1]": { + "lm": null, + "traces": [], + "train": [], + "demos": [ + { + "augmented": true, + "context": [ + "G.o.d discography | This is the discography of the South Korean pop music group g.o.d who are currently managed by SidusHQ. Debuting in 1999, the five-member group has released eight albums, two of which have sold over a million copies each. They discontinued group activities following the release of their seventh album and parted ways to begin their solo careers. Their eighth and most recent full album was released in July 2014 to mark their reunion and the 15th anniversary of their debut.", + "Sakura Gakuin discography | Japanese idol group Sakura Gakuin has released seven studio albums, one compilation album, nine video albums, thirteen singles, and twenty music videos. Seven more singles and eight more music videos have been released by sub-units, including those by the band Babymetal released prior to March 2013. Studio albums are released annually, under the supertitle \"Sakura Gakuin [Year] Nendo\" (\u3055\u304f\u3089\u5b66\u9662[Year]\u5e74\u5ea6 , lit. \"Cherry Blossom Academy School Year [Year]\") .", + "I-II-III (Icon of Coil Albums) | I-II-III are three compilation albums released by Icon of Coil in 2006. Each released separately as re-issues of their first 3 albums ('Serenity is the Devil', 'The Soul is in the Software' & remix album 'One Nation Under Beat'), 2 singles ('Shallow Nation' & 'Access And Amplify') and EP ('Seren EP'), which were all either out of print, or hard to find." + ], + "statement": "Since 2021, the K-pop group (G)I-DLE has released more music albums than the girl group BLACKPINK.", + "rationale": "We need to find the number of albums released by (G)I-DLE and BLACKPINK since 2021.", + "query": "number of albums released by (G)I-DLE since 2021 AND number of albums released by BLACKPINK since 2021" + }, + { + "augmented": true, + "context": [ + "Louise Bille-Brahe | Louise Bille-Brahe (1830-1910) was a Danish courtier; \"Overhofmesterinde\" (Mistress of the Robes) to the queen of Denmark, Louise of Hesse-Kassel, from 1888 to 1898, and to the next queen of Denmark, Louise of Sweden, from 1906 to 1910.", + "Alexandra of Denmark | Alexandra of Denmark (Alexandra Caroline Marie Charlotte Louise Julia; 1 December 1844 \u2013 20 November 1925) was Queen of the United Kingdom of Great Britain and Ireland and Empress of India as the wife of King-Emperor Edward VII.", + "Princess Feodora of Denmark | Princess Feodora of Denmark (Feodora Louise Caroline-Mathilde Viktoria Alexandra Frederikke Johanne) (3 July 1910 \u2013 17 March 1975) was a Danish princess as a daughter of Prince Harald of Denmark and granddaughter of Frederick VIII of Denmark." + ], + "statement": "Louise Bille-Brahe was a Danish courtier to the wife of King Frederick VIII.", + "rationale": "We need to find out who was the wife of King Frederick VIII. Then we need to find out if Louise Bille-Brahe was a courtier to her.", + "query": "wife of King Frederick VIII and Louise Bille-Brahe" + } + ], + "signature_instructions": "Write a simple search query that will help retrieve info related to the statement.", + "signature_prefix": "Query:", + "extended_signature_instructions": "Considering the context provided, generate a search query that aims to verify the truthfulness of the statement. The query should focus on identifying information relevant to the statement's claims and potential supporting evidence within the context.", + "extended_signature_prefix": "Query:" + }, + "generate_query[2]": { + "lm": null, + "traces": [], + "train": [], + "demos": [ + { + "augmented": true, + "context": [ + "Robin Perutz | Robin Perutz FRS (born December 1949 in Cambridge), son of the Nobel Prize winner Max Perutz, is a professor of Inorganic Chemistry at the University of York, where he was formerly head of department.", + "Robin Koontz | Robin Michal Koontz (born July 29, 1954) is an American author and illustrator of picture books and early readers for children as well as non-fiction for middle school readers. Her books are published in English, Spanish, and Indonesian. Many of her titles have been reviewed in School Library Journal, Kirkus Reviews, and the CLCD (Children's Literature Comprehension Database).", + "Robin Weisman | Robin Weisman (born March 14, 1984) is a former American child actress. She acted between 1990 and 1994.", + "Max Perutz | Max Ferdinand Perutz (19 May 1914 \u2013 6 February 2002) was an Austrian-born British molecular biologist, who shared the 1962 Nobel Prize for Chemistry with John Kendrew, for their studies of the structures of haemoglobin and myoglobin. He went on to win the Royal Medal of the Royal Society in 1971 and the Copley Medal in 1979. At Cambridge he founded and chaired (1962\u201379) The Medical Research Council Laboratory of Molecular Biology, fourteen of whose scientists have won Nobel Prizes. Perutz's contributions to molecular biology in Cambridge are documented in \"The History of the University of Cambridge: Volume 4 (1870 to 1990)\" published by the Cambridge University Press in 1992.", + "Max Grosskreutz | Max Octavius Grosskreutz (born 27 April 1906 in Proserpine, Queensland- died 20 September 1994) was an Australian speedway rider who finished third in the Star Riders' Championship in 1935, the forerunner to the Speedway World Championship which began a year later in 1936." + ], + "statement": "The father of Robin Perutz was born in 1914.", + "rationale": "We need to find information about Robin Perutz's father. We know from the context that Robin Perutz's father is Max Perutz. We also know from the context that Max Perutz was born in 1914.", + "query": "Max Perutz birth year" + }, + { + "statement": "In 2021, the Cuban government announced plans to gradually phase out the CUC and unify the country's currency.", + "answer": "irrelevant" + } + ], + "signature_instructions": "Write a simple search query that will help retrieve info related to the statement.", + "signature_prefix": "Query:", + "extended_signature_instructions": "Given a statement about a person and their profession, generate a search query that will help retrieve information to determine if the statement is true, focusing on confirming the person's profession.", + "extended_signature_prefix": "Query:" + }, + "retrieve": { + "k": 3 + }, + "generate_verdict": { + "lm": null, + "traces": [], + "train": [], + "demos": [ + { + "augmented": true, + "context": [ + "2017 MTV Video Music Awards | The 2017 MTV Video Music Awards were held on August 27, 2017 at The Forum in Inglewood, California, honoring music videos released between June 25, 2016 and June 23, 2017. It was hosted by Katy Perry. The 34th annual award show aired live from the venue for the second time in its history. The music video for Taylor Swift's song \"Look What You Made Me Do\" premiered during the broadcast. Lil Yachty co-hosted the pre-show with Terrence J, Charlamagne Tha God, and MTV News' Gaby Wilson. It was broadcast across various Viacom networks and their related apps.", + "2013 MTV Video Music Awards | The 2013 MTV Video Music Awards were held on August 25, 2013 at the Barclays Center in Brooklyn, New York. Marking the 30th installment of the award show, they were the first to be held in New York City not to use a venue within the borough of Manhattan. Nominations were announced on July 17, 2013. Leading the nominees were Justin Timberlake and Macklemore & Ryan Lewis with six, followed by Bruno Mars, Miley Cyrus, and Robin Thicke with four. Justin Timberlake was the big winner on the night with four awards, including Video of the Year for \"Mirrors\" and the Michael Jackson Vanguard Award. Macklemore & Ryan Lewis, Bruno Mars and Taylor Swift were also among the winners of the night. The ceremony drew a total of 10.1 million viewers.", + "2015 MTV Video Music Awards | The 2015 MTV Video Music Awards were held on August 30, 2015. The 32nd installment of the event was held at the Microsoft Theater in Los Angeles, California, and hosted by Miley Cyrus. Taylor Swift led the nominations with a total of ten, followed by Ed Sheeran, who had six., bringing his total number of mentions to 13. Swift's \"Wildest Dreams\" music video premiered during the pre-show. Cyrus also announced and released her studio album \"Miley Cyrus & Her Dead Petz\", right after her performance at the end of the show. During his acceptance speech, Kanye West announced that he would be running for the 2020 U.S. Presidential Election. Taylor Swift won the most awards with four, including Video of the Year and Best Female Video. The VMA trophies were redesigned by Jeremy Scott." + ], + "statement": "The 2022 MTV Video Music Awards were held at the Prudential Center in Newark, New Jersey.", + "rationale": "The provided context describes the MTV Video Music Awards from 2013, 2015, and 2017. There is no information about the 2022 MTV Video Music Awards.", + "verdict": "Irrelevant" + }, + { + "statement": "The island that housed a federal prison from 1934 to 1963 is called Alcatraz.", + "answer": "True" + } + ], + "signature_instructions": "Verify that the statement is based on the provided context.", + "signature_prefix": "Verdict:", + "extended_signature_instructions": "Using the given context, determine whether the statement is factually supported, refuted by the context, or irrelevant to the information provided. Justify your response by identifying specific details from the context that support or contradict the statement.", + "extended_signature_prefix": "Verdict:" + } +} \ No newline at end of file diff --git a/src/pipeline.py b/src/pipeline.py new file mode 100644 index 0000000..38c08d1 --- /dev/null +++ b/src/pipeline.py @@ -0,0 +1,13 @@ +import utils +from dspy_modules import VerdictCitation + +def get_verdict(search_json, statement): + docs = utils.search_json_to_docs(search_json) + rep = VerdictCitation(docs=docs).get(statement=statement) + + return { + "verdict": rep[0], + "citation": rep[1], + "statement": statement, + } + \ No newline at end of file diff --git a/src/retrieve.py b/src/retrieve.py new file mode 100644 index 0000000..982bc54 --- /dev/null +++ b/src/retrieve.py @@ -0,0 +1,191 @@ +""" +LlamaIndexCustomRetriever +""" + +import os, logging +from typing import Optional + +from llama_index.core import ( + Document, + ServiceContext, + Settings, + StorageContext, + VectorStoreIndex, + load_index_from_storage, +) +from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes +from llama_index.core.retrievers import AutoMergingRetriever +from llama_index.core.indices.postprocessor import SentenceTransformerRerank +from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core.llms import MockLLM + +Settings.llm = MockLLM(max_tokens=256) # retrieve only, do not use LLM for synthesize + +import utils +from settings import settings + +import llama_index.postprocessor.jinaai_rerank.base as jinaai_rerank # todo: shall we lock package version? +jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank" # switch to on-premise + +# todo: high lantency between client and the ollama embedding server will slow down embedding a lot +from llama_index.embeddings.ollama import OllamaEmbedding + +# todo: improve embedding performance +if settings.EMBEDDING_MODEL_DEPLOY == "local": + embed_model="local:" + settings.EMBEDDING_MODEL_NAME +else: + embed_model = OllamaEmbedding( + model_name=settings.EMBEDDING_MODEL_NAME, + base_url=os.environ.get("OLLAMA_BASE_URL"), # todo: any other configs here? + ) +Settings.embed_model = embed_model + +class LlamaIndexCustomRetriever(): + def __init__( + self, + docs = None, + similarity_top_k: Optional[int] = 6, + ): + self.similarity_top_k = similarity_top_k + if docs: + self.build_index(docs) + + def build_automerging_index( + self, + documents, + chunk_sizes=[2048, 512, 128], + ): + node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes) + nodes = node_parser.get_nodes_from_documents(documents) + leaf_nodes = get_leaf_nodes(nodes) + + storage_context = StorageContext.from_defaults() + storage_context.docstore.add_documents(nodes) + + automerging_index = VectorStoreIndex( + leaf_nodes, storage_context=storage_context + ) + return automerging_index + + def get_automerging_query_engine( + self, + automerging_index, + similarity_top_k=6, + rerank_top_n=3, + ): + base_retriever = automerging_index.as_retriever(similarity_top_k=similarity_top_k) + retriever = AutoMergingRetriever( + base_retriever, automerging_index.storage_context, verbose=True + ) + + if settings.RERANK_MODEL_DEPLOY == "local": + rerank = SentenceTransformerRerank( + top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME, + ) # TODO: add support `trust_remote_code=True` + else: + rerank = jinaai_rerank.JinaRerank(api_key='', top_n=rerank_top_n, model=settings.RERANK_MODEL_NAME) + + auto_merging_engine = RetrieverQueryEngine.from_args( + retriever, node_postprocessors=[rerank] + ) + + return auto_merging_engine + + def build_index(self, docs): + """Initiate index or build a new one.""" + + if docs: + index = self.build_automerging_index( + docs, + chunk_sizes=settings.RAG_CHUNK_SIZES, + ) # TODO: try to retrieve directly + self.index = index + + def retrieve(self, query): + query_engine = self.get_automerging_query_engine( + self.index, + similarity_top_k=self.similarity_top_k * 3, + rerank_top_n=self.similarity_top_k + ) + self.query_engine = query_engine + auto_merging_response = self.query_engine.query(query) + contexts = utils.llama_index_nodes_to_list(auto_merging_response.source_nodes) + return contexts + +import dspy + +NO_TOP_K_WARNING = "The underlying LlamaIndex retriever does not support top k retrieval. Ignoring k value." + +class LlamaIndexRM(dspy.Retrieve): + """Implements a retriever which wraps over a LlamaIndex retriever. + + This is done to bridge LlamaIndex and DSPy and allow the various retrieval + abstractions in LlamaIndex to be used in DSPy. + + Args: + retriever (LlamaIndexCustomRetriever): A LlamaIndex retriever object - text based only + k (int): Optional; the number of examples to retrieve (similarity_top_k) + docs (list): list of documents for building index + + Returns: + DSPy RM Object - this is a retriever object that can be used in DSPy + """ + + retriever: LlamaIndexCustomRetriever + + def __init__( + self, + docs, + k: Optional[int] = None, + ): + self.retriever = LlamaIndexCustomRetriever(docs=docs) + + if k: + self.k = k + + @property + def k(self) -> Optional[int]: + """Get similarity top k of retriever.""" + if not hasattr(self.retriever, "similarity_top_k"): + logging.warning(NO_TOP_K_WARNING) + return None + + return self.retriever.similarity_top_k + + @k.setter + def k(self, k: int) -> None: + """Set similarity top k of retriever.""" + if hasattr(self.retriever, "similarity_top_k"): + self.retriever.similarity_top_k = k + else: + logging.warning(NO_TOP_K_WARNING) + + def forward(self, query: str, k: Optional[int] = None, text_only = False) -> list[dspy.Example]: + """Forward function for the LI retriever. + + This is the function that is called to retrieve the top k examples for a given query. + Top k is set via the setter similarity_top_k or at LI instantiation. + + Args: + query (str): The query to retrieve examples for + k (int): Optional; the number of examples to retrieve (similarity_top_k) + + If the underlying LI retriever does not have the property similarity_top_k, k will be ignored. + + Returns: + List[dspy.Example]: A list of examples retrieved by the retriever + """ + if k: + self.k = k + + raw = self.retriever.retrieve(query) + + if text_only: + rep = [result['text'] for result in raw] + else: + # change key text to long_text as required by DSPy + rep = [ + dspy.Example(**{'long_text': result.pop('text', None), **result}) + for result in raw + ] + return rep \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 4335f7d..ad190c2 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,6 +2,7 @@ import aiohttp import itertools import logging +from llama_index.core import Document from settings import settings @@ -85,8 +86,7 @@ def generate_report_markdown(input_text, verdicts): markdown.append(f"### Statement {i}\n") markdown.append(f"**Statement**: {verdict['statement']}\n") markdown.append(f"**Verdict**: `{verdict['verdict']}`\n") - markdown.append(f"**Weight**: {verdict['weights']['winning']} out of {verdict['weights']['valid']} ({verdict['weights']['irrelevant']} irrelevant)\n") - markdown.append(f"**Reason**:\n\n{verdict['reason']}\n") + markdown.append(f"**Citation**:\n\n{verdict['citation']}\n") markdown_str = "\n".join(markdown) return markdown_str @@ -149,3 +149,28 @@ async def get_status(): "stack": stack } return status + +def llama_index_nodes_to_list(nodes): + nodes_list = [] + for node in nodes: + _sub = { + 'id': node.node_id, + 'score': node.score, + 'text': node.get_content().strip(), + 'metadata': node.metadata, + } + nodes_list.append(_sub) + return nodes_list + +def search_json_to_docs(search_json): + """Search JSON results to Llama-Index documents""" + documents = [] + for result in search_json['data']: + content = clear_md_links(result.get('content')) + metadata = { + "url": result.get('url'), + "title": result.get('title'), + } + document = Document(text=content, metadata=metadata) + documents.append(document) + return documents \ No newline at end of file