From 6afe7f8fa40fe3b4b78438297eb4748de75e2a5b Mon Sep 17 00:00:00 2001
From: Meet Patel
Date: Tue, 18 Mar 2025 19:00:46 -0400
Subject: [PATCH 1/3] Split from other PR

---
 bindings/node/Cargo.toml | 1 +
 bindings/node/src/models.rs | 4 +-
 bindings/node/src/tokenizer.rs | 4 +-
 bindings/python/Cargo.toml | 1 +
 bindings/python/src/models.rs | 4 +-
 bindings/python/src/tokenizer.rs | 6 +-
 tokenizers/Cargo.toml | 1 +
 tokenizers/benches/unigram_benchmark.rs | 6 +-
 tokenizers/src/models/bpe/model.rs | 14 ++--
 tokenizers/src/models/bpe/serialization.rs | 4 +-
 tokenizers/src/models/bpe/trainer.rs | 71 ++++++++++---------
 tokenizers/src/models/bpe/word.rs | 5 +-
 tokenizers/src/models/mod.rs | 14 ++--
 tokenizers/src/models/unigram/model.rs | 8 +--
 tokenizers/src/models/unigram/trainer.rs | 31 ++++----
 tokenizers/src/models/unigram/trie.rs | 6 +-
 tokenizers/src/models/wordlevel/mod.rs | 22 +++---
 .../src/models/wordlevel/serialization.rs | 4 +-
 tokenizers/src/models/wordlevel/trainer.rs | 18 ++---
 tokenizers/src/models/wordpiece/mod.rs | 16 ++---
 .../src/models/wordpiece/serialization.rs | 4 +-
 tokenizers/src/models/wordpiece/trainer.rs | 8 +--
 tokenizers/src/normalizers/byte_level.rs | 7 +-
 tokenizers/src/pre_tokenizers/byte_level.rs | 21 +++---
 tokenizers/src/processors/bert.rs | 17 ++---
 tokenizers/src/processors/roberta.rs | 17 ++---
 tokenizers/src/processors/sequence.rs | 8 +--
 tokenizers/src/processors/template.rs | 35 ++++-----
 tokenizers/src/tokenizer/added_vocabulary.rs | 30 ++++----
 tokenizers/src/tokenizer/encoding.rs | 23 +++---
 tokenizers/src/tokenizer/mod.rs | 16 ++---
 tokenizers/src/tokenizer/pre_tokenizer.rs | 4 +-
 tokenizers/src/utils/cache.rs | 9 ++-
 tokenizers/src/utils/from_pretrained.rs | 6 +-
 tokenizers/src/utils/mod.rs | 5 +-
 tokenizers/src/utils/padding.rs | 6 +-
 tokenizers/src/utils/truncation.rs | 10 +--
 tokenizers/tests/documentation.rs | 4 +-
 tokenizers/tests/unigram.rs | 4 +-
 39 files changed, 248 insertions(+), 226 deletions(-)

diff --git a/bindings/node/Cargo.toml b/bindings/node/Cargo.toml
index cf1e51e99..43e86ed84 100644
--- a/bindings/node/Cargo.toml
+++ b/bindings/node/Cargo.toml
@@ -12,6 +12,7 @@ crate-type = ["cdylib"]
 [dependencies]
 napi = "2"
 napi-derive = "2"
+rustc-hash = "2.1.1"
 serde = { version = "1.0.163", features = ["derive"] }
 tokenizers = { path = "../../tokenizers/" }
 
diff --git a/bindings/node/src/models.rs b/bindings/node/src/models.rs
index a4138b91f..f66962742 100644
--- a/bindings/node/src/models.rs
+++ b/bindings/node/src/models.rs
@@ -3,8 +3,8 @@ use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFr
 use crate::trainers::Trainer;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
+use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 use tokenizers as tk;
@@ -95,7 +95,7 @@ impl tk::Model for Model {
     self.model.as_ref()?.read().unwrap().id_to_token(id)
   }
 
-  fn get_vocab(&self) -> HashMap<String, u32> {
+  fn get_vocab(&self) -> FxHashMap<String, u32> {
     self
       .model
       .as_ref()
diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs
index 4acbcac83..a99ac0313 100644
--- a/bindings/node/src/tokenizer.rs
+++ b/bindings/node/src/tokenizer.rs
@@ -6,7 +6,7 @@ use crate::pre_tokenizers::PreTokenizer;
 use crate::processors::Processor;
 use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask};
 use crate::trainers::Trainer;
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use tokenizers::Model as ModelTrait;
 
 use napi::bindgen_prelude::*;
@@ -433,7 +433,7 @@ impl Tokenizer {
   }
 
   #[napi]
-  pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> HashMap<String, u32> {
+  pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> FxHashMap<String, u32> {
     let with_added_tokens = with_added_tokens.unwrap_or(true);
     self.tokenizer.read().unwrap().get_vocab(with_added_tokens)
   }
diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 6e8b0c34c..ac8041e6e 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -18,6 +18,7 @@ pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
 numpy = "0.23"
 ndarray = "0.16"
 itertools = "0.12"
+rustc-hash = "2.1.1"
 
 [dependencies.tokenizers]
 path = "../../tokenizers"
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 2f4dba825..4d6e084f3 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 
@@ -70,7 +70,7 @@ impl Model for PyModel {
         self.model.read().unwrap().id_to_token(id)
     }
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         self.model.read().unwrap().get_vocab()
     }
 
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 73a0dbbe8..cb4f6ad47 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1,5 +1,5 @@
+use rustc_hash::{FxHashMap, FxHasher};
 use serde::Serialize;
-use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};
 
 use numpy::{npyffi, PyArray1, PyArrayMethods};
@@ -255,7 +255,7 @@ impl PyAddedToken {
     }
 
     fn __hash__(&self) -> u64 {
-        let mut hasher = DefaultHasher::new();
+        let mut hasher = FxHasher::default();
         self.get_token().hash(&mut hasher);
         hasher.finish()
     }
@@ -675,7 +675,7 @@ impl PyTokenizer {
     ///     :obj:`Dict[str, int]`: The vocabulary
     #[pyo3(signature = (with_added_tokens = true))]
     #[pyo3(text_signature = "(self, with_added_tokens=True)")]
-    fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
+    fn get_vocab(&self, with_added_tokens: bool) -> FxHashMap<String, u32> {
         self.tokenizer.get_vocab(with_added_tokens)
     }
 
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index db56865d2..154b7a698 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -67,6 +67,7 @@ fancy-regex = { version = "0.14", optional = true}
 getrandom = { version = "0.2.10" }
 esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
 monostate = "0.1.12"
+rustc-hash = "2.1.1"
 
 [features]
 default = ["progressbar", "onig", "esaxx_fast"]
diff --git a/tokenizers/benches/unigram_benchmark.rs b/tokenizers/benches/unigram_benchmark.rs
index 9121a1937..c840aef52 100644
--- a/tokenizers/benches/unigram_benchmark.rs
+++ b/tokenizers/benches/unigram_benchmark.rs
@@ -2,7 +2,7 @@
 extern crate criterion;
 
 use criterion::Criterion;
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::fs::read_to_string;
 use std::time::{Duration, Instant};
 use tokenizers::models::unigram::Unigram;
@@ -18,7 +18,7 @@ pub fn bench_train(c: &mut Criterion) {
 
     let mut model = Unigram::default();
     let content = read_to_string("data/small.txt").unwrap();
-    let mut word_counts = HashMap::new();
+    let mut word_counts = FxHashMap::default();
     content.split_whitespace().for_each(|word| {
         // This is important for the test of char vs u8
         let word = format!("▁{word}");
@@ -46,7 +46,7 @@ pub fn bench_train(c: &mut Criterion) {
read_to_string("data/big.txt").unwrap(); // creating `medium` data, which is the first 25% of `data/big.txt` let content = String::from(&content[..(content.len() as f64 * 0.25) as usize]); - let mut word_counts = HashMap::new(); + let mut word_counts = FxHashMap::default(); content.split_whitespace().for_each(|word| { // This is important for the test of char vs u8 let word = format!("▁{word}"); diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 217c37e90..2f9687e16 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -2,19 +2,19 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; +use rustc_hash::FxHashMap; use serde_json::Value; use std::borrow::Cow; use std::{ - collections::HashMap, fs::File, io::prelude::*, io::{BufRead, BufReader}, path::{Path, PathBuf}, }; -pub type Vocab = HashMap; -type VocabR = HashMap; -pub type MergeMap = HashMap; +pub type Vocab = FxHashMap; +type VocabR = FxHashMap; +pub type MergeMap = FxHashMap; pub type Merges = Vec<(String, String)>; struct Config { @@ -41,7 +41,7 @@ impl Default for BpeBuilder { Self { config: Config { files: None, - vocab: HashMap::new(), + vocab: FxHashMap::default(), merges: vec![], cache_capacity: DEFAULT_CACHE_CAPACITY, dropout: None, @@ -324,7 +324,7 @@ impl BPE { let mut buffer = String::new(); vocab_file.read_to_string(&mut buffer)?; let json: Value = serde_json::from_str(&buffer)?; - let mut vocab = HashMap::new(); + let mut vocab = FxHashMap::default(); match json { Value::Object(m) => { for (token, id) in m { @@ -493,7 +493,7 @@ impl BPE { impl Model for BPE { type Trainer = BpeTrainer; - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self.vocab.clone() } diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 98cc15102..b443889c8 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -1,10 +1,10 @@ use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE}; +use rustc_hash::FxHashMap; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer, }; -use std::collections::HashMap; impl Serialize for BPE { fn serialize(&self, serializer: S) -> Result @@ -80,7 +80,7 @@ impl<'de> Visitor<'de> for BPEVisitor { V: MapAccess<'de>, { let mut builder = BpeBuilder::new(); - let mut vocab: Option> = None; + let mut vocab: Option> = None; #[derive(Debug, Deserialize)] #[serde(untagged)] diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index a1a0aba76..2890ecb1f 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -4,15 +4,17 @@ use super::{Pair, WithFirstLastIterator, Word, BPE}; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use rustc_hash::FxHashMap; +use rustc_hash::FxHashSet; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; -use std::collections::{BinaryHeap, HashMap, HashSet}; +use std::collections::BinaryHeap; #[derive(Debug, Eq)] struct Merge { pair: Pair, count: u64, - pos: HashSet, + pos: FxHashSet, } impl PartialEq for Merge { fn eq(&self, other: &Self) -> bool { @@ -41,7 +43,7 @@ 
@@ -41,7 +43,7 @@ struct Config {
     show_progress: bool,
     special_tokens: Vec<AddedToken>,
     limit_alphabet: Option<usize>,
-    initial_alphabet: HashSet<char>,
+    initial_alphabet: FxHashSet<char>,
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     max_token_length: Option<usize>,
@@ -62,7 +64,7 @@ impl Default for BpeTrainerBuilder {
                 show_progress: true,
                 special_tokens: vec![],
                 limit_alphabet: None,
-                initial_alphabet: HashSet::new(),
+                initial_alphabet: FxHashSet::default(),
                 continuing_subword_prefix: None,
                 end_of_word_suffix: None,
                 max_token_length: None,
@@ -114,7 +116,7 @@ impl BpeTrainerBuilder {
 
     /// Set the initial alphabet
     #[must_use]
-    pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
+    pub fn initial_alphabet(mut self, alphabet: FxHashSet<char>) -> Self {
         self.config.initial_alphabet = alphabet;
         self
     }
@@ -151,7 +153,7 @@ impl BpeTrainerBuilder {
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
             max_token_length: self.config.max_token_length,
-            words: HashMap::new(),
+            words: FxHashMap::default(),
         }
     }
 }
@@ -187,7 +189,7 @@ pub struct BpeTrainer {
     pub limit_alphabet: Option<usize>,
     /// The initial alphabet we want absolutely to include. This allows to cover
     /// some characters that are not necessarily in the training set
-    pub initial_alphabet: HashSet<char>,
+    pub initial_alphabet: FxHashSet<char>,
     /// An optional prefix to use on any subword that exist only behind another one
     pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to caracterize and end-of-word subword
@@ -195,7 +197,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
 
-    words: HashMap<String, u64>,
+    words: FxHashMap<String, u64>,
 }
 
 impl Default for BpeTrainer {
@@ -251,7 +253,7 @@ impl BpeTrainer {
     }
 
     /// Add the provided special tokens to the initial vocabulary
-    fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
+    fn add_special_tokens(&self, w2id: &mut FxHashMap<String, u32>, id2w: &mut Vec<String>) {
         for token in &self.special_tokens {
             if !w2id.contains_key(&token.content) {
                 id2w.push(token.content.to_owned());
@@ -263,12 +265,12 @@ impl BpeTrainer {
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
-        wc: &HashMap<String, u64>,
-        w2id: &mut HashMap<String, u32>,
+        wc: &FxHashMap<String, u64>,
+        w2id: &mut FxHashMap<String, u32>,
         id2w: &mut Vec<String>,
     ) {
         // Compute the alphabet from seen words
-        let mut alphabet: HashMap<char, usize> = HashMap::new();
+        let mut alphabet: FxHashMap<char, usize> = FxHashMap::default();
         for (word, count) in wc {
             for c in word.chars() {
                 alphabet
@@ -322,8 +324,8 @@ impl BpeTrainer {
     /// Tokenize words and add subwords to the vocabulary when relevant
     fn tokenize_words(
         &self,
-        wc: &HashMap<String, u64>,
-        w2id: &mut HashMap<String, u32>,
+        wc: &FxHashMap<String, u64>,
+        w2id: &mut FxHashMap<String, u32>,
         id2w: &mut Vec<String>,
         p: &Option<ProgressBar>,
     ) -> (Vec<Word>, Vec<u64>) {
@@ -375,13 +377,13 @@ impl BpeTrainer {
         words: &[Word],
         counts: &[u64],
         p: &Option<ProgressBar>,
-    ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
+    ) -> (FxHashMap<Pair, i32>, FxHashMap<Pair, FxHashSet<usize>>) {
         words
             .maybe_par_iter()
             .enumerate()
             .map(|(i, word)| {
-                let mut pair_counts = HashMap::new();
-                let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
+                let mut pair_counts = FxHashMap::default();
+                let mut where_to_update: FxHashMap<Pair, FxHashSet<usize>> = FxHashMap::default();
 
                 for window in word.get_chars().windows(2) {
                     let cur_pair: Pair = (window[0], window[1]);
@@ -399,7 +401,7 @@ impl BpeTrainer {
                             h.insert(i);
                         })
                         .or_insert_with(|| {
-                            let mut h = HashSet::new();
+                            let mut h = FxHashSet::default();
                             h.insert(i);
                             h
                         });
@@ -413,7 +415,7 @@ impl BpeTrainer {
                 (pair_counts, where_to_update)
             })
             .reduce(
-                || (HashMap::new(), HashMap::new()),
+                || (FxHashMap::default(), FxHashMap::default()),
                 |(mut pair_counts, mut where_to_update), (pc, wtu)| {
                     for (k, v) in pc {
                         pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v);
@@ -431,10 +433,11 @@ impl BpeTrainer {
 
     pub fn do_train(
         &self,
-        word_counts: &HashMap<String, u64>,
+        word_counts: &FxHashMap<String, u64>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
-        let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
+        let mut word_to_id: FxHashMap<String, u32> =
+            FxHashMap::with_capacity_and_hasher(self.vocab_size, Default::default());
         let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
         let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);
 
@@ -532,7 +535,7 @@ impl BpeTrainer {
             // Merge the new pair in every words
             // Safety: This is just a type assertion, the code below may no longer be safe
             // if the type of `pos` changes
-            let pos: &HashSet<usize> = &top.pos;
+            let pos: &FxHashSet<usize> = &top.pos;
             let words_len = words.len();
 
             struct WordPtr(*mut Word);
@@ -577,7 +580,7 @@ impl BpeTrainer {
                         h.insert(iw);
                     })
                     .or_insert_with(|| {
-                        let mut h = HashSet::new();
+                        let mut h = FxHashSet::default();
                         h.insert(iw);
                         h
                     });
@@ -647,18 +650,18 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u64>> = iterator
+        let words: Result<FxHashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
-                let mut map = HashMap::new();
+                let mut map = FxHashMap::default();
                 for word in words {
                     map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                 }
                 Ok(map)
             })
             .reduce(
-                || Ok(HashMap::new()),
+                || Ok(FxHashMap::default()),
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
@@ -676,11 +679,11 @@ impl Trainer for BpeTrainer {
 #[cfg(test)]
 mod tests {
     use super::{BpeTrainer, Pair, BPE};
-    use std::collections::HashMap;
+    use rustc_hash::FxHashMap;
 
     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u64> = [
+        let word_counts: FxHashMap<String, u64> = [
             ("roses".into(), 1),
             ("are".into(), 2),
             ("red".into(), 1),
@@ -705,7 +708,7 @@ mod tests {
 
         // Vocab should contain all of the characters from the `word_counts` mapping
        // as well as three merges: 're', 'are', and 'is'.
-        let expected_vocab: HashMap<String, u32> = [
+        let expected_vocab: FxHashMap<String, u32> = [
             ("-".into(), 0),
             ("2".into(), 1),
             ("B".into(), 2),
@@ -741,7 +744,7 @@ mod tests {
         // where 'rank' determines the order in which this merge will be applied during
         // tokenization, and 'id' is the vocab id of the symbol resulting from merging
         // the pair of symbols in the corresponding key.
-        let expected_merges: HashMap<Pair, (u32, u32)> = [
+        let expected_merges: FxHashMap<Pair, (u32, u32)> = [
             ((17, 11), (0, 22)), // 'r' + 'e' -> 're'
             ((8, 22), (1, 23)),  // 'a' + 're' -> 'are'
             ((13, 18), (2, 24)), // 'i' + 's' -> 'is'
@@ -759,7 +762,7 @@ mod tests {
         */
         let max_token_length = 16;
 
-        let long_word_counts: HashMap<String, u64> = [
+        let long_word_counts: FxHashMap<String, u64> = [
             ("singlelongtokenwithoutcasechange", 2),
             ("singleLongTokenWithCamelCaseChange", 2),
             ("Longsingletokenwithpunctu@t!onwithin", 2),
@@ -799,7 +802,7 @@ mod tests {
         // directly compares tokens with known expected values.
         // maybe unstable depending on specific settings or changes.
         */
-        let long_word_counts: HashMap<String, u64> = [
+        let long_word_counts: FxHashMap<String, u64> = [
             ("sin", 2),
             ("Sin", 2),
             ("Lon", 2),
@@ -823,8 +826,8 @@ mod tests {
             .build();
         let mut model = BPE::default();
         trainer.do_train(&long_word_counts, &mut model).unwrap();
-        let trained_vocab: HashMap<String, u32> = model.get_vocab();
-        let expected_vocab: HashMap<String, u32> = [
+        let trained_vocab: FxHashMap<String, u32> = model.get_vocab();
+        let expected_vocab: FxHashMap<String, u32> = [
             ("短", 12),
             ("n", 6),
             ("i", 5),
diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index 93b3d9c37..60bd2258d 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -1,7 +1,8 @@
 use super::Pair;
 use rand::{thread_rng, Rng};
+use rustc_hash::FxHashMap;
 use std::cmp::Ordering;
-use std::collections::{BinaryHeap, HashMap};
+use std::collections::BinaryHeap;
 
 #[derive(Debug, Eq)]
 struct Merge {
@@ -158,7 +159,7 @@ impl Word {
         changes
     }
 
-    pub(super) fn merge_all(&mut self, merges: &HashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
+    pub(super) fn merge_all(&mut self, merges: &FxHashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
         let mut queue = BinaryHeap::with_capacity(self.symbols.len());
 
         let mut skip = Vec::with_capacity(queue.len());
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index 3a3a91adc..48433d480 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -5,7 +5,7 @@ pub mod unigram;
 pub mod wordlevel;
 pub mod wordpiece;
 
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::path::{Path, PathBuf};
 
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -19,11 +19,11 @@ use crate::{AddedToken, Model, Result, Token, Trainer};
 /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
 /// of token ID, smallest to largest.
 struct OrderedVocabIter<'a> {
-    vocab_r: &'a HashMap<u32, String>,
+    vocab_r: &'a FxHashMap<u32, String>,
 }
 
 impl<'a> OrderedVocabIter<'a> {
-    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+    fn new(vocab_r: &'a FxHashMap<u32, String>) -> Self {
         Self { vocab_r }
     }
 }
@@ -170,7 +170,7 @@ impl Model for ModelWrapper {
         }
     }
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         match self {
             Self::WordLevel(t) => t.get_vocab(),
             Self::WordPiece(t) => t.get_vocab(),
@@ -287,6 +287,8 @@ impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);
 
 #[cfg(test)]
 mod tests {
+    use std::iter::FromIterator;
+
     use super::*;
     use crate::models::bpe::{BpeBuilder, Vocab};
 
@@ -301,8 +303,8 @@ mod tests {
 
     #[test]
     fn incomplete_ordered_vocab() {
-        let vocab_r: HashMap<u32, String> =
-            HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
+        let vocab_r: FxHashMap<u32, String> =
+            FxHashMap::from_iter([(0, "Hi".to_string()), (2, "There".to_string())]);
 
         let ordered = OrderedVocabIter::new(&vocab_r);
 
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index da4d631ce..1c3f14234 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -6,12 +6,12 @@ use super::{
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, MAX_LENGTH};
 
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::convert::TryInto;
 use std::fs::read_to_string;
 use std::path::{Path, PathBuf};
 
-type TokenMap = HashMap<String, u32>;
+type TokenMap = FxHashMap<String, u32>;
 type Vocab = Vec<(String, f64)>;
 
 /// A `Unigram` model to encode sentences.
@@ -98,7 +98,7 @@ impl Unigram {
         byte_fallback: bool,
     ) -> Result<Self> {
         let n = vocab.len();
-        let mut token_to_ids: TokenMap = HashMap::new();
+        let mut token_to_ids: TokenMap = FxHashMap::default();
         let mut builder = TrieBuilder::default();
 
         if let Some(unk_id) = unk_id {
@@ -415,7 +415,7 @@ impl<'a> Iterator for UnigramIterator<'a> {
 impl Model for Unigram {
     type Trainer = UnigramTrainer;
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         self.token_to_ids.clone()
     }
 
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 5d178e77b..66122377f 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -3,9 +3,10 @@ use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
 use log::debug;
+use rustc_hash::FxHashMap;
+use rustc_hash::FxHashSet;
 use serde::{Deserialize, Serialize};
 use std::cmp::Reverse;
-use std::collections::{HashMap, HashSet};
 use std::convert::TryInto;
 
 // A token and a score
@@ -57,8 +58,8 @@ pub struct UnigramTrainer {
     pub shrinking_factor: f64,
     #[builder(default = "vec![]")]
     pub special_tokens: Vec<AddedToken>,
-    #[builder(default = "HashSet::new()")]
-    pub initial_alphabet: HashSet<char>,
+    #[builder(default = "FxHashSet::default()")]
+    pub initial_alphabet: FxHashSet<char>,
     #[builder(default = "None")]
     pub unk_token: Option<String>,
 
@@ -67,8 +68,8 @@ pub struct UnigramTrainer {
     pub max_piece_length: usize,
     #[builder(default = "1_000_000")]
     seed_size: usize,
-    #[builder(default = "HashMap::new()")]
-    words: HashMap<String, u64>,
+    #[builder(default = "FxHashMap::default()")]
+    words: FxHashMap<String, u64>,
 }
 
 impl Default for UnigramTrainer {
@@ -110,17 +111,17 @@ impl UnigramTrainer {
         true
     }
 
-    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
+    fn finalize(&self, model: Unigram, required_chars: FxHashSet<String>) -> Result<Unigram> {
         let mut min_score_penalty = 0.0;
         let min_score_penalty_delta = 0.0001;
 
         let mut pieces: Vec<(String, f64)> = vec![];
-        let mut inserted: HashSet<String> = HashSet::new();
+        let mut inserted: FxHashSet<String> = FxHashSet::default();
 
         // We don't want to include the <UNK> that was used to train
         inserted.insert("<UNK>".into());
 
-        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
+        let existing_pieces: FxHashMap<String, f64> = model.iter().cloned().collect();
         for c in required_chars {
             if let Some(t) = existing_pieces.get(&c) {
                 inserted.insert(c.clone());
@@ -185,7 +186,7 @@ impl UnigramTrainer {
         )
     }
 
-    fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
+    fn required_chars(&self, word_counts: &[Sentence]) -> FxHashSet<String> {
         word_counts
             .iter()
             .flat_map(|(s, _count)| s.chars())
@@ -205,7 +206,7 @@ impl UnigramTrainer {
             .sum::<usize>()
             + sentences.len();
         let mut flat_string = String::with_capacity(total);
-        let mut all_chars: HashMap<char, u64> = HashMap::new();
+        let mut all_chars: FxHashMap<char, u64> = FxHashMap::default();
         let c_sentence_boundary = '\0';
         let k_sentence_boundary = '\0'.to_string();
         for (string, n) in sentences {
@@ -631,18 +632,18 @@ impl Trainer for UnigramTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u64>> = iterator
+        let words: Result<FxHashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
-                let mut map = HashMap::new();
+                let mut map = FxHashMap::default();
                 for word in words {
                     map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                 }
                 Ok(map)
             })
             .reduce(
-                || Ok(HashMap::new()),
+                || Ok(FxHashMap::default()),
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
@@ -716,7 +717,7 @@ mod tests {
     fn test_initial_alphabet() {
         let trainer = UnigramTrainerBuilder::default()
             .show_progress(false)
-            .initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
+            .initial_alphabet(FxHashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
             .build()
             .unwrap();
 
@@ -727,7 +728,7 @@
             vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
                 .into_iter()
                 .map(|s| s.to_owned())
-                .collect::<HashSet<_>>()
+                .collect::<FxHashSet<_>>()
         );
     }
 
diff --git a/tokenizers/src/models/unigram/trie.rs b/tokenizers/src/models/unigram/trie.rs
index 2f94b1766..70f5333d4 100644
--- a/tokenizers/src/models/unigram/trie.rs
+++ b/tokenizers/src/models/unigram/trie.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::hash::Hash;
 
 #[derive(Default)]
@@ -78,14 +78,14 @@ impl