From 93487bb29c4b2c365e5830b1aa1baed6578eb58e Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 12 Oct 2023 14:29:51 -0700 Subject: [PATCH 1/6] paper qa --- papermage/predictors/paper_qa_predictor.py | 124 ++++++++++++++++++ .../predictors/utils_paper_qa/__init__.py | 0 .../predictors/utils_paper_qa/hashing.py | 82 ++++++++++++ pyproject.toml | 4 + 4 files changed, 210 insertions(+) create mode 100644 papermage/predictors/paper_qa_predictor.py create mode 100644 papermage/predictors/utils_paper_qa/__init__.py create mode 100644 papermage/predictors/utils_paper_qa/hashing.py diff --git a/papermage/predictors/paper_qa_predictor.py b/papermage/predictors/paper_qa_predictor.py new file mode 100644 index 0000000..7f528ee --- /dev/null +++ b/papermage/predictors/paper_qa_predictor.py @@ -0,0 +1,124 @@ +from typing import List, Set + +import numpy as np +import pysbd +from tokreate import CallAction, ParseAction + +from papermage.magelib import Document, Entity, ParagraphsFieldName, SentencesFieldName +from papermage.predictors import BasePredictor + +from .utils_paper_qa.hashing import create_hash, int_to_bin, similarity + +FULL_DOC_QA_ATTRIBUTED_PROMPT = """\ +Answer a question using the provided scientific paper. +Your response should be a JSON object with the following fields: + - answer: The answer to the question. The answer should use concise language, but be comprehensive. Only provide answers that are objectively supported by the text in paper. + - excerpts: A list of one or more *EXACT* text spans extracted from the paper that support the answer. Return between at most ten spans, and no more that 800 words. Make sure to cover all aspects of the answer above. +If there is no answer, return an empty dictionary, i.e., `{}`. + +Paper: +{{ full_text }} + +Given the information above, please answer the question: "{{ question }}".""" + +FULL_DOC_QA_SYSTEM_JSON_PROMPT = """\ +You are a helpful research assistant, answering questions about scientific papers accurately and concisely. +You ONLY respond to questions that have an objective answer, and return an empty response for subjective requests. +You always return a valid JSON object to each user request.""" + + +class PaperQaPredictor(BasePredictor): + def __init__(self, model_name: str = "gpt-3.5-turbo-16k", max_tokens: int = 2048): + self.call = CallAction( + prompt=FULL_DOC_QA_ATTRIBUTED_PROMPT, + system=FULL_DOC_QA_SYSTEM_JSON_PROMPT, + model=model_name, + parameters={"max_tokens": max_tokens}, + ) >> ParseAction(name="json_parser", parser="json.loads") + self.sentencizer = pysbd.Segmenter(language="en", clean=False) + + @property + def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]: + return [ParagraphsFieldName, SentencesFieldName] + + def merge_adjacent_sentences(self, locs: List[int], slack: int = 1, max_len: int = 3) -> List[List[int]]: + """Merge adjacent sentences that are within a certain distance of each other.""" + seen: Set[int] = set() + grouped: List[List[int]] = [] + end_pos = len(locs) - 1 + for i in range(len(locs)): + init_pos = locs[i] + if init_pos in seen: + grouped.append([]) + elif i < end_pos: + for j in range(i + 1, len(locs)): + curr_pos, prev_pos = locs[j], locs[j - 1] + if (curr_pos - prev_pos) > slack or (curr_pos - init_pos) >= max_len or j == end_pos: + new_pos = list(range(init_pos, prev_pos + 1)) + grouped.append(new_pos) + seen.update(new_pos) + break + else: + grouped.append([init_pos]) + seen.add(init_pos) + + return grouped + + def predict(self, doc: Document, *args, **kwargs) -> List[Entity]: + # TODO: fix base predictor so that it can handle questions + self._doc_field_checker(doc) + return self._predict(doc, *args, **kwargs) + + def _predict(self, doc: Document, question: str) -> List[Entity]: # type: ignore + full_text = "" + sentences_vecs = [] + sentences_ids = [] + + # build full text representation and hash sentences for similarity + for i, paragraph in enumerate(doc.paragraphs): # type: ignore + header = getattr(paragraph.metadata, "header_text", None) + for j, sent in enumerate(paragraph.sentences): + if j == 0 and (header := getattr(sent.metadata, "header_text", header)): + # add a header if it exists + full_text += f"\n\n## {header}\n" + else: + full_text += "\n\n" + + full_text += f"{sent.text.strip()} " + sentences_vecs.append(int_to_bin(create_hash(sent.text))) + sentences_ids.append((i, j)) + full_text.strip() + sentences_array = np.vstack(sentences_vecs) + + *_, output = self.call.run(full_text=full_text, question=question) + parsed_output = output.state["json_parser"] + + if not parsed_output: + # the model could not answer the question + return [] + + excerpts = [s for e in parsed_output.get("excerpts", []) for s in self.sentencizer.segment(e)] + + if not excerpts: + # the model could not find supporting evidence + return [] + + encoded_context = np.vstack([int_to_bin(create_hash(e)) for e in excerpts]) + similarities = similarity(queries=encoded_context, targets=sentences_array) + grouped_locs = self.merge_adjacent_sentences(np.argmax(similarities, axis=1)) + + extracted_excerpts: List[Entity] = [] + for locs in grouped_locs: + if len(locs) == 0: + # nothing matched for this group + continue + + ids = [sentences_ids[loc] for loc in locs] + sents = [doc.paragraphs[i].sentences[j] for i, j in ids] # type: ignore + for sent in sents: + matched_sent = Entity.from_json(sent.to_json()) + matched_sent.metadata.score = np.max(similarities[:, locs]).tolist() + matched_sent.metadata.answer = parsed_output["answer"] + extracted_excerpts.append(matched_sent) + + return extracted_excerpts diff --git a/papermage/predictors/utils_paper_qa/__init__.py b/papermage/predictors/utils_paper_qa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/papermage/predictors/utils_paper_qa/hashing.py b/papermage/predictors/utils_paper_qa/hashing.py new file mode 100644 index 0000000..e67f718 --- /dev/null +++ b/papermage/predictors/utils_paper_qa/hashing.py @@ -0,0 +1,82 @@ +import re +from functools import lru_cache +from hashlib import blake2b +from typing import List + +import numpy as np + +MIN_WORD_LENGTH = 4 +HASH_BITS = 64 + + +def int_to_bin(i: int) -> np.ndarray: + return np.array(list(map(int, np.binary_repr(i, width=64))), dtype=np.int8) + + +def bin_to_int(b: np.ndarray) -> int: + return int("".join(map(str, b)), 2) + + +def int_to_hex(i: int) -> str: + return hex(i)[2:] + + +def hex_to_int(h: str) -> int: + return int(h, 16) + + +def generate_tokens(s: str, min_tokens: int = HASH_BITS, min_length: int = MIN_WORD_LENGTH) -> List[str]: + """Generate tokens and sample a subset if you have enough of minimum length.""" + tokens = [t.strip() for t in re.findall(r" ?[^.\s,!?…。,、।۔،]+", s)] + lengths = [len(t) for t in tokens] + for n in range(min_length, -1, -1): + if sum(lt for lt in lengths if lt > n) >= min_tokens: + return [t for t in tokens if len(t) > n] + return tokens + + +def _hash_token(token: str) -> np.ndarray: + token_bytes = blake2b(token.encode(), digest_size=8).digest() + return np.frombuffer(token_bytes, dtype=">u8").astype(np.int64) + + +@lru_cache(maxsize=2**14) +def hash_token(token: str, hash_bits: int = HASH_BITS): + # we use blake2b because it's faster than md5, sha1, and sha3 + h = _hash_token(token=token) + + # initialize an array of -1 + bit_array = -1 * np.ones(hash_bits, dtype=np.int64) + + # find the indices where condition h & (1 << np.arange(hash_bits)) is true + indices = np.where(h & (1 << np.arange(hash_bits, dtype=np.int64))) + + # set those indices to 1 + bit_array[indices] = 1 + + # return the bit array + return bit_array + + +def create_hash(s: str, hash_bits: int = HASH_BITS) -> int: + tokens = generate_tokens(s=s, min_tokens=hash_bits) + v = np.zeros(hash_bits, dtype=np.int64) + + for t in tokens: + v = v + hash_token(token=t, hash_bits=hash_bits) + + # Get the indices where v[i] is >= 0 + indices, *_ = np.where(v >= 0) + + # Calculate fingerprint using bitwise shift and summation (equivalent to sum(2**indices)) + fingerprint = np.sum(1 << indices.astype(np.uint64)) + + return fingerprint.tolist() + + +def similarity(targets: np.ndarray, queries: np.ndarray) -> np.ndarray: + if len(targets.shape) == 1: + targets = targets.reshape(1, -1) + if len(queries.shape) == 1: + queries = queries.reshape(1, -1) + return 1.0 - np.sum(queries[:, np.newaxis, :] != targets[np.newaxis, :, :], axis=-1) / queries.shape[-1] diff --git a/pyproject.toml b/pyproject.toml index ed6515e..05c30a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,10 @@ dev = [ visualizers = [ 'layoutparser==0.3.4' ] +qa = [ + 'tokreate', + 'pysbd' +] predictors = [ 'thefuzz[speedup]', 'scikit-learn>=1.3.0', From cabbb2b197534e1b72c13150725132e85b019fac Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 29 Nov 2023 20:56:47 -0800 Subject: [PATCH 2/6] adding import --- papermage/predictors/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/papermage/predictors/__init__.py b/papermage/predictors/__init__.py index ae62f67..1fbe446 100644 --- a/papermage/predictors/__init__.py +++ b/papermage/predictors/__init__.py @@ -2,6 +2,7 @@ from papermage.predictors.base_predictors.hf_predictors import HFBIOTaggerPredictor from papermage.predictors.block_predictors import LPEffDetPubLayNetBlockPredictor from papermage.predictors.formula_predictors import LPEffDetFormulaPredictor +from papermage.predictors.paper_qa_predictors import PaperQaPredictor from papermage.predictors.sentence_predictors import PysbdSentencePredictor from papermage.predictors.span_qa_predictors import APISpanQAPredictor from papermage.predictors.token_predictors import HFWhitspaceTokenPredictor @@ -18,4 +19,5 @@ "LPEffDetFormulaPredictor", "APISpanQAPredictor", "BasePredictor", + "PaperQaPredictor", ] From 4a5d04f155230497e2ab80cd46ba7d915a806e54 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 29 Nov 2023 20:57:24 -0800 Subject: [PATCH 3/6] add args kwargs to base predictor predict --- papermage/predictors/base_predictors/base_predictor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/papermage/predictors/base_predictors/base_predictor.py b/papermage/predictors/base_predictors/base_predictor.py index 5ee0f17..df88f85 100644 --- a/papermage/predictors/base_predictors/base_predictor.py +++ b/papermage/predictors/base_predictors/base_predictor.py @@ -31,13 +31,13 @@ def _doc_field_checker(self, doc: Document) -> None: field in doc.layers ), f"The input Document object {doc} doesn't contain the required field {field}" - def predict(self, doc: Document) -> List[Entity]: + def predict(self, doc: Document, *args, **kwargs) -> List[Entity]: """For all the predictors, the input is a document object, and the output is a list of annotations. """ self._doc_field_checker(doc) - return self._predict(doc=doc) + return self._predict(doc=doc, *args, **kwargs) @abstractmethod - def _predict(self, doc: Document) -> List[Entity]: + def _predict(self, doc: Document, *args, **kwargs) -> List[Entity]: raise NotImplementedError From e5879d2fe76a8acff5a8d82c088fcc483126768d Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 29 Nov 2023 20:57:45 -0800 Subject: [PATCH 4/6] add args kwargs to predict method --- papermage/predictors/base_predictors/lp_predictors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/papermage/predictors/base_predictors/lp_predictors.py b/papermage/predictors/base_predictors/lp_predictors.py index 9d48b32..6ddb6c0 100644 --- a/papermage/predictors/base_predictors/lp_predictors.py +++ b/papermage/predictors/base_predictors/lp_predictors.py @@ -100,7 +100,7 @@ def postprocess(self, model_outputs: lp.Layout, page_index: int, image: Image) - for block in model_outputs ] - def _predict(self, doc: Document) -> List[Entity]: + def _predict(self, doc: Document, *args, **kwargs) -> List[Entity]: """Returns a list of Entities for the detected layouts for all pages Args: From e7f772c3ff10c475499032bc984e0e4219e5707b Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 29 Nov 2023 20:59:24 -0800 Subject: [PATCH 5/6] create new space for predictor utils --- papermage/predictors/utils/__init__.py | 0 papermage/predictors/utils/paper_qa_utils.py | 88 +++++++++++ papermage/predictors/utils/vila_utils.py | 149 +++++++++++++++++++ papermage/predictors/vila_predictors.py | 139 +---------------- 4 files changed, 243 insertions(+), 133 deletions(-) create mode 100644 papermage/predictors/utils/__init__.py create mode 100644 papermage/predictors/utils/paper_qa_utils.py create mode 100644 papermage/predictors/utils/vila_utils.py diff --git a/papermage/predictors/utils/__init__.py b/papermage/predictors/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/papermage/predictors/utils/paper_qa_utils.py b/papermage/predictors/utils/paper_qa_utils.py new file mode 100644 index 0000000..97f947b --- /dev/null +++ b/papermage/predictors/utils/paper_qa_utils.py @@ -0,0 +1,88 @@ +""" + +@soldni + +""" + +import re +from functools import lru_cache +from hashlib import blake2b +from typing import List + +import numpy as np + +MIN_WORD_LENGTH = 4 +HASH_BITS = 64 + + +def int_to_bin(i: int) -> np.ndarray: + return np.array(list(map(int, np.binary_repr(i, width=64))), dtype=np.int8) + + +def bin_to_int(b: np.ndarray) -> int: + return int("".join(map(str, b)), 2) + + +def int_to_hex(i: int) -> str: + return hex(i)[2:] + + +def hex_to_int(h: str) -> int: + return int(h, 16) + + +def generate_tokens(s: str, min_tokens: int = HASH_BITS, min_length: int = MIN_WORD_LENGTH) -> List[str]: + """Generate tokens and sample a subset if you have enough of minimum length.""" + tokens = [t.strip() for t in re.findall(r" ?[^.\s,!?…。,、।۔،]+", s)] + lengths = [len(t) for t in tokens] + for n in range(min_length, -1, -1): + if sum(lt for lt in lengths if lt > n) >= min_tokens: + return [t for t in tokens if len(t) > n] + return tokens + + +def _hash_token(token: str) -> np.ndarray: + token_bytes = blake2b(token.encode(), digest_size=8).digest() + return np.frombuffer(token_bytes, dtype=">u8").astype(np.int64) + + +@lru_cache(maxsize=2**14) +def hash_token(token: str, hash_bits: int = HASH_BITS): + # we use blake2b because it's faster than md5, sha1, and sha3 + h = _hash_token(token=token) + + # initialize an array of -1 + bit_array = -1 * np.ones(hash_bits, dtype=np.int64) + + # find the indices where condition h & (1 << np.arange(hash_bits)) is true + indices = np.where(h & (1 << np.arange(hash_bits, dtype=np.int64))) + + # set those indices to 1 + bit_array[indices] = 1 + + # return the bit array + return bit_array + + +def create_hash(s: str, hash_bits: int = HASH_BITS) -> int: + tokens = generate_tokens(s=s, min_tokens=hash_bits) + v = np.zeros(hash_bits, dtype=np.int64) + + for t in tokens: + v = v + hash_token(token=t, hash_bits=hash_bits) + + # Get the indices where v[i] is >= 0 + indices, *_ = np.where(v >= 0) + + # Calculate fingerprint using bitwise shift and summation (equivalent to sum(2**indices)) + fingerprint = np.sum(1 << indices.astype(np.uint64)) + + return fingerprint.tolist() + + +def similarity(targets: np.ndarray, queries: np.ndarray) -> np.ndarray: + if len(targets.shape) == 1: + targets = targets.reshape(1, -1) + if len(queries.shape) == 1: + queries = queries.reshape(1, -1) + return 1.0 - np.sum(queries[:, np.newaxis, :] != targets[np.newaxis, :, :], axis=-1) / queries.shape[-1] diff --git a/papermage/predictors/utils/vila_utils.py b/papermage/predictors/utils/vila_utils.py new file mode 100644 index 0000000..650e0ec --- /dev/null +++ b/papermage/predictors/utils/vila_utils.py @@ -0,0 +1,149 @@ +""" + +@shannons + +""" + +import inspect +import itertools +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from papermage.magelib import ( + BlocksFieldName, + Document, + Entity, + Metadata, + PagesFieldName, + RowsFieldName, + Span, + TokensFieldName, +) + + +# util +def columns_used_in_model_inputs(model): + signature = inspect.signature(model.forward) + signature_columns = list(signature.parameters.keys()) + return signature_columns + + +# util +def normalize_bbox( + bbox, + page_width, + page_height, + target_width, + target_height, +): + """ + Normalize bounding box to the target size. + """ + + x1, y1, x2, y2 = bbox + + # Right now only execute this for only "large" PDFs + # TODO: Change it for all PDFs + if page_width > target_width or page_height > target_height: + x1 = float(x1) / page_width * target_width + x2 = float(x2) / page_width * target_width + y1 = float(y1) / page_height * target_height + y2 = float(y2) / page_height * target_height + + return (x1, y1, x2, y2) + + +# util +def shift_index_sequence_to_zero_start(sequence): + """ + Shift a sequence to start at 0. + """ + sequence_start = min(sequence) + return [i - sequence_start for i in sequence] + + +# util +def get_visual_group_id(token: Entity, field_name: str, defaults=-1) -> int: + if not hasattr(token, field_name): + return defaults + field_value = getattr(token, field_name) + if len(field_value) == 0 or field_value[0].id is None: + return defaults + return field_value[0].id + + +# util +def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_height: int) -> Dict[str, List]: + """Convert a document to a dictionary of the form: + { + 'words': ['word1', 'word2', ...], + 'bbox': [[x1, y1, x2, y2], [x1, y1, x2, y2], ...], + 'block_ids': [0, 0, 0, 1 ...], + 'line_ids': [0, 1, 1, 2 ...], + 'labels': [0, 0, 0, 1 ...], # could be empty + } + + Args: + document (Document): + The input document object + page_width (int): + Typically the transformer model requires to use + the absolute coordinates for encoding the coordinates. + Set the correspnding page_width and page_height to convert the + relative coordinates to the absolute coordinates. + page_height (int): + Typically the transformer model requires to use + the absolute coordinates for encoding the coordinates. + Set the correspnding page_width and page_height to convert the + relative coordinates to the absolute coordinates. + + Returns: + Dict[str, List]: The pdf_dict object + """ + + token_data = [ + ( + token.symbols_from_spans[0], # words + token.boxes[0].to_absolute(page_width=page_width, page_height=page_height).xy_coordinates, # bbox + get_visual_group_id(token, RowsFieldName, -1), # line_ids + get_visual_group_id(token, BlocksFieldName, -1), # block_ids + ) + for token in doc.tokens + ] + + words, bbox, line_ids, block_ids = (list(l) for l in zip(*token_data)) + line_ids = shift_index_sequence_to_zero_start(line_ids) + block_ids = shift_index_sequence_to_zero_start(block_ids) + + labels = [None] * len(words) + # TODO: We provide an empty label list. + + return { + "words": words, + "bbox": bbox, + "block_ids": block_ids, + "line_ids": line_ids, + "labels": labels, + } + + +# util +def convert_sequence_tagging_to_spans( + token_prediction_sequence: List, +) -> List[Tuple[int, int, int]]: + """For a sequence of token predictions, convert them to spans + of consecutive same predictions. + + Args: + token_prediction_sequence (List) + + Returns: + List[Tuple[int, int, int]]: A list of (start, end, label) + of consecutive prediction of the same label. + """ + prev_len = 0 + spans = [] + for gp, seq in itertools.groupby(token_prediction_sequence): + cur_len = len(list(seq)) + spans.append((prev_len, prev_len + cur_len, gp)) + prev_len = prev_len + cur_len + return spans diff --git a/papermage/predictors/vila_predictors.py b/papermage/predictors/vila_predictors.py index 76eaeaf..6ab0a64 100644 --- a/papermage/predictors/vila_predictors.py +++ b/papermage/predictors/vila_predictors.py @@ -11,13 +11,10 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" -import inspect -import itertools + from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -import torch -from tqdm import tqdm from vila.predictors import LayoutIndicatorPDFPredictor, SimplePDFPredictor from papermage.magelib import ( @@ -32,6 +29,11 @@ ) from papermage.predictors import BasePredictor +from .utils.vila_utils import ( + convert_document_page_to_pdf_dict, + convert_sequence_tagging_to_spans, +) + # Two constants for the constraining the size of the page for # inputs to the model. # TODO: Move this to somewhere else. @@ -59,135 +61,6 @@ ] -# util -def columns_used_in_model_inputs(model): - signature = inspect.signature(model.forward) - signature_columns = list(signature.parameters.keys()) - return signature_columns - - -# util -def normalize_bbox( - bbox, - page_width, - page_height, - target_width, - target_height, -): - """ - Normalize bounding box to the target size. - """ - - x1, y1, x2, y2 = bbox - - # Right now only execute this for only "large" PDFs - # TODO: Change it for all PDFs - if page_width > target_width or page_height > target_height: - x1 = float(x1) / page_width * target_width - x2 = float(x2) / page_width * target_width - y1 = float(y1) / page_height * target_height - y2 = float(y2) / page_height * target_height - - return (x1, y1, x2, y2) - - -# util -def shift_index_sequence_to_zero_start(sequence): - """ - Shift a sequence to start at 0. - """ - sequence_start = min(sequence) - return [i - sequence_start for i in sequence] - - -# util -def get_visual_group_id(token: Entity, field_name: str, defaults=-1) -> int: - if not hasattr(token, field_name): - return defaults - field_value = getattr(token, field_name) - if len(field_value) == 0 or field_value[0].id is None: - return defaults - return field_value[0].id - - -# util -def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_height: int) -> Dict[str, List]: - """Convert a document to a dictionary of the form: - { - 'words': ['word1', 'word2', ...], - 'bbox': [[x1, y1, x2, y2], [x1, y1, x2, y2], ...], - 'block_ids': [0, 0, 0, 1 ...], - 'line_ids': [0, 1, 1, 2 ...], - 'labels': [0, 0, 0, 1 ...], # could be empty - } - - Args: - document (Document): - The input document object - page_width (int): - Typically the transformer model requires to use - the absolute coordinates for encoding the coordinates. - Set the correspnding page_width and page_height to convert the - relative coordinates to the absolute coordinates. - page_height (int): - Typically the transformer model requires to use - the absolute coordinates for encoding the coordinates. - Set the correspnding page_width and page_height to convert the - relative coordinates to the absolute coordinates. - - Returns: - Dict[str, List]: The pdf_dict object - """ - - token_data = [ - ( - token.symbols_from_spans[0], # words - token.boxes[0].to_absolute(page_width=page_width, page_height=page_height).xy_coordinates, # bbox - get_visual_group_id(token, RowsFieldName, -1), # line_ids - get_visual_group_id(token, BlocksFieldName, -1), # block_ids - ) - for token in doc.tokens - ] - - words, bbox, line_ids, block_ids = (list(l) for l in zip(*token_data)) - line_ids = shift_index_sequence_to_zero_start(line_ids) - block_ids = shift_index_sequence_to_zero_start(block_ids) - - labels = [None] * len(words) - # TODO: We provide an empty label list. - - return { - "words": words, - "bbox": bbox, - "block_ids": block_ids, - "line_ids": line_ids, - "labels": labels, - } - - -# util -def convert_sequence_tagging_to_spans( - token_prediction_sequence: List, -) -> List[Tuple[int, int, int]]: - """For a sequence of token predictions, convert them to spans - of consecutive same predictions. - - Args: - token_prediction_sequence (List) - - Returns: - List[Tuple[int, int, int]]: A list of (start, end, label) - of consecutive prediction of the same label. - """ - prev_len = 0 - spans = [] - for gp, seq in itertools.groupby(token_prediction_sequence): - cur_len = len(list(seq)) - spans.append((prev_len, prev_len + cur_len, gp)) - prev_len = prev_len + cur_len - return spans - - class BaseSinglePageTokenClassificationPredictor(BasePredictor): @property def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]: From 818fdd199b2637f70feb8da001e9902e98070359 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 29 Nov 2023 21:00:12 -0800 Subject: [PATCH 6/6] placeholders for paper qa predictors --- papermage/predictors/paper_qa_predictors.py | 137 ++++++++++++++++++ papermage/predictors/paragraph_predictors.py | 22 +++ .../test_paper_qa_predictors.py | 37 +++++ 3 files changed, 196 insertions(+) create mode 100644 papermage/predictors/paper_qa_predictors.py create mode 100644 papermage/predictors/paragraph_predictors.py create mode 100644 tests/test_predictors/test_paper_qa_predictors.py diff --git a/papermage/predictors/paper_qa_predictors.py b/papermage/predictors/paper_qa_predictors.py new file mode 100644 index 0000000..9f90ab7 --- /dev/null +++ b/papermage/predictors/paper_qa_predictors.py @@ -0,0 +1,137 @@ +""" + +@soldni + +""" + +from typing import List, Set + +import numpy as np +import pysbd +from tokreate import CallAction, ParseAction + +from papermage.magelib import ( + BlocksFieldName, + Document, + Entity, + ParagraphsFieldName, + SentencesFieldName, +) +from papermage.predictors import BasePredictor + +from .utils.paper_qa_utils import create_hash, int_to_bin, similarity + +FULL_DOC_QA_ATTRIBUTED_PROMPT = """\ +Answer a question using the provided scientific paper. +Your response should be a JSON object with the following fields: + - answer: The answer to the question. The answer should use concise language, but be comprehensive. Only provide answers that are objectively supported by the text in paper. + - excerpts: A list of one or more *EXACT* text spans extracted from the paper that support the answer. Return between at most ten spans, and no more that 800 words. Make sure to cover all aspects of the answer above. +If there is no answer, return an empty dictionary, i.e., `{}`. + +Paper: +{{ full_text }} + +Given the information above, please answer the question: "{{ question }}".""" + +FULL_DOC_QA_SYSTEM_JSON_PROMPT = """\ +You are a helpful research assistant, answering questions about scientific papers accurately and concisely. +You ONLY respond to questions that have an objective answer, and return an empty response for subjective requests. +You always return a valid JSON object to each user request.""" + + +class PaperQaPredictor(BasePredictor): + def __init__(self, model_name: str = "gpt-3.5-turbo-16k", max_tokens: int = 2048): + self.call = CallAction( + prompt=FULL_DOC_QA_ATTRIBUTED_PROMPT, + system=FULL_DOC_QA_SYSTEM_JSON_PROMPT, + model=model_name, + parameters={"max_tokens": max_tokens}, + ) >> ParseAction(name="json_parser", parser="json.loads") + self.sentencizer = pysbd.Segmenter(language="en", clean=False) + + @property + def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]: + return [BlocksFieldName, SentencesFieldName] + + def merge_adjacent_sentences(self, locs: List[int], slack: int = 1, max_len: int = 3) -> List[List[int]]: + """Merge adjacent sentences that are within a certain distance of each other.""" + seen: Set[int] = set() + grouped: List[List[int]] = [] + end_pos = len(locs) - 1 + for i in range(len(locs)): + init_pos = locs[i] + if init_pos in seen: + grouped.append([]) + elif i < end_pos: + for j in range(i + 1, len(locs)): + curr_pos, prev_pos = locs[j], locs[j - 1] + if (curr_pos - prev_pos) > slack or (curr_pos - init_pos) >= max_len or j == end_pos: + new_pos = list(range(init_pos, prev_pos + 1)) + grouped.append(new_pos) + seen.update(new_pos) + break + else: + grouped.append([init_pos]) + seen.add(init_pos) + + return grouped + + def _hash_sentences(self, doc: Document) -> List[str]: # type: ignore + """Hash sentences in a document.""" + for i, sent in enumerate(doc.sentences): + sent.metadata.hash = create_hash(sent.text) + sent.metadata.index = i + + def _predict(self, doc: Document, question: str) -> List[Entity]: # type: ignore + full_text = "" + sentences_vecs = [] + sentences_ids = [] + + # build full text representation and hash sentences for similarity + for i, paragraph in enumerate(doc.paragraphs): # type: ignore + header = getattr(paragraph.metadata, "header_text", None) + for j, sent in enumerate(paragraph.sentences): + if j == 0 and (header := getattr(sent.metadata, "header_text", header)): + # add a header if it exists + full_text += f"\n\n## {header}\n" + else: + full_text += "\n\n" + + full_text += f"{sent.text.strip()} " + sentences_vecs.append(int_to_bin(create_hash(sent.text))) + sentences_ids.append((i, j)) + full_text.strip() + sentences_array = np.vstack(sentences_vecs) + + *_, output = self.call.run(full_text=full_text, question=question) + parsed_output = output.state["json_parser"] + + if not parsed_output: + # the model could not answer the question + return [] + + excerpts = [s for e in parsed_output.get("excerpts", []) for s in self.sentencizer.segment(e)] + + if not excerpts: + # the model could not find supporting evidence + return [] + + encoded_context = np.vstack([int_to_bin(create_hash(e)) for e in excerpts]) + similarities = similarity(queries=encoded_context, targets=sentences_array) + grouped_locs = self.merge_adjacent_sentences(np.argmax(similarities, axis=1)) + + extracted_excerpts: List[Entity] = [] + for locs in grouped_locs: + if len(locs) == 0: + # nothing matched for this group + continue + + ids = [sentences_ids[loc] for loc in locs] + sents = [doc.paragraphs[i].sentences[j] for i, j in ids] # type: ignore + for sent in sents: + matched_sent = Entity.from_json(sent.to_json()) + matched_sent.metadata.score = np.max(similarities[:, locs]).tolist() + matched_sent.metadata.answer = parsed_output["answer"] + extracted_excerpts.append(matched_sent) + + return extracted_excerpts diff --git a/papermage/predictors/paragraph_predictors.py b/papermage/predictors/paragraph_predictors.py new file mode 100644 index 0000000..4bd8d33 --- /dev/null +++ b/papermage/predictors/paragraph_predictors.py @@ -0,0 +1,22 @@ +""" + +Paragraph Predictor + +@kylel + +""" + +import re +from typing import List, Tuple + +from papermage.magelib import ( + BlocksFieldName, + Document, + Entity, + PagesFieldName, + ParagraphsFieldName, + RowsFieldName, + Span, + TokensFieldName, +) +from papermage.predictors import BasePredictor diff --git a/tests/test_predictors/test_paper_qa_predictors.py b/tests/test_predictors/test_paper_qa_predictors.py new file mode 100644 index 0000000..5070b9f --- /dev/null +++ b/tests/test_predictors/test_paper_qa_predictors.py @@ -0,0 +1,37 @@ +""" + +@kylel + +""" + + +import json +import os +import pathlib +import unittest + +from papermage.magelib import Document, Entity, Metadata, Span +from papermage.predictors import PaperQaPredictor + + +class TestPaperQAPredictor(unittest.TestCase): + def setUp(self): + self.fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" + with open(self.fixture_path / "2304.02623v1.json", "r") as f: + test_doc_json = json.load(f) + self.doc = Document.from_json(doc_json=test_doc_json) + + self.paper_qa_predictor = PaperQaPredictor() + + self.using_github_actions = ( + "USING_GITHUB_ACTIONS" in os.environ and os.environ["USING_GITHUB_ACTIONS"] == "true" + ) + + def test_merge_adjacent_sentences(self): + locs = [0, 1, 2, 3] + result = self.paper_qa_predictor.merge_adjacent_sentences(locs=locs, slack=1, max_len=3) + self.assertListEqual(result, [[0, 1, 2], [], [], [3]]) + + locs = [0, 1, 3, 5, 8] + result = self.paper_qa_predictor.merge_adjacent_sentences(locs=locs, slack=2, max_len=3) + self.assertListEqual(result, [[0, 1, 2], [], [], [3]])