Merge pull request #428 from robertknight/normalizer-trait
Convert `Normalizer` into a trait
robertknight authored Dec 2, 2024
2 parents 313d074 + d1992c0 commit 2af6fc5
Showing 4 changed files with 64 additions and 47 deletions.
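In short, for downstream users: the concrete `Normalizer` struct is renamed to `BertNormalizer`, and a new `Normalizer` trait abstracts over normalizers, held behind `Rc<dyn Normalizer>`. A minimal sketch of the new call-site shape, assuming only the APIs visible in this diff (not code taken verbatim from the PR):

```rust
use std::rc::Rc;

use rten_text::normalizer::{BertNormalizer, BertNormalizerOptions, Normalizer};

fn main() {
    // Before: Normalizer::new(NormalizerOptions { lowercase: true, ..Default::default() })
    // After: construct the BERT implementation and hold it as a trait object.
    let normalizer: Rc<dyn Normalizer> = Rc::new(BertNormalizer::new(BertNormalizerOptions {
        lowercase: true,
        ..Default::default()
    }));

    // Per the trait's documented contract, `offsets[i]` is the byte offset in
    // the original string corresponding to byte `i` of the normalized string.
    let (normalized, offsets) = normalizer.normalize("Motörhead");
    println!("{normalized} {offsets:?}");
}
```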
58 changes: 33 additions & 25 deletions rten-text/src/normalizer.rs
@@ -60,22 +60,33 @@ impl CharNormalizer {
}
}

/// Normalizer applies normalization such as Unicode normalization and
/// A normalizer applies normalization such as Unicode normalization and
/// lower-casing to strings.
///
/// In addition to the normalized text, Normalizer methods also return mappings
/// from positions in the normalized string back to the original string. This
/// is useful for post-processing in NLP tasks to map machine learning model
/// outputs back to the location in the original text.
pub trait Normalizer: std::fmt::Debug {
/// Apply normalization to a string.
///
/// Returns a tuple of `(normalized_string, offset_map)` where `offset_map`
/// is a mapping from byte offsets in the normalized string to corresponding
/// offsets in the original string.
fn normalize(&self, text: &str) -> (String, Vec<usize>);
}

/// A [`Normalizer`] that implements normalization used by BERT and BERT-derived
/// models.
#[derive(Clone, Debug)]
pub struct Normalizer {
pub struct BertNormalizer {
lowercase: bool,
strip_accents: bool,
}

/// Configuration for a [`Normalizer`].
/// Configuration for a [`BertNormalizer`].
#[derive(Clone, Debug, Default)]
pub struct NormalizerOptions {
pub struct BertNormalizerOptions {
/// If true, convert all text to lowercase using [`char::to_lowercase`].
pub lowercase: bool,

@@ -84,20 +95,22 @@ pub struct NormalizerOptions {
pub strip_accents: bool,
}

impl Normalizer {
pub fn new(opts: NormalizerOptions) -> Normalizer {
Normalizer {
impl BertNormalizer {
pub fn new(opts: BertNormalizerOptions) -> BertNormalizer {
BertNormalizer {
lowercase: opts.lowercase,
strip_accents: opts.strip_accents,
}
}

/// Apply normalization to a string.
///
/// Returns a tuple of `(normalized_string, offset_map)` where `offset_map`
/// is a mapping from byte offsets in the normalized string to corresponding
/// offsets in the original string.
pub fn normalize(&self, text: &str) -> (String, Vec<usize>) {
/// Return true if this normalizer doesn't alter its input.
fn is_noop(&self) -> bool {
!self.lowercase && !self.strip_accents
}
}

impl Normalizer for BertNormalizer {
fn normalize(&self, text: &str) -> (String, Vec<usize>) {
if self.is_noop() {
let offsets = (0..text.len()).collect();
return (text.to_string(), offsets);
@@ -128,20 +141,15 @@ impl Normalizer {

(normalized, offsets)
}

/// Return true if this normalizer doesn't alter its input.
fn is_noop(&self) -> bool {
!self.lowercase && !self.strip_accents
}
}

#[cfg(test)]
mod tests {
use super::{Normalizer, NormalizerOptions};
use super::{BertNormalizer, BertNormalizerOptions, Normalizer};

#[test]
fn test_normalizer_noop() {
let normalizer = Normalizer::new(NormalizerOptions::default());
fn test_bert_normalizer_noop() {
let normalizer = BertNormalizer::new(BertNormalizerOptions::default());
let inputs = [
"Hello world!", // Mixed case
"Motörhead", // Accented
@@ -155,8 +163,8 @@ mod tests {
}

#[test]
fn test_normalizer_lowercase() {
let normalizer = Normalizer::new(NormalizerOptions {
fn test_bert_normalizer_lowercase() {
let normalizer = BertNormalizer::new(BertNormalizerOptions {
lowercase: true,
..Default::default()
});
@@ -200,7 +208,7 @@ mod tests {
}

#[test]
fn test_normalizer_strip_accepts() {
fn test_bert_normalizer_strip_accepts() {
struct Case<'a> {
input: &'a str,
lowercase: bool,
@@ -236,7 +244,7 @@
expected_offsets,
} in cases
{
let normalizer = Normalizer::new(NormalizerOptions {
let normalizer = BertNormalizer::new(BertNormalizerOptions {
lowercase,
strip_accents: true,
..Default::default()
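With `Normalizer` now a trait, callers are no longer limited to the built-in BERT normalization. A sketch of a custom implementation following the trait's documented offset-map contract; the `IdentityNormalizer` type is a hypothetical example, not part of this PR:

```rust
use rten_text::normalizer::Normalizer;

/// Hypothetical example: a normalizer that returns its input unchanged.
#[derive(Debug)] // `Normalizer` has `std::fmt::Debug` as a supertrait.
struct IdentityNormalizer;

impl Normalizer for IdentityNormalizer {
    fn normalize(&self, text: &str) -> (String, Vec<usize>) {
        // The offset map has one entry per byte of the normalized string,
        // mapping it back to a byte offset in the original string. For an
        // identity transform this is simply 0..len, mirroring the no-op
        // fast path in `BertNormalizer`.
        (text.to_string(), (0..text.len()).collect())
    }
}
```

Any such type can then be passed wherever an `Rc<dyn Normalizer>` is accepted, e.g. `WordPieceOptions::normalizer`.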
30 changes: 18 additions & 12 deletions rten-text/src/tokenizers.rs
@@ -15,8 +15,9 @@ use std::error::Error;
use std::fmt;
use std::iter::repeat;
use std::ops::Range;
use std::rc::Rc;

use crate::normalizer::{Normalizer, NormalizerOptions};
use crate::normalizer::{BertNormalizer, BertNormalizerOptions, Normalizer};
use crate::split::SliceExt;

mod bpe;
@@ -292,17 +293,22 @@ impl Tokenizer {
}

fn from_parsed_json(json: json::TokenizerJson) -> Result<Tokenizer, FromJsonError> {
let normalizer = json.normalizer.map(|normalizer| match normalizer {
json::Normalizer::Bert(bert_norm) => Normalizer::new(NormalizerOptions {
lowercase: bert_norm.lowercase,
strip_accents: bert_norm.strip_accents.unwrap_or(bert_norm.lowercase),
}),

// Dummy implementation of NFC normalization.
json::Normalizer::Nfc => Normalizer::new(NormalizerOptions {
lowercase: false,
strip_accents: false,
}),
let normalizer: Option<Rc<dyn Normalizer>> = json.normalizer.map(|normalizer| {
let normalizer: Rc<dyn Normalizer> = match normalizer {
json::Normalizer::Bert(bert_norm) => {
Rc::new(BertNormalizer::new(BertNormalizerOptions {
lowercase: bert_norm.lowercase,
strip_accents: bert_norm.strip_accents.unwrap_or(bert_norm.lowercase),
}))
}

// Dummy implementation of NFC normalization.
json::Normalizer::Nfc => Rc::new(BertNormalizer::new(BertNormalizerOptions {
lowercase: false,
strip_accents: false,
})),
};
normalizer
});

match json.model {
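The `Rc` wrapper appears to be what keeps `WordPiece`'s `#[derive(Clone)]` working after the switch to trait objects: `dyn Normalizer` is unsized and not `Clone`, but an `Rc` handle is. A small sketch of the resulting ownership behaviour, assuming only the types shown in this diff:

```rust
use std::rc::Rc;

use rten_text::normalizer::{BertNormalizer, BertNormalizerOptions, Normalizer};

fn main() {
    let normalizer: Rc<dyn Normalizer> = Rc::new(BertNormalizer::new(BertNormalizerOptions {
        lowercase: true,
        strip_accents: true,
        ..Default::default()
    }));

    // Cloning the handle bumps a reference count rather than deep-copying the
    // normalizer, so a tokenizer holding `Option<Rc<dyn Normalizer>>` stays
    // cheap to clone.
    let shared = Rc::clone(&normalizer);
    assert_eq!(Rc::strong_count(&shared), 2);
}
```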
16 changes: 9 additions & 7 deletions rten-text/src/tokenizers/wordpiece.rs
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::rc::Rc;

use super::{Encoder, TokenId, TokenizerError};
use crate::normalizer::Normalizer;
@@ -18,7 +19,7 @@ use unicode_categories::UnicodeCategories;
/// (2018). <https://arxiv.org/abs/1810.04805>
#[derive(Clone)]
pub struct WordPiece {
normalizer: Option<Normalizer>,
normalizer: Option<Rc<dyn Normalizer>>,
token_to_id: HashMap<String, TokenId>,
id_to_token: HashMap<TokenId, String>,
subword_prefix: String,
@@ -30,7 +31,7 @@ pub struct WordPiece {
pub struct WordPieceOptions {
/// The normalizer that handles Unicode normalization, lower-casing the
/// input etc.
pub normalizer: Option<Normalizer>,
pub normalizer: Option<Rc<dyn Normalizer>>,

/// The maximum length of words that can be tokenized. Any words longer than
/// this are tokenized as `[UNK]`.
@@ -172,8 +173,9 @@ impl Encoder for WordPiece {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::rc::Rc;

use crate::normalizer::{Normalizer, NormalizerOptions};
use crate::normalizer::{BertNormalizer, BertNormalizerOptions};
use crate::tokenizers::{
EncodeOptions, Tokenizer, TokenizerOptions, WordPiece, WordPieceOptions,
};
@@ -298,10 +300,10 @@ mod tests {
let tokenizer = create_tokenizer(
vocab,
WordPieceOptions {
normalizer: Some(Normalizer::new(NormalizerOptions {
normalizer: Some(Rc::new(BertNormalizer::new(BertNormalizerOptions {
lowercase: true,
..Default::default()
})),
}))),
..Default::default()
},
);
@@ -358,10 +360,10 @@ mod tests {
let tokenizer = create_tokenizer(
vocab,
WordPieceOptions {
normalizer: Some(Normalizer::new(NormalizerOptions {
normalizer: Some(Rc::new(BertNormalizer::new(BertNormalizerOptions {
lowercase: true,
..Default::default()
})),
}))),
..Default::default()
},
);
7 changes: 4 additions & 3 deletions rten-text/tests/reftest.rs
@@ -3,8 +3,9 @@ use std::error::Error;
use std::fs::read_to_string;
use std::io;
use std::path::PathBuf;
use std::rc::Rc;

use rten_text::normalizer::{Normalizer, NormalizerOptions};
use rten_text::normalizer::{BertNormalizer, BertNormalizerOptions};
use rten_text::tokenizers::patterns::GPT2 as GPT2_SPLIT_PATTERN;
use rten_text::tokenizers::{
merge_pairs_from_lines, Bpe, TokenId, Tokenizer, TokenizerOptions, WordPiece, WordPieceOptions,
@@ -120,15 +121,15 @@ fn test_wordpiece_bert_uncased() -> Result<(), Box<dyn Error>> {

let vocab = read_vocab_text_file("models/bert-base-uncased/vocab.txt")?;

let normalizer = Normalizer::new(NormalizerOptions {
let normalizer = BertNormalizer::new(BertNormalizerOptions {
lowercase: true,
strip_accents: true,
..Default::default()
});
let encoder = WordPiece::from_vocab(
vocab,
WordPieceOptions {
normalizer: Some(normalizer),
normalizer: Some(Rc::new(normalizer)),
..Default::default()
},
);
