Merge pull request #17 from ittia-research/dev
change to multi-sources mode
etwk authored Aug 26, 2024
2 parents fe03620 + 405070b commit 724950c
Showing 20 changed files with 701 additions and 207 deletions.
10 changes: 8 additions & 2 deletions Dockerfile.local
@@ -1,9 +1,15 @@
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime

WORKDIR /app

COPY requirements.*.txt /app
RUN pip install --no-cache-dir -r requirements.base.txt
RUN pip install --no-cache-dir -r requirements.local.txt
COPY . /app
EXPOSE 8000

WORKDIR /app/src

COPY ./src .

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
10 changes: 8 additions & 2 deletions Dockerfile.remote
@@ -1,8 +1,14 @@
FROM python:3.11-slim-bookworm

WORKDIR /app

COPY requirements.base.txt /app
RUN pip install --no-cache-dir -r requirements.base.txt
COPY . /app
EXPOSE 8000

WORKDIR /app/src

COPY ./src .

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
7 changes: 6 additions & 1 deletion README.md
@@ -55,7 +55,6 @@ LLM

Embedding:
- [ ] optimize chunk size
- [ ] Ollama embedding performance

Contexts
- [ ] Filter out unrelated contexts before sending for verdict
@@ -66,9 +65,15 @@ Retrieval
### pipeline
DSPy:
- [ ] make dspy.settings apply to sessions only, in order to support multiple retrieval indexes
- [ ] choose the right LLM temperature
- [ ] better training datasets

### Retrieval
- [ ] Better retrieval solution: high performance, concurrency, multiple indexes, editable index.
- [ ] Get more sources when needed.

### Verdict
- [ ] Set final verdict standards.

### Toolchain
- [ ] Evaluate MLOps pipeline
314 changes: 260 additions & 54 deletions datasets/HotPotQA/HotPotQA_statement_verdict.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -2,6 +2,8 @@ services:
  check:
    image: ittia/check:remote
    container_name: check
    volumes:
      - /data/cache:/data/cache
    env_file:
      - ./infra/env.d/check
    ports:
2 changes: 2 additions & 0 deletions docs/changelog.md
@@ -3,5 +3,7 @@
- Change from AutoGen to plain OpenAI, since the AutoGen AssistantAgent adds a system role that is not compatible with Gemma 2 + vLLM.

## pipeline
2024/8/26:
- Changed to multi-sources mode (divide sources based on hostname) instead of using all web search results as one single source.
2024/8/13:
- Introduce DSPy to replace the get-verdict step, with multi-step reasoning.
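For illustration, a minimal sketch of that hostname split; the function name and the result shape (dicts carrying a `url` field) are assumptions, not the pipeline's actual code:

from collections import defaultdict
from urllib.parse import urlsplit

def group_by_hostname(search_results):
    # Divide web search results into one source per hostname (sketch only;
    # each result is assumed to be a dict with a `url` field).
    sources = defaultdict(list)
    for result in search_results:
        hostname = urlsplit(result["url"]).hostname or "unknown"
        sources[hostname].append(result)
    return dict(sources)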
15 changes: 12 additions & 3 deletions infra/env.d/check
@@ -1,13 +1,22 @@
CONCURRENCY_VERDICT=8

DSP_CACHEBOOL=True
DSP_CACHEDIR=/data/cache

EMBEDDING_API_KEY=<CHANGE_ME>
EMBEDDING_BASE_URL=http://ollama:11434
EMBEDDING_BASE_URL=http://infinity:7997
EMBEDDING_BATCH_SIZE=1024
EMBEDDING_MODEL_DEPLOY=api
EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
EMBEDDING_MODEL_NAME=jinaai/jina-embeddings-v2-base-en
EMBEDDING_SERVER_TYPE=infinity
INDEX_CHUNK_SIZES=[2048, 512, 128]

LLM_MODEL_NAME=google/gemma-2-27b-it
LLM_MODEL_NAME=mistralai/Mistral-Nemo-Instruct-2407
OPENAI_API_KEY=<CHANGE_ME>
OPENAI_BASE_URL=http://localhost:8000/v1

OPTIMIZER_FILE_NAME=verdict_MIPROv2.json

RERANK_API_KEY=<CHANGE_ME>
RERANK_BASE_URL=http://infinity:7997
RERANK_MODEL_DEPLOY=api
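For orientation, a hedged sketch of how these variables might be consumed on the application side; the project's real settings loader is not part of this diff, so the names and defaults below are assumptions:

import json
import os

# Hypothetical loader mirroring infra/env.d/check; the repo's actual settings module may differ.
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", "1024"))
EMBEDDING_SERVER_TYPE = os.environ.get("EMBEDDING_SERVER_TYPE", "infinity")
INDEX_CHUNK_SIZES = json.loads(os.environ.get("INDEX_CHUNK_SIZES", "[2048, 512, 128]"))
DSP_CACHEDIR = os.environ.get("DSP_CACHEDIR", "/data/cache")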
2 changes: 1 addition & 1 deletion src/main.py
@@ -57,7 +57,7 @@ async def fact_check(input):
            verdict = await run_in_threadpool(pipeline.get_verdict, search_json=search, statement=statement)
            logger.info(f"Verdict: {verdict}")
        except Exception as e:
            logger.error(f"Getting verdict for statement {statement} failed: {e}")
            logger.error(f"Getting verdict for statement '{statement}' failed: {e}")
            continue

        verdicts.append(verdict)
4 changes: 2 additions & 2 deletions src/modules/__init__.py
@@ -10,8 +10,8 @@
llm_long = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=500, stop='\n\n')

from .citation import Citation
from .context_verdict import ContextVerdict
from .retrieve import LlamaIndexRM
from .search import Search
from .search_query import SearchQuery
from .statements import Statements
from .verdict import Verdict
from .statements import Statements
7 changes: 3 additions & 4 deletions src/modules/citation.py
@@ -1,13 +1,12 @@
import dspy

# TODO: citation needs higher token limits
class GenerateCitedParagraph(dspy.Signature):
    """Generate a paragraph with citations."""
    context = dspy.InputField(desc="may contain relevant facts")
    context = dspy.InputField(desc="May contain relevant facts.")
    statement = dspy.InputField()
    verdict = dspy.InputField()
    paragraph = dspy.OutputField(desc="includes citations")

    paragraph = dspy.OutputField(desc="Includes citations.")
"""Generate citation from context and verdict"""
class Citation(dspy.Module):
    def __init__(self):
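The remainder of the Citation module is truncated in this diff. Assuming it simply wraps GenerateCitedParagraph in a dspy.ChainOfThought, a minimal sketch (the attribute name is hypothetical) could be:

import dspy

class Citation(dspy.Module):
    # Sketch only: the real module body is not shown in this diff.
    def __init__(self):
        super().__init__()
        self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph)

    def forward(self, statement, context, verdict):
        pred = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict)
        return dspy.Prediction(paragraph=pred.paragraph)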
61 changes: 61 additions & 0 deletions src/modules/context_verdict.py
@@ -0,0 +1,61 @@
import dspy
import re
from dsp.utils import deduplicate

class CheckStatement(dspy.Signature):
    """Verify the statement based on the provided context."""
    context = dspy.InputField(desc="Facts here are assumed to be true.")
    statement = dspy.InputField()
    verdict = dspy.OutputField(desc=("In order,"
                                     " `False` if the context directly negates the statement,"
                                     " `True` if it directly supports the statement,"
                                     " else `Irrelevant`."))

"""
The LM sometimes replies with additional words after the verdict; this function addresses the issue.
"""
def extract_verdict(text):
    # Extract the first word
    match = re.match(r'\s*(\w+)', text)
    if match:
        first_word = match.group(1)
        if first_word.lower() in ['false', 'true', 'irrelevant']:
            # Return the verdict with the first letter capitalized
            return first_word.capitalize()
    # If not in the verdict list, return the input directly
    return text

class GenerateSearchQuery(dspy.Signature):
    """Write a search query that will help retrieve additional info related to the statement."""
    context = dspy.InputField(desc="Existing context.")
    statement = dspy.InputField()
    query = dspy.OutputField()

"""
SimplifiedBaleen module.
Avoid unnecessary content in the module, because the MIPROv2 optimizer will analyze modules.
To-do:
- retrieve latest facts
- query results might stay the same across hops: better retrieval
"""
class ContextVerdict(dspy.Module):
    def __init__(self, passages_per_hop=3, max_hops=3):
        super().__init__()
        # self.generate_query = dspy.ChainOfThought(GenerateSearchQuery)  # IMPORTANT: solves error `list index out of range`
        self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_verdict = dspy.ChainOfThought(CheckStatement)
        self.max_hops = max_hops

    def forward(self, statement):
        context = []
        for hop in range(self.max_hops):
            query = self.generate_query[hop](context=context, statement=statement).query
            passages = self.retrieve(query).passages
            context = deduplicate(context + passages)

        _verdict_predict = self.generate_verdict(context=context, statement=statement)
        verdict = extract_verdict(_verdict_predict.verdict)
        pred = dspy.Prediction(answer=verdict, rationale=_verdict_predict.rationale, context=context)
        return pred
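A hedged usage sketch for ContextVerdict, mirroring the wiring in src/modules/__init__.py; the model name and base URL follow infra/env.d/check, and `retriever` stands in for a configured LlamaIndexRM instance, whose construction is not shown here:

import dspy

# Hypothetical wiring; src/modules/__init__.py configures the real LM and retriever.
lm = dspy.OpenAI(model="mistralai/Mistral-Nemo-Instruct-2407",
                 api_base="http://localhost:8000/v1/", max_tokens=200, stop='\n\n')
dspy.settings.configure(lm=lm, rm=retriever)  # `retriever`: an assumed LlamaIndexRM instance

verdict = ContextVerdict()
pred = verdict(statement="Water boils at 100 degrees Celsius at sea level.")
print(pred.answer)     # "True", "False", or "Irrelevant"
print(pred.rationale)  # the chain-of-thought rationale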
38 changes: 22 additions & 16 deletions src/modules/retrieve.py
@@ -1,9 +1,5 @@
"""
LlamaIndexCustomRetriever
"""

import logging
import concurrent.futures
import logging
from typing import Optional

from llama_index.core import (
@@ -25,18 +21,27 @@

from llama_index.postprocessor.jinaai_rerank import JinaRerank

from integrations import OllamaEmbedding
from integrations import InfinityEmbedding

# todo: improve embedding performance
if settings.EMBEDDING_MODEL_DEPLOY == "local":
    embed_model = "local:" + settings.EMBEDDING_MODEL_NAME
else:
    embed_model = OllamaEmbedding(
        api_key=settings.EMBEDDING_API_KEY,
        base_url=settings.EMBEDDING_BASE_URL,
        embed_batch_size=32,  # TODO: what's the best batch size for Ollama
        model_name=settings.EMBEDDING_MODEL_NAME,
    )
else:  # TODO: best batch size
    if settings.EMBEDDING_SERVER_TYPE == "infinity":
        embed_model = InfinityEmbedding(
            api_key=settings.EMBEDDING_API_KEY,
            base_url=settings.EMBEDDING_BASE_URL,
            embed_batch_size=settings.EMBEDDING_BATCH_SIZE,
            model_name=settings.EMBEDDING_MODEL_NAME,
        )
    elif settings.EMBEDDING_SERVER_TYPE == "ollama":
        embed_model = OllamaEmbedding(
            api_key=settings.EMBEDDING_API_KEY,
            base_url=settings.EMBEDDING_BASE_URL,
            embed_batch_size=settings.EMBEDDING_BATCH_SIZE,
            model_name=settings.EMBEDDING_MODEL_NAME,
        )
    else:
        raise ValueError(f"Embedding server {settings.EMBEDDING_SERVER_TYPE} not supported")

Settings.embed_model = embed_model

@@ -61,9 +66,10 @@ def build_automerging_index(

        storage_context = StorageContext.from_defaults()
        storage_context.docstore.add_documents(self.nodes)


        # TODO: enable async
        automerging_index = VectorStoreIndex(
            leaf_nodes, storage_context=storage_context, use_async=True
            leaf_nodes, storage_context=storage_context, use_async=False
        )

        return automerging_index
6 changes: 5 additions & 1 deletion src/modules/statements.py
@@ -7,7 +7,11 @@
class Output(BaseModel):
    statements: List = Field(description="A list of key statements")

# TODO: test consistency especially when content contains false claims
"""
TODO:
- correct statements format: time, etc.
- test consistency especially when content contains false claims
"""
class GenerateStatements(dspy.Signature):
    """Extract the original statements from given content without fact check."""
    content: str = dspy.InputField(desc="The content to summarize")
47 changes: 0 additions & 47 deletions src/modules/verdict.py

This file was deleted.

16 changes: 16 additions & 0 deletions src/optimizers/README.md
@@ -0,0 +1,16 @@
## Optimizer Card
verdict_MIPROv2:
- datasets:
  - size: train 1002, val 500, test 498
  - source: generated from HotPotQA
  - quality: low
- optimizer: MIPROv2
- compile:
  - model: mistralai/Mistral-Nemo-Instruct-2407
  - init_temperature: 1
  - num_candidates: 20
  - num_batches: 120
  - max_bootstrapped_demos: 4
  - max_labeled_demos: 4
- version:
  - dspy-ai==2.4.13
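For context on the card above, a hedged sketch of how such a compile run might look with dspy-ai==2.4.13; `verdict_metric`, `trainset`, and `valset` are placeholders for the HotPotQA-derived data described above, and the exact MIPROv2 signature may differ between DSPy versions:

from dspy.teleprompt import MIPROv2

# Sketch only: parameter values follow the optimizer card above.
optimizer = MIPROv2(metric=verdict_metric, num_candidates=20, init_temperature=1)
compiled = optimizer.compile(
    ContextVerdict(),
    trainset=trainset,
    valset=valset,
    num_batches=120,
    max_bootstrapped_demos=4,
    max_labeled_demos=4,
)
compiled.save("verdict_MIPROv2.json")  # matches OPTIMIZER_FILE_NAME in infra/env.d/check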