Merge pull request #441 from robertknight/byte-level-pre-tokenizer
Extract pre-tokenization out of tokenization models
robertknight authored Dec 4, 2024
2 parents fe19446 + 154df81 commit d51994e
Showing 9 changed files with 315 additions and 157 deletions.
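
The practical effect of this change: the splitting step that Bpe used to own (the GPT-2 regex) now lives in the new pre_tokenizers module and is attached to the Tokenizer rather than the model. A minimal sketch of the new wiring, assembled from the updated tests in this diff; the sample input string and the assertion are illustrative additions, not part of the commit:

use rten_text::models::Bpe;
use rten_text::pre_tokenizers::ByteLevelPreTokenizer;
use rten_text::tokenizers::Tokenizer;

fn main() {
    // Empty-vocab BPE model: encodes arbitrary text using one token per UTF-8 byte.
    // Note that the former `pattern` argument to `Bpe::new` is gone.
    let model = Bpe::new(&[], None, Default::default(), None).unwrap();

    // The GPT-2 byte-level split is now configured on the tokenizer, not the model.
    let tokenizer = Tokenizer::new(model, Default::default())
        .with_pre_tokenizer(Box::new(ByteLevelPreTokenizer::gpt2()));

    let encoded = tokenizer.encode("Hello world!", None).unwrap();
    // With an empty vocab this yields one token per UTF-8 byte of the input.
    assert_eq!(encoded.token_ids().len(), "Hello world!".len());
}

WordPiece tokenizers get the analogous treatment via BertPreTokenizer::new(), as the wordpiece.rs test changes below show.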
5 changes: 3 additions & 2 deletions rten-generate/src/text_decoder.rs
@@ -72,8 +72,8 @@ impl<G: Iterator<Item = GeneratorItem>> Iterator for TextDecoder<'_, G> {
mod tests {
use std::collections::HashMap;

use rten_text::models::patterns::GPT2;
use rten_text::models::{Bpe, WordPiece};
use rten_text::pre_tokenizers::ByteLevelPreTokenizer;
use rten_text::tokenizers::{TokenId, Tokenizer};

use crate::{GeneratorError, GeneratorUtils};
@@ -92,8 +92,9 @@ mod tests {
/// Create a BPE tokenizer with an empty vocab. This can encode and decode
/// arbitrary Unicode characters, by using one token per UTF-8 byte.
fn create_bpe_tokenizer() -> Tokenizer {
let model = Bpe::new(&[], GPT2, None, Default::default(), None).unwrap();
let model = Bpe::new(&[], None, Default::default(), None).unwrap();
Tokenizer::new(model, Default::default())
.with_pre_tokenizer(Box::new(ByteLevelPreTokenizer::gpt2()))
}

#[test]
1 change: 1 addition & 0 deletions rten-text/src/lib.rs
@@ -10,6 +10,7 @@
pub mod models;
pub mod normalizer;
pub mod pre_tokenizers;
pub mod tokenizers;

mod split;
63 changes: 16 additions & 47 deletions rten-text/src/models/bpe.rs
@@ -3,8 +3,6 @@ use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display};

use fancy_regex::Regex;

use crate::tokenizers::{Model, TokenId, TokenizerError};

/// Errors that can occur when building a [`Bpe`] tokenizer or encoding or
@@ -17,9 +15,6 @@ pub enum BpeError {
/// of another pair in the merge list.
InvalidMergeEntry(String),

/// The regex for splitting tokens is invalid.
InvalidPattern(Box<fancy_regex::Error>),

/// An entry in the vocab (token string to ID map) is not either a known
/// special token or an entry in the merge list.
InvalidVocabEntry(String),
@@ -29,7 +24,6 @@ impl Display for BpeError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BpeError::InvalidMergeEntry(entry) => write!(fmt, "invalid merge entry: {}", entry),
BpeError::InvalidPattern(err) => write!(fmt, "invalid regex: {}", err),
BpeError::InvalidVocabEntry(entry) => write!(fmt, "invalid vocab entry: {}", entry),
}
}
@@ -299,10 +293,6 @@ pub struct Bpe {
/// that each represent single byte values.
token_id_to_encoded_bytes: Option<HashMap<TokenId, EncodedBytes>>,

/// Pattern used to split the text into pieces prior to applying BPE
/// tokenization.
splitter: Regex,

/// Map from token ID to content for special tokens (eg. end-of-string).
added_tokens: HashMap<TokenId, String>,

@@ -341,16 +331,13 @@ impl Bpe {
/// substring that is tokenized, after initial splitting.
pub fn new(
merges: &[(EncodedByteSlice, EncodedByteSlice)],
pattern: &str,
vocab: Option<HashMap<EncodedBytes, TokenId>>,
added_tokens: HashMap<TokenId, String>,
mut end_of_word_suffix: Option<String>,
) -> Result<Bpe, BpeError> {
// Normalize empty end-of-word suffix to `None`.
end_of_word_suffix.take_if(|suffix| suffix.is_empty());

let splitter = Regex::new(pattern).map_err(|err| BpeError::InvalidPattern(err.into()))?;

let bb_opts = BpeBuilderOptions {
end_of_word_suffix: end_of_word_suffix.as_deref(),
};
@@ -376,7 +363,6 @@ impl Bpe {
merges: builder.ranks,
byte_to_rank: builder.byte_to_rank,
rank_to_token_id,
splitter,
added_tokens,
token_id_to_encoded_bytes,
end_of_word_suffix,
@@ -516,21 +502,15 @@ impl Model for Bpe {

fn encode_with_offsets(
&self,
text: &str,
piece: &str,
on_token: &mut dyn FnMut(usize, TokenId),
) -> Result<(), TokenizerError> {
for piece in self.splitter.find_iter(text) {
let piece = piece.map_err(|err| TokenizerError::RegexSplitFailed(err.into()))?;
if piece.range().is_empty() {
continue;
}

let piece_str = piece.as_str();
for token in self.encode_piece(piece_str, true /* end_of_word */) {
on_token(piece.start(), token)
}
if piece.is_empty() {
return Ok(());
}
for token in self.encode_piece(piece, true /* end_of_word */) {
on_token(0, token)
}

Ok(())
}

@@ -566,8 +546,8 @@ impl Model for Bpe {
mod tests {
use std::collections::HashMap;

use super::patterns::GPT2 as GPT2_SPLIT_PATTERN;
use super::{merge_pairs_from_lines, Bpe, EncodedBytes};
use crate::pre_tokenizers::ByteLevelPreTokenizer;
use crate::tokenizers::{TokenId, Tokenizer};

// The first ~25 lines of the merge list from GPT 2.
@@ -700,15 +680,9 @@ ba r",
{
let merges: Vec<&str> = merges.lines().collect();
let merge_pairs = merge_pairs_from_lines(&merges);
let model = Bpe::new(
&merge_pairs,
GPT2_SPLIT_PATTERN,
vocab,
HashMap::new(),
end_of_word_suffix,
)
.unwrap();
let tokenizer = Tokenizer::new(model, Default::default());
let model = Bpe::new(&merge_pairs, vocab, HashMap::new(), end_of_word_suffix).unwrap();
let tokenizer = Tokenizer::new(model, Default::default())
.with_pre_tokenizer(Box::new(ByteLevelPreTokenizer::gpt2()));
let encoded = tokenizer.encode(text, None).unwrap();
assert_eq!(
tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
@@ -745,8 +719,9 @@ ba r",

let merges: Vec<&str> = MINI_GPT2.lines().collect();
let merge_pairs = merge_pairs_from_lines(&merges);
let model = Bpe::new(&merge_pairs, GPT2_SPLIT_PATTERN, None, added_tokens(), None).unwrap();
let tokenizer = Tokenizer::new(model, Default::default());
let model = Bpe::new(&merge_pairs, None, added_tokens(), None).unwrap();
let tokenizer = Tokenizer::new(model, Default::default())
.with_pre_tokenizer(Box::new(ByteLevelPreTokenizer::gpt2()));

for Case { input, encoded_str } in cases {
let tok_id = tokenizer.model().get_token_id(input).unwrap();
@@ -802,15 +777,9 @@ ba r",
{
let merges: Vec<&str> = MINI_GPT2.lines().collect();
let merge_pairs = merge_pairs_from_lines(&merges);
let model = Bpe::new(
&merge_pairs,
GPT2_SPLIT_PATTERN,
vocab,
added_tokens(),
None,
)
.unwrap();
let tokenizer = Tokenizer::new(model, Default::default());
let model = Bpe::new(&merge_pairs, vocab, added_tokens(), None).unwrap();
let tokenizer = Tokenizer::new(model, Default::default())
.with_pre_tokenizer(Box::new(ByteLevelPreTokenizer::gpt2()));

let encoded = tokenizer.encode(text, None).unwrap();
let mut token_ids = encoded.token_ids().to_vec();
86 changes: 39 additions & 47 deletions rten-text/src/models/wordpiece.rs
@@ -1,10 +1,7 @@
use std::collections::HashMap;

use crate::split::SplitExt;
use crate::tokenizers::{Model, TokenId, TokenizerError};

use unicode_categories::UnicodeCategories;

/// WordPiece tokenizer [^1] used by BERT [^2] models.
///
/// [^1]: Schuster, Mike, and Kaisuke Nakajima. "Japanese and korean voice
@@ -55,14 +52,10 @@ impl WordPiece {
impl Model for WordPiece {
fn encode_with_offsets(
&self,
text: &str,
word: &str,
on_token: &mut dyn FnMut(usize, TokenId),
) -> Result<(), TokenizerError> {
let mut tmp_buf = String::with_capacity(self.max_word_len);

let is_punc_or_space =
|ch: char| ch.is_ascii_punctuation() || ch.is_punctuation() || ch.is_whitespace();
let words = text.split_keep_delimeters(is_punc_or_space);
let mut offset = 0;

macro_rules! add_unknown_token {
@@ -72,51 +65,48 @@
};
}

for word in words {
if word.trim().is_empty() {
offset += word.len();
continue;
}

if word.chars().count() > self.max_word_len {
add_unknown_token!();
continue;
}
if word.trim().is_empty() {
return Ok(());
}

let mut remainder = word;
let mut word_tokens = 0;
while !remainder.is_empty() {
// Find longest prefix of `remainder` that is in the vocab.
let mut len = remainder.len();
while len > 0 {
let prefix = if word_tokens > 0 {
tmp_buf.clear();
tmp_buf.push_str(&self.subword_prefix);
tmp_buf.push_str(&remainder[..len]);
&tmp_buf[..]
} else {
&remainder[..len]
};

if let Some(id) = self.token_to_id.get(prefix) {
on_token(offset, *id);
remainder = remainder.split_at(len).1;
word_tokens += 1;
break;
} else {
let last_char_bytes = prefix.chars().next_back().unwrap().len_utf8();
len -= last_char_bytes;
}
}
if word.chars().count() > self.max_word_len {
add_unknown_token!();
return Ok(());
}

if len == 0 {
add_unknown_token!();
let mut remainder = word;
let mut word_tokens = 0;
while !remainder.is_empty() {
// Find longest prefix of `remainder` that is in the vocab.
let mut len = remainder.len();
while len > 0 {
let prefix = if word_tokens > 0 {
tmp_buf.clear();
tmp_buf.push_str(&self.subword_prefix);
tmp_buf.push_str(&remainder[..len]);
&tmp_buf[..]
} else {
&remainder[..len]
};

if let Some(id) = self.token_to_id.get(prefix) {
on_token(offset, *id);
offset += prefix.len();
remainder = remainder.split_at(len).1;
word_tokens += 1;
break;
} else {
let last_char_bytes = prefix.chars().next_back().unwrap().len_utf8();
len -= last_char_bytes;
}
}

offset += word.len();
if len == 0 {
add_unknown_token!();
break;
}
}

Ok(())
}
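
The loop above is WordPiece's greedy longest-prefix match, now applied to a single pre-tokenized word rather than to the whole input text. A hypothetical walk-through, assuming the conventional "##" subword prefix and a three-entry vocab (both are illustrative assumptions, not values taken from this diff):

// Assume subword_prefix = "##" and a vocab containing "un", "##aff" and "##able".
//
//   word = "unaffable"
//   step 1: longest vocab prefix of "unaffable"       -> "un"      (remainder: "affable")
//   step 2: longest vocab prefix of "##" + "affable"  -> "##aff"   (remainder: "able")
//   step 3: longest vocab prefix of "##" + "able"     -> "##able"  (remainder: "")
//
// If no prefix matches at some step (len reaches 0), the whole word is emitted
// as the unknown token and the loop stops.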

@@ -146,6 +136,7 @@ mod tests {

use crate::models::{WordPiece, WordPieceOptions};
use crate::normalizer::{BertNormalizer, BertNormalizerOptions, Normalizer};
use crate::pre_tokenizers::BertPreTokenizer;
use crate::tokenizers::{Tokenizer, TokenizerOptions};

fn create_tokenizer(
@@ -165,7 +156,8 @@
cls_token: Some("[CLS]"),
sep_token: Some("[SEP]"),
},
);
)
.with_pre_tokenizer(Box::new(BertPreTokenizer::new()));

if let Some(normalizer) = normalizer {
tokenizer = tokenizer.with_normalizer(normalizer);
Diffs for the remaining 5 changed files are not shown here.
