Merge pull request #440 from robertknight/normalize-in-tokenizer
Move normalization from model into Tokenizer
robertknight authored Dec 3, 2024
2 parents e538910 + 690bbca commit fe19446
Showing 3 changed files with 160 additions and 102 deletions.
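
For orientation, here is a minimal sketch of how a caller's code changes, pieced together from the updated tests in the diff below. The `rten_text::models`, `rten_text::normalizer`, and `rten_text::tokenizers` paths, the `u32` token-ID type, and the exact signatures of `from_vocab`, `Tokenizer::new`, and `with_normalizer` are assumptions inferred from this file's `use` statements and test code; the crate's public API may differ.

    use std::collections::HashMap;

    // NOTE: module paths and signatures below are assumptions based on the
    // in-crate `use` statements and calls visible in the updated tests.
    use rten_text::models::{WordPiece, WordPieceOptions};
    use rten_text::normalizer::{BertNormalizer, BertNormalizerOptions};
    use rten_text::tokenizers::{Tokenizer, TokenizerOptions};

    fn build_tokenizer(vocab: HashMap<String, u32>) -> Tokenizer {
        // The model no longer takes a normalizer; WordPieceOptions keeps only
        // tokenization settings such as `max_word_len`.
        let model = WordPiece::from_vocab(vocab, WordPieceOptions::default());

        // Normalization (here: BERT-style lower-casing) is configured once...
        let normalizer = BertNormalizer::new(BertNormalizerOptions {
            lowercase: true,
            ..Default::default()
        });

        // ...and attached to the Tokenizer rather than to the model.
        Tokenizer::new(
            model,
            TokenizerOptions {
                cls_token: Some("[CLS]"),
                sep_token: Some("[SEP]"),
            },
        )
        .with_normalizer(Box::new(normalizer))
    }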
89 changes: 30 additions & 59 deletions rten-text/src/models/wordpiece.rs
@@ -1,7 +1,5 @@
 use std::collections::HashMap;
-use std::rc::Rc;
 
-use crate::normalizer::Normalizer;
 use crate::split::SplitExt;
 use crate::tokenizers::{Model, TokenId, TokenizerError};
 
@@ -19,7 +17,6 @@ use unicode_categories::UnicodeCategories;
 /// (2018). <https://arxiv.org/abs/1810.04805>
 #[derive(Clone)]
 pub struct WordPiece {
-    normalizer: Option<Rc<dyn Normalizer>>,
     token_to_id: HashMap<String, TokenId>,
     id_to_token: HashMap<TokenId, String>,
     subword_prefix: String,
@@ -29,10 +26,6 @@ pub struct WordPiece {
 /// Configuration for a [`WordPiece`] tokenizer.
 #[derive(Debug, Default, Clone)]
 pub struct WordPieceOptions {
-    /// The normalizer that handles Unicode normalization, lower-casing the
-    /// input etc.
-    pub normalizer: Option<Rc<dyn Normalizer>>,
-
     /// The maximum length of words that can be tokenized. Any words longer than
     /// this are tokenized as `[UNK]`.
     ///
@@ -51,7 +44,6 @@ impl WordPiece {
         let subword_prefix = "##".to_string();
 
         WordPiece {
-            normalizer: options.normalizer,
             token_to_id: vocab,
             subword_prefix,
             max_word_len: options.max_word_len.unwrap_or(100),
@@ -68,28 +60,6 @@ impl Model for WordPiece {
     ) -> Result<(), TokenizerError> {
         let mut tmp_buf = String::with_capacity(self.max_word_len);
 
-        // Apply normalization to the input text.
-        let (text, normalized_to_source_offsets) = match &self.normalizer {
-            None => (text.to_string(), None),
-            Some(normalizer) => {
-                let (normalized_text, offsets) = normalizer.normalize(text);
-                (normalized_text, Some(offsets))
-            }
-        };
-
-        // Map an offset into the normalized string into an offset in the source
-        // string.
-        let map_offset = |offset: usize| {
-            if let Some(mappings) = &normalized_to_source_offsets {
-                mappings
-                    .get(offset)
-                    .copied()
-                    .expect("invalid normalized offset")
-            } else {
-                offset
-            }
-        };
-
         let is_punc_or_space =
             |ch: char| ch.is_ascii_punctuation() || ch.is_punctuation() || ch.is_whitespace();
         let words = text.split_keep_delimeters(is_punc_or_space);
@@ -98,7 +68,7 @@ impl Model for WordPiece {
         macro_rules! add_unknown_token {
             () => {
                 let unknown_token = self.get_token_id("[UNK]")?;
-                on_token(map_offset(offset), unknown_token);
+                on_token(offset, unknown_token);
             };
         }
 
@@ -129,7 +99,7 @@
                 };
 
                 if let Some(id) = self.token_to_id.get(prefix) {
-                    on_token(map_offset(offset), *id);
+                    on_token(offset, *id);
                     remainder = remainder.split_at(len).1;
                     word_tokens += 1;
                     break;
@@ -173,26 +143,35 @@
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
-    use std::rc::Rc;
 
     use crate::models::{WordPiece, WordPieceOptions};
-    use crate::normalizer::{BertNormalizer, BertNormalizerOptions};
+    use crate::normalizer::{BertNormalizer, BertNormalizerOptions, Normalizer};
     use crate::tokenizers::{Tokenizer, TokenizerOptions};
 
-    fn create_tokenizer(vocab: &[&str], options: WordPieceOptions) -> Tokenizer {
+    fn create_tokenizer(
+        vocab: &[&str],
+        normalizer: Option<Box<dyn Normalizer>>,
+        options: WordPieceOptions,
+    ) -> Tokenizer {
         let vocab: HashMap<_, _> = vocab
             .iter()
             .enumerate()
             .map(|(i, token)| (token.to_string(), i as u32))
             .collect();
         let model = WordPiece::from_vocab(vocab, options);
-        Tokenizer::new(
+        let mut tokenizer = Tokenizer::new(
             model,
             TokenizerOptions {
                 cls_token: Some("[CLS]"),
                 sep_token: Some("[SEP]"),
             },
-        )
+        );
+
+        if let Some(normalizer) = normalizer {
+            tokenizer = tokenizer.with_normalizer(normalizer);
+        }
+
+        tokenizer
     }
 
     #[test]
@@ -207,7 +186,7 @@ mod tests {
             "Piece", "of", "pie", ".", "!", "?", "Hey", "Hello", "the", "game", "is", "set", "in",
             "Faerûn",
         ];
-        let tokenizer = create_tokenizer(vocab, Default::default());
+        let tokenizer = create_tokenizer(vocab, None, Default::default());
 
         let cases = [
             // Single sequence, no subwords.
@@ -269,7 +248,7 @@
             max_word_len: Some(6),
             ..Default::default()
         };
-        let tokenizer = create_tokenizer(vocab, opts);
+        let tokenizer = create_tokenizer(vocab, None, opts);
 
         // The third word should be tokenized to `[UNK]` because it exceeds
         // `max_word_len`.
@@ -292,16 +271,12 @@
         let vocab = &[
             "[CLS]", "[SEP]", "[UNK]", "this", "is", "a", "test", "sequence",
         ];
-        let tokenizer = create_tokenizer(
-            vocab,
-            WordPieceOptions {
-                normalizer: Some(Rc::new(BertNormalizer::new(BertNormalizerOptions {
-                    lowercase: true,
-                    ..Default::default()
-                }))),
-                ..Default::default()
-            },
-        );
+
+        let normalizer = BertNormalizer::new(BertNormalizerOptions {
+            lowercase: true,
+            ..Default::default()
+        });
+        let tokenizer = create_tokenizer(vocab, Some(Box::new(normalizer)), Default::default());
 
         let cases = [
             // Single sequence, no subwords.
@@ -350,16 +325,12 @@
         let vocab = &[
             "[CLS]", "[SEP]", "[UNK]", "this", "is", "a", "test", "sequence",
         ];
-        let tokenizer = create_tokenizer(
-            vocab,
-            WordPieceOptions {
-                normalizer: Some(Rc::new(BertNormalizer::new(BertNormalizerOptions {
-                    lowercase: true,
-                    ..Default::default()
-                }))),
-                ..Default::default()
-            },
-        );
+
+        let normalizer = BertNormalizer::new(BertNormalizerOptions {
+            lowercase: true,
+            ..Default::default()
+        });
+        let tokenizer = create_tokenizer(vocab, Some(Box::new(normalizer)), Default::default());
 
         for Case { input, expected } in cases {
             let encoded = tokenizer.encode(input, None).unwrap();
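
The offset bookkeeping that WordPiece loses in this diff (the `map_offset` closure) presumably now happens in the Tokenizer's encode path, whose changes are not part of this file. For reference, the removed logic in isolation amounts to the sketch below; the `&[usize]` type for the offset table returned by `Normalizer::normalize` is an assumption based on how the closure indexes it.

    // Standalone restatement of the `map_offset` closure deleted above.
    // Assumption: the offset table is a slice indexed by byte offset in the
    // normalized text; the actual type in rten-text may differ.
    fn map_offset(normalized_to_source_offsets: Option<&[usize]>, offset: usize) -> usize {
        if let Some(mappings) = normalized_to_source_offsets {
            // Translate a normalized-text offset back to a source-text offset.
            mappings
                .get(offset)
                .copied()
                .expect("invalid normalized offset")
        } else {
            // No normalizer configured: offsets already refer to the source text.
            offset
        }
    }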
[Diffs for the other two changed files are not shown.]

0 comments on commit fe19446
