diff --git a/Dockerfile.local b/Dockerfile.local index c9774ac..f91e761 100644 --- a/Dockerfile.local +++ b/Dockerfile.local @@ -1,10 +1,9 @@ -FROM intel/intel-optimized-pytorch:2.3.0-serving-cpu +FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime WORKDIR /app COPY requirements.*.txt /app RUN pip install --no-cache-dir -r requirements.base.txt RUN pip install --no-cache-dir -r requirements.local.txt COPY . /app EXPOSE 8000 -ENV NAME "Fact-check API" WORKDIR /app/src CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/Dockerfile.remote b/Dockerfile.remote index a744bd1..478fd1a 100644 --- a/Dockerfile.remote +++ b/Dockerfile.remote @@ -4,6 +4,5 @@ COPY requirements.base.txt /app RUN pip install --no-cache-dir -r requirements.base.txt COPY . /app EXPOSE 8000 -ENV NAME "Fact-check API" WORKDIR /app/src CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/src/modules/verdict.py b/src/modules/verdict.py index 319aa04..75ecf96 100644 --- a/src/modules/verdict.py +++ b/src/modules/verdict.py @@ -26,12 +26,11 @@ class GenerateSearchQuery(dspy.Signature): - does different InputField name other than answer compateble with dspy evaluate """ class Verdict(dspy.Module): - def __init__(self, retrieve, passages_per_hop=3, max_hops=3): + def __init__(self, passages_per_hop=3, max_hops=3): super().__init__() # self.generate_query = dspy.ChainOfThought(GenerateSearchQuery) # IMPORTANT: solves error `list index out of range` self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)] - self.retrieve = retrieve - self.retrieve.k = passages_per_hop + self.retrieve = dspy.Retrieve(k=passages_per_hop) self.generate_verdict = dspy.ChainOfThought(CheckStatementFaithfulness) self.max_hops = max_hops @@ -39,7 +38,7 @@ def forward(self, statement): context = [] for hop in range(self.max_hops): query = self.generate_query[hop](context=context, statement=statement).query - passages = self.retrieve(query=query, text_only=True) + passages = self.retrieve(query).passages context = deduplicate(context + passages) verdict = self.generate_verdict(context=context, statement=statement) diff --git a/src/pipeline/verdict_citation.py b/src/pipeline/verdict_citation.py index db36b71..5decb33 100644 --- a/src/pipeline/verdict_citation.py +++ b/src/pipeline/verdict_citation.py @@ -17,13 +17,16 @@ def __init__( ): self.retrieve = LlamaIndexRM(docs=docs) - # loading compiled Verdict - self.context_verdict = Verdict(retrieve=self.retrieve) - optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../optimizers/verdict_MIPROv2.json") - self.context_verdict.load(optimizer_path) - def get(self, statement): - rep = self.context_verdict(statement) + with dspy.context(rm=self.retrieve): + self.context_verdict = Verdict() + + # loading compiled Verdict + optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../optimizers/verdict_MIPROv2.json") + self.context_verdict.load(optimizer_path) + + rep = self.context_verdict(statement) + context = rep.context verdict = rep.answer