Skip to content

Commit

Permalink
Merge pull request #11 from ittia-research/dev
Browse files Browse the repository at this point in the history
change base image to CUDA, change to dspy.Retrieve
  • Loading branch information
etwk authored Aug 16, 2024
2 parents cd95cbe + 558379c commit 147b2b7
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
3 changes: 1 addition & 2 deletions Dockerfile.local
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
FROM intel/intel-optimized-pytorch:2.3.0-serving-cpu
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime
WORKDIR /app
COPY requirements.*.txt /app
RUN pip install --no-cache-dir -r requirements.base.txt
RUN pip install --no-cache-dir -r requirements.local.txt
COPY . /app
EXPOSE 8000
ENV NAME "Fact-check API"
WORKDIR /app/src
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
1 change: 0 additions & 1 deletion Dockerfile.remote
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ COPY requirements.base.txt /app
RUN pip install --no-cache-dir -r requirements.base.txt
COPY . /app
EXPOSE 8000
ENV NAME "Fact-check API"
WORKDIR /app/src
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
7 changes: 3 additions & 4 deletions src/modules/verdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,19 @@ class GenerateSearchQuery(dspy.Signature):
- does different InputField name other than answer compateble with dspy evaluate
"""
class Verdict(dspy.Module):
def __init__(self, retrieve, passages_per_hop=3, max_hops=3):
def __init__(self, passages_per_hop=3, max_hops=3):
super().__init__()
# self.generate_query = dspy.ChainOfThought(GenerateSearchQuery) # IMPORTANT: solves error `list index out of range`
self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
self.retrieve = retrieve
self.retrieve.k = passages_per_hop
self.retrieve = dspy.Retrieve(k=passages_per_hop)
self.generate_verdict = dspy.ChainOfThought(CheckStatementFaithfulness)
self.max_hops = max_hops

def forward(self, statement):
context = []
for hop in range(self.max_hops):
query = self.generate_query[hop](context=context, statement=statement).query
passages = self.retrieve(query=query, text_only=True)
passages = self.retrieve(query).passages
context = deduplicate(context + passages)

verdict = self.generate_verdict(context=context, statement=statement)
Expand Down
15 changes: 9 additions & 6 deletions src/pipeline/verdict_citation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ def __init__(
):
self.retrieve = LlamaIndexRM(docs=docs)

# loading compiled Verdict
self.context_verdict = Verdict(retrieve=self.retrieve)
optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../optimizers/verdict_MIPROv2.json")
self.context_verdict.load(optimizer_path)

def get(self, statement):
rep = self.context_verdict(statement)
with dspy.context(rm=self.retrieve):
self.context_verdict = Verdict()

# loading compiled Verdict
optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../optimizers/verdict_MIPROv2.json")
self.context_verdict.load(optimizer_path)

rep = self.context_verdict(statement)

context = rep.context
verdict = rep.answer

Expand Down

0 comments on commit 147b2b7

Please sign in to comment.