Clean up outdated comments in BPE tokenizer, pass configuration as a struct #455

Merged: 2 commits, Dec 9, 2024
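The headline API change: `Bpe::new` previously took four positional arguments (merges, vocab, added tokens, end-of-word suffix) and now takes a single `BpeOptions` struct, which derives `Default` so callers only spell out the fields they need. A minimal before/after sketch, assuming the post-PR API shown in this diff (the merge entries are made up for illustration):

```rust
use rten_text::models::{merge_pairs_from_lines, Bpe, BpeOptions};
use rten_text::pre_tokenizers::Split;
use rten_text::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Tiny, made-up merge list in the space-separated merges.txt format.
    let lines: Vec<&str> = "h e\nl l\nhe ll\nhell o".lines().collect();
    let merge_pairs = merge_pairs_from_lines(&lines);

    // Before this PR:
    // let model = Bpe::new(&merge_pairs, None, Default::default(), None)?;

    // After this PR: configuration is passed as a struct, with unused
    // fields filled in by `Default`.
    let model = Bpe::new(BpeOptions {
        merges: &merge_pairs,
        ..Default::default()
    })?;
    let _tokenizer = Tokenizer::new(model, Default::default())
        .with_pre_tokenizer(Box::new(Split::gpt2()));
    Ok(())
}
```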
4 changes: 2 additions & 2 deletions rten-generate/src/text_decoder.rs
@@ -73,7 +73,7 @@ impl<G: Iterator<Item = GeneratorItem>> Iterator for TextDecoder<'_, G> {
mod tests {
use std::collections::HashMap;

use rten_text::models::{Bpe, WordPiece};
use rten_text::models::{Bpe, BpeOptions, WordPiece};
use rten_text::pre_tokenizers::Split;
use rten_text::{TokenId, Tokenizer};

@@ -93,7 +93,7 @@ mod tests {
/// Create a BPE tokenizer with an empty vocab. This can encode and decode
/// arbitrary Unicode characters, by using one token per UTF-8 byte.
fn create_bpe_tokenizer() -> Tokenizer {
let model = Bpe::new(&[], None, Default::default(), None).unwrap();
let model = Bpe::new(BpeOptions::default()).unwrap();
Tokenizer::new(model, Default::default()).with_pre_tokenizer(Box::new(Split::gpt2()))
}

4 changes: 3 additions & 1 deletion rten-text/src/models.rs
@@ -7,7 +7,9 @@ use std::fmt;
mod bpe;
mod wordpiece;

pub use bpe::{merge_pairs_from_lines, patterns, Bpe, BpeError};
pub use bpe::{
char_to_byte, merge_pairs_from_lines, Bpe, BpeError, BpeOptions, EncodedByteSlice, EncodedBytes,
};
pub use wordpiece::{WordPiece, WordPieceOptions};

use crate::tokenizer::TokenId;
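Since `char_to_byte` and the `EncodedByteSlice`/`EncodedBytes` aliases are now exported, downstream code can translate the printable-character token encoding back into raw bytes. A small sketch under the assumption that a well-formed token only uses characters from the GPT-2 mapping; the `decode_token_bytes` helper is hypothetical, not part of the crate:

```rust
use std::collections::HashMap;

use rten_text::models::{char_to_byte, EncodedByteSlice};

/// Hypothetical helper: map an encoded token back to the raw bytes it stands
/// for, or `None` if a character falls outside the GPT-2 byte mapping.
fn decode_token_bytes(token: EncodedByteSlice<'_>) -> Option<Vec<u8>> {
    let map: HashMap<char, u8> = char_to_byte();
    token.chars().map(|c| map.get(&c).copied()).collect()
}

fn main() {
    // In the GPT-2 scheme, "Ġ" is the printable stand-in for a space byte.
    assert_eq!(decode_token_bytes("Ġhello"), Some(b" hello".to_vec()));
}
```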
123 changes: 66 additions & 57 deletions rten-text/src/models/bpe.rs
@@ -36,16 +36,16 @@ impl Error for BpeError {}
/// smaller tokens, or a single byte.
type Rank = u32;

/// A sequence of UTF-8 bytes, encoded as a string of characters.
/// A sequence of UTF-8 bytes, encoded as a string of printable characters.
/// [`char_to_byte`] provides the mapping between characters and bytes.
///
/// Unlike a Rust `str`, the sequence of bytes do not necessarily form a
/// complete sequence of Unicode characters. The bytes may end in the middle of
/// a character.
type EncodedByteSlice<'a> = &'a str;
pub type EncodedByteSlice<'a> = &'a str;

/// Like [`EncodedByteSlice`], but owned.
type EncodedBytes = String;
pub type EncodedBytes = String;

/// Return true if `c` is considered a printable character.
///
@@ -77,12 +77,12 @@ fn byte_to_rank() -> [Rank; 256] {
ranks
}

/// Return a mapping between the characters used in the GPT 2 merge list
/// and vocabulary, and the byte values they represent.
/// Return a mapping between the printable characters used in the GPT 2 merge
/// list and vocabulary, and the byte values they represent.
///
/// Based on the `bytes_to_unicode` function in the original GPT-2 encoder -
/// https://github.com/openai/gpt-2/blob/master/src/encoder.py.
fn char_to_byte() -> HashMap<char, u8> {
/// <https://github.com/openai/gpt-2/blob/master/src/encoder.py>.
pub fn char_to_byte() -> HashMap<char, u8> {
let mut n = 0;
(0..=255u8)
.map(|b| {
@@ -189,9 +189,9 @@ impl BpeBuilder {

/// Build the BPE merge map that assigns a rank to pairs of tokens.
///
/// `merges` contains entries of the BPE merge table. Each entry is a
/// space-separated pair of tokens. Each token is a sequence of byte values
/// encoded using the scheme described in [`char_to_byte`].
/// `merges` contains entries of the BPE merge table. Each entry is a pair
/// of tokens. Each token is a sequence of byte values encoded using the
/// scheme described in [`char_to_byte`].
fn add_merges(
&mut self,
merges: &[(EncodedByteSlice, EncodedByteSlice)],
@@ -222,19 +222,6 @@ impl BpeBuilder {
}
}

/// Regex patterns used by popular tokenizer models.
///
/// Some models (eg. GPT-2) use a regex to split input text into pieces prior
/// to applying the trained tokenizer model. This module contains some widely
/// used patterns.
pub mod patterns {
/// Tokenization regex used by GPT-2.
///
/// See <https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py>.
pub const GPT2: &str =
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
}

/// Parse a list of space-separated BPE merge entries into pairs of tokens.
///
/// Lines that are empty or contain only a `#version` marker are ignored.
@@ -254,6 +241,32 @@ pub fn merge_pairs_from_lines(
.collect()
}
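As a usage note, `merge_pairs_from_lines` turns the space-separated `merges.txt` format into the token pairs that the new `BpeOptions::merges` field expects. Roughly (the sample lines are made up; the `#version` marker and empty lines are skipped as its doc comment describes):

```rust
use rten_text::models::merge_pairs_from_lines;

fn main() {
    // Typical merges.txt shape: optional version marker, then one
    // space-separated token pair per line.
    let lines: Vec<&str> = "#version: 0.2\nh e\nhe llo".lines().collect();
    let pairs = merge_pairs_from_lines(&lines);

    // The version marker and any empty lines are ignored.
    assert_eq!(pairs, [("h", "e"), ("he", "llo")]);
}
```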

/// Configuration for a [`Bpe`] tokenization model.
#[derive(Default)]
pub struct BpeOptions<'a> {
/// Ordered entries of the merge list. Each entry is a pair of strings
/// representing byte sequences. See also [`merge_pairs_from_lines`] which
/// can be used to extract pairs from the space-separated format used in eg.
/// `merges.txt` files.
pub merges: &'a [(EncodedByteSlice<'a>, EncodedByteSlice<'a>)],

/// Mapping between token strings and IDs. If not provided, the
/// ID of a token is 256 + the index of the pair in the merge list which
/// form the token string when concatenated. For example, if index 10 in the
/// merge list is "foo bar", then the token ID of "foobar" would be 266.
/// Token IDs below 256 are reserved for individual bytes.
pub vocab: Option<HashMap<EncodedBytes, TokenId>>,

/// Set of tokens which don't appear in `merges` but do have a mapping in
/// `vocab`. These are used for special purposes such as representing the
/// end of output.
pub added_tokens: HashMap<TokenId, String>,

/// A string which is implicitly appended to each substring that is
/// tokenized, after initial splitting.
pub end_of_word_suffix: Option<String>,
}

/// Byte Pair Encoding tokenizer used by GPT-2 [^1] and subsequently used by
/// many other models.
///
@@ -306,36 +319,15 @@ pub struct Bpe {
}

impl Bpe {
/// Create a new Byte Pair Encoding tokenizer.
///
/// `merges` are the ordered entries of the merge list. Each entry is a
/// pair of strings representing byte sequences. See also
/// [`merge_pairs_from_lines`] which can be used to extract pairs from
/// the space-separated format used in eg. `merges.txt` files.
///
/// `pattern` is a regex used to split input text into pieces before BPE
/// encoding is applied. The supported syntax is that supported by the
/// [fancy_regex](https://crates.io/crates/fancy-regex) crate. The
/// [patterns] module contains patterns used by popular models.
///
/// `vocab` is a mapping between token strings and IDs. If not provided, the
/// ID of a token is 256 + the index of the pair in the merge list which
/// form the token string when concatenated. For example, if index 10 in the
/// merge list is "foo bar", then the token ID of "foobar" would be 266.
/// Token IDs below 256 are reserved for individual bytes.
///
/// `added_tokens` is a set of tokens which don't appear in `merges` but
/// do have a mapping in `vocab`. These are used for special purposes such
/// as representing the end of output.
///
/// `end_of_word_suffix` is a string which is implicitly appended to each
/// substring that is tokenized, after initial splitting.
pub fn new(
merges: &[(EncodedByteSlice, EncodedByteSlice)],
vocab: Option<HashMap<EncodedBytes, TokenId>>,
added_tokens: HashMap<TokenId, String>,
mut end_of_word_suffix: Option<String>,
) -> Result<Bpe, BpeError> {
/// Create a new Byte Pair Encoding tokenizer using the given configuration.
pub fn new(config: BpeOptions) -> Result<Bpe, BpeError> {
let BpeOptions {
merges,
vocab,
added_tokens,
mut end_of_word_suffix,
} = config;

// Normalize empty end-of-word suffix to `None`.
end_of_word_suffix.take_if(|suffix| suffix.is_empty());
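The `take_if` call above is what collapses an empty suffix: it moves the value out of the `Option` when the predicate returns true, leaving `None` behind. A standalone illustration in plain std Rust (`Option::take_if` is stable since Rust 1.80; the "</w>" suffix is just an example value):

```rust
fn main() {
    // An empty end-of-word suffix is normalized away...
    let mut suffix: Option<String> = Some(String::new());
    suffix.take_if(|s| s.is_empty());
    assert_eq!(suffix, None);

    // ...while a non-empty one is left in place.
    let mut suffix = Some("</w>".to_string());
    suffix.take_if(|s| s.is_empty());
    assert_eq!(suffix, Some("</w>".to_string()));
}
```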

@@ -545,7 +537,7 @@ impl Model for Bpe {
mod tests {
use std::collections::HashMap;

use super::{merge_pairs_from_lines, Bpe, EncodedBytes};
use super::{merge_pairs_from_lines, Bpe, BpeOptions, EncodedBytes};
use crate::pre_tokenizers::Split;
use crate::tokenizer::{TokenId, Tokenizer};

@@ -679,7 +671,13 @@ ba r",
{
let merges: Vec<&str> = merges.lines().collect();
let merge_pairs = merge_pairs_from_lines(&merges);
let model = Bpe::new(&merge_pairs, vocab, HashMap::new(), end_of_word_suffix).unwrap();
let bpe_opts = BpeOptions {
merges: &merge_pairs,
vocab,
end_of_word_suffix,
..Default::default()
};
let model = Bpe::new(bpe_opts).unwrap();
let tokenizer = Tokenizer::new(model, Default::default())
.with_pre_tokenizer(Box::new(Split::gpt2()));
let encoded = tokenizer.encode(text, None).unwrap();
@@ -718,7 +716,12 @@ ba r",

let merges: Vec<&str> = MINI_GPT2.lines().collect();
let merge_pairs = merge_pairs_from_lines(&merges);
let model = Bpe::new(&merge_pairs, None, added_tokens(), None).unwrap();
let bpe_opts = BpeOptions {
merges: &merge_pairs,
added_tokens: added_tokens(),
..Default::default()
};
let model = Bpe::new(bpe_opts).unwrap();
let tokenizer =
Tokenizer::new(model, Default::default()).with_pre_tokenizer(Box::new(Split::gpt2()));

@@ -776,7 +779,13 @@ ba r",
{
let merges: Vec<&str> = MINI_GPT2.lines().collect();
let merge_pairs = merge_pairs_from_lines(&merges);
let model = Bpe::new(&merge_pairs, vocab, added_tokens(), None).unwrap();
let bpe_opts = BpeOptions {
merges: &merge_pairs,
vocab,
added_tokens: added_tokens(),
..Default::default()
};
let model = Bpe::new(bpe_opts).unwrap();
let tokenizer = Tokenizer::new(model, Default::default())
.with_pre_tokenizer(Box::new(Split::gpt2()));

14 changes: 7 additions & 7 deletions rten-text/src/tokenizer.rs
@@ -19,7 +19,7 @@ use std::ops::Range;
use std::path::Path;

use crate::models::{
merge_pairs_from_lines, Bpe, BpeError, DecodeError, EncodeError, Model, WordPiece,
merge_pairs_from_lines, Bpe, BpeError, BpeOptions, DecodeError, EncodeError, Model, WordPiece,
};
use crate::normalizers::{NormalizeError, Normalizer};
use crate::pre_tokenizers::{PreTokenizeError, PreTokenizer};
@@ -404,13 +404,13 @@ impl Tokenizer {
.map(|(a, b)| (a.as_str(), b.as_str()))
.collect(),
};
let model = Bpe::new(
&merges,
Some(model.vocab),
let bpe_opts = BpeOptions {
merges: &merges,
vocab: Some(model.vocab),
added_tokens,
model.end_of_word_suffix,
)
.map_err(FromJsonError::BpeError)?;
end_of_word_suffix: model.end_of_word_suffix,
};
let model = Bpe::new(bpe_opts).map_err(FromJsonError::BpeError)?;

let tokenizer = Tokenizer::new(
model,
8 changes: 6 additions & 2 deletions rten-text/tests/reftest.rs
@@ -4,7 +4,7 @@ use std::fs::read_to_string;
use std::io;
use std::path::PathBuf;

use rten_text::models::{merge_pairs_from_lines, Bpe, WordPiece};
use rten_text::models::{merge_pairs_from_lines, Bpe, BpeOptions, WordPiece};
use rten_text::tokenizer::{TokenId, Tokenizer, TokenizerOptions};
use rten_text::{normalizers, pre_tokenizers};
use serde::Deserialize;
@@ -160,7 +160,7 @@ fn test_bpe_gpt2() -> Result<(), Box<dyn Error>> {
let merges = read_test_file("models/gpt2/merges.txt")?;
let merge_lines: Vec<_> = merges.lines().collect();
let merge_pairs = merge_pairs_from_lines(&merge_lines);
let model = Bpe::new(&merge_pairs, None, Default::default(), None)?;
let bpe_opts = BpeOptions {
merges: &merge_pairs,
..Default::default()
};
let model = Bpe::new(bpe_opts)?;
let tokenizer = Tokenizer::new(model, Default::default())
.with_pre_tokenizer(Box::new(pre_tokenizers::Split::gpt2()));

Expand Down