diff --git a/papermage/predictors/__init__.py b/papermage/predictors/__init__.py index ae62f67..1fbe446 100644 --- a/papermage/predictors/__init__.py +++ b/papermage/predictors/__init__.py @@ -2,6 +2,7 @@ from papermage.predictors.base_predictors.hf_predictors import HFBIOTaggerPredictor from papermage.predictors.block_predictors import LPEffDetPubLayNetBlockPredictor from papermage.predictors.formula_predictors import LPEffDetFormulaPredictor +from papermage.predictors.paper_qa_predictors import PaperQaPredictor from papermage.predictors.sentence_predictors import PysbdSentencePredictor from papermage.predictors.span_qa_predictors import APISpanQAPredictor from papermage.predictors.token_predictors import HFWhitspaceTokenPredictor @@ -18,4 +19,5 @@ "LPEffDetFormulaPredictor", "APISpanQAPredictor", "BasePredictor", + "PaperQaPredictor", ] diff --git a/papermage/predictors/base_predictors/base_predictor.py b/papermage/predictors/base_predictors/base_predictor.py index 5ee0f17..df88f85 100644 --- a/papermage/predictors/base_predictors/base_predictor.py +++ b/papermage/predictors/base_predictors/base_predictor.py @@ -31,13 +31,13 @@ def _doc_field_checker(self, doc: Document) -> None: field in doc.layers ), f"The input Document object {doc} doesn't contain the required field {field}" - def predict(self, doc: Document) -> List[Entity]: + def predict(self, doc: Document, *args, **kwargs) -> List[Entity]: """For all the predictors, the input is a document object, and the output is a list of annotations. """ self._doc_field_checker(doc) - return self._predict(doc=doc) + return self._predict(doc=doc, *args, **kwargs) @abstractmethod - def _predict(self, doc: Document) -> List[Entity]: + def _predict(self, doc: Document, *args, **kwargs) -> List[Entity]: raise NotImplementedError diff --git a/papermage/predictors/base_predictors/lp_predictors.py b/papermage/predictors/base_predictors/lp_predictors.py index 9d48b32..6ddb6c0 100644 --- a/papermage/predictors/base_predictors/lp_predictors.py +++ b/papermage/predictors/base_predictors/lp_predictors.py @@ -100,7 +100,7 @@ def postprocess(self, model_outputs: lp.Layout, page_index: int, image: Image) - for block in model_outputs ] - def _predict(self, doc: Document) -> List[Entity]: + def _predict(self, doc: Document, *args, **kwargs) -> List[Entity]: """Returns a list of Entities for the detected layouts for all pages Args: diff --git a/papermage/predictors/paper_qa_predictor.py b/papermage/predictors/paper_qa_predictor.py new file mode 100644 index 0000000..7f528ee --- /dev/null +++ b/papermage/predictors/paper_qa_predictor.py @@ -0,0 +1,124 @@ +from typing import List, Set + +import numpy as np +import pysbd +from tokreate import CallAction, ParseAction + +from papermage.magelib import Document, Entity, ParagraphsFieldName, SentencesFieldName +from papermage.predictors import BasePredictor + +from .utils_paper_qa.hashing import create_hash, int_to_bin, similarity + +FULL_DOC_QA_ATTRIBUTED_PROMPT = """\ +Answer a question using the provided scientific paper. +Your response should be a JSON object with the following fields: + - answer: The answer to the question. The answer should use concise language, but be comprehensive. Only provide answers that are objectively supported by the text in paper. + - excerpts: A list of one or more *EXACT* text spans extracted from the paper that support the answer. Return between at most ten spans, and no more that 800 words. Make sure to cover all aspects of the answer above. +If there is no answer, return an empty dictionary, i.e., `{}`. + +Paper: +{{ full_text }} + +Given the information above, please answer the question: "{{ question }}".""" + +FULL_DOC_QA_SYSTEM_JSON_PROMPT = """\ +You are a helpful research assistant, answering questions about scientific papers accurately and concisely. +You ONLY respond to questions that have an objective answer, and return an empty response for subjective requests. +You always return a valid JSON object to each user request.""" + + +class PaperQaPredictor(BasePredictor): + def __init__(self, model_name: str = "gpt-3.5-turbo-16k", max_tokens: int = 2048): + self.call = CallAction( + prompt=FULL_DOC_QA_ATTRIBUTED_PROMPT, + system=FULL_DOC_QA_SYSTEM_JSON_PROMPT, + model=model_name, + parameters={"max_tokens": max_tokens}, + ) >> ParseAction(name="json_parser", parser="json.loads") + self.sentencizer = pysbd.Segmenter(language="en", clean=False) + + @property + def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]: + return [ParagraphsFieldName, SentencesFieldName] + + def merge_adjacent_sentences(self, locs: List[int], slack: int = 1, max_len: int = 3) -> List[List[int]]: + """Merge adjacent sentences that are within a certain distance of each other.""" + seen: Set[int] = set() + grouped: List[List[int]] = [] + end_pos = len(locs) - 1 + for i in range(len(locs)): + init_pos = locs[i] + if init_pos in seen: + grouped.append([]) + elif i < end_pos: + for j in range(i + 1, len(locs)): + curr_pos, prev_pos = locs[j], locs[j - 1] + if (curr_pos - prev_pos) > slack or (curr_pos - init_pos) >= max_len or j == end_pos: + new_pos = list(range(init_pos, prev_pos + 1)) + grouped.append(new_pos) + seen.update(new_pos) + break + else: + grouped.append([init_pos]) + seen.add(init_pos) + + return grouped + + def predict(self, doc: Document, *args, **kwargs) -> List[Entity]: + # TODO: fix base predictor so that it can handle questions + self._doc_field_checker(doc) + return self._predict(doc, *args, **kwargs) + + def _predict(self, doc: Document, question: str) -> List[Entity]: # type: ignore + full_text = "" + sentences_vecs = [] + sentences_ids = [] + + # build full text representation and hash sentences for similarity + for i, paragraph in enumerate(doc.paragraphs): # type: ignore + header = getattr(paragraph.metadata, "header_text", None) + for j, sent in enumerate(paragraph.sentences): + if j == 0 and (header := getattr(sent.metadata, "header_text", header)): + # add a header if it exists + full_text += f"\n\n## {header}\n" + else: + full_text += "\n\n" + + full_text += f"{sent.text.strip()} " + sentences_vecs.append(int_to_bin(create_hash(sent.text))) + sentences_ids.append((i, j)) + full_text.strip() + sentences_array = np.vstack(sentences_vecs) + + *_, output = self.call.run(full_text=full_text, question=question) + parsed_output = output.state["json_parser"] + + if not parsed_output: + # the model could not answer the question + return [] + + excerpts = [s for e in parsed_output.get("excerpts", []) for s in self.sentencizer.segment(e)] + + if not excerpts: + # the model could not find supporting evidence + return [] + + encoded_context = np.vstack([int_to_bin(create_hash(e)) for e in excerpts]) + similarities = similarity(queries=encoded_context, targets=sentences_array) + grouped_locs = self.merge_adjacent_sentences(np.argmax(similarities, axis=1)) + + extracted_excerpts: List[Entity] = [] + for locs in grouped_locs: + if len(locs) == 0: + # nothing matched for this group + continue + + ids = [sentences_ids[loc] for loc in locs] + sents = [doc.paragraphs[i].sentences[j] for i, j in ids] # type: ignore + for sent in sents: + matched_sent = Entity.from_json(sent.to_json()) + matched_sent.metadata.score = np.max(similarities[:, locs]).tolist() + matched_sent.metadata.answer = parsed_output["answer"] + extracted_excerpts.append(matched_sent) + + return extracted_excerpts diff --git a/papermage/predictors/paper_qa_predictors.py b/papermage/predictors/paper_qa_predictors.py new file mode 100644 index 0000000..9f90ab7 --- /dev/null +++ b/papermage/predictors/paper_qa_predictors.py @@ -0,0 +1,137 @@ +""" + +@soldni + +""" + +from typing import List, Set + +import numpy as np +import pysbd +from tokreate import CallAction, ParseAction + +from papermage.magelib import ( + BlocksFieldName, + Document, + Entity, + ParagraphsFieldName, + SentencesFieldName, +) +from papermage.predictors import BasePredictor + +from .utils.paper_qa_utils import create_hash, int_to_bin, similarity + +FULL_DOC_QA_ATTRIBUTED_PROMPT = """\ +Answer a question using the provided scientific paper. +Your response should be a JSON object with the following fields: + - answer: The answer to the question. The answer should use concise language, but be comprehensive. Only provide answers that are objectively supported by the text in paper. + - excerpts: A list of one or more *EXACT* text spans extracted from the paper that support the answer. Return between at most ten spans, and no more that 800 words. Make sure to cover all aspects of the answer above. +If there is no answer, return an empty dictionary, i.e., `{}`. + +Paper: +{{ full_text }} + +Given the information above, please answer the question: "{{ question }}".""" + +FULL_DOC_QA_SYSTEM_JSON_PROMPT = """\ +You are a helpful research assistant, answering questions about scientific papers accurately and concisely. +You ONLY respond to questions that have an objective answer, and return an empty response for subjective requests. +You always return a valid JSON object to each user request.""" + + +class PaperQaPredictor(BasePredictor): + def __init__(self, model_name: str = "gpt-3.5-turbo-16k", max_tokens: int = 2048): + self.call = CallAction( + prompt=FULL_DOC_QA_ATTRIBUTED_PROMPT, + system=FULL_DOC_QA_SYSTEM_JSON_PROMPT, + model=model_name, + parameters={"max_tokens": max_tokens}, + ) >> ParseAction(name="json_parser", parser="json.loads") + self.sentencizer = pysbd.Segmenter(language="en", clean=False) + + @property + def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]: + return [BlocksFieldName, SentencesFieldName] + + def merge_adjacent_sentences(self, locs: List[int], slack: int = 1, max_len: int = 3) -> List[List[int]]: + """Merge adjacent sentences that are within a certain distance of each other.""" + seen: Set[int] = set() + grouped: List[List[int]] = [] + end_pos = len(locs) - 1 + for i in range(len(locs)): + init_pos = locs[i] + if init_pos in seen: + grouped.append([]) + elif i < end_pos: + for j in range(i + 1, len(locs)): + curr_pos, prev_pos = locs[j], locs[j - 1] + if (curr_pos - prev_pos) > slack or (curr_pos - init_pos) >= max_len or j == end_pos: + new_pos = list(range(init_pos, prev_pos + 1)) + grouped.append(new_pos) + seen.update(new_pos) + break + else: + grouped.append([init_pos]) + seen.add(init_pos) + + return grouped + + def _hash_sentences(self, doc: Document) -> List[str]: # type: ignore + """Hash sentences in a document.""" + for i, sent in enumerate(doc.sentences): + sent.metadata.hash = create_hash(sent.text) + sent.metadata.index = i + + def _predict(self, doc: Document, question: str) -> List[Entity]: # type: ignore + full_text = "" + sentences_vecs = [] + sentences_ids = [] + + # build full text representation and hash sentences for similarity + for i, paragraph in enumerate(doc.paragraphs): # type: ignore + header = getattr(paragraph.metadata, "header_text", None) + for j, sent in enumerate(paragraph.sentences): + if j == 0 and (header := getattr(sent.metadata, "header_text", header)): + # add a header if it exists + full_text += f"\n\n## {header}\n" + else: + full_text += "\n\n" + + full_text += f"{sent.text.strip()} " + sentences_vecs.append(int_to_bin(create_hash(sent.text))) + sentences_ids.append((i, j)) + full_text.strip() + sentences_array = np.vstack(sentences_vecs) + + *_, output = self.call.run(full_text=full_text, question=question) + parsed_output = output.state["json_parser"] + + if not parsed_output: + # the model could not answer the question + return [] + + excerpts = [s for e in parsed_output.get("excerpts", []) for s in self.sentencizer.segment(e)] + + if not excerpts: + # the model could not find supporting evidence + return [] + + encoded_context = np.vstack([int_to_bin(create_hash(e)) for e in excerpts]) + similarities = similarity(queries=encoded_context, targets=sentences_array) + grouped_locs = self.merge_adjacent_sentences(np.argmax(similarities, axis=1)) + + extracted_excerpts: List[Entity] = [] + for locs in grouped_locs: + if len(locs) == 0: + # nothing matched for this group + continue + + ids = [sentences_ids[loc] for loc in locs] + sents = [doc.paragraphs[i].sentences[j] for i, j in ids] # type: ignore + for sent in sents: + matched_sent = Entity.from_json(sent.to_json()) + matched_sent.metadata.score = np.max(similarities[:, locs]).tolist() + matched_sent.metadata.answer = parsed_output["answer"] + extracted_excerpts.append(matched_sent) + + return extracted_excerpts diff --git a/papermage/predictors/paragraph_predictors.py b/papermage/predictors/paragraph_predictors.py new file mode 100644 index 0000000..4bd8d33 --- /dev/null +++ b/papermage/predictors/paragraph_predictors.py @@ -0,0 +1,22 @@ +""" + +Paragraph Predictor + +@kylel + +""" + +import re +from typing import List, Tuple + +from papermage.magelib import ( + BlocksFieldName, + Document, + Entity, + PagesFieldName, + ParagraphsFieldName, + RowsFieldName, + Span, + TokensFieldName, +) +from papermage.predictors import BasePredictor diff --git a/papermage/predictors/utils/__init__.py b/papermage/predictors/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/papermage/predictors/utils/paper_qa_utils.py b/papermage/predictors/utils/paper_qa_utils.py new file mode 100644 index 0000000..97f947b --- /dev/null +++ b/papermage/predictors/utils/paper_qa_utils.py @@ -0,0 +1,88 @@ +""" + +@soldni + +""" + +import re +from functools import lru_cache +from hashlib import blake2b +from typing import List + +import numpy as np + +MIN_WORD_LENGTH = 4 +HASH_BITS = 64 + + +def int_to_bin(i: int) -> np.ndarray: + return np.array(list(map(int, np.binary_repr(i, width=64))), dtype=np.int8) + + +def bin_to_int(b: np.ndarray) -> int: + return int("".join(map(str, b)), 2) + + +def int_to_hex(i: int) -> str: + return hex(i)[2:] + + +def hex_to_int(h: str) -> int: + return int(h, 16) + + +def generate_tokens(s: str, min_tokens: int = HASH_BITS, min_length: int = MIN_WORD_LENGTH) -> List[str]: + """Generate tokens and sample a subset if you have enough of minimum length.""" + tokens = [t.strip() for t in re.findall(r" ?[^.\s,!?…。,、।۔،]+", s)] + lengths = [len(t) for t in tokens] + for n in range(min_length, -1, -1): + if sum(lt for lt in lengths if lt > n) >= min_tokens: + return [t for t in tokens if len(t) > n] + return tokens + + +def _hash_token(token: str) -> np.ndarray: + token_bytes = blake2b(token.encode(), digest_size=8).digest() + return np.frombuffer(token_bytes, dtype=">u8").astype(np.int64) + + +@lru_cache(maxsize=2**14) +def hash_token(token: str, hash_bits: int = HASH_BITS): + # we use blake2b because it's faster than md5, sha1, and sha3 + h = _hash_token(token=token) + + # initialize an array of -1 + bit_array = -1 * np.ones(hash_bits, dtype=np.int64) + + # find the indices where condition h & (1 << np.arange(hash_bits)) is true + indices = np.where(h & (1 << np.arange(hash_bits, dtype=np.int64))) + + # set those indices to 1 + bit_array[indices] = 1 + + # return the bit array + return bit_array + + +def create_hash(s: str, hash_bits: int = HASH_BITS) -> int: + tokens = generate_tokens(s=s, min_tokens=hash_bits) + v = np.zeros(hash_bits, dtype=np.int64) + + for t in tokens: + v = v + hash_token(token=t, hash_bits=hash_bits) + + # Get the indices where v[i] is >= 0 + indices, *_ = np.where(v >= 0) + + # Calculate fingerprint using bitwise shift and summation (equivalent to sum(2**indices)) + fingerprint = np.sum(1 << indices.astype(np.uint64)) + + return fingerprint.tolist() + + +def similarity(targets: np.ndarray, queries: np.ndarray) -> np.ndarray: + if len(targets.shape) == 1: + targets = targets.reshape(1, -1) + if len(queries.shape) == 1: + queries = queries.reshape(1, -1) + return 1.0 - np.sum(queries[:, np.newaxis, :] != targets[np.newaxis, :, :], axis=-1) / queries.shape[-1] diff --git a/papermage/predictors/utils/vila_utils.py b/papermage/predictors/utils/vila_utils.py new file mode 100644 index 0000000..650e0ec --- /dev/null +++ b/papermage/predictors/utils/vila_utils.py @@ -0,0 +1,149 @@ +""" + +@shannons + +""" + +import inspect +import itertools +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from papermage.magelib import ( + BlocksFieldName, + Document, + Entity, + Metadata, + PagesFieldName, + RowsFieldName, + Span, + TokensFieldName, +) + + +# util +def columns_used_in_model_inputs(model): + signature = inspect.signature(model.forward) + signature_columns = list(signature.parameters.keys()) + return signature_columns + + +# util +def normalize_bbox( + bbox, + page_width, + page_height, + target_width, + target_height, +): + """ + Normalize bounding box to the target size. + """ + + x1, y1, x2, y2 = bbox + + # Right now only execute this for only "large" PDFs + # TODO: Change it for all PDFs + if page_width > target_width or page_height > target_height: + x1 = float(x1) / page_width * target_width + x2 = float(x2) / page_width * target_width + y1 = float(y1) / page_height * target_height + y2 = float(y2) / page_height * target_height + + return (x1, y1, x2, y2) + + +# util +def shift_index_sequence_to_zero_start(sequence): + """ + Shift a sequence to start at 0. + """ + sequence_start = min(sequence) + return [i - sequence_start for i in sequence] + + +# util +def get_visual_group_id(token: Entity, field_name: str, defaults=-1) -> int: + if not hasattr(token, field_name): + return defaults + field_value = getattr(token, field_name) + if len(field_value) == 0 or field_value[0].id is None: + return defaults + return field_value[0].id + + +# util +def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_height: int) -> Dict[str, List]: + """Convert a document to a dictionary of the form: + { + 'words': ['word1', 'word2', ...], + 'bbox': [[x1, y1, x2, y2], [x1, y1, x2, y2], ...], + 'block_ids': [0, 0, 0, 1 ...], + 'line_ids': [0, 1, 1, 2 ...], + 'labels': [0, 0, 0, 1 ...], # could be empty + } + + Args: + document (Document): + The input document object + page_width (int): + Typically the transformer model requires to use + the absolute coordinates for encoding the coordinates. + Set the correspnding page_width and page_height to convert the + relative coordinates to the absolute coordinates. + page_height (int): + Typically the transformer model requires to use + the absolute coordinates for encoding the coordinates. + Set the correspnding page_width and page_height to convert the + relative coordinates to the absolute coordinates. + + Returns: + Dict[str, List]: The pdf_dict object + """ + + token_data = [ + ( + token.symbols_from_spans[0], # words + token.boxes[0].to_absolute(page_width=page_width, page_height=page_height).xy_coordinates, # bbox + get_visual_group_id(token, RowsFieldName, -1), # line_ids + get_visual_group_id(token, BlocksFieldName, -1), # block_ids + ) + for token in doc.tokens + ] + + words, bbox, line_ids, block_ids = (list(l) for l in zip(*token_data)) + line_ids = shift_index_sequence_to_zero_start(line_ids) + block_ids = shift_index_sequence_to_zero_start(block_ids) + + labels = [None] * len(words) + # TODO: We provide an empty label list. + + return { + "words": words, + "bbox": bbox, + "block_ids": block_ids, + "line_ids": line_ids, + "labels": labels, + } + + +# util +def convert_sequence_tagging_to_spans( + token_prediction_sequence: List, +) -> List[Tuple[int, int, int]]: + """For a sequence of token predictions, convert them to spans + of consecutive same predictions. + + Args: + token_prediction_sequence (List) + + Returns: + List[Tuple[int, int, int]]: A list of (start, end, label) + of consecutive prediction of the same label. + """ + prev_len = 0 + spans = [] + for gp, seq in itertools.groupby(token_prediction_sequence): + cur_len = len(list(seq)) + spans.append((prev_len, prev_len + cur_len, gp)) + prev_len = prev_len + cur_len + return spans diff --git a/papermage/predictors/utils_paper_qa/__init__.py b/papermage/predictors/utils_paper_qa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/papermage/predictors/utils_paper_qa/hashing.py b/papermage/predictors/utils_paper_qa/hashing.py new file mode 100644 index 0000000..e67f718 --- /dev/null +++ b/papermage/predictors/utils_paper_qa/hashing.py @@ -0,0 +1,82 @@ +import re +from functools import lru_cache +from hashlib import blake2b +from typing import List + +import numpy as np + +MIN_WORD_LENGTH = 4 +HASH_BITS = 64 + + +def int_to_bin(i: int) -> np.ndarray: + return np.array(list(map(int, np.binary_repr(i, width=64))), dtype=np.int8) + + +def bin_to_int(b: np.ndarray) -> int: + return int("".join(map(str, b)), 2) + + +def int_to_hex(i: int) -> str: + return hex(i)[2:] + + +def hex_to_int(h: str) -> int: + return int(h, 16) + + +def generate_tokens(s: str, min_tokens: int = HASH_BITS, min_length: int = MIN_WORD_LENGTH) -> List[str]: + """Generate tokens and sample a subset if you have enough of minimum length.""" + tokens = [t.strip() for t in re.findall(r" ?[^.\s,!?…。,、।۔،]+", s)] + lengths = [len(t) for t in tokens] + for n in range(min_length, -1, -1): + if sum(lt for lt in lengths if lt > n) >= min_tokens: + return [t for t in tokens if len(t) > n] + return tokens + + +def _hash_token(token: str) -> np.ndarray: + token_bytes = blake2b(token.encode(), digest_size=8).digest() + return np.frombuffer(token_bytes, dtype=">u8").astype(np.int64) + + +@lru_cache(maxsize=2**14) +def hash_token(token: str, hash_bits: int = HASH_BITS): + # we use blake2b because it's faster than md5, sha1, and sha3 + h = _hash_token(token=token) + + # initialize an array of -1 + bit_array = -1 * np.ones(hash_bits, dtype=np.int64) + + # find the indices where condition h & (1 << np.arange(hash_bits)) is true + indices = np.where(h & (1 << np.arange(hash_bits, dtype=np.int64))) + + # set those indices to 1 + bit_array[indices] = 1 + + # return the bit array + return bit_array + + +def create_hash(s: str, hash_bits: int = HASH_BITS) -> int: + tokens = generate_tokens(s=s, min_tokens=hash_bits) + v = np.zeros(hash_bits, dtype=np.int64) + + for t in tokens: + v = v + hash_token(token=t, hash_bits=hash_bits) + + # Get the indices where v[i] is >= 0 + indices, *_ = np.where(v >= 0) + + # Calculate fingerprint using bitwise shift and summation (equivalent to sum(2**indices)) + fingerprint = np.sum(1 << indices.astype(np.uint64)) + + return fingerprint.tolist() + + +def similarity(targets: np.ndarray, queries: np.ndarray) -> np.ndarray: + if len(targets.shape) == 1: + targets = targets.reshape(1, -1) + if len(queries.shape) == 1: + queries = queries.reshape(1, -1) + return 1.0 - np.sum(queries[:, np.newaxis, :] != targets[np.newaxis, :, :], axis=-1) / queries.shape[-1] diff --git a/papermage/predictors/vila_predictors.py b/papermage/predictors/vila_predictors.py index 76eaeaf..6ab0a64 100644 --- a/papermage/predictors/vila_predictors.py +++ b/papermage/predictors/vila_predictors.py @@ -11,13 +11,10 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" -import inspect -import itertools + from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -import torch -from tqdm import tqdm from vila.predictors import LayoutIndicatorPDFPredictor, SimplePDFPredictor from papermage.magelib import ( @@ -32,6 +29,11 @@ ) from papermage.predictors import BasePredictor +from .utils.vila_utils import ( + convert_document_page_to_pdf_dict, + convert_sequence_tagging_to_spans, +) + # Two constants for the constraining the size of the page for # inputs to the model. # TODO: Move this to somewhere else. @@ -59,135 +61,6 @@ ] -# util -def columns_used_in_model_inputs(model): - signature = inspect.signature(model.forward) - signature_columns = list(signature.parameters.keys()) - return signature_columns - - -# util -def normalize_bbox( - bbox, - page_width, - page_height, - target_width, - target_height, -): - """ - Normalize bounding box to the target size. - """ - - x1, y1, x2, y2 = bbox - - # Right now only execute this for only "large" PDFs - # TODO: Change it for all PDFs - if page_width > target_width or page_height > target_height: - x1 = float(x1) / page_width * target_width - x2 = float(x2) / page_width * target_width - y1 = float(y1) / page_height * target_height - y2 = float(y2) / page_height * target_height - - return (x1, y1, x2, y2) - - -# util -def shift_index_sequence_to_zero_start(sequence): - """ - Shift a sequence to start at 0. - """ - sequence_start = min(sequence) - return [i - sequence_start for i in sequence] - - -# util -def get_visual_group_id(token: Entity, field_name: str, defaults=-1) -> int: - if not hasattr(token, field_name): - return defaults - field_value = getattr(token, field_name) - if len(field_value) == 0 or field_value[0].id is None: - return defaults - return field_value[0].id - - -# util -def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_height: int) -> Dict[str, List]: - """Convert a document to a dictionary of the form: - { - 'words': ['word1', 'word2', ...], - 'bbox': [[x1, y1, x2, y2], [x1, y1, x2, y2], ...], - 'block_ids': [0, 0, 0, 1 ...], - 'line_ids': [0, 1, 1, 2 ...], - 'labels': [0, 0, 0, 1 ...], # could be empty - } - - Args: - document (Document): - The input document object - page_width (int): - Typically the transformer model requires to use - the absolute coordinates for encoding the coordinates. - Set the correspnding page_width and page_height to convert the - relative coordinates to the absolute coordinates. - page_height (int): - Typically the transformer model requires to use - the absolute coordinates for encoding the coordinates. - Set the correspnding page_width and page_height to convert the - relative coordinates to the absolute coordinates. - - Returns: - Dict[str, List]: The pdf_dict object - """ - - token_data = [ - ( - token.symbols_from_spans[0], # words - token.boxes[0].to_absolute(page_width=page_width, page_height=page_height).xy_coordinates, # bbox - get_visual_group_id(token, RowsFieldName, -1), # line_ids - get_visual_group_id(token, BlocksFieldName, -1), # block_ids - ) - for token in doc.tokens - ] - - words, bbox, line_ids, block_ids = (list(l) for l in zip(*token_data)) - line_ids = shift_index_sequence_to_zero_start(line_ids) - block_ids = shift_index_sequence_to_zero_start(block_ids) - - labels = [None] * len(words) - # TODO: We provide an empty label list. - - return { - "words": words, - "bbox": bbox, - "block_ids": block_ids, - "line_ids": line_ids, - "labels": labels, - } - - -# util -def convert_sequence_tagging_to_spans( - token_prediction_sequence: List, -) -> List[Tuple[int, int, int]]: - """For a sequence of token predictions, convert them to spans - of consecutive same predictions. - - Args: - token_prediction_sequence (List) - - Returns: - List[Tuple[int, int, int]]: A list of (start, end, label) - of consecutive prediction of the same label. - """ - prev_len = 0 - spans = [] - for gp, seq in itertools.groupby(token_prediction_sequence): - cur_len = len(list(seq)) - spans.append((prev_len, prev_len + cur_len, gp)) - prev_len = prev_len + cur_len - return spans - - class BaseSinglePageTokenClassificationPredictor(BasePredictor): @property def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]: diff --git a/pyproject.toml b/pyproject.toml index c88ea3a..9b8f1ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,10 @@ dev = [ visualizers = [ 'layoutparser==0.3.4' ] +qa = [ + 'tokreate', + 'pysbd' +] predictors = [ 'thefuzz[speedup]', 'scikit-learn>=1.3.0', diff --git a/tests/test_predictors/test_paper_qa_predictors.py b/tests/test_predictors/test_paper_qa_predictors.py new file mode 100644 index 0000000..5070b9f --- /dev/null +++ b/tests/test_predictors/test_paper_qa_predictors.py @@ -0,0 +1,37 @@ +""" + +@kylel + +""" + + +import json +import os +import pathlib +import unittest + +from papermage.magelib import Document, Entity, Metadata, Span +from papermage.predictors import PaperQaPredictor + + +class TestPaperQAPredictor(unittest.TestCase): + def setUp(self): + self.fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" + with open(self.fixture_path / "2304.02623v1.json", "r") as f: + test_doc_json = json.load(f) + self.doc = Document.from_json(doc_json=test_doc_json) + + self.paper_qa_predictor = PaperQaPredictor() + + self.using_github_actions = ( + "USING_GITHUB_ACTIONS" in os.environ and os.environ["USING_GITHUB_ACTIONS"] == "true" + ) + + def test_merge_adjacent_sentences(self): + locs = [0, 1, 2, 3] + result = self.paper_qa_predictor.merge_adjacent_sentences(locs=locs, slack=1, max_len=3) + self.assertListEqual(result, [[0, 1, 2], [], [], [3]]) + + locs = [0, 1, 3, 5, 8] + result = self.paper_qa_predictor.merge_adjacent_sentences(locs=locs, slack=2, max_len=3) + self.assertListEqual(result, [[0, 1, 2], [], [], [3]])