diff --git a/bindings/node/Cargo.toml b/bindings/node/Cargo.toml
index cf1e51e99..43e86ed84 100644
--- a/bindings/node/Cargo.toml
+++ b/bindings/node/Cargo.toml
@@ -12,6 +12,7 @@ crate-type = ["cdylib"]
 [dependencies]
 napi = "2"
 napi-derive = "2"
+rustc-hash = "2.1.1"
 serde = { version = "1.0.163", features = ["derive"] }
 tokenizers = { path = "../../tokenizers/" }
diff --git a/bindings/node/src/models.rs b/bindings/node/src/models.rs
index a4138b91f..f66962742 100644
--- a/bindings/node/src/models.rs
+++ b/bindings/node/src/models.rs
@@ -3,8 +3,8 @@ use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFr
 use crate::trainers::Trainer;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
+use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 use tokenizers as tk;
@@ -95,7 +95,7 @@ impl tk::Model for Model {
     self.model.as_ref()?.read().unwrap().id_to_token(id)
   }
 
-  fn get_vocab(&self) -> HashMap<String, u32> {
+  fn get_vocab(&self) -> FxHashMap<String, u32> {
     self
       .model
       .as_ref()
diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs
index 4acbcac83..a99ac0313 100644
--- a/bindings/node/src/tokenizer.rs
+++ b/bindings/node/src/tokenizer.rs
@@ -6,7 +6,7 @@ use crate::pre_tokenizers::PreTokenizer;
 use crate::processors::Processor;
 use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask};
 use crate::trainers::Trainer;
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use tokenizers::Model as ModelTrait;
 
 use napi::bindgen_prelude::*;
@@ -433,7 +433,7 @@ impl Tokenizer {
   }
 
   #[napi]
-  pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> HashMap<String, u32> {
+  pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> FxHashMap<String, u32> {
     let with_added_tokens = with_added_tokens.unwrap_or(true);
     self.tokenizer.read().unwrap().get_vocab(with_added_tokens)
   }
diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 6e8b0c34c..ac8041e6e 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -18,6 +18,7 @@ pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
 numpy = "0.23"
 ndarray = "0.16"
 itertools = "0.12"
+rustc-hash = "2.1.1"
 
 [dependencies.tokenizers]
 path = "../../tokenizers"
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 2f4dba825..4d6e084f3 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 
@@ -70,7 +70,7 @@ impl Model for PyModel {
         self.model.read().unwrap().id_to_token(id)
     }
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         self.model.read().unwrap().get_vocab()
     }
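
An aside, not part of the patch: `rustc_hash::FxHashMap` is a type alias for `std::collections::HashMap` with the faster, non-DoS-resistant `FxBuildHasher`, which is why the bindings above can change the declared return type of `get_vocab` without touching any call sites. A minimal sketch of the drop-in pattern, assuming rustc-hash 2.x as added in the Cargo.toml files (`build_vocab` is a hypothetical helper):

    use rustc_hash::FxHashMap;

    // Same map API as std's HashMap; only the hasher differs.
    fn build_vocab(tokens: &[&str]) -> FxHashMap<String, u32> {
        // `FxHashMap::default()` replaces `HashMap::new()`: `new()` is only defined for
        // the default `RandomState` hasher, which is why the diff rewrites those calls.
        let mut vocab = FxHashMap::default();
        for (id, tok) in tokens.iter().enumerate() {
            vocab.insert((*tok).to_string(), id as u32);
        }
        vocab
    }
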
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 73a0dbbe8..cb4f6ad47 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1,5 +1,5 @@
+use rustc_hash::{FxHashMap, FxHasher};
 use serde::Serialize;
-use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};
 
 use numpy::{npyffi, PyArray1, PyArrayMethods};
@@ -255,7 +255,7 @@ impl PyAddedToken {
     }
 
     fn __hash__(&self) -> u64 {
-        let mut hasher = DefaultHasher::new();
+        let mut hasher = FxHasher::default();
         self.get_token().hash(&mut hasher);
         hasher.finish()
     }
@@ -675,7 +675,7 @@ impl PyTokenizer {
     /// :obj:`Dict[str, int]`: The vocabulary
     #[pyo3(signature = (with_added_tokens = true))]
     #[pyo3(text_signature = "(self, with_added_tokens=True)")]
-    fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
+    fn get_vocab(&self, with_added_tokens: bool) -> FxHashMap<String, u32> {
         self.tokenizer.get_vocab(with_added_tokens)
     }
 
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index db56865d2..154b7a698 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -67,6 +67,7 @@ fancy-regex = { version = "0.14", optional = true}
 getrandom = { version = "0.2.10" }
 esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
 monostate = "0.1.12"
+rustc-hash = "2.1.1"
 
 [features]
 default = ["progressbar", "onig", "esaxx_fast"]
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 217c37e90..db9cfaebd 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -2,19 +2,22 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
 use crate::utils::iter::ResultShunt;
+use rustc_hash::FxHashMap;
 use serde_json::Value;
 use std::borrow::Cow;
+use std::collections::HashMap;
+use std::hash::BuildHasher;
+use std::iter::FromIterator;
 use std::{
-    collections::HashMap,
     fs::File,
     io::prelude::*,
     io::{BufRead, BufReader},
     path::{Path, PathBuf},
 };
 
-pub type Vocab = HashMap<String, u32>;
-type VocabR = HashMap<u32, String>;
-pub type MergeMap = HashMap<Pair, (u32, u32)>;
+pub type Vocab = FxHashMap<String, u32>;
+type VocabR = FxHashMap<u32, String>;
+pub type MergeMap = FxHashMap<Pair, (u32, u32)>;
 pub type Merges = Vec<(String, String)>;
 
 struct Config {
@@ -41,7 +44,7 @@ impl Default for BpeBuilder {
         Self {
             config: Config {
                 files: None,
-                vocab: HashMap::new(),
+                vocab: FxHashMap::default(),
                 merges: vec![],
                 cache_capacity: DEFAULT_CACHE_CAPACITY,
                 dropout: None,
@@ -71,8 +74,12 @@ impl BpeBuilder {
 
     /// Set the vocab (token -> ID) and merges mappings.
     #[must_use]
-    pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self {
-        self.config.vocab = vocab;
+    pub fn vocab_and_merges<S: BuildHasher>(
+        mut self,
+        vocab: HashMap<String, u32, S>,
+        merges: Merges,
+    ) -> Self {
+        self.config.vocab = FxHashMap::from_iter(vocab);
         self.config.merges = merges;
         self
     }
@@ -324,7 +331,7 @@ impl BPE {
         let mut buffer = String::new();
         vocab_file.read_to_string(&mut buffer)?;
         let json: Value = serde_json::from_str(&buffer)?;
-        let mut vocab = HashMap::new();
+        let mut vocab = FxHashMap::default();
         match json {
             Value::Object(m) => {
                 for (token, id) in m {
@@ -493,7 +500,7 @@ impl BPE {
 impl Model for BPE {
     type Trainer = BpeTrainer;
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         self.vocab.clone()
     }
diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs
index 98cc15102..b443889c8 100644
--- a/tokenizers/src/models/bpe/serialization.rs
+++ b/tokenizers/src/models/bpe/serialization.rs
@@ -1,10 +1,10 @@
 use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
+use rustc_hash::FxHashMap;
 use serde::{
     de::{Error, MapAccess, Visitor},
     ser::SerializeStruct,
     Deserialize, Deserializer, Serialize, Serializer,
 };
-use std::collections::HashMap;
 
 impl Serialize for BPE {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -80,7 +80,7 @@ impl<'de> Visitor<'de> for BPEVisitor {
         V: MapAccess<'de>,
     {
         let mut builder = BpeBuilder::new();
-        let mut vocab: Option<HashMap<String, u32>> = None;
+        let mut vocab: Option<FxHashMap<String, u32>> = None;
 
         #[derive(Debug, Deserialize)]
         #[serde(untagged)]
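
Another illustrative aside, not from the patch: the BPE model keeps its public API on the standard `HashMap`, generic over the hasher, and converts to `FxHashMap` internally (as `vocab_and_merges` does via `FxHashMap::from_iter`), so existing callers that build a plain `HashMap` keep compiling. A rough sketch of that boundary pattern (`into_fx_vocab` is a hypothetical helper; pre-2021 editions also need the `std::iter::FromIterator` import the patch adds):

    use rustc_hash::FxHashMap;
    use std::collections::HashMap;
    use std::hash::BuildHasher;

    // Accept a standard HashMap built with any hasher `S`; store an FxHashMap internally.
    fn into_fx_vocab<S: BuildHasher>(vocab: HashMap<String, u32, S>) -> FxHashMap<String, u32> {
        FxHashMap::from_iter(vocab)
    }

    fn main() {
        let std_vocab = HashMap::from([("hello".to_string(), 0u32)]);
        let fx_vocab = into_fx_vocab(std_vocab);
        assert_eq!(fx_vocab["hello"], 0);
    }
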
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index a1a0aba76..7d7fe2a2f 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -4,15 +4,19 @@ use super::{Pair, WithFirstLastIterator, Word, BPE};
 use crate::parallelism::*;
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use rustc_hash::FxHashMap;
+use rustc_hash::FxHashSet;
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
 use std::collections::{BinaryHeap, HashMap, HashSet};
+use std::hash::BuildHasher;
+use std::iter::FromIterator;
 
 #[derive(Debug, Eq)]
 struct Merge {
     pair: Pair,
     count: u64,
-    pos: HashSet<usize>,
+    pos: FxHashSet<usize>,
 }
 impl PartialEq for Merge {
     fn eq(&self, other: &Self) -> bool {
@@ -41,7 +45,7 @@ struct Config {
     show_progress: bool,
     special_tokens: Vec<AddedToken>,
     limit_alphabet: Option<usize>,
-    initial_alphabet: HashSet<char>,
+    initial_alphabet: FxHashSet<char>,
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     max_token_length: Option<usize>,
@@ -62,7 +66,7 @@ impl Default for BpeTrainerBuilder {
                 show_progress: true,
                 special_tokens: vec![],
                 limit_alphabet: None,
-                initial_alphabet: HashSet::new(),
+                initial_alphabet: FxHashSet::default(),
                 continuing_subword_prefix: None,
                 end_of_word_suffix: None,
                 max_token_length: None,
@@ -114,8 +118,8 @@ impl BpeTrainerBuilder {
 
     /// Set the initial alphabet
     #[must_use]
-    pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
-        self.config.initial_alphabet = alphabet;
+    pub fn initial_alphabet<S: BuildHasher>(mut self, alphabet: HashSet<char, S>) -> Self {
+        self.config.initial_alphabet = FxHashSet::from_iter(alphabet);
         self
     }
 
@@ -151,7 +155,7 @@ impl BpeTrainerBuilder {
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
             max_token_length: self.config.max_token_length,
-            words: HashMap::new(),
+            words: FxHashMap::default(),
         }
     }
 }
@@ -187,7 +191,7 @@ pub struct BpeTrainer {
     pub limit_alphabet: Option<usize>,
     /// The initial alphabet we want absolutely to include. This allows to cover
     /// some characters that are not necessarily in the training set
-    pub initial_alphabet: HashSet<char>,
+    pub initial_alphabet: FxHashSet<char>,
     /// An optional prefix to use on any subword that exist only behind another one
     pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to caracterize and end-of-word subword
@@ -195,7 +199,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
 
-    words: HashMap<String, u64>,
+    words: FxHashMap<String, u64>,
 }
 
 impl Default for BpeTrainer {
@@ -251,7 +255,11 @@ impl BpeTrainer {
     }
 
     /// Add the provided special tokens to the initial vocabulary
-    fn add_special_tokens(&self, w2id: &mut HashMap<String, u32>, id2w: &mut Vec<String>) {
+    fn add_special_tokens<S: BuildHasher>(
+        &self,
+        w2id: &mut HashMap<String, u32, S>,
+        id2w: &mut Vec<String>,
+    ) {
         for token in &self.special_tokens {
             if !w2id.contains_key(&token.content) {
                 id2w.push(token.content.to_owned());
@@ -261,14 +269,14 @@ impl BpeTrainer {
     }
 
     /// Compute the initial alphabet and limit it if relevant
-    fn compute_alphabet(
+    fn compute_alphabet<S: BuildHasher, S2: BuildHasher>(
         &self,
-        wc: &HashMap<String, u64>,
-        w2id: &mut HashMap<String, u32>,
+        wc: &HashMap<String, u64, S>,
+        w2id: &mut HashMap<String, u32, S2>,
         id2w: &mut Vec<String>,
     ) {
         // Compute the alphabet from seen words
-        let mut alphabet: HashMap<char, usize> = HashMap::new();
+        let mut alphabet: FxHashMap<char, usize> = FxHashMap::default();
         for (word, count) in wc {
             for c in word.chars() {
                 alphabet
@@ -320,10 +328,10 @@ impl BpeTrainer {
     }
 
     /// Tokenize words and add subwords to the vocabulary when relevant
-    fn tokenize_words(
+    fn tokenize_words<S: BuildHasher, S2: BuildHasher>(
        &self,
-        wc: &HashMap<String, u64>,
-        w2id: &mut HashMap<String, u32>,
+        wc: &HashMap<String, u64, S>,
+        w2id: &mut HashMap<String, u32, S2>,
         id2w: &mut Vec<String>,
         p: &Option<ProgressBar>,
     ) -> (Vec<Word>, Vec<u64>) {
@@ -375,13 +383,13 @@ impl BpeTrainer {
         words: &[Word],
         counts: &[u64],
         p: &Option<ProgressBar>,
-    ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
+    ) -> (FxHashMap<Pair, i32>, FxHashMap<Pair, FxHashSet<usize>>) {
         words
             .maybe_par_iter()
             .enumerate()
             .map(|(i, word)| {
-                let mut pair_counts = HashMap::new();
-                let mut where_to_update: HashMap<Pair, HashSet<usize>> = HashMap::new();
+                let mut pair_counts = FxHashMap::default();
+                let mut where_to_update: FxHashMap<Pair, FxHashSet<usize>> = FxHashMap::default();
 
                 for window in word.get_chars().windows(2) {
                     let cur_pair: Pair = (window[0], window[1]);
@@ -399,7 +407,7 @@ impl BpeTrainer {
                             h.insert(i);
                         })
                         .or_insert_with(|| {
-                            let mut h = HashSet::new();
+                            let mut h = FxHashSet::default();
                             h.insert(i);
                             h
                         });
@@ -413,7 +421,7 @@ impl BpeTrainer {
                 (pair_counts, where_to_update)
             })
             .reduce(
-                || (HashMap::new(), HashMap::new()),
+                || (FxHashMap::default(), FxHashMap::default()),
                 |(mut pair_counts, mut where_to_update), (pc, wtu)| {
                     for (k, v) in pc {
                         pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v);
@@ -429,12 +437,13 @@ impl BpeTrainer {
         )
     }
 
-    pub fn do_train(
+    pub fn do_train<S: BuildHasher>(
         &self,
-        word_counts: &HashMap<String, u64>,
+        word_counts: &HashMap<String, u64, S>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
-        let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
+        let mut word_to_id: FxHashMap<String, u32> =
+            FxHashMap::with_capacity_and_hasher(self.vocab_size, Default::default());
         let mut id_to_word: Vec<String> = Vec::with_capacity(self.vocab_size);
         let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX);
 
@@ -532,7 +541,7 @@ impl BpeTrainer {
             // Merge the new pair in every words
             // Safety: This is just a type assertion, the code below may no longer be safe
             // if the type of `pos` changes
-            let pos: &HashSet<usize> = &top.pos;
+            let pos: &FxHashSet<usize> = &top.pos;
 
             let words_len = words.len();
             struct WordPtr(*mut Word);
@@ -577,7 +586,7 @@ impl BpeTrainer {
                             h.insert(iw);
                         })
                         .or_insert_with(|| {
-                            let mut h = HashSet::new();
+                            let mut h = FxHashSet::default();
                             h.insert(iw);
                             h
                         });
@@ -647,18 +656,18 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u64>> = iterator
+        let words: Result<FxHashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
-                let mut map = HashMap::new();
+                let mut map = FxHashMap::default();
                 for word in words {
                     map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                 }
                 Ok(map)
             })
             .reduce(
-                || Ok(HashMap::new()),
+                || Ok(FxHashMap::default()),
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
@@ -675,9 +684,10 @@ impl Trainer for BpeTrainer {
 
 #[cfg(test)]
 mod tests {
-    use super::{BpeTrainer, Pair, BPE};
     use std::collections::HashMap;
 
+    use super::{BpeTrainer, Pair, BPE};
+
     #[test]
     fn test_train() {
         let word_counts: HashMap<String, u64> = [
@@ -735,7 +745,12 @@ mod tests {
         .iter()
         .cloned()
         .collect();
-        assert_eq!(model.vocab, expected_vocab);
+
+        let mut lhs = model.vocab.into_iter().collect::<Vec<_>>();
+        let mut rhs = expected_vocab.into_iter().collect::<Vec<_>>();
+        lhs.sort_unstable();
+        rhs.sort_unstable();
+        assert_eq!(lhs, rhs);
 
         // The keys in `merges` are pairs of symbols, the values are tuples of (rank, id),
         // where 'rank' determines the order in which this merge will be applied during
@@ -749,7 +764,12 @@ mod tests {
         .iter()
         .cloned()
         .collect();
-        assert_eq!(model.merges, expected_merges);
+
+        let mut lhs = model.merges.into_iter().collect::<Vec<_>>();
+        let mut rhs = expected_merges.into_iter().collect::<Vec<_>>();
+        lhs.sort_unstable();
+        rhs.sort_unstable();
+        assert_eq!(lhs, rhs);
     }
     #[test]
     fn bpe_test_max_token_length_16() {
@@ -823,7 +843,7 @@ mod tests {
             .build();
         let mut model = BPE::default();
         trainer.do_train(&long_word_counts, &mut model).unwrap();
-        let trained_vocab: HashMap<String, u32> = model.get_vocab();
+        let trained_vocab = model.get_vocab();
         let expected_vocab: HashMap<String, u32> = [
             ("短", 12),
             ("n", 6),
@@ -860,6 +880,11 @@ mod tests {
         .cloned()
         .map(|(k, v)| (k.to_string(), v))
         .collect();
-        assert_eq!(trained_vocab, expected_vocab)
+
+        let mut lhs = trained_vocab.into_iter().collect::<Vec<_>>();
+        let mut rhs = expected_vocab.into_iter().collect::<Vec<_>>();
+        lhs.sort_unstable();
+        rhs.sort_unstable();
+        assert_eq!(lhs, rhs)
     }
 }
diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index 93b3d9c37..24eb0e730 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -2,6 +2,7 @@ use super::Pair;
 use rand::{thread_rng, Rng};
 use std::cmp::Ordering;
 use std::collections::{BinaryHeap, HashMap};
+use std::hash::BuildHasher;
 
 #[derive(Debug, Eq)]
 struct Merge {
@@ -158,7 +159,11 @@ impl Word {
         changes
     }
 
-    pub(super) fn merge_all(&mut self, merges: &HashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
+    pub(super) fn merge_all<S: BuildHasher>(
+        &mut self,
+        merges: &HashMap<Pair, (u32, u32), S>,
+        dropout: Option<f32>,
+    ) {
         let mut queue = BinaryHeap::with_capacity(self.symbols.len());
         let mut skip = Vec::with_capacity(queue.len());
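
Illustrative aside, not from the patch: `merge_all` above, like the trainer's `do_train` and `compute_alphabet`, goes the other way around — it stays generic over `S: BuildHasher`, so the same code accepts a default `HashMap` as well as an `FxHashMap` (which is just `HashMap` with `FxBuildHasher`). A reduced sketch of that signature pattern (`total_count` is a hypothetical function):

    use rustc_hash::FxHashMap;
    use std::collections::HashMap;
    use std::hash::BuildHasher;

    // Generic over the hasher: works for HashMap<String, u64> and FxHashMap<String, u64> alike.
    fn total_count<S: BuildHasher>(word_counts: &HashMap<String, u64, S>) -> u64 {
        word_counts.values().sum()
    }

    fn main() {
        let std_counts = HashMap::from([("a".to_string(), 2u64)]);
        let mut fx_counts: FxHashMap<String, u64> = FxHashMap::default();
        fx_counts.insert("b".to_string(), 3);

        assert_eq!(total_count(&std_counts), 2);
        assert_eq!(total_count(&fx_counts), 3);
    }
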
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index 3a3a91adc..0fd750447 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -5,7 +5,7 @@ pub mod unigram;
 pub mod wordlevel;
 pub mod wordpiece;
 
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::path::{Path, PathBuf};
 
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -19,11 +19,11 @@ use crate::{AddedToken, Model, Result, Token, Trainer};
 /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
 /// of token ID, smallest to largest.
 struct OrderedVocabIter<'a> {
-    vocab_r: &'a HashMap<u32, String>,
+    vocab_r: &'a FxHashMap<u32, String>,
 }
 
 impl<'a> OrderedVocabIter<'a> {
-    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+    fn new(vocab_r: &'a FxHashMap<u32, String>) -> Self {
         Self { vocab_r }
     }
 }
@@ -35,7 +35,7 @@ impl Serialize for OrderedVocabIter<'_> {
     {
         // There could be holes so max + 1 is more correct than vocab_r.len()
         let mut holes = vec![];
-        let result = if let Some(max) = self.vocab_r.iter().map(|(key, _)| key).max() {
+        let result = if let Some(max) = self.vocab_r.keys().max() {
             let iter = (0..*max + 1).filter_map(|i| {
                 if let Some(token) = self.vocab_r.get(&i) {
                     Some((token, i))
@@ -170,7 +170,7 @@ impl Model for ModelWrapper {
         }
     }
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         match self {
             Self::WordLevel(t) => t.get_vocab(),
             Self::WordPiece(t) => t.get_vocab(),
@@ -287,6 +287,8 @@ impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);
 
 #[cfg(test)]
 mod tests {
+    use std::iter::FromIterator;
+
     use super::*;
     use crate::models::bpe::{BpeBuilder, Vocab};
 
@@ -301,8 +303,8 @@ mod tests {
 
     #[test]
     fn incomplete_ordered_vocab() {
-        let vocab_r: HashMap<u32, String> =
-            HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
+        let vocab_r: FxHashMap<u32, String> =
+            FxHashMap::from_iter([(0, "Hi".to_string()), (2, "There".to_string())]);
 
         let ordered = OrderedVocabIter::new(&vocab_r);
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index da4d631ce..1c3f14234 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -6,12 +6,12 @@ use super::{
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, MAX_LENGTH};
 
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::convert::TryInto;
 use std::fs::read_to_string;
 use std::path::{Path, PathBuf};
 
-type TokenMap = HashMap<String, u32>;
+type TokenMap = FxHashMap<String, u32>;
 type Vocab = Vec<(String, f64)>;
 
 /// A `Unigram` model to encode sentences.
@@ -98,7 +98,7 @@ impl Unigram {
         byte_fallback: bool,
     ) -> Result<Self> {
         let n = vocab.len();
-        let mut token_to_ids: TokenMap = HashMap::new();
+        let mut token_to_ids: TokenMap = FxHashMap::default();
         let mut builder = TrieBuilder::default();
 
         if let Some(unk_id) = unk_id {
@@ -415,7 +415,7 @@ impl<'a> Iterator for UnigramIterator<'a> {
 
 impl Model for Unigram {
     type Trainer = UnigramTrainer;
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<String, u32> {
         self.token_to_ids.clone()
     }
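
A final aside (illustrative, not from the patch): the test rewrites in the trainers — for BPE above, and for `required_chars` in the Unigram trainer below — are needed because `assert_eq!` between an Fx map/set and a default `std` one no longer type-checks (`PartialEq` for `HashMap`/`HashSet` requires both sides to use the same hasher type), so both sides are normalized into sorted `Vec`s. A standalone sketch of that comparison idiom:

    use rustc_hash::FxHashMap;
    use std::collections::HashMap;

    fn main() {
        let expected: HashMap<String, u32> =
            HashMap::from([("a".to_string(), 0), ("b".to_string(), 1)]);
        let mut produced: FxHashMap<String, u32> = FxHashMap::default();
        produced.insert("b".to_string(), 1);
        produced.insert("a".to_string(), 0);

        // `assert_eq!(produced, expected)` would not compile (different hasher types), and
        // map iteration order is arbitrary, so compare sorted key/value pairs instead.
        let mut lhs = produced.into_iter().collect::<Vec<_>>();
        let mut rhs = expected.into_iter().collect::<Vec<_>>();
        lhs.sort_unstable();
        rhs.sort_unstable();
        assert_eq!(lhs, rhs);
    }
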
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 5d178e77b..ebd4749fa 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -3,10 +3,13 @@ use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
 use log::debug;
+use rustc_hash::FxHashMap;
+use rustc_hash::FxHashSet;
 use serde::{Deserialize, Serialize};
 use std::cmp::Reverse;
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 use std::convert::TryInto;
+use std::hash::BuildHasher;
 
 // A token and a score
 type SentencePiece = (String, f64);
@@ -57,8 +60,8 @@ pub struct UnigramTrainer {
     pub shrinking_factor: f64,
     #[builder(default = "vec![]")]
     pub special_tokens: Vec<AddedToken>,
-    #[builder(default = "HashSet::new()")]
-    pub initial_alphabet: HashSet<char>,
+    #[builder(default = "FxHashSet::default()")]
+    pub initial_alphabet: FxHashSet<char>,
 
     #[builder(default = "None")]
     pub unk_token: Option<String>,
@@ -67,8 +70,8 @@ pub struct UnigramTrainer {
     pub max_piece_length: usize,
     #[builder(default = "1_000_000")]
     seed_size: usize,
-    #[builder(default = "HashMap::new()")]
-    words: HashMap<String, u32>,
+    #[builder(default = "FxHashMap::default()")]
+    words: FxHashMap<String, u32>,
 }
 
 impl Default for UnigramTrainer {
@@ -110,17 +113,21 @@ impl UnigramTrainer {
         true
     }
 
-    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
+    fn finalize<S: BuildHasher>(
+        &self,
+        model: Unigram,
+        required_chars: HashSet<String, S>,
+    ) -> Result<Unigram> {
         let mut min_score_penalty = 0.0;
         let min_score_penalty_delta = 0.0001;
 
         let mut pieces: Vec<(String, f64)> = vec![];
-        let mut inserted: HashSet<String> = HashSet::new();
+        let mut inserted: FxHashSet<String> = FxHashSet::default();
 
         // We don't want to include the <UNK> that was used to train
         inserted.insert("<UNK>".into());
 
-        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
+        let existing_pieces: FxHashMap<String, f64> = model.iter().cloned().collect();
         for c in required_chars {
             if let Some(t) = existing_pieces.get(&c) {
                 inserted.insert(c.clone());
@@ -185,7 +192,7 @@ impl UnigramTrainer {
         )
     }
 
-    fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
+    fn required_chars(&self, word_counts: &[Sentence]) -> FxHashSet<String> {
         word_counts
             .iter()
             .flat_map(|(s, _count)| s.chars())
@@ -205,7 +212,7 @@ impl UnigramTrainer {
             .sum::<usize>()
             + sentences.len();
         let mut flat_string = String::with_capacity(total);
-        let mut all_chars: HashMap<char, u64> = HashMap::new();
+        let mut all_chars: FxHashMap<char, u64> = FxHashMap::default();
         let c_sentence_boundary = '\0';
         let k_sentence_boundary = '\0'.to_string();
         for (string, n) in sentences {
@@ -631,18 +638,18 @@ impl Trainer for UnigramTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<FxHashMap<String, u32>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
-                let mut map = HashMap::new();
+                let mut map = FxHashMap::default();
                 for word in words {
                     map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                 }
                 Ok(map)
             })
             .reduce(
-                || Ok(HashMap::new()),
+                || Ok(FxHashMap::default()),
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
@@ -661,7 +668,7 @@ impl Trainer for UnigramTrainer {
 mod tests {
     use super::*;
     use assert_approx_eq::assert_approx_eq;
-    use std::iter::FromIterator;
+    use std::{collections::HashSet, iter::FromIterator};
 
     #[test]
     fn test_unigram_chars() {
@@ -722,13 +729,18 @@ mod tests {
 
         let sentences = vec![("こんにちは友達".to_string(), 1)];
         let required_chars = trainer.required_chars(&sentences);
-        assert_eq!(
-            required_chars,
-            vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
-                .into_iter()
-                .map(|s| s.to_owned())
-                .collect::<HashSet<_>>()
-        );
+
+        let mut lhs = required_chars.into_iter().collect::<Vec<_>>();
+        let mut rhs = vec![
+            "こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f",
+        ]
+        .into_iter()
+        .collect::<Vec<_>>();
+
+        lhs.sort_unstable();
+        rhs.sort_unstable();
+
+        assert_eq!(lhs, rhs);
     }
 
     #[test]
diff --git a/tokenizers/src/models/unigram/trie.rs b/tokenizers/src/models/unigram/trie.rs
index 2f94b1766..70f5333d4 100644
--- a/tokenizers/src/models/unigram/trie.rs
+++ b/tokenizers/src/models/unigram/trie.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 use std::hash::Hash;
 
 #[derive(Default)]
@@ -78,14 +78,14 @@ impl