Skip to content

Commit

Permalink
Merge pull request #147 from lsorber/main
Browse files Browse the repository at this point in the history
Improve block weighting with uniform and hat functions
  • Loading branch information
markus583 authored Jan 18, 2025
2 parents 159a493 + 16fba36 commit 5902e7e
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 3 deletions.
11 changes: 11 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
# noqa: E501
from wtpsplit import WtP, SaT

def test_weighting():
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])

text = "This is a test sentence This is another test sentence."
splits_default = sat.split(text, threshold=0.25)
splits_uniform = sat.split(text, threshold=0.25, weighting="uniform")
splits_hat = sat.split(text, threshold=0.25, weighting="hat")
expected_splits = ["This is a test sentence ", "This is another test sentence."]
assert splits_default == splits_uniform == splits_hat == expected_splits
assert "".join(splits_default) == text


def test_split_ort():
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])
Expand Down
21 changes: 21 additions & 0 deletions wtpsplit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import warnings
from pathlib import Path
from typing import Literal

# avoid the "None of PyTorch, TensorFlow, etc. have been found" warning.
with contextlib.redirect_stderr(open(os.devnull, "w")):
Expand Down Expand Up @@ -141,6 +142,7 @@ def predict_proba(
block_size: int = 512,
batch_size=32,
pad_last_batch: bool = False,
weighting: Literal["uniform", "hat"] = "uniform",
remove_whitespace_before_inference: bool = False,
outer_batch_size=1000,
return_paragraph_probabilities=False,
Expand All @@ -156,6 +158,7 @@ def predict_proba(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
return_paragraph_probabilities=return_paragraph_probabilities,
Expand All @@ -171,6 +174,7 @@ def predict_proba(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
return_paragraph_probabilities=return_paragraph_probabilities,
Expand All @@ -186,6 +190,7 @@ def _predict_proba(
block_size: int,
batch_size: int,
pad_last_batch: bool,
weighting: Literal["uniform", "hat"],
remove_whitespace_before_inference: bool,
outer_batch_size: int,
return_paragraph_probabilities: bool,
Expand Down Expand Up @@ -246,6 +251,7 @@ def _predict_proba(
max_block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
verbose=verbose,
)[0]
else:
Expand Down Expand Up @@ -290,6 +296,7 @@ def split(
block_size: int = 512,
batch_size=32,
pad_last_batch: bool = False,
weighting: Literal["uniform", "hat"] = "uniform",
remove_whitespace_before_inference: bool = False,
outer_batch_size=1000,
paragraph_threshold: float = 0.5,
Expand All @@ -308,6 +315,7 @@ def split(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
paragraph_threshold=paragraph_threshold,
Expand All @@ -326,6 +334,7 @@ def split(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
paragraph_threshold=paragraph_threshold,
Expand Down Expand Up @@ -355,6 +364,7 @@ def _split(
block_size: int,
batch_size: int,
pad_last_batch: bool,
weighting: Literal["uniform", "hat"],
remove_whitespace_before_inference: bool,
outer_batch_size: int,
paragraph_threshold: float,
Expand Down Expand Up @@ -391,6 +401,7 @@ def _split(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
return_paragraph_probabilities=do_paragraph_segmentation,
Expand Down Expand Up @@ -573,6 +584,7 @@ def predict_proba(
block_size: int = 512,
batch_size=32,
pad_last_batch: bool = False,
weighting: Literal["uniform", "hat"] = "uniform",
remove_whitespace_before_inference: bool = False,
outer_batch_size=1000,
return_paragraph_probabilities=False,
Expand All @@ -586,6 +598,7 @@ def predict_proba(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
return_paragraph_probabilities=return_paragraph_probabilities,
Expand All @@ -599,6 +612,7 @@ def predict_proba(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
return_paragraph_probabilities=return_paragraph_probabilities,
Expand All @@ -612,6 +626,7 @@ def _predict_proba(
block_size: int,
batch_size: int,
pad_last_batch: bool,
weighting: Literal["uniform", "hat"],
remove_whitespace_before_inference: bool,
outer_batch_size: int,
return_paragraph_probabilities: bool,
Expand Down Expand Up @@ -657,6 +672,7 @@ def newline_probability_fn(logits):
max_block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
verbose=verbose,
tokenizer=self.tokenizer,
)
Expand Down Expand Up @@ -705,6 +721,7 @@ def split(
block_size: int = 512,
batch_size=32,
pad_last_batch: bool = False,
weighting: Literal["uniform", "hat"] = "uniform",
remove_whitespace_before_inference: bool = False,
outer_batch_size=1000,
paragraph_threshold: float = 0.5,
Expand All @@ -722,6 +739,7 @@ def split(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
paragraph_threshold=paragraph_threshold,
Expand All @@ -739,6 +757,7 @@ def split(
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
paragraph_threshold=paragraph_threshold,
Expand All @@ -756,6 +775,7 @@ def _split(
block_size: int,
batch_size: int,
pad_last_batch: bool,
weighting: Literal["uniform", "hat"],
paragraph_threshold: float,
remove_whitespace_before_inference: bool,
outer_batch_size: int,
Expand Down Expand Up @@ -784,6 +804,7 @@ def get_default_threshold(model_str: str):
block_size=block_size,
batch_size=batch_size,
pad_last_batch=pad_last_batch,
weighting=weighting,
remove_whitespace_before_inference=remove_whitespace_before_inference,
outer_batch_size=outer_batch_size,
return_paragraph_probabilities=do_paragraph_segmentation,
Expand Down
3 changes: 3 additions & 0 deletions wtpsplit/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import unicodedata
import os
from typing import Literal

import numpy as np
import regex as re
Expand Down Expand Up @@ -240,6 +241,7 @@ def our_sentencize(
block_size=512,
stride=64,
batch_size=32,
weighting: Literal["uniform", "hat"] = "uniform",
):
logits = extract(
[text],
Expand All @@ -249,6 +251,7 @@ def our_sentencize(
max_block_size=block_size,
batch_size=batch_size,
pad_last_batch=False,
weighting=weighting,
use_hidden_states=False,
verbose=False,
)[0]
Expand Down
16 changes: 13 additions & 3 deletions wtpsplit/extract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import sys
import logging
from typing import Literal

import numpy as np
from tqdm.auto import tqdm
Expand Down Expand Up @@ -93,6 +94,7 @@ def extract(
batch_size,
lang_code=None,
pad_last_batch=False,
weighting: Literal["uniform", "hat"] = "uniform",
verbose=False,
tokenizer=None,
):
Expand Down Expand Up @@ -202,7 +204,7 @@ def extract(
for length in text_lengths
]
# container for the number of chunks that any character was part of (to average chunk predictions)
all_counts = [np.zeros(length, dtype=np.int16) for length in text_lengths]
all_counts = [np.zeros(length, dtype=np.float16) for length in text_lengths]

uses_lang_adapters = getattr(model.config, "language_adapter", "off") == "on"
if uses_lang_adapters:
Expand All @@ -218,6 +220,13 @@ def extract(
)
else:
language_ids = None

# compute weights for the given weighting scheme
if weighting == "uniform":
weights = np.ones(block_size, dtype=np.float16)
elif weighting == "hat":
x = np.linspace(-(1 - 1 / block_size), 1 - 1 / block_size, block_size, dtype=np.float16)
weights = 1 - np.abs(x)

# forward passes through all chunks
for batch_idx in tqdm(range(n_batches), disable=not verbose):
Expand Down Expand Up @@ -255,8 +264,9 @@ def extract(

for i in range(start, end):
original_idx, start_char_idx, end_char_idx = locs[i]
all_logits[original_idx][start_char_idx:end_char_idx] += logits[i - start, : end_char_idx - start_char_idx]
all_counts[original_idx][start_char_idx:end_char_idx] += 1
n = end_char_idx - start_char_idx
all_logits[original_idx][start_char_idx:end_char_idx] += weights[:n, np.newaxis] * logits[i - start, :n]
all_counts[original_idx][start_char_idx:end_char_idx] += weights[:n]

# so far, logits are summed, so we average them here
all_logits = [(logits / counts[:, None]).astype(np.float16) for logits, counts in zip(all_logits, all_counts)]
Expand Down
3 changes: 3 additions & 0 deletions wtpsplit/train/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from typing import Literal

import numpy as np
import pysbd
Expand Down Expand Up @@ -74,6 +75,7 @@ def evaluate_sentence(
stride,
block_size,
batch_size,
weighting: Literal["uniform", "hat"] = "uniform",
use_pysbd=False,
positive_index=None,
do_lowercase=False,
Expand All @@ -97,6 +99,7 @@ def evaluate_sentence(
stride=stride,
max_block_size=block_size,
batch_size=batch_size,
weighting=weighting,
)
logits = logits[0]
if offsets_mapping is not None:
Expand Down

0 comments on commit 5902e7e

Please sign in to comment.