Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions papermage/predictors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from papermage.predictors.base_predictors.hf_predictors import HFBIOTaggerPredictor
from papermage.predictors.block_predictors import LPEffDetPubLayNetBlockPredictor
from papermage.predictors.formula_predictors import LPEffDetFormulaPredictor
from papermage.predictors.paper_qa_predictors import PaperQaPredictor
from papermage.predictors.sentence_predictors import PysbdSentencePredictor
from papermage.predictors.span_qa_predictors import APISpanQAPredictor
from papermage.predictors.token_predictors import HFWhitspaceTokenPredictor
Expand All @@ -18,4 +19,5 @@
"LPEffDetFormulaPredictor",
"APISpanQAPredictor",
"BasePredictor",
"PaperQaPredictor",
]
6 changes: 3 additions & 3 deletions papermage/predictors/base_predictors/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def _doc_field_checker(self, doc: Document) -> None:
field in doc.layers
), f"The input Document object {doc} doesn't contain the required field {field}"

def predict(self, doc: Document) -> List[Entity]:
def predict(self, doc: Document, *args, **kwargs) -> List[Entity]:
"""For all the predictors, the input is a document object, and
the output is a list of annotations.
"""
self._doc_field_checker(doc)
return self._predict(doc=doc)
return self._predict(doc=doc, *args, **kwargs)

@abstractmethod
def _predict(self, doc: Document) -> List[Entity]:
def _predict(self, doc: Document, *args, **kwargs) -> List[Entity]:
raise NotImplementedError
2 changes: 1 addition & 1 deletion papermage/predictors/base_predictors/lp_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def postprocess(self, model_outputs: lp.Layout, page_index: int, image: Image) -
for block in model_outputs
]

def _predict(self, doc: Document) -> List[Entity]:
def _predict(self, doc: Document, *args, **kwargs) -> List[Entity]:
"""Returns a list of Entities for the detected layouts for all pages

Args:
Expand Down
124 changes: 124 additions & 0 deletions papermage/predictors/paper_qa_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import List, Set

import numpy as np
import pysbd
from tokreate import CallAction, ParseAction

from papermage.magelib import Document, Entity, ParagraphsFieldName, SentencesFieldName
from papermage.predictors import BasePredictor

from .utils_paper_qa.hashing import create_hash, int_to_bin, similarity

FULL_DOC_QA_ATTRIBUTED_PROMPT = """\
Answer a question using the provided scientific paper.
Your response should be a JSON object with the following fields:
- answer: The answer to the question. The answer should use concise language, but be comprehensive. Only provide answers that are objectively supported by the text in paper.
- excerpts: A list of one or more *EXACT* text spans extracted from the paper that support the answer. Return between at most ten spans, and no more that 800 words. Make sure to cover all aspects of the answer above.
If there is no answer, return an empty dictionary, i.e., `{}`.

Paper:
{{ full_text }}

Given the information above, please answer the question: "{{ question }}"."""

FULL_DOC_QA_SYSTEM_JSON_PROMPT = """\
You are a helpful research assistant, answering questions about scientific papers accurately and concisely.
You ONLY respond to questions that have an objective answer, and return an empty response for subjective requests.
You always return a valid JSON object to each user request."""


class PaperQaPredictor(BasePredictor):
def __init__(self, model_name: str = "gpt-3.5-turbo-16k", max_tokens: int = 2048):
self.call = CallAction(
prompt=FULL_DOC_QA_ATTRIBUTED_PROMPT,
system=FULL_DOC_QA_SYSTEM_JSON_PROMPT,
model=model_name,
parameters={"max_tokens": max_tokens},
) >> ParseAction(name="json_parser", parser="json.loads")
self.sentencizer = pysbd.Segmenter(language="en", clean=False)

@property
def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]:
return [ParagraphsFieldName, SentencesFieldName]

def merge_adjacent_sentences(self, locs: List[int], slack: int = 1, max_len: int = 3) -> List[List[int]]:
"""Merge adjacent sentences that are within a certain distance of each other."""
seen: Set[int] = set()
grouped: List[List[int]] = []
end_pos = len(locs) - 1
for i in range(len(locs)):
init_pos = locs[i]
if init_pos in seen:
grouped.append([])
elif i < end_pos:
for j in range(i + 1, len(locs)):
curr_pos, prev_pos = locs[j], locs[j - 1]
if (curr_pos - prev_pos) > slack or (curr_pos - init_pos) >= max_len or j == end_pos:
new_pos = list(range(init_pos, prev_pos + 1))
grouped.append(new_pos)
seen.update(new_pos)
break
else:
grouped.append([init_pos])
seen.add(init_pos)

return grouped

def predict(self, doc: Document, *args, **kwargs) -> List[Entity]:
# TODO: fix base predictor so that it can handle questions
self._doc_field_checker(doc)
return self._predict(doc, *args, **kwargs)

def _predict(self, doc: Document, question: str) -> List[Entity]: # type: ignore
full_text = ""
sentences_vecs = []
sentences_ids = []

# build full text representation and hash sentences for similarity
for i, paragraph in enumerate(doc.paragraphs): # type: ignore
header = getattr(paragraph.metadata, "header_text", None)
for j, sent in enumerate(paragraph.sentences):
if j == 0 and (header := getattr(sent.metadata, "header_text", header)):
# add a header if it exists
full_text += f"\n\n## {header}\n"
else:
full_text += "\n\n"

full_text += f"{sent.text.strip()} "
sentences_vecs.append(int_to_bin(create_hash(sent.text)))
sentences_ids.append((i, j))
full_text.strip()
sentences_array = np.vstack(sentences_vecs)

*_, output = self.call.run(full_text=full_text, question=question)
parsed_output = output.state["json_parser"]

if not parsed_output:
# the model could not answer the question
return []

excerpts = [s for e in parsed_output.get("excerpts", []) for s in self.sentencizer.segment(e)]

if not excerpts:
# the model could not find supporting evidence
return []

encoded_context = np.vstack([int_to_bin(create_hash(e)) for e in excerpts])
similarities = similarity(queries=encoded_context, targets=sentences_array)
grouped_locs = self.merge_adjacent_sentences(np.argmax(similarities, axis=1))

extracted_excerpts: List[Entity] = []
for locs in grouped_locs:
if len(locs) == 0:
# nothing matched for this group
continue

ids = [sentences_ids[loc] for loc in locs]
sents = [doc.paragraphs[i].sentences[j] for i, j in ids] # type: ignore
for sent in sents:
matched_sent = Entity.from_json(sent.to_json())
matched_sent.metadata.score = np.max(similarities[:, locs]).tolist()
matched_sent.metadata.answer = parsed_output["answer"]
extracted_excerpts.append(matched_sent)

return extracted_excerpts
137 changes: 137 additions & 0 deletions papermage/predictors/paper_qa_predictors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""

@soldni

"""

from typing import List, Set

import numpy as np
import pysbd
from tokreate import CallAction, ParseAction

from papermage.magelib import (
BlocksFieldName,
Document,
Entity,
ParagraphsFieldName,
SentencesFieldName,
)
from papermage.predictors import BasePredictor

from .utils.paper_qa_utils import create_hash, int_to_bin, similarity

FULL_DOC_QA_ATTRIBUTED_PROMPT = """\
Answer a question using the provided scientific paper.
Your response should be a JSON object with the following fields:
- answer: The answer to the question. The answer should use concise language, but be comprehensive. Only provide answers that are objectively supported by the text in paper.
- excerpts: A list of one or more *EXACT* text spans extracted from the paper that support the answer. Return between at most ten spans, and no more that 800 words. Make sure to cover all aspects of the answer above.
If there is no answer, return an empty dictionary, i.e., `{}`.

Paper:
{{ full_text }}

Given the information above, please answer the question: "{{ question }}"."""

FULL_DOC_QA_SYSTEM_JSON_PROMPT = """\
You are a helpful research assistant, answering questions about scientific papers accurately and concisely.
You ONLY respond to questions that have an objective answer, and return an empty response for subjective requests.
You always return a valid JSON object to each user request."""


class PaperQaPredictor(BasePredictor):
def __init__(self, model_name: str = "gpt-3.5-turbo-16k", max_tokens: int = 2048):
self.call = CallAction(
prompt=FULL_DOC_QA_ATTRIBUTED_PROMPT,
system=FULL_DOC_QA_SYSTEM_JSON_PROMPT,
model=model_name,
parameters={"max_tokens": max_tokens},
) >> ParseAction(name="json_parser", parser="json.loads")
self.sentencizer = pysbd.Segmenter(language="en", clean=False)

@property
def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]:
return [BlocksFieldName, SentencesFieldName]

def merge_adjacent_sentences(self, locs: List[int], slack: int = 1, max_len: int = 3) -> List[List[int]]:
"""Merge adjacent sentences that are within a certain distance of each other."""
seen: Set[int] = set()
grouped: List[List[int]] = []
end_pos = len(locs) - 1
for i in range(len(locs)):
init_pos = locs[i]
if init_pos in seen:
grouped.append([])
elif i < end_pos:
for j in range(i + 1, len(locs)):
curr_pos, prev_pos = locs[j], locs[j - 1]
if (curr_pos - prev_pos) > slack or (curr_pos - init_pos) >= max_len or j == end_pos:
new_pos = list(range(init_pos, prev_pos + 1))
grouped.append(new_pos)
seen.update(new_pos)
break
else:
grouped.append([init_pos])
seen.add(init_pos)

return grouped

def _hash_sentences(self, doc: Document) -> List[str]: # type: ignore
"""Hash sentences in a document."""
for i, sent in enumerate(doc.sentences):
sent.metadata.hash = create_hash(sent.text)
sent.metadata.index = i

def _predict(self, doc: Document, question: str) -> List[Entity]: # type: ignore
full_text = ""
sentences_vecs = []
sentences_ids = []

# build full text representation and hash sentences for similarity
for i, paragraph in enumerate(doc.paragraphs): # type: ignore
header = getattr(paragraph.metadata, "header_text", None)
for j, sent in enumerate(paragraph.sentences):
if j == 0 and (header := getattr(sent.metadata, "header_text", header)):
# add a header if it exists
full_text += f"\n\n## {header}\n"
else:
full_text += "\n\n"

full_text += f"{sent.text.strip()} "
sentences_vecs.append(int_to_bin(create_hash(sent.text)))
sentences_ids.append((i, j))
full_text.strip()
sentences_array = np.vstack(sentences_vecs)

*_, output = self.call.run(full_text=full_text, question=question)
parsed_output = output.state["json_parser"]

if not parsed_output:
# the model could not answer the question
return []

excerpts = [s for e in parsed_output.get("excerpts", []) for s in self.sentencizer.segment(e)]

if not excerpts:
# the model could not find supporting evidence
return []

encoded_context = np.vstack([int_to_bin(create_hash(e)) for e in excerpts])
similarities = similarity(queries=encoded_context, targets=sentences_array)
grouped_locs = self.merge_adjacent_sentences(np.argmax(similarities, axis=1))

extracted_excerpts: List[Entity] = []
for locs in grouped_locs:
if len(locs) == 0:
# nothing matched for this group
continue

ids = [sentences_ids[loc] for loc in locs]
sents = [doc.paragraphs[i].sentences[j] for i, j in ids] # type: ignore
for sent in sents:
matched_sent = Entity.from_json(sent.to_json())
matched_sent.metadata.score = np.max(similarities[:, locs]).tolist()
matched_sent.metadata.answer = parsed_output["answer"]
extracted_excerpts.append(matched_sent)

return extracted_excerpts
22 changes: 22 additions & 0 deletions papermage/predictors/paragraph_predictors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""

Paragraph Predictor

@kylel

"""

import re
from typing import List, Tuple

from papermage.magelib import (
BlocksFieldName,
Document,
Entity,
PagesFieldName,
ParagraphsFieldName,
RowsFieldName,
Span,
TokensFieldName,
)
from papermage.predictors import BasePredictor
Empty file.
Loading