-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from ittia-research/dev
change all LLM calling to DSPy, increase citation token limit
- Loading branch information
Showing
17 changed files
with
171 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import dspy | ||
|
||
from settings import settings | ||
|
||
# set DSPy default language model
# Both clients below share the model name, API base and stop sequence taken
# from `settings`; they differ only in `max_tokens`.
llm = dspy.OpenAI(
    model=settings.LLM_MODEL_NAME,
    api_base=f"{settings.OPENAI_BASE_URL}/",
    max_tokens=200,
    stop='\n\n',
)
dspy.settings.configure(lm=llm)

# LM with higher token limits
# Not registered as the default here; it is only created and exposed by name.
llm_long = dspy.OpenAI(
    model=settings.LLM_MODEL_NAME,
    api_base=f"{settings.OPENAI_BASE_URL}/",
    max_tokens=500,
    stop='\n\n',
)
|
||
from .citation import Citation | ||
from .ollama_embedding import OllamaEmbedding | ||
from .retrieve import LlamaIndexRM | ||
from .search_query import SearchQuery | ||
from .statements import Statements | ||
from .verdict import Verdict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import dspy | ||
|
||
# TODO: citation needs higher token limits
class GenerateCitedParagraph(dspy.Signature):
    """Generate a paragraph with citations."""

    # NOTE(review): the docstring above and the `desc` strings below are
    # presumably consumed by DSPy as prompt text -- treat them as runtime
    # strings and do not reword them casually.

    # Inputs
    context = dspy.InputField(desc="may contain relevant facts")
    statement = dspy.InputField()
    verdict = dspy.InputField()

    # Output
    paragraph = dspy.OutputField(desc="includes citations")
|
||
"""Generate citation from context and verdict""" | ||
class Citation(dspy.Module):
    """DSPy module that produces a cited paragraph for a statement and its verdict."""

    def __init__(self):
        super().__init__()
        # Chain-of-thought predictor over the GenerateCitedParagraph signature.
        self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph)

    def forward(self, statement, context, verdict):
        """Run the predictor and bundle its paragraph with the inputs.

        Returns a `dspy.Prediction` with three fields: `verdict` and `context`
        are passed through unchanged, and `citation` holds the generated
        `paragraph`.
        """
        generated = self.generate_cited_paragraph(
            context=context,
            statement=statement,
            verdict=verdict,
        )
        return dspy.Prediction(
            verdict=verdict,
            citation=generated.paragraph,
            context=context,
        )
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import dspy | ||
import logging | ||
|
||
"""Notes: LLM will choose a direction based on known facts""" | ||
class GenerateSearchEngineQuery(dspy.Signature):
    """Write a search engine query that will help retrieve info related to the statement."""

    # Input: the statement to build a query for.
    statement = dspy.InputField()

    # Output: the generated search engine query.
    query = dspy.OutputField()
|
||
class SearchQuery(dspy.Module):
    """DSPy module that turns a statement into a search engine query."""

    def __init__(self):
        super().__init__()
        # Chain-of-thought predictor over the GenerateSearchEngineQuery signature.
        self.generate_query = dspy.ChainOfThought(
            GenerateSearchEngineQuery,
        )

    def forward(self, statement):
        """Return the `query` field predicted for `statement`.

        The whole prediction object is logged at INFO level before its `query`
        field is returned.
        """
        query = self.generate_query(statement=statement)
        logging.info(f"DSPy CoT search query: {query}")
        return query.query
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import dspy | ||
import logging | ||
from pydantic import BaseModel, Field | ||
from typing import List | ||
|
||
# references: https://github.com/weaviate/recipes/blob/main/integrations/llm-frameworks/dspy/4.Structured-Outputs-with-DSPy.ipynb
# Structured output model: a single `statements` field holding a list.
# NOTE(review): documented with `#` comments on purpose -- a class docstring
# would presumably be picked up by pydantic as the schema description; confirm
# before adding one.
class Output(BaseModel):
    statements: List = Field(description="A list of key statements")
|
||
# TODO: test consistency especially when content contains false claims
class GenerateStatements(dspy.Signature):
    """Extract the original statements from given content without fact check."""

    # Input: raw text to extract statements from.
    content: str = dspy.InputField(desc="The content to summarize")

    # Output: typed as the `Output` pydantic model (a list of statements).
    output: Output = dspy.OutputField()
|
||
class Statements(dspy.Module):
    """DSPy module that extracts a list of statements from a piece of content."""

    def __init__(self):
        super().__init__()
        # Typed chain-of-thought predictor over GenerateStatements, configured
        # with max_retries=6.
        self.generate_statements = dspy.TypedChainOfThought(
            GenerateStatements,
            max_retries=6,
        )

    def forward(self, content):
        """Return the `statements` list from the typed prediction for `content`.

        The whole prediction object is logged at INFO level before the list is
        returned.
        """
        statements = self.generate_statements(content=content)
        logging.info(f"DSPy CoT statements: {statements}")
        return statements.output.statements
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .common import get_search_query, get_statements, get_verdict | ||
from .verdict_citation import VerdictCitation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import logging | ||
import utils | ||
from modules import SearchQuery, Statements | ||
from .verdict_citation import VerdictCitation | ||
|
||
def get_statements(content):
    """Get list of statements from a text string.

    Builds a fresh `Statements` module and calls it with `content`. Any
    exception is logged at ERROR level and an empty list is returned instead,
    so this function itself does not raise on extraction failure.
    """
    try:
        return Statements()(content=content)
    except Exception as e:
        # Best effort: report the failure and fall back to "no statements".
        logging.error(f"Getting statements failed: {e}")
        return []
|
||
def get_search_query(statement):
    """Get search query from one statement.

    Builds a fresh `SearchQuery` module and calls it with `statement`. Any
    exception is logged at ERROR level and an empty string is returned
    instead.
    """
    try:
        return SearchQuery()(statement=statement)
    except Exception as e:
        # Best effort: report the failure and fall back to an empty query.
        logging.error(f"Getting search query from statement '{statement}' failed: {e}")
        return ""
|
||
def get_verdict(search_json, statement):
    """Get the verdict and citation for one statement.

    Converts `search_json` to docs with `utils.search_json_to_docs`, runs
    `VerdictCitation` over them for `statement`, and returns a dict with the
    keys "verdict", "citation" and "statement". Unlike the other helpers in
    this module, exceptions are not caught here.
    """
    pipeline = VerdictCitation(docs=utils.search_json_to_docs(search_json))
    outcome = pipeline.get(statement=statement)

    # The result is indexed positionally: element 0 is reported as the
    # verdict, element 1 as the citation.
    verdict = outcome[0]
    citation = outcome[1]

    return {
        "verdict": verdict,
        "citation": citation,
        "statement": statement,
    }
|
Oops, something went wrong.