Merge pull request #10 from ittia-research/dev
change all LLM calling to DSPy, increase citation token limit
etwk authored Aug 16, 2024
2 parents 44b7391 + 4e9326e commit cd95cbe
Showing 17 changed files with 171 additions and 175 deletions.
4 changes: 2 additions & 2 deletions .env
@@ -1,6 +1,7 @@
 EMBEDDING_API_KEY=ollama:abc
 EMBEDDING_MODEL_DEPLOY=api
 EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
+INDEX_CHUNK_SIZES=[2048, 512, 128]
 LLM_MODEL_NAME=google/gemma-2-27b-it
 OLLAMA_BASE_URL=http://ollama:11434
 OPENAI_API_KEY=sk-proj-aaaaaaaaaaaaaaaaa
@@ -10,5 +11,4 @@ RERANK_MODEL_DEPLOY=local
 RERANK_MODEL_NAME=BAAI/bge-reranker-v2-m3
 RERANK_BASE_URL=http://xinference:9997/v1
 SEARCH_BASE_URL=https://s.jina.ai
-RAG_CHUNK_SIZES=[4096, 1024, 256]
-THREAD_BUILD_INDEX=12
+THREAD_BUILD_INDEX=12
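
Note: `INDEX_CHUNK_SIZES` replaces `RAG_CHUNK_SIZES` (see src/modules/retrieve.py below). The settings module itself is not shown in this diff; a minimal sketch of how it might parse the variable, with every name here assumed rather than taken from the repo:

```python
# Hypothetical sketch of src/settings.py (not part of this commit).
# INDEX_CHUNK_SIZES arrives as a JSON-style string and is parsed into a list of ints.
import json
import os

class Settings:
    def __init__(self):
        self.INDEX_CHUNK_SIZES = json.loads(os.environ.get("INDEX_CHUNK_SIZES", "[2048, 512, 128]"))
        self.LLM_MODEL_NAME = os.environ.get("LLM_MODEL_NAME", "")
        self.OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "http://localhost:8000/v1")

settings = Settings()
```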
1 change: 1 addition & 0 deletions README.md
@@ -106,6 +106,7 @@ DSPy:
 ### Reports
 - [ ] AI-generated misinformation
 ### Factcheck
+- https://www.snopes.com
 - https://www.bmi.bund.de/SharedDocs/schwerpunkte/EN/disinformation/examples-of-russian-disinformation-and-the-facts.html
 ### Resources
 #### Inference
76 changes: 0 additions & 76 deletions src/llm.py

This file was deleted.

12 changes: 6 additions & 6 deletions src/main.py
@@ -4,7 +4,7 @@
 from fastapi.responses import Response, JSONResponse, HTMLResponse, PlainTextResponse, FileResponse
 import logging
 
-import llm, utils, pipeline
+import utils, pipeline
 
 logging.basicConfig(
     level=logging.INFO,
@@ -18,7 +18,7 @@
 async def fact_check(input):
     status = 500
     logger.info(f"Fact checking: {input}")
-    statements = await run_in_threadpool(llm.get_statements, input)
+    statements = await run_in_threadpool(pipeline.get_statements, input)
     logger.info(f"statements: {statements}")
     if not statements:
         raise HTTPException(status_code=status, detail="No statements found")
@@ -29,11 +29,11 @@ async def fact_check(input):
         if not statement:
             continue
         logger.info(f"statement: {statement}")
-        keywords = await run_in_threadpool(llm.get_search_keywords, statement)
-        if not keywords:
+        query = await run_in_threadpool(pipeline.get_search_query, statement)
+        if not query:
             continue
-        logger.info(f"keywords: {keywords}")
-        search = await utils.search(keywords)
+        logger.info(f"search query: {query}")
+        search = await utils.search(query)
         if not search:
             fail_search = True
             continue
17 changes: 17 additions & 0 deletions src/modules/__init__.py
@@ -0,0 +1,17 @@
import dspy

from settings import settings

# set DSPy default language model
llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n')
dspy.settings.configure(lm=llm)

# LM with higher token limits
llm_long = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=500, stop='\n\n')

from .citation import Citation
from .ollama_embedding import OllamaEmbedding
from .retrieve import LlamaIndexRM
from .search_query import SearchQuery
from .statements import Statements
from .verdict import Verdict
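
The second LM here is what implements the commit's "increase citation token limit": callers can swap it in for a single module run. A minimal sketch (not in this commit), assuming DSPy's `dspy.settings.context` context manager:

```python
# Sketch: run only the citation step under the higher-limit LM.
import dspy
from modules import Citation, llm_long

def cite(statement, context, verdict):
    # Cited paragraphs routinely exceed the default 200-token cap,
    # so generate them under llm_long (max_tokens=500) instead.
    with dspy.settings.context(lm=llm_long):
        return Citation()(statement=statement, context=context, verdict=verdict)
```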
20 changes: 20 additions & 0 deletions src/modules/citation.py
@@ -0,0 +1,20 @@
import dspy

# TODO: citation needs higher token limits
class GenerateCitedParagraph(dspy.Signature):
    """Generate a paragraph with citations."""
    context = dspy.InputField(desc="may contain relevant facts")
    statement = dspy.InputField()
    verdict = dspy.InputField()
    paragraph = dspy.OutputField(desc="includes citations")

"""Generate citation from context and verdict"""
class Citation(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph)

    def forward(self, statement, context, verdict):
        citation = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict)
        pred = dspy.Prediction(verdict=verdict, citation=citation.paragraph, context=context)
        return pred
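
A hypothetical call, for illustration only (statement, context, and verdict values are invented):

```python
from modules import Citation

pred = Citation()(
    statement="The Eiffel Tower is in Paris.",
    context=["[1] The Eiffel Tower stands on the Champ de Mars in Paris, France."],
    verdict="true",
)
print(pred.citation)  # paragraph citing the context, e.g. with "[1]" markers
```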
src/ollama_embedding.py → src/modules/ollama_embedding.py
File renamed without changes.
6 changes: 2 additions & 4 deletions src/retrieve.py → src/modules/retrieve.py
@@ -8,11 +8,9 @@

 from llama_index.core import (
     Document,
-    ServiceContext,
     Settings,
     StorageContext,
     VectorStoreIndex,
-    load_index_from_storage,
 )
 from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes
 from llama_index.core.retrievers import AutoMergingRetriever
@@ -29,7 +27,7 @@
 jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank"  # switch to on-premise
 
 # todo: high latency between client and the ollama embedding server will slow down embedding a lot
-from ollama_embedding import OllamaEmbedding
+from . import OllamaEmbedding
 
 # todo: improve embedding performance
 if settings.EMBEDDING_MODEL_DEPLOY == "local":
@@ -132,7 +130,7 @@ def build_index(self, docs):
         if docs:
             self.index, self.storage_context = self.build_automerging_index(
                 docs,
-                chunk_sizes=settings.RAG_CHUNK_SIZES,
+                chunk_sizes=settings.INDEX_CHUNK_SIZES,
             )  # TODO: try to retrieve directly
 
     def retrieve(self, query):
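
For reference, a rough sketch of what `INDEX_CHUNK_SIZES` drives inside `build_automerging_index` (llama_index's hierarchical node parser; the exact wiring in this repo may differ):

```python
from llama_index.core import Document
from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes

# [2048, 512, 128]: each document is split into 2048-token chunks, those into
# 512-token chunks, and those into 128-token leaves.
parser = HierarchicalNodeParser.from_defaults(chunk_sizes=[2048, 512, 128])
nodes = parser.get_nodes_from_documents([Document(text="long source text ...")])
leaf_nodes = get_leaf_nodes(nodes)  # leaves get embedded; AutoMergingRetriever merges them back up
```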
18 changes: 18 additions & 0 deletions src/modules/search_query.py
@@ -0,0 +1,18 @@
import dspy
import logging

"""Notes: LLM will choose a direction based on known facts"""
class GenerateSearchEngineQuery(dspy.Signature):
    """Write a search engine query that will help retrieve info related to the statement."""
    statement = dspy.InputField()
    query = dspy.OutputField()

class SearchQuery(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_query = dspy.ChainOfThought(GenerateSearchEngineQuery)

    def forward(self, statement):
        query = self.generate_query(statement=statement)
        logging.info(f"DSPy CoT search query: {query}")
        return query.query
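
Hypothetical usage (the statement is invented); note that `forward` returns the bare query string rather than the full prediction:

```python
from modules import SearchQuery

query = SearchQuery()(statement="Drinking coffee cures the common cold.")
print(query)  # a single search-engine query string
```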
24 changes: 24 additions & 0 deletions src/modules/statements.py
@@ -0,0 +1,24 @@
import dspy
import logging
from pydantic import BaseModel, Field
from typing import List

# references: https://github.com/weaviate/recipes/blob/main/integrations/llm-frameworks/dspy/4.Structured-Outputs-with-DSPy.ipynb
class Output(BaseModel):
    statements: List = Field(description="A list of key statements")

# TODO: test consistency especially when content contains false claims
class GenerateStatements(dspy.Signature):
    """Extract the original statements from given content without fact check."""
    content: str = dspy.InputField(desc="The content to summarize")
    output: Output = dspy.OutputField()

class Statements(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_statements = dspy.TypedChainOfThought(GenerateStatements, max_retries=6)

    def forward(self, content):
        statements = self.generate_statements(content=content)
        logging.info(f"DSPy CoT statements: {statements}")
        return statements.output.statements
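
`TypedChainOfThought` validates the LLM output against the pydantic `Output` model and retries (here up to 6 times) when parsing fails. A hypothetical call:

```python
from modules import Statements

content = "The moon landing happened in 1969. Some claim it was filmed in a studio."
statements = Statements()(content=content)  # -> list of extracted statement strings
for s in statements:
    print(s)
```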
56 changes: 2 additions & 54 deletions src/dspy_modules.py → src/modules/verdict.py
@@ -1,12 +1,6 @@
 import dspy
 from dsp.utils import deduplicate
 
-from retrieve import LlamaIndexRM
-from settings import settings
-
-llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n')
-dspy.settings.configure(lm=llm)
-
 class CheckStatementFaithfulness(dspy.Signature):
     """Verify that the statement is based on the provided context."""
     context = dspy.InputField(desc="facts here are assumed to be true")
@@ -19,14 +13,6 @@ class GenerateSearchQuery(dspy.Signature):
     statement = dspy.InputField()
     query = dspy.OutputField()
 
-# TODO: citation needs higher token limits
-class GenerateCitedParagraph(dspy.Signature):
-    """Generate a paragraph with citations."""
-    context = dspy.InputField(desc="may contain relevant facts")
-    statement = dspy.InputField()
-    verdict = dspy.InputField()
-    paragraph = dspy.OutputField(desc="includes citations")
-
"""
SimplifiedBaleen module
Avoid unnecessary content in module cause MIPROv2 optimizer will analize modules.
@@ -39,7 +25,7 @@ class GenerateCitedParagraph(dspy.Signature):
 - remove some contexts in case tokens reach the max
 - is a different InputField name (other than `answer`) compatible with dspy evaluate
 """
-class ContextVerdict(dspy.Module):
+class Verdict(dspy.Module):
     def __init__(self, retrieve, passages_per_hop=3, max_hops=3):
         super().__init__()
         # self.generate_query = dspy.ChainOfThought(GenerateSearchQuery)  # IMPORTANT: solves error `list index out of range`
@@ -59,42 +45,4 @@ def forward(self, statement):
         verdict = self.generate_verdict(context=context, statement=statement)
         pred = dspy.Prediction(answer=verdict.verdict, rationale=verdict.rationale, context=context)
         return pred
-
-"""Generate citation from context and verdict"""
-class Citation(dspy.Module):
-    def __init__(self):
-        super().__init__()
-        self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph)
-
-    def forward(self, statement, context, verdict):
-        citation = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict)
-        pred = dspy.Prediction(verdict=verdict, citation=citation.paragraph, context=context)
-        return pred
-
-"""
-Get both verdict and citation.
-Args:
-    retrieve: dspy.Retrieve
-"""
-class VerdictCitation():
-    def __init__(
-        self,
-        docs,
-    ):
-        self.retrieve = LlamaIndexRM(docs=docs)
-
-        # loading compiled ContextVerdict
-        self.context_verdict = ContextVerdict(retrieve=self.retrieve)
-        self.context_verdict.load("./optimizers/verdict_MIPROv2.json")
-
-    def get(self, statement):
-        rep = self.context_verdict(statement)
-        context = rep.context
-        verdict = rep.answer
-
-        rep = Citation()(statement=statement, context=context, verdict=verdict)
-        citation = rep.citation
-
-        return verdict, citation
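
A sketch of driving the renamed `Verdict` module directly (the `docs` value is invented; the optimizer path comes from the old `VerdictCitation`, which now lives under src/pipeline/):

```python
from llama_index.core import Document
from modules import LlamaIndexRM, Verdict

docs = [Document(text="NASA states the Great Wall is not visible to the naked eye from orbit.")]
verdict = Verdict(retrieve=LlamaIndexRM(docs=docs))
verdict.load("./optimizers/verdict_MIPROv2.json")  # compiled MIPROv2 program

pred = verdict("The Great Wall of China is visible from space.")
print(pred.answer, pred.rationale)
```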
13 changes: 0 additions & 13 deletions src/pipeline.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/pipeline/__init__.py
@@ -0,0 +1,2 @@
from .common import get_search_query, get_statements, get_verdict
from .verdict_citation import VerdictCitation
36 changes: 36 additions & 0 deletions src/pipeline/common.py
@@ -0,0 +1,36 @@
import logging

import utils
from modules import SearchQuery, Statements
from .verdict_citation import VerdictCitation

def get_statements(content):
    """Get list of statements from a text string"""
    try:
        statements = Statements()(content=content)
    except Exception as e:
        logging.error(f"Getting statements failed: {e}")
        statements = []

    return statements

def get_search_query(statement):
    """Get search query from one statement"""
    try:
        query = SearchQuery()(statement=statement)
    except Exception as e:
        logging.error(f"Getting search query from statement '{statement}' failed: {e}")
        query = ""

    return query

def get_verdict(search_json, statement):
    docs = utils.search_json_to_docs(search_json)
    rep = VerdictCitation(docs=docs).get(statement=statement)

    return {
        "verdict": rep[0],
        "citation": rep[1],
        "statement": statement,
    }
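
End-to-end, the new package mirrors the flow in src/main.py. A minimal sketch (assuming `utils.search` stays async and returns the JSON that `get_verdict` expects):

```python
import asyncio

import pipeline, utils

async def fact_check(content):
    results = []
    for statement in pipeline.get_statements(content):
        query = pipeline.get_search_query(statement)
        if not query:
            continue
        search_json = await utils.search(query)
        if search_json:
            results.append(pipeline.get_verdict(search_json=search_json, statement=statement))
    return results

print(asyncio.run(fact_check("The Earth is flat.")))
```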
