diff --git a/bindings/node/src/models.rs b/bindings/node/src/models.rs index a4138b91f..9e35727b8 100644 --- a/bindings/node/src/models.rs +++ b/bindings/node/src/models.rs @@ -4,7 +4,7 @@ use crate::trainers::Trainer; use napi::bindgen_prelude::*; use napi_derive::napi; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; use tokenizers as tk; @@ -95,7 +95,7 @@ impl tk::Model for Model { self.model.as_ref()?.read().unwrap().id_to_token(id) } - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self .model .as_ref() diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs index 4acbcac83..a99ac0313 100644 --- a/bindings/node/src/tokenizer.rs +++ b/bindings/node/src/tokenizer.rs @@ -6,7 +6,7 @@ use crate::pre_tokenizers::PreTokenizer; use crate::processors::Processor; use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask}; use crate::trainers::Trainer; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use tokenizers::Model as ModelTrait; use napi::bindgen_prelude::*; @@ -433,7 +433,7 @@ impl Tokenizer { } #[napi] - pub fn get_vocab(&self, with_added_tokens: Option) -> HashMap { + pub fn get_vocab(&self, with_added_tokens: Option) -> FxHashMap { let with_added_tokens = with_added_tokens.unwrap_or(true); self.tokenizer.read().unwrap().get_vocab(with_added_tokens) } diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 98bf2d694..bfa063771 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -18,6 +18,8 @@ pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] } numpy = "0.23" ndarray = "0.16" itertools = "0.12" +rustc-hash = "2.1.1" +compact_str = { version = "0.8.1", features = ["serde"] } [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 4a408ff1d..78b0354ed 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -3,6 +3,7 @@ use std::sync::{Arc, RwLock}; use crate::pre_tokenizers::from_string; use crate::tokenizer::PyTokenizer; use crate::utils::PyPattern; +use compact_str::ToCompactString; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -91,7 +92,10 @@ impl PyDecoder { } impl Decoder for PyDecoder { - fn decode_chain(&self, tokens: Vec) -> tk::Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> tk::Result> { self.decoder.decode_chain(tokens) } } @@ -139,7 +143,12 @@ impl PyDecoder { /// :obj:`str`: The decoded string #[pyo3(text_signature = "(self, tokens)")] fn decode(&self, tokens: Vec) -> PyResult { - ToPyResult(self.decoder.decode(tokens)).into() + ToPyResult( + self.decoder + .decode(tokens) + .map(|t| t.to_compact_string().to_string()), + ) + .into() } fn __repr__(&self) -> PyResult { @@ -235,12 +244,12 @@ pub struct PyWordPieceDec {} impl PyWordPieceDec { #[getter] fn get_prefix(self_: PyRef) -> String { - getter!(self_, WordPiece, prefix.clone()) + getter!(self_, WordPiece, prefix.clone().to_string()) } #[setter] fn set_prefix(self_: PyRef, prefix: String) { - setter!(self_, WordPiece, prefix, prefix); + setter!(self_, WordPiece, prefix, prefix.to_compact_string()); } #[getter] @@ -256,7 +265,10 @@ impl PyWordPieceDec { #[new] #[pyo3(signature = (prefix = String::from("##"), cleanup = true), text_signature = "(self, prefix=\"##\", cleanup=True)")] fn new(prefix: String, cleanup: bool) -> (Self, 
PyDecoder) { - (PyWordPieceDec {}, WordPiece::new(prefix, cleanup).into()) + ( + PyWordPieceDec {}, + WordPiece::new(prefix.to_compact_string(), cleanup).into(), + ) } } @@ -412,12 +424,12 @@ pub struct PyBPEDecoder {} impl PyBPEDecoder { #[getter] fn get_suffix(self_: PyRef) -> String { - getter!(self_, BPE, suffix.clone()) + getter!(self_, BPE, suffix.to_string()) } #[setter] fn set_suffix(self_: PyRef, suffix: String) { - setter!(self_, BPE, suffix, suffix); + setter!(self_, BPE, suffix, suffix.into()); } #[new] @@ -443,22 +455,27 @@ pub struct PyCTCDecoder {} impl PyCTCDecoder { #[getter] fn get_pad_token(self_: PyRef) -> String { - getter!(self_, CTC, pad_token.clone()) + getter!(self_, CTC, pad_token.to_string()) } #[setter] fn set_pad_token(self_: PyRef, pad_token: String) { - setter!(self_, CTC, pad_token, pad_token); + setter!(self_, CTC, pad_token, pad_token.into()); } #[getter] fn get_word_delimiter_token(self_: PyRef) -> String { - getter!(self_, CTC, word_delimiter_token.clone()) + getter!(self_, CTC, word_delimiter_token.clone()).to_string() } #[setter] fn set_word_delimiter_token(self_: PyRef, word_delimiter_token: String) { - setter!(self_, CTC, word_delimiter_token, word_delimiter_token); + setter!( + self_, + CTC, + word_delimiter_token, + word_delimiter_token.into() + ); } #[getter] @@ -526,22 +543,33 @@ impl CustomDecoder { } impl Decoder for CustomDecoder { - fn decode(&self, tokens: Vec) -> tk::Result { + fn decode(&self, tokens: Vec) -> tk::Result { + let tokens: Vec = tokens + .into_iter() + .map(|t| t.to_compact_string().to_string()) + .collect(); Python::with_gil(|py| { let decoded = self .inner .call_method(py, "decode", (tokens,), None)? - .extract(py)?; + .extract::(py)?; Ok(decoded) }) } - fn decode_chain(&self, tokens: Vec) -> tk::Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> tk::Result> { + let tokens: Vec = tokens + .into_iter() + .map(|t| t.to_compact_string().to_string()) + .collect(); Python::with_gil(|py| { let decoded = self .inner .call_method(py, "decode_chain", (tokens,), None)? 
- .extract(py)?; + .extract::>(py)?; Ok(decoded) }) } @@ -595,10 +623,21 @@ where } impl Decoder for PyDecoderWrapper { - fn decode_chain(&self, tokens: Vec) -> tk::Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> tk::Result> { match self { - PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens), - PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens), + PyDecoderWrapper::Wrapped(inner) => inner + .read() + .unwrap() + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + PyDecoderWrapper::Custom(inner) => inner + .read() + .unwrap() + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), } } } @@ -663,14 +702,17 @@ impl PyDecodeStream { #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")] fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult> { - ToPyResult(tk::tokenizer::step_decode_stream( - &tokenizer.tokenizer, - id, - self.skip_special_tokens, - &mut self.ids, - &mut self.prefix, - &mut self.prefix_index, - )) + ToPyResult( + tk::tokenizer::step_decode_stream( + &tokenizer.tokenizer, + id, + self.skip_special_tokens, + &mut self.ids, + &mut self.prefix.to_compact_string(), + &mut self.prefix_index, + ) + .map(|o| o.map(|s| s.to_string())), + ) .into() } } diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs index e157b8006..a2917c086 100644 --- a/bindings/python/src/encoding.rs +++ b/bindings/python/src/encoding.rs @@ -127,7 +127,11 @@ impl PyEncoding { /// :obj:`List[str]`: The list of tokens #[getter] fn get_tokens(&self) -> Vec { - self.encoding.get_tokens().to_vec() + self.encoding + .get_tokens() + .iter() + .map(|x| x.to_string()) + .collect() } /// The generated word indices. diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 2f4dba825..e03b1fab4 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -1,4 +1,5 @@ -use std::collections::HashMap; +use compact_str::{CompactString, ToCompactString}; +use rustc_hash::FxHashMap; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; @@ -8,7 +9,7 @@ use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; use serde::{Deserialize, Serialize}; -use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE}; +use tk::models::bpe::{BpeBuilder, BPE}; use tk::models::unigram::Unigram; use tk::models::wordlevel::WordLevel; use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; @@ -18,6 +19,9 @@ use tokenizers as tk; use super::error::{deprecation_warning, ToPyResult}; +pub type Vocab = FxHashMap; +pub type Merges = Vec<(String, String)>; + /// Base class for all models /// /// The model represents the actual tokenization algorithm. 
This is the part that @@ -66,11 +70,15 @@ impl Model for PyModel { self.model.read().unwrap().token_to_id(token) } - fn id_to_token(&self, id: u32) -> Option { - self.model.read().unwrap().id_to_token(id) + fn id_to_token(&self, id: u32) -> Option { + self.model + .read() + .unwrap() + .id_to_token(id) + .map(|t| t.to_compact_string()) } - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self.model.read().unwrap().get_vocab() } @@ -175,7 +183,7 @@ impl PyModel { /// :obj:`str`: The token associated to the ID #[pyo3(text_signature = "(self, id)")] fn id_to_token(&self, id: u32) -> Option { - self.model.read().unwrap().id_to_token(id) + self.model.read().unwrap().id_to_token(id).map(|t| t.into()) } /// Save the current model @@ -297,14 +305,19 @@ impl PyBPE { } } "unk_token" => { - if let Some(unk) = value.extract()? { - builder = builder.unk_token(unk); + if let Some(unk) = value.extract::>()? { + builder = builder.unk_token(unk.to_compact_string()); } } "continuing_subword_prefix" => { - builder = builder.continuing_subword_prefix(value.extract()?) + builder = builder.continuing_subword_prefix( + value.extract::()?.to_compact_string(), + ) + } + "end_of_word_suffix" => { + builder = builder + .end_of_word_suffix(value.extract::()?.to_compact_string()) } - "end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?), "fuse_unk" => builder = builder.fuse_unk(value.extract()?), "byte_fallback" => builder = builder.byte_fallback(value.extract()?), "ignore_merges" => builder = builder.ignore_merges(value.extract()?), @@ -345,16 +358,64 @@ macro_rules! setter { }}; } -#[derive(FromPyObject)] enum PyVocab { Vocab(Vocab), Filename(String), } -#[derive(FromPyObject)] +impl<'py> FromPyObject<'py> for PyVocab { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(dict) = ob.downcast::() { + let mut vocab: Vocab = FxHashMap::default(); + for (key, value) in dict.iter() { + let key_str = key.extract::()?; + let value_u32 = value.extract::()?; + vocab.insert(key_str, value_u32); + } + Ok(PyVocab::Vocab(vocab)) + } else if let Ok(s) = ob.extract::() { + Ok(PyVocab::Filename(s)) + } else { + Err(PyErr::new::( + "Expected a dictionary (str -> u32) or a string for PyVocab", + )) + } + } +} enum PyMerges { Merges(Merges), Filename(String), } +impl<'py> FromPyObject<'py> for PyMerges { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(list) = ob.downcast::() { + let mut merges: Merges = Vec::new(); + for item in list.iter() { + if let Ok(tup) = item.downcast::() { + if tup.len() == 2 { + let first = tup.get_item(0)?.extract::()?; + let second = tup.get_item(1)?.extract::()?; + merges.push((first, second)) + } else { + return Err(PyErr::new::( + "Expected tuples of length 2 for Merges variant", + )); + } + } else { + return Err(PyErr::new::( + "Expected a list of tuples (CompactString, CompactString) for Merges variant", + )); + } + } + Ok(PyMerges::Merges(merges)) + } else if let Ok(s) = ob.downcast::() { + Ok(PyMerges::Filename(s.to_string())) + } else { + Err(PyErr::new::( + "Expected list of tuples or a string for Merges", + )) + } + } +} #[pymethods] impl PyBPE { @@ -370,17 +431,22 @@ impl PyBPE { #[getter] fn get_unk_token(self_: PyRef) -> Option { - getter!(self_, BPE, unk_token.clone()) + getter!(self_, BPE, unk_token.clone()).map(|t| t.into()) } #[setter] fn set_unk_token(self_: PyRef, unk_token: Option) { - setter!(self_, BPE, unk_token, unk_token); + setter!( + self_, + BPE, + unk_token, + unk_token.map(|t| 
t.to_compact_string()) + ); } #[getter] fn get_continuing_subword_prefix(self_: PyRef) -> Option { - getter!(self_, BPE, continuing_subword_prefix.clone()) + getter!(self_, BPE, continuing_subword_prefix.clone()).map(|t| t.into()) } #[setter] @@ -392,18 +458,23 @@ impl PyBPE { self_, BPE, continuing_subword_prefix, - continuing_subword_prefix + continuing_subword_prefix.map(|t| t.to_compact_string()) ); } #[getter] fn get_end_of_word_suffix(self_: PyRef) -> Option { - getter!(self_, BPE, end_of_word_suffix.clone()) + getter!(self_, BPE, end_of_word_suffix.clone()).map(|t| t.into()) } #[setter] fn set_end_of_word_suffix(self_: PyRef, end_of_word_suffix: Option) { - setter!(self_, BPE, end_of_word_suffix, end_of_word_suffix); + setter!( + self_, + BPE, + end_of_word_suffix, + end_of_word_suffix.map(|t| t.to_compact_string()) + ); } #[getter] @@ -454,6 +525,14 @@ impl PyBPE { if let (Some(vocab), Some(merges)) = (vocab, merges) { match (vocab, merges) { (PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => { + let vocab = vocab + .into_iter() + .map(|(k, v)| (k.to_compact_string(), v)) + .collect(); + let merges = merges + .into_iter() + .map(|(k, v)| (k.to_compact_string(), v.to_compact_string())) + .collect(); builder = builder.vocab_and_merges(vocab, merges); } (PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => { @@ -495,12 +574,22 @@ impl PyBPE { #[staticmethod] #[pyo3(text_signature = "(self, vocab, merges)")] fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> { - BPE::read_file(vocab, merges).map_err(|e| { - exceptions::PyException::new_err(format!( - "Error while reading vocab & merges files: {}", - e - )) - }) + BPE::read_file(vocab, merges) + .map(|(vocab, merges)| { + ( + vocab.into_iter().map(|(k, v)| (k.to_string(), v)).collect(), + merges + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(), + ) + }) + .map_err(|e| { + exceptions::PyException::new_err(format!( + "Error while reading vocab & merges files: {}", + e + )) + }) } /// Instantiate a BPE model from the given files. 
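// Illustrative sketch, not part of the patch: it mirrors the conversion pattern used in
// `PyBPE::read_file` and the vocab getters/setters above. The core crate now keys its
// vocabularies with `CompactString` in an `FxHashMap`, while the Python layer keeps
// trading in plain `String` keys, so maps are converted at the binding boundary.
// The helper names `core_vocab_to_py` and `py_vocab_to_core` are hypothetical.
use compact_str::{CompactString, ToCompactString};
use rustc_hash::FxHashMap;
use std::collections::HashMap;

// Convert the core vocab (CompactString keys) into the String-keyed map handed to Python.
fn core_vocab_to_py(vocab: FxHashMap<CompactString, u32>) -> HashMap<String, u32> {
    vocab.into_iter().map(|(k, v)| (k.to_string(), v)).collect()
}

// Convert a map received from Python back into the core representation.
fn py_vocab_to_core(vocab: HashMap<String, u32>) -> FxHashMap<CompactString, u32> {
    vocab
        .into_iter()
        .map(|(k, v)| (k.to_compact_string(), v))
        .collect()
}

fn main() {
    let mut core: FxHashMap<CompactString, u32> = FxHashMap::default();
    core.insert(CompactString::new("hello"), 0);

    let py_side = core_vocab_to_py(core);
    assert_eq!(py_side.get("hello"), Some(&0));

    let back = py_vocab_to_core(py_side);
    assert_eq!(back.get("hello"), Some(&0));
}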
@@ -540,8 +629,15 @@ impl PyBPE { py, PyBPE::new( py, - Some(PyVocab::Vocab(vocab)), - Some(PyMerges::Merges(merges)), + Some(PyVocab::Vocab( + vocab.into_iter().map(|(k, v)| (k.to_string(), v)).collect(), + )), + Some(PyMerges::Merges( + merges + .into_iter() + .map(|(s1, s2)| (s1.to_string(), s2.to_string())) + .collect(), + )), kwargs, )?, ) @@ -596,13 +692,15 @@ impl PyWordPiece { let key: String = key.extract()?; match key.as_ref() { "unk_token" => { - builder = builder.unk_token(val.extract()?); + builder = builder.unk_token(val.extract::()?.to_compact_string()); } "max_input_chars_per_word" => { builder = builder.max_input_chars_per_word(val.extract()?); } "continuing_subword_prefix" => { - builder = builder.continuing_subword_prefix(val.extract()?); + builder = builder.continuing_subword_prefix( + val.extract::()?.to_compact_string(), + ); } _ => println!("Ignored unknown kwargs option {}", key), } @@ -623,17 +721,17 @@ impl PyWordPiece { impl PyWordPiece { #[getter] fn get_unk_token(self_: PyRef) -> String { - getter!(self_, WordPiece, unk_token.clone()) + getter!(self_, WordPiece, unk_token.clone()).into() } #[setter] fn set_unk_token(self_: PyRef, unk_token: String) { - setter!(self_, WordPiece, unk_token, unk_token); + setter!(self_, WordPiece, unk_token, unk_token.to_compact_string()); } #[getter] fn get_continuing_subword_prefix(self_: PyRef) -> String { - getter!(self_, WordPiece, continuing_subword_prefix.clone()) + getter!(self_, WordPiece, continuing_subword_prefix.clone()).into() } #[setter] @@ -642,7 +740,7 @@ impl PyWordPiece { self_, WordPiece, continuing_subword_prefix, - continuing_subword_prefix + continuing_subword_prefix.to_compact_string() ); } @@ -668,6 +766,10 @@ impl PyWordPiece { if let Some(vocab) = vocab { match vocab { PyVocab::Vocab(vocab) => { + let vocab = vocab + .into_iter() + .map(|(k, v)| (k.to_compact_string(), v)) + .collect(); builder = builder.vocab(vocab); } PyVocab::Filename(vocab_filename) => { @@ -676,7 +778,7 @@ impl PyWordPiece { "0.9.0", "WordPiece.__init__ will not create from files anymore, try `WordPiece.from_file` instead", )?; - builder = builder.files(vocab_filename.to_string()); + builder = builder.files(vocab_filename.to_compact_string()); } } } @@ -700,9 +802,14 @@ impl PyWordPiece { #[staticmethod] #[pyo3(text_signature = "(vocab)")] fn read_file(vocab: &str) -> PyResult { - WordPiece::read_file(vocab).map_err(|e| { - exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e)) - }) + WordPiece::read_file(vocab) + .map(|vocab| vocab.into_iter().map(|(k, v)| (k.to_string(), v)).collect()) + .map_err(|e| { + exceptions::PyException::new_err(format!( + "Error while reading WordPiece file: {}", + e + )) + }) } /// Instantiate a WordPiece model from the given file @@ -736,7 +843,13 @@ impl PyWordPiece { })?; Py::new( py, - PyWordPiece::new(py, Some(PyVocab::Vocab(vocab)), kwargs)?, + PyWordPiece::new( + py, + Some(PyVocab::Vocab( + vocab.into_iter().map(|(k, v)| (k.to_string(), v)).collect(), + )), + kwargs, + )?, ) } } @@ -758,12 +871,12 @@ pub struct PyWordLevel {} impl PyWordLevel { #[getter] fn get_unk_token(self_: PyRef) -> String { - getter!(self_, WordLevel, unk_token.clone()) + getter!(self_, WordLevel, unk_token.clone()).into() } #[setter] fn set_unk_token(self_: PyRef, unk_token: String) { - setter!(self_, WordLevel, unk_token, unk_token); + setter!(self_, WordLevel, unk_token, unk_token.to_compact_string()); } #[new] @@ -778,6 +891,10 @@ impl PyWordLevel { if let Some(vocab) = vocab { match vocab 
{ PyVocab::Vocab(vocab) => { + let vocab = vocab + .into_iter() + .map(|(k, v)| (k.to_compact_string(), v)) + .collect(); builder = builder.vocab(vocab); } PyVocab::Filename(vocab_filename) => { @@ -787,12 +904,12 @@ impl PyWordLevel { "WordLevel.__init__ will not create from files anymore, \ try `WordLevel.from_file` instead", )?; - builder = builder.files(vocab_filename.to_string()); + builder = builder.files(vocab_filename.to_compact_string()); } }; } if let Some(unk_token) = unk_token { - builder = builder.unk_token(unk_token); + builder = builder.unk_token(unk_token.to_compact_string()); } Ok(( @@ -819,9 +936,14 @@ impl PyWordLevel { #[staticmethod] #[pyo3(text_signature = "(vocab)")] fn read_file(vocab: &str) -> PyResult { - WordLevel::read_file(vocab).map_err(|e| { - exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e)) - }) + WordLevel::read_file(vocab) + .map(|vocab| vocab.into_iter().map(|(k, v)| (k.to_string(), v)).collect()) + .map_err(|e| { + exceptions::PyException::new_err(format!( + "Error while reading WordLevel file: {}", + e + )) + }) } /// Instantiate a WordLevel model from the given file @@ -855,7 +977,13 @@ impl PyWordLevel { })?; Py::new( py, - PyWordLevel::new(py, Some(PyVocab::Vocab(vocab)), unk_token)?, + PyWordLevel::new( + py, + Some(PyVocab::Vocab( + vocab.into_iter().map(|(k, v)| (k.to_string(), v)).collect(), + )), + unk_token, + )?, ) } } @@ -879,6 +1007,10 @@ impl PyUnigram { ) -> PyResult<(Self, PyModel)> { match (vocab, unk_id, byte_fallback) { (Some(vocab), unk_id, byte_fallback) => { + let vocab = vocab + .into_iter() + .map(|(t, s)| (t.to_compact_string(), s)) + .collect(); let model = Unigram::from(vocab, unk_id, byte_fallback.unwrap_or(false)).map_err(|e| { exceptions::PyException::new_err(format!( diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 3cd59a3c7..c5c33f555 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -521,12 +521,12 @@ pub struct PyPrepend {} impl PyPrepend { #[getter] fn get_prepend(self_: PyRef) -> String { - getter!(self_, Prepend, prepend) + getter!(self_, Prepend, prepend).into() } #[setter] fn set_prepend(self_: PyRef, prepend: String) { - setter!(self_, Prepend, prepend, prepend) + setter!(self_, Prepend, prepend, prepend.into()) } #[new] @@ -624,12 +624,12 @@ impl PyReplace { #[getter] fn get_content(self_: PyRef) -> String { - getter!(self_, Replace, content) + getter!(self_, Replace, content).to_string() } #[setter] fn set_content(self_: PyRef, content: String) { - setter!(self_, Replace, content, content) + setter!(self_, Replace, content, content.into()) } } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 07784afaa..d396bd03b 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -4,6 +4,7 @@ use std::sync::RwLock; use crate::encoding::PyEncoding; use crate::error::ToPyResult; +use compact_str::ToCompactString; use pyo3::exceptions; use pyo3::exceptions::PyException; use pyo3::prelude::*; @@ -335,7 +336,14 @@ impl PyBertProcessing { #[new] #[pyo3(text_signature = "(self, sep, cls)")] fn new(sep: (String, u32), cls: (String, u32)) -> (Self, PyPostProcessor) { - (PyBertProcessing {}, BertProcessing::new(sep, cls).into()) + ( + PyBertProcessing {}, + BertProcessing::new( + (sep.0.to_compact_string(), sep.1), + (cls.0.to_compact_string(), cls.1), + ) + .into(), + ) } fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult> { @@ -354,8 
+362,8 @@ impl PyBertProcessing { #[setter] fn set_sep(self_: PyRef, sep: Bound<'_, PyTuple>) -> PyResult<()> { - let sep = sep.extract()?; - setter!(self_, Bert, sep, sep); + let sep = sep.extract::<(String, u32)>()?; + setter!(self_, Bert, sep, (sep.0.to_compact_string(), sep.1)); Ok(()) } @@ -371,8 +379,8 @@ impl PyBertProcessing { #[setter] fn set_cls(self_: PyRef, cls: Bound<'_, PyTuple>) -> PyResult<()> { - let cls = cls.extract()?; - setter!(self_, Bert, cls, cls); + let cls = cls.extract::<(String, u32)>()?; + setter!(self_, Bert, cls, (cls.0.to_compact_string(), cls.1)); Ok(()) } } @@ -413,6 +421,8 @@ impl PyRobertaProcessing { trim_offsets: bool, add_prefix_space: bool, ) -> (Self, PyPostProcessor) { + let sep = (sep.0.to_compact_string(), sep.1); + let cls = (cls.0.to_compact_string(), cls.1); let proc = RobertaProcessing::new(sep, cls) .trim_offsets(trim_offsets) .add_prefix_space(add_prefix_space); @@ -435,8 +445,8 @@ impl PyRobertaProcessing { #[setter] fn set_sep(self_: PyRef, sep: Bound<'_, PyTuple>) -> PyResult<()> { - let sep = sep.extract()?; - setter!(self_, Roberta, sep, sep); + let sep = sep.extract::<(String, u32)>()?; + setter!(self_, Roberta, sep, (sep.0.to_compact_string(), sep.1)); Ok(()) } @@ -452,8 +462,8 @@ impl PyRobertaProcessing { #[setter] fn set_cls(self_: PyRef, cls: Bound<'_, PyTuple>) -> PyResult<()> { - let cls = cls.extract()?; - setter!(self_, Roberta, cls, cls); + let cls = cls.extract::<(String, u32)>()?; + setter!(self_, Roberta, cls, (cls.0.to_compact_string(), cls.1)); Ok(()) } @@ -558,9 +568,9 @@ impl From for SpecialToken { impl FromPyObject<'_> for PySpecialToken { fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { if let Ok(v) = ob.extract::<(String, u32)>() { - Ok(Self(v.into())) + Ok(Self((v.0.to_compact_string(), v.1).into())) } else if let Ok(v) = ob.extract::<(u32, String)>() { - Ok(Self(v.into())) + Ok(Self((v.0, v.1.to_compact_string()).into())) } else if let Ok(d) = ob.downcast::() { let id = d .get_item("id")? @@ -573,10 +583,13 @@ impl FromPyObject<'_> for PySpecialToken { let tokens = d .get_item("tokens")? .ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))? - .extract::>()?; + .extract::>()? 
+ .into_iter() + .map(|s| s.to_compact_string()) + .collect(); Ok(Self( - ToPyResult(SpecialToken::new(id, ids, tokens)).into_py()?, + ToPyResult(SpecialToken::new(id.to_compact_string(), ids, tokens)).into_py()?, )) } else { Err(exceptions::PyTypeError::new_err( @@ -708,7 +721,7 @@ impl PyTemplateProcessing { #[getter] fn get_single(self_: PyRef) -> String { - getter!(self_, Template, get_single()) + getter!(self_, Template, get_single()).into() } #[setter] @@ -834,7 +847,7 @@ mod test { fn get_subtype() { Python::with_gil(|py| { let py_proc = PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(Arc::new( - RwLock::new(BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into()), + RwLock::new(BertProcessing::new(("SEP", 0), ("CLS", 1)).into()), ))); let py_bert = py_proc.get_as_subtype(py).unwrap(); assert_eq!( @@ -846,7 +859,7 @@ mod test { #[test] fn serialize() { - let rs_processing = BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)); + let rs_processing = BertProcessing::new(("SEP", 0), ("CLS", 1)); let rs_wrapper: PostProcessorWrapper = rs_processing.clone().into(); let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap(); let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap(); diff --git a/bindings/python/src/token.rs b/bindings/python/src/token.rs index 86e926028..120b61016 100644 --- a/bindings/python/src/token.rs +++ b/bindings/python/src/token.rs @@ -1,3 +1,4 @@ +use compact_str::ToCompactString; use pyo3::prelude::*; use tk::Token; @@ -22,7 +23,7 @@ impl PyToken { #[new] #[pyo3(text_signature = None)] fn new(id: u32, value: String, offsets: (usize, usize)) -> PyToken { - Token::new(id, value, offsets).into() + Token::new(id, value.to_compact_string(), offsets).into() } #[getter] diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 73a0dbbe8..9b325b266 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,5 +1,7 @@ +use compact_str::ToCompactString; +use rustc_hash::FxHashMap; use serde::Serialize; -use std::collections::{hash_map::DefaultHasher, HashMap}; +use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use numpy::{npyffi, PyArray1, PyArrayMethods}; @@ -103,7 +105,7 @@ impl PyAddedToken { let dict = PyDict::new(py); let token = self.get_token(); - dict.set_item("content", token.content)?; + dict.set_item("content", token.content.to_string())?; dict.set_item("single_word", token.single_word)?; dict.set_item("lstrip", token.lstrip)?; dict.set_item("rstrip", token.rstrip)?; @@ -117,7 +119,7 @@ impl PyAddedToken { impl From for PyAddedToken { fn from(token: tk::AddedToken) -> Self { Self { - content: token.content, + content: token.content.to_string(), single_word: Some(token.single_word), lstrip: Some(token.lstrip), rstrip: Some(token.rstrip), @@ -675,8 +677,12 @@ impl PyTokenizer { /// :obj:`Dict[str, int]`: The vocabulary #[pyo3(signature = (with_added_tokens = true))] #[pyo3(text_signature = "(self, with_added_tokens=True)")] - fn get_vocab(&self, with_added_tokens: bool) -> HashMap { - self.tokenizer.get_vocab(with_added_tokens) + fn get_vocab(&self, with_added_tokens: bool) -> FxHashMap { + self.tokenizer + .get_vocab(with_added_tokens) + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect() } /// Get the underlying vocabulary @@ -865,7 +871,9 @@ impl PyTokenizer { } "pad_id" => params.pad_id = value.extract()?, "pad_type_id" => params.pad_type_id = value.extract()?, - "pad_token" => params.pad_token = value.extract()?, + 
"pad_token" => { + params.pad_token = value.extract::()?.to_compact_string() + } "max_length" => { println!( "enable_padding(max_length=X) is deprecated, \ @@ -921,7 +929,7 @@ impl PyTokenizer { )?; dict.set_item("pad_to_multiple_of", params.pad_to_multiple_of)?; dict.set_item("pad_id", params.pad_id)?; - dict.set_item("pad_token", ¶ms.pad_token)?; + dict.set_item("pad_token", &*params.pad_token)?; dict.set_item("pad_type_id", params.pad_type_id)?; dict.set_item("direction", params.direction.as_ref())?; @@ -1135,7 +1143,12 @@ impl PyTokenizer { #[pyo3(signature = (ids, skip_special_tokens = true))] #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")] fn decode(&self, ids: Vec, skip_special_tokens: bool) -> PyResult { - ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into() + ToPyResult( + self.tokenizer + .decode(&ids, skip_special_tokens) + .map(|t| t.to_compact_string().to_string()), + ) + .into() } /// Decode a batch of ids back to their corresponding string @@ -1159,7 +1172,16 @@ impl PyTokenizer { ) -> PyResult> { py.allow_threads(|| { let slices = sequences.iter().map(|v| &v[..]).collect::>(); - ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into() + ToPyResult( + self.tokenizer + .decode_batch(&slices, skip_special_tokens) + .map(|r| { + r.into_iter() + .map(|e| e.to_compact_string().to_string()) + .collect() + }), + ) + .into() }) } @@ -1186,7 +1208,7 @@ impl PyTokenizer { /// :obj:`Optional[str]`: An optional token, :obj:`None` if out of vocabulary #[pyo3(text_signature = "(self, id)")] fn id_to_token(&self, id: u32) -> Option { - self.tokenizer.id_to_token(id) + self.tokenizer.id_to_token(id).map(|v| v.to_string()) } /// Modifies the tokenizer in order to use or not the special tokens diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index a3d2d556d..e65058e90 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -2,6 +2,8 @@ use std::sync::{Arc, RwLock}; use crate::models::PyModel; use crate::tokenizer::PyAddedToken; +use compact_str::CompactString; +use compact_str::ToCompactString; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -105,7 +107,7 @@ impl Trainer for PyTrainer { where I: Iterator + Send, S: AsRef + Send, - F: Fn(&str) -> tk::Result> + Sync, + F: Fn(&str) -> tk::Result> + Sync, { self.trainer.write().unwrap().feed(iterator, process) } @@ -295,22 +297,40 @@ impl PyBpeTrainer { #[getter] fn get_continuing_subword_prefix(self_: PyRef) -> Option { - getter!(self_, BpeTrainer, continuing_subword_prefix.clone()) + getter!( + self_, + BpeTrainer, + continuing_subword_prefix.clone().map(|s| s.to_string()) + ) } #[setter] fn set_continuing_subword_prefix(self_: PyRef, prefix: Option) { - setter!(self_, BpeTrainer, continuing_subword_prefix, prefix); + setter!( + self_, + BpeTrainer, + continuing_subword_prefix, + prefix.map(|s| s.to_compact_string()) + ); } #[getter] fn get_end_of_word_suffix(self_: PyRef) -> Option { - getter!(self_, BpeTrainer, end_of_word_suffix.clone()) + getter!( + self_, + BpeTrainer, + end_of_word_suffix.clone().map(|s| s.to_string()) + ) } #[setter] fn set_end_of_word_suffix(self_: PyRef, suffix: Option) { - setter!(self_, BpeTrainer, end_of_word_suffix, suffix); + setter!( + self_, + BpeTrainer, + end_of_word_suffix, + suffix.map(|s| s.to_compact_string()) + ); } #[new] @@ -357,9 +377,13 @@ impl PyBpeTrainer { ); } "continuing_subword_prefix" => { - builder = builder.continuing_subword_prefix(val.extract()?) 
+ builder = builder + .continuing_subword_prefix(val.extract::()?.to_compact_string()) + } + "end_of_word_suffix" => { + builder = + builder.end_of_word_suffix(val.extract::()?.to_compact_string()) } - "end_of_word_suffix" => builder = builder.end_of_word_suffix(val.extract()?), _ => println!("Ignored unknown kwargs option {}", key), }; } @@ -499,22 +523,30 @@ impl PyWordPieceTrainer { #[getter] fn get_continuing_subword_prefix(self_: PyRef) -> Option { - getter!(self_, WordPieceTrainer, continuing_subword_prefix().clone()) + getter!( + self_, + WordPieceTrainer, + continuing_subword_prefix().clone().map(|s| s.to_string()) + ) } #[setter] fn set_continuing_subword_prefix(self_: PyRef, prefix: Option) { - setter!(self_, WordPieceTrainer, @set_continuing_subword_prefix, prefix); + setter!(self_, WordPieceTrainer, @set_continuing_subword_prefix, prefix.map(|s| s.to_compact_string())); } #[getter] fn get_end_of_word_suffix(self_: PyRef) -> Option { - getter!(self_, WordPieceTrainer, end_of_word_suffix().clone()) + getter!( + self_, + WordPieceTrainer, + end_of_word_suffix().clone().map(|s| s.to_string()) + ) } #[setter] fn set_end_of_word_suffix(self_: PyRef, suffix: Option) { - setter!(self_, WordPieceTrainer, @set_end_of_word_suffix, suffix); + setter!(self_, WordPieceTrainer, @set_end_of_word_suffix, suffix.map(|s| s.to_compact_string())); } #[new] @@ -563,9 +595,13 @@ impl PyWordPieceTrainer { ); } "continuing_subword_prefix" => { - builder = builder.continuing_subword_prefix(val.extract()?) + builder = builder + .continuing_subword_prefix(val.extract::()?.to_compact_string()) + } + "end_of_word_suffix" => { + builder = + builder.end_of_word_suffix(val.extract::()?.to_compact_string()) } - "end_of_word_suffix" => builder = builder.end_of_word_suffix(val.extract()?), _ => println!("Ignored unknown kwargs option {}", key), }; } @@ -840,7 +876,10 @@ impl PyUnigramTrainer { "show_progress" => builder.show_progress(val.extract()?), "n_sub_iterations" => builder.n_sub_iterations(val.extract()?), "shrinking_factor" => builder.shrinking_factor(val.extract()?), - "unk_token" => builder.unk_token(val.extract()?), + "unk_token" => builder.unk_token( + val.extract::>()? 
+ .map(|s| s.to_compact_string()), + ), "max_piece_length" => builder.max_piece_length(val.extract()?), "seed_size" => builder.seed_size(val.extract()?), "initial_alphabet" => { diff --git a/bindings/python/src/utils/normalization.rs b/bindings/python/src/utils/normalization.rs index 21a9ae966..f0cd1e855 100644 --- a/bindings/python/src/utils/normalization.rs +++ b/bindings/python/src/utils/normalization.rs @@ -1,6 +1,7 @@ use super::regex::PyRegex; use super::{DestroyPtr, RefMutContainer, RefMutGuard}; use crate::error::ToPyResult; +use compact_str::ToCompactString; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -38,8 +39,10 @@ impl Pattern for PyPattern { impl From for tk::normalizers::replace::ReplacePattern { fn from(pattern: PyPattern) -> Self { match pattern { - PyPattern::Str(s) => Self::String(s.to_owned()), - PyPattern::Regex(r) => Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.clone())), + PyPattern::Str(s) => Self::String(s.into()), + PyPattern::Regex(r) => { + Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.to_compact_string())) + } } } } @@ -47,8 +50,10 @@ impl From for tk::normalizers::replace::ReplacePattern { impl From for tk::pre_tokenizers::split::SplitPattern { fn from(pattern: PyPattern) -> Self { match pattern { - PyPattern::Str(s) => Self::String(s.to_owned()), - PyPattern::Regex(r) => Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.clone())), + PyPattern::Str(s) => Self::String(s.into()), + PyPattern::Regex(r) => { + Python::with_gil(|py| Self::Regex(r.borrow(py).pattern.to_compact_string())) + } } } } diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 0633b8ef6..825438a11 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -1,5 +1,8 @@ [package] -authors = ["Anthony MOI ", "Nicolas Patry "] +authors = [ + "Anthony MOI ", + "Nicolas Patry ", +] edition = "2018" name = "tokenizers" version = "0.21.0-dev.0" @@ -13,7 +16,14 @@ description = """ Provides an implementation of today's most used tokenizers, with a focus on performances and versatility. 
""" -exclude = [ "rust-toolchain", "target/*", "Cargo.lock", "benches/*.txt", "benches/*.json", "data/*" ] +exclude = [ + "rust-toolchain", + "target/*", + "Cargo.lock", + "benches/*.txt", + "benches/*.json", + "data/*", +] [lib] name = "tokenizers" @@ -49,12 +59,12 @@ regex = "1.10" regex-syntax = "0.8" rayon = "1.10" rayon-cond = "0.3" -serde = { version = "1.0", features = [ "derive" ] } +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" unicode-normalization-alignments = "0.1" unicode_categories = "0.1" unicode-segmentation = "1.11" -indicatif = {version = "0.17", optional = true} +indicatif = { version = "0.17", optional = true } itertools = "0.13" log = "0.4" derive_builder = "0.20" @@ -64,10 +74,12 @@ aho-corasick = "1.1" paste = "1.0.14" macro_rules_attribute = "0.2.0" thiserror = "2" -fancy-regex = { version = "0.14", optional = true} +fancy-regex = { version = "0.14", optional = true } getrandom = { version = "0.2.10" } -esaxx-rs = { version = "0.1.10", default-features = false, features=[]} +esaxx-rs = { version = "0.1.10", default-features = false, features = [] } monostate = "0.1.12" +rustc-hash = "2.1.1" +compact_str = { version = "0.8.1", features = ["serde"] } [features] default = ["progressbar", "onig", "esaxx_fast"] @@ -90,4 +102,3 @@ lto = "fat" [[example]] name = "encode_batch" required-features = ["http"] - diff --git a/tokenizers/README.md b/tokenizers/README.md index be7f5eb02..ca74476a2 100644 --- a/tokenizers/README.md +++ b/tokenizers/README.md @@ -95,11 +95,11 @@ fn main() -> Result<()> { .vocab_size(vocab_size) .min_frequency(0) .special_tokens(vec![ - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), + AddedToken::from("", true), + AddedToken::from("", true), + AddedToken::from("", true), + AddedToken::from("", true), + AddedToken::from("", true), ]) .build(); diff --git a/tokenizers/benches/bert_benchmark.rs b/tokenizers/benches/bert_benchmark.rs index cfdab9070..8d6e3690a 100644 --- a/tokenizers/benches/bert_benchmark.rs +++ b/tokenizers/benches/bert_benchmark.rs @@ -38,8 +38,8 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer { tokenizer.with_normalizer(Some(BertNormalizer::default())); tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default())); tokenizer.with_post_processor(Some(BertProcessing::new( - ("[SEP]".to_string(), sep_id), - ("[CLS]".to_string(), cls_id), + ("[SEP]", sep_id), + ("[CLS]", cls_id), ))); tokenizer } diff --git a/tokenizers/benches/unigram_benchmark.rs b/tokenizers/benches/unigram_benchmark.rs index 9121a1937..2de8feadd 100644 --- a/tokenizers/benches/unigram_benchmark.rs +++ b/tokenizers/benches/unigram_benchmark.rs @@ -25,10 +25,7 @@ pub fn bench_train(c: &mut Criterion) { *word_counts.entry(word).or_insert(0) += 1; }); - let sentences: Vec<_> = word_counts - .iter() - .map(|(s, i)| (s.to_owned(), *i)) - .collect(); + let sentences: Vec<_> = word_counts.iter().map(|(s, i)| (s.into(), *i)).collect(); c.bench_function("Unigram Train vocabulary (small)", |b| { b.iter_custom(|iters| { @@ -53,10 +50,7 @@ pub fn bench_train(c: &mut Criterion) { *word_counts.entry(word).or_insert(0) += 1; }); - let sentences: Vec<_> = word_counts - .iter() - .map(|(s, i)| (s.to_owned(), *i)) - .collect(); + let sentences: Vec<_> = word_counts.iter().map(|(s, i)| (s.into(), *i)).collect(); c.bench_function("Unigram Train vocabulary (medium)", |b| { 
b.iter_custom(|iters| { diff --git a/tokenizers/src/decoders/bpe.rs b/tokenizers/src/decoders/bpe.rs index 813dc7083..440322c35 100644 --- a/tokenizers/src/decoders/bpe.rs +++ b/tokenizers/src/decoders/bpe.rs @@ -1,5 +1,6 @@ use crate::tokenizer::{Decoder, Result}; +use compact_str::{CompactString, ToCompactString}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize)] @@ -8,31 +9,39 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "type")] #[non_exhaustive] pub struct BPEDecoder { - pub suffix: String, + pub suffix: CompactString, } impl BPEDecoder { - pub fn new(suffix: String) -> Self { - Self { suffix } + pub fn new(suffix: impl Into) -> Self { + Self { + suffix: suffix.into(), + } } } impl Default for BPEDecoder { fn default() -> Self { - Self::new("".into()) + Self::new("") } } impl Decoder for BPEDecoder { - fn decode_chain(&self, tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { let n = tokens.len() - 1; Ok(tokens .into_iter() .enumerate() .map(|(i, token)| { let replacement = if i == n { "" } else { " " }; - token.replace(&self.suffix, replacement) + token + .to_compact_string() + .replace(&*self.suffix, replacement) + .to_compact_string() }) - .collect()) + .collect::>()) } } diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs index b04b3db60..eef920b74 100644 --- a/tokenizers/src/decoders/byte_fallback.rs +++ b/tokenizers/src/decoders/byte_fallback.rs @@ -1,4 +1,5 @@ use crate::tokenizer::{Decoder, Result}; +use compact_str::{CompactString, ToCompactString}; use monostate::MustBe; use serde::{Deserialize, Serialize}; @@ -22,11 +23,15 @@ impl ByteFallback { } impl Decoder for ByteFallback { - fn decode_chain(&self, tokens: Vec) -> Result> { - let mut new_tokens: Vec = vec![]; + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { + let mut new_tokens: Vec = vec![]; let mut previous_byte_tokens: Vec = vec![]; for token in tokens { + let token: CompactString = token.to_compact_string(); let bytes = if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') { if let Ok(byte) = u8::from_str_radix(&token[3..5], 16) { Some(byte) @@ -40,7 +45,7 @@ impl Decoder for ByteFallback { previous_byte_tokens.push(bytes); } else { if !previous_byte_tokens.is_empty() { - if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) { + if let Ok(string) = CompactString::from_utf8(previous_byte_tokens.clone()) { new_tokens.push(string); } else { for _ in 0..previous_byte_tokens.len() { @@ -53,7 +58,7 @@ impl Decoder for ByteFallback { } } if !previous_byte_tokens.is_empty() { - if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) { + if let Ok(string) = CompactString::from_utf8(previous_byte_tokens.clone()) { new_tokens.push(string); } else { for _ in 0..previous_byte_tokens.len() { @@ -73,41 +78,65 @@ mod tests { #[test] fn decode() { let decoder = ByteFallback::new(); - let res = decoder - .decode_chain(vec!["Hey".into(), "friend!".into()]) - .unwrap(); - assert_eq!(res, vec!["Hey", "friend!"]); + let res = decoder.decode_chain(vec!["Hey", "friend!"]).unwrap(); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["Hey", "friend!"] + ); - let res = decoder.decode_chain(vec!["<0x61>".into()]).unwrap(); - assert_eq!(res, vec!["a"]); + let res = decoder.decode_chain(vec!["<0x61>"]).unwrap(); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["a"] + ); - let res = 
decoder.decode_chain(vec!["<0xE5>".into()]).unwrap(); - assert_eq!(res, vec!["�"]); + let res = decoder.decode_chain(vec!["<0xE5>"]).unwrap(); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["�"] + ); - let res = decoder - .decode_chain(vec!["<0xE5>".into(), "<0x8f>".into()]) - .unwrap(); - assert_eq!(res, vec!["�", "�"]); + let res = decoder.decode_chain(vec!["<0xE5>", "<0x8f>"]).unwrap(); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["�", "�"] + ); // 叫 let res = decoder - .decode_chain(vec!["<0xE5>".into(), "<0x8f>".into(), "<0xab>".into()]) + .decode_chain(vec!["<0xE5>", "<0x8f>", "<0xab>"]) .unwrap(); - assert_eq!(res, vec!["叫"]); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["叫"] + ); let res = decoder - .decode_chain(vec![ - "<0xE5>".into(), - "<0x8f>".into(), - "<0xab>".into(), - "a".into(), - ]) + .decode_chain(vec!["<0xE5>", "<0x8f>", "<0xab>", "a"]) .unwrap(); - assert_eq!(res, vec!["叫", "a"]); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["叫", "a"] + ); - let res = decoder - .decode_chain(vec!["<0xE5>".into(), "<0x8f>".into(), "a".into()]) - .unwrap(); - assert_eq!(res, vec!["�", "�", "a"]); + let res = decoder.decode_chain(vec!["<0xE5>", "<0x8f>", "a"]).unwrap(); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["�", "�", "a"] + ); } } diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs index 9d5a57188..18f822b7b 100644 --- a/tokenizers/src/decoders/ctc.rs +++ b/tokenizers/src/decoders/ctc.rs @@ -1,6 +1,7 @@ use crate::decoders::wordpiece; use crate::tokenizer::{Decoder, Result}; +use compact_str::{CompactString, ToCompactString}; use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -13,19 +14,23 @@ use serde::{Deserialize, Serialize}; #[non_exhaustive] pub struct CTC { /// The pad token used by CTC to delimit a new token. - pub pad_token: String, + pub pad_token: CompactString, /// The word delimiter token. It will be replaced by a ``. - pub word_delimiter_token: String, + pub word_delimiter_token: CompactString, /// Whether to cleanup some tokenization artifacts. /// Mainly spaces before punctuation, and some abbreviated english forms. 
pub cleanup: bool, } impl CTC { - pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self { + pub fn new( + pad_token: impl Into, + word_delimiter_token: impl Into, + cleanup: bool, + ) -> Self { Self { - pad_token, - word_delimiter_token, + pad_token: pad_token.into(), + word_delimiter_token: word_delimiter_token.into(), cleanup, } } @@ -33,24 +38,25 @@ impl CTC { impl Default for CTC { fn default() -> Self { - Self { - pad_token: "".to_string(), - word_delimiter_token: "|".to_string(), - cleanup: true, - } + Self::new("", "|", true) } } impl Decoder for CTC { - fn decode_chain(&self, tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { Ok(tokens .into_iter() + .map(|token| token.to_compact_string()) .dedup() .filter_map(|token| { - let mut replaced = token.replace(&self.pad_token, ""); + let mut replaced: CompactString = token.replace(&*self.pad_token, "").into(); if self.cleanup { - replaced = - wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " "); + replaced = wordpiece::cleanup(&replaced) + .replace(&*self.word_delimiter_token, " ") + .into(); } if replaced.is_empty() { None @@ -65,15 +71,22 @@ impl Decoder for CTC { #[cfg(test)] mod tests { use super::*; + use compact_str::ToCompactString; + #[test] fn handmade_sample() { let ctc_decoder = CTC::default(); let id_to_string_result = " h e e l l l o o o " .split(' ') - .map(|s| s.to_string()) + .map(|s| s.to_compact_string()) .collect(); assert_eq!( - ctc_decoder.decode_chain(id_to_string_result).unwrap(), + ctc_decoder + .decode_chain(id_to_string_result) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), vec!["h", "e", "l", "l", "o"] ); } @@ -82,19 +95,29 @@ mod tests { let ctc_decoder = CTC::default(); let id_to_string_result = " h e e l l l o o o | w o o o r l l d " .split(' ') - .map(|s| s.to_string()) + .map(|s| s.to_compact_string()) .collect(); assert_eq!( - ctc_decoder.decode_chain(id_to_string_result).unwrap(), + ctc_decoder + .decode_chain(id_to_string_result) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"] ); } #[test] fn librispeech_sample() { let ctc_decoder = CTC::default(); - let id_to_string_result = " A | | M A N | | | S A I D D | | T T O | | T H E E | | | U U N N I V E R R S E E | | S S I R R | | | I | E X I S T | | ".split(' ').map(|s| s.to_string()).collect(); + let id_to_string_result = " A | | M A N | | | S A I D D | | T T O | | T H E E | | | U U N N I V E R R S E E | | S S I R R | | | I | E X I S T | | ".split(' ').map(|s| s.to_compact_string()).collect(); assert_eq!( - ctc_decoder.decode_chain(id_to_string_result).unwrap(), + ctc_decoder + .decode_chain(id_to_string_result) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), vec![ "A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H", "E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I", @@ -105,9 +128,14 @@ mod tests { #[test] fn another_librispeech_sample() { let ctc_decoder = CTC::default(); - let id_to_string_result = " H I S S | | I N S T T A N C C T | | | | | P A N N N I C | | W A S | | F O L L L O O W E E D | | B Y | | | A | | S S S M M A L L L | | | S H H A R R P | B L L O W W | | | H I G H H | | O N | | H I S S | | C H H E S S T T | | | ".split(' ').map(|s| s.to_string()).collect(); + let id_to_string_result = " H I S S | | I N S T T A N C C T | | | | | P A N N N I C | | W A S | | F O L 
L L O O W E E D | | B Y | | | A | | S S S M M A L L L | | | S H H A R R P | B L L O W W | | | H I G H H | | O N | | H I S S | | C H H E S S T T | | | ".split(' ').map(|s| s.to_compact_string()).collect(); assert_eq!( - ctc_decoder.decode_chain(id_to_string_result).unwrap(), + ctc_decoder + .decode_chain(id_to_string_result) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), vec![ "H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N", "I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ", diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs index 5e4a1c119..15b2de64b 100644 --- a/tokenizers/src/decoders/fuse.rs +++ b/tokenizers/src/decoders/fuse.rs @@ -1,4 +1,5 @@ use crate::tokenizer::{Decoder, Result}; +use compact_str::{CompactString, ToCompactString}; use monostate::MustBe; use serde::{Deserialize, Serialize}; @@ -22,8 +23,16 @@ impl Fuse { } impl Decoder for Fuse { - fn decode_chain(&self, tokens: Vec) -> Result> { - let new_string = tokens.join(""); + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { + let new_string: CompactString = tokens + .into_iter() + .map(|token| token.to_compact_string()) + .collect::>() + .join("") + .into(); Ok(vec![new_string]) } } @@ -35,9 +44,12 @@ mod tests { #[test] fn decode() { let decoder = Fuse::new(); - let res = decoder - .decode_chain(vec!["Hey".into(), " friend!".into()]) - .unwrap(); - assert_eq!(res, vec!["Hey friend!"]); + let res = decoder.decode_chain(vec!["Hey", " friend!"]).unwrap(); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["Hey friend!"] + ); } } diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 6e79e7029..9372a3579 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -10,6 +10,7 @@ pub mod wordpiece; pub use super::pre_tokenizers::byte_level; pub use super::pre_tokenizers::metaspace; +use compact_str::ToCompactString; use serde::{Deserialize, Deserializer, Serialize}; use crate::decoders::bpe::BPEDecoder; @@ -150,18 +151,41 @@ impl<'de> Deserialize<'de> for DecoderWrapper { } impl Decoder for DecoderWrapper { - fn decode_chain(&self, tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { match self { - Self::BPE(bpe) => bpe.decode_chain(tokens), - Self::ByteLevel(bl) => bl.decode_chain(tokens), - Self::Metaspace(ms) => ms.decode_chain(tokens), - Self::WordPiece(wp) => wp.decode_chain(tokens), - Self::CTC(ctc) => ctc.decode_chain(tokens), - Self::Sequence(seq) => seq.decode_chain(tokens), - Self::Replace(seq) => seq.decode_chain(tokens), - Self::ByteFallback(bf) => bf.decode_chain(tokens), - Self::Strip(bf) => bf.decode_chain(tokens), - Self::Fuse(bf) => bf.decode_chain(tokens), + Self::BPE(bpe) => bpe + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::ByteLevel(bl) => bl + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::Metaspace(ms) => ms + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::WordPiece(wp) => wp + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::CTC(ctc) => ctc + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::Sequence(seq) => seq + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + 
Self::Replace(seq) => seq + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::ByteFallback(bf) => bf + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::Strip(bf) => bf + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), + Self::Fuse(bf) => bf + .decode_chain(tokens) + .map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()), } } } diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs index 73169b695..a7d83b3b3 100644 --- a/tokenizers/src/decoders/sequence.rs +++ b/tokenizers/src/decoders/sequence.rs @@ -1,6 +1,7 @@ use crate::decoders::DecoderWrapper; use crate::tokenizer::{Decoder, Result}; use crate::utils::macro_rules_attribute; +use compact_str::ToCompactString; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug)] @@ -24,11 +25,19 @@ impl Sequence { } impl Decoder for Sequence { - fn decode_chain(&self, mut tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { + let mut current_tokens = tokens.into_iter().map(|t| t.to_compact_string()).collect(); for decoder in &self.decoders { - tokens = decoder.decode_chain(tokens)?; + current_tokens = decoder + .decode_chain(current_tokens)? + .into_iter() + .map(|t| t.to_compact_string()) + .collect(); } - Ok(tokens) + Ok(current_tokens) } } @@ -37,6 +46,7 @@ mod tests { use super::*; use crate::decoders::ctc::CTC; use crate::pre_tokenizers::metaspace::Metaspace; + use compact_str::CompactString; #[test] fn sequence_basic() { @@ -45,11 +55,11 @@ mod tests { DecoderWrapper::Metaspace(Metaspace::default()), ]; let decoder = Sequence::new(decoders); - let tokens: Vec = vec!["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"] + let tokens: Vec = vec!["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"] .into_iter() - .map(|s| s.to_string()) + .map(|s| s.into()) .collect(); - let out_tokens = decoder.decode(tokens).unwrap(); + let out_tokens = decoder.decode(tokens).unwrap().to_compact_string(); assert_eq!(out_tokens, "Hi you"); } } diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs index 9aeffec64..ae40301a3 100644 --- a/tokenizers/src/decoders/strip.rs +++ b/tokenizers/src/decoders/strip.rs @@ -1,5 +1,6 @@ use crate::tokenizer::{Decoder, Result}; +use compact_str::{CompactString, ToCompactString}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize, Default)] @@ -25,11 +26,14 @@ impl Strip { } impl Decoder for Strip { - fn decode_chain(&self, tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { Ok(tokens .into_iter() .map(|token| { - let chars: Vec = token.chars().collect(); + let chars: Vec = token.to_compact_string().chars().collect(); let mut start_cut = 0; for (i, &c) in chars.iter().enumerate().take(self.start) { @@ -52,7 +56,7 @@ impl Decoder for Strip { } } - let new_token: String = chars[start_cut..stop_cut].iter().collect(); + let new_token: CompactString = chars[start_cut..stop_cut].iter().collect(); new_token }) .collect()) @@ -67,14 +71,22 @@ mod tests { fn decode() { let decoder = Strip::new('H', 1, 0); let res = decoder - .decode_chain(vec!["Hey".into(), " friend!".into(), "HHH".into()]) + .decode_chain(vec!["Hey", " friend!", "HHH"]) .unwrap(); - assert_eq!(res, vec!["ey", " friend!", "HH"]); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["ey", " friend!", "HH"] + ); let decoder = 
Strip::new('y', 0, 1); - let res = decoder - .decode_chain(vec!["Hey".into(), " friend!".into()]) - .unwrap(); - assert_eq!(res, vec!["He", " friend!"]); + let res = decoder.decode_chain(vec!["Hey", " friend!"]).unwrap(); + assert_eq!( + res.into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), + vec!["He", " friend!"] + ); } } diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index 1a78586e2..d50d293cd 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -1,5 +1,6 @@ use crate::tokenizer::{Decoder, Result}; +use compact_str::{format_compact, CompactString, ToCompactString}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Clone, Debug, Serialize)] @@ -9,13 +10,13 @@ use serde::{Deserialize, Serialize}; #[non_exhaustive] pub struct WordPiece { /// The prefix to be used for continuing subwords - pub prefix: String, + pub prefix: CompactString, /// Whether to cleanup some tokenization artifacts (spaces before punctuation, ...) pub cleanup: bool, } impl WordPiece { - pub fn new(prefix: String, cleanup: bool) -> Self { + pub fn new(prefix: CompactString, cleanup: bool) -> Self { Self { prefix, cleanup } } } @@ -23,13 +24,14 @@ impl WordPiece { impl Default for WordPiece { fn default() -> Self { Self { - prefix: "##".to_owned(), + prefix: "##".into(), cleanup: true, } } } -pub fn cleanup(dirty_input: &str) -> String { +pub fn cleanup(dirty_input: impl ToCompactString) -> CompactString { dirty_input + .to_compact_string() .replace(" .", ".") .replace(" ?", "?") .replace(" !", "!") @@ -41,27 +43,32 @@ pub fn cleanup(dirty_input: &str) -> String { .replace(" 's", "'s") .replace(" 've", "'ve") .replace(" 're", "'re") + .into() } impl Decoder for WordPiece { - fn decode_chain(&self, mut tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { tokens - .iter_mut() + .into_iter() + .map(|t| t.to_compact_string()) .enumerate() - .map(|(i, token)| { + .map(|(i, mut token)| { if i != 0 { - if token.starts_with(&self.prefix) { - *token = token.replacen(&self.prefix, "", 1); + if token.starts_with(&*self.prefix) { + token = token.replacen(&*self.prefix, "", 1).to_compact_string(); } else { - *token = format!(" {token}"); + token = format_compact!(" {}", token); } } if self.cleanup { - *token = cleanup(token); + token = cleanup(token); } - Ok(token.to_string()) + Ok(token) }) - .collect::>() + .collect::>>() } } @@ -71,19 +78,13 @@ mod tests { #[test] fn wordpiece_decoder() { - let decoder = WordPiece::new("##".to_string(), false); + let decoder = WordPiece::new("##".into(), false); assert_eq!( decoder - .decode(vec![ - "##uelo".to_string(), - "Ara".to_string(), - "##új".to_string(), - "##o".to_string(), - "No".to_string(), - "##guera".to_string() - ]) - .unwrap(), + .decode(vec!["##uelo", "Ara", "##új", "##o", "No", "##guera"]) + .unwrap() + .to_compact_string(), "##uelo Araújo Noguera" ); } diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs index 441612717..1bd02cc03 100644 --- a/tokenizers/src/lib.rs +++ b/tokenizers/src/lib.rs @@ -83,11 +83,11 @@ //! .vocab_size(vocab_size) //! .min_frequency(0) //! .special_tokens(vec![ -//! AddedToken::from(String::from(""), true), -//! AddedToken::from(String::from(""), true), -//! AddedToken::from(String::from(""), true), -//! AddedToken::from(String::from(""), true), -//! AddedToken::from(String::from(""), true), +//! AddedToken::from("", true), +//! AddedToken::from("", true), +//! AddedToken::from("", true), +//! 
AddedToken::from("", true), +//! AddedToken::from("", true), //! ]) //! .build(); //! diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 217c37e90..24ecbf78a 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -2,20 +2,21 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; +use compact_str::{format_compact, CompactString}; +use rustc_hash::FxHashMap; use serde_json::Value; use std::borrow::Cow; use std::{ - collections::HashMap, fs::File, io::prelude::*, io::{BufRead, BufReader}, path::{Path, PathBuf}, }; -pub type Vocab = HashMap; -type VocabR = HashMap; -pub type MergeMap = HashMap; -pub type Merges = Vec<(String, String)>; +pub type Vocab = FxHashMap; +type VocabR = FxHashMap; +pub type MergeMap = FxHashMap; +pub type Merges = Vec<(CompactString, CompactString)>; struct Config { files: Option<(String, String)>, @@ -23,9 +24,9 @@ struct Config { merges: Merges, cache_capacity: usize, dropout: Option, - unk_token: Option, - continuing_subword_prefix: Option, - end_of_word_suffix: Option, + unk_token: Option, + continuing_subword_prefix: Option, + end_of_word_suffix: Option, fuse_unk: bool, byte_fallback: bool, ignore_merges: bool, @@ -41,7 +42,7 @@ impl Default for BpeBuilder { Self { config: Config { files: None, - vocab: HashMap::new(), + vocab: FxHashMap::default(), merges: vec![], cache_capacity: DEFAULT_CACHE_CAPACITY, dropout: None, @@ -93,22 +94,22 @@ impl BpeBuilder { /// Set the `UNK` token for the vocab. #[must_use] - pub fn unk_token(mut self, unk_token: String) -> Self { - self.config.unk_token = Some(unk_token); + pub fn unk_token(mut self, unk_token: impl Into) -> Self { + self.config.unk_token = Some(unk_token.into()); self } /// Set the `continuing_subword_prefix` option. #[must_use] - pub fn continuing_subword_prefix(mut self, prefix: String) -> Self { - self.config.continuing_subword_prefix = Some(prefix); + pub fn continuing_subword_prefix(mut self, prefix: impl Into) -> Self { + self.config.continuing_subword_prefix = Some(prefix.into()); self } /// Set the `end_of_word_suffix` option. #[must_use] - pub fn end_of_word_suffix(mut self, prefix: String) -> Self { - self.config.end_of_word_suffix = Some(prefix); + pub fn end_of_word_suffix(mut self, prefix: impl Into) -> Self { + self.config.end_of_word_suffix = Some(prefix.into()); self } @@ -173,14 +174,14 @@ impl BpeBuilder { .map(|(i, (a, b))| -> Result<(Pair, (u32, u32))> { let a_id = vocab .get(&a) - .ok_or_else(|| Error::MergeTokenOutOfVocabulary(a.to_owned()))?; + .ok_or_else(|| Error::MergeTokenOutOfVocabulary(a.to_string()))?; let b_id = vocab .get(&b) - .ok_or_else(|| Error::MergeTokenOutOfVocabulary(b.to_owned()))?; - let new_token = format!("{}{}", a, &b[prefix_len..]); + .ok_or_else(|| Error::MergeTokenOutOfVocabulary(b.to_string()))?; + let new_token = format_compact!("{}{}", a, &b[prefix_len..]); let new_id = vocab .get(&new_token) - .ok_or(Error::MergeTokenOutOfVocabulary(new_token))?; + .ok_or(Error::MergeTokenOutOfVocabulary(new_token.to_string()))?; Ok(((*a_id, *b_id), (i as u32, *new_id))) }) .collect::>()?; @@ -213,16 +214,16 @@ pub struct BPE { /// Contains the mapping between Pairs and their (rank, new_id). pub(crate) merges: MergeMap, /// Contains the cache for optimizing the encoding step. 
- cache: Option>, + cache: Option>, /// Dropout probability for merges. 0.0 = no dropout is the default. At 1.0, tokenization will /// perform no merges, so the result will just be characters. pub dropout: Option, /// The unknown token to be used when we encounter an unknown char - pub unk_token: Option, + pub unk_token: Option, /// An optional prefix to use on any subword that exist only behind another one - pub continuing_subword_prefix: Option, + pub continuing_subword_prefix: Option, /// An optional suffix to caracterize and end-of-word subword - pub end_of_word_suffix: Option, + pub end_of_word_suffix: Option, /// Do multiple unk tokens get fused pub fuse_unk: bool, /// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"` @@ -277,7 +278,7 @@ impl Clone for BPE { /// Converts the merges strings (for example from `merges.txt` file) with the format /// "{pair_a} {pair_b}" into the format expected by the BPE struct -pub(crate) fn convert_merges_to_hashmap>( +pub(crate) fn convert_merges_to_hashmap>( iter: I, _vocab: &Vocab, ) -> Result { @@ -290,7 +291,7 @@ pub(crate) fn convert_merges_to_hashmap>( return Err(Error::BadMerges(rank + 1).into()); } - merges.push((parts[0].to_string(), parts[1].to_string())); + merges.push((parts[0].into(), parts[1].into())); } Ok(merges) @@ -324,13 +325,13 @@ impl BPE { let mut buffer = String::new(); vocab_file.read_to_string(&mut buffer)?; let json: Value = serde_json::from_str(&buffer)?; - let mut vocab = HashMap::new(); + let mut vocab = FxHashMap::default(); match json { Value::Object(m) => { for (token, id) in m { if let Value::Number(id) = id { let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32; - vocab.insert(token, id); + vocab.insert(token.into(), id); } } } @@ -340,9 +341,10 @@ impl BPE { // Read merges file let merge_file = File::open(merges)?; let merge_file = BufReader::new(merge_file); - let merges = ResultShunt::process(merge_file.lines(), |iter| { - convert_merges_to_hashmap(iter, &vocab) - })??; + let merges = ResultShunt::process( + merge_file.lines().map(|line| line.map(CompactString::from)), + |iter| convert_merges_to_hashmap(iter, &vocab), + )??; Ok((vocab, merges)) } @@ -365,11 +367,11 @@ impl BPE { self.vocab.clone() } - pub fn get_unk_token(&self) -> &Option { + pub fn get_unk_token(&self) -> &Option { &self.unk_token } - pub fn get_continuing_subword_prefix(&self) -> &Option { + pub fn get_continuing_subword_prefix(&self) -> &Option { &self.continuing_subword_prefix } @@ -413,7 +415,7 @@ impl BPE { let tokens: Option> = s .bytes() .map(|b| -> Option<&u32> { - let code = format!("<{b:#04X}>"); + let code = format_compact!("<{b:#04X}>"); self.vocab.get(&code) }) @@ -436,14 +438,14 @@ impl BPE { word.add(unk_id, unk_len); Some(( *self.vocab.get(unk_token).ok_or_else(|| { - Error::UnkTokenOutOfVocabulary(unk_token.to_owned()) + Error::UnkTokenOutOfVocabulary(unk_token.to_string()) })?, byte_len, )) } _ => Some(( *self.vocab.get(unk_token).ok_or_else(|| { - Error::UnkTokenOutOfVocabulary(unk_token.to_owned()) + Error::UnkTokenOutOfVocabulary(unk_token.to_string()) })?, byte_len, )), @@ -469,11 +471,7 @@ impl BPE { fn tokenize_with_cache(&self, sequence: &str) -> Result> { if self.ignore_merges { if let Some(id) = self.vocab.get(sequence) { - return Ok(vec![Token::new( - *id, - sequence.to_string().clone(), - (0, sequence.len()), - )]); + return Ok(vec![Token::new(*id, sequence, (0, sequence.len()))]); } } if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) { @@ -483,7 +481,7 @@ impl BPE { let 
ret = self.word_to_tokens(&word).collect(); if let Some(ref cache) = self.cache { if sequence.len() < MAX_LENGTH { - cache.set(sequence.to_owned(), word); + cache.set(sequence.into(), word); } } Ok(ret) @@ -493,7 +491,7 @@ impl BPE { impl Model for BPE { type Trainer = BpeTrainer; - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self.vocab.clone() } @@ -518,7 +516,7 @@ impl Model for BPE { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option { + fn id_to_token(&self, id: u32) -> Option { self.vocab_r.get(&id).cloned() } @@ -600,18 +598,18 @@ mod tests { .collect(); let bpe = BpeBuilder::default() .vocab_and_merges(vocab, vec![]) - .unk_token("".to_string()) + .unk_token("") .build() .unwrap(); let tokens = bpe.tokenize("c").unwrap(); - assert_eq!(tokens, vec![Token::new(0u32, "".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(0u32, "", (0, 1)),]); let tokens = bpe.tokenize("cc").unwrap(); assert_eq!( tokens, vec![ - Token::new(0u32, "".into(), (0, 1)), - Token::new(0u32, "".into(), (1, 2)), + Token::new(0u32, "", (0, 1)), + Token::new(0u32, "", (1, 2)), ] ); @@ -619,10 +617,10 @@ mod tests { assert_eq!( tokens, vec![ - Token::new(1u32, "a".into(), (0, 1)), - Token::new(0u32, "".into(), (1, 2)), - Token::new(0u32, "".into(), (2, 3)), - Token::new(2u32, "b".into(), (3, 4)), + Token::new(1u32, "a", (0, 1)), + Token::new(0u32, "", (1, 2)), + Token::new(0u32, "", (2, 3)), + Token::new(2u32, "b", (3, 4)), ] ); } @@ -634,23 +632,23 @@ mod tests { .collect(); let bpe = BpeBuilder::default() .vocab_and_merges(vocab, vec![]) - .unk_token("".to_string()) + .unk_token("") .fuse_unk(true) .build() .unwrap(); let tokens = bpe.tokenize("c").unwrap(); - assert_eq!(tokens, vec![Token::new(0u32, "".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(0u32, "", (0, 1)),]); let tokens = bpe.tokenize("cc").unwrap(); - assert_eq!(tokens, vec![Token::new(0u32, "".into(), (0, 2)),]); + assert_eq!(tokens, vec![Token::new(0u32, "", (0, 2)),]); let tokens = bpe.tokenize("accb").unwrap(); assert_eq!( tokens, vec![ - Token::new(1u32, "a".into(), (0, 1)), - Token::new(0u32, "".into(), (1, 3)), - Token::new(2u32, "b".into(), (3, 4)), + Token::new(1u32, "a", (0, 1)), + Token::new(0u32, "", (1, 3)), + Token::new(2u32, "b", (3, 4)), ] ); } @@ -683,25 +681,25 @@ mod tests { .cloned() .collect(); let merges: Merges = vec![ - ("r".to_string(), "e".to_string()), - ("a".to_string(), "t".to_string()), - ("e".to_string(), "d".to_string()), - ("u".to_string(), "n".to_string()), - ("at".to_string(), "ed".to_string()), - ("re".to_string(), "l".to_string()), - ("rel".to_string(), "ated".to_string()), - ("un".to_string(), "related".to_string()), + ("r".into(), "e".into()), + ("a".into(), "t".into()), + ("e".into(), "d".into()), + ("u".into(), "n".into()), + ("at".into(), "ed".into()), + ("re".into(), "l".into()), + ("rel".into(), "ated".into()), + ("un".into(), "related".into()), ]; let mut bpe = BPE::new(vocab, merges); // With no dropout: let tokens = bpe.tokenize("unrelated").unwrap(); - assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]); + assert_eq!(tokens, vec![Token::new(15u32, "unrelated", (0, 9))]); // With dropout = 0.0 (equivalent to dropout == none) bpe.dropout = Some(0.0); let tokens = bpe.tokenize("unrelated").unwrap(); - assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]); + assert_eq!(tokens, vec![Token::new(15u32, "unrelated", (0, 9))]); // Now set dropout to 1.0. Result should be no merges performed. 
bpe.dropout = Some(1.0); @@ -709,15 +707,15 @@ mod tests { assert_eq!( tokens, vec![ - Token::new(0u32, "u".into(), (0, 1)), - Token::new(1u32, "n".into(), (1, 2)), - Token::new(2u32, "r".into(), (2, 3)), - Token::new(3u32, "e".into(), (3, 4)), - Token::new(4u32, "l".into(), (4, 5)), - Token::new(5u32, "a".into(), (5, 6)), - Token::new(6u32, "t".into(), (6, 7)), - Token::new(3u32, "e".into(), (7, 8)), - Token::new(7u32, "d".into(), (8, 9)), + Token::new(0u32, "u", (0, 1)), + Token::new(1u32, "n", (1, 2)), + Token::new(2u32, "r", (2, 3)), + Token::new(3u32, "e", (3, 4)), + Token::new(4u32, "l", (4, 5)), + Token::new(5u32, "a", (5, 6)), + Token::new(6u32, "t", (6, 7)), + Token::new(3u32, "e", (7, 8)), + Token::new(7u32, "d", (8, 9)), ] ); @@ -768,45 +766,28 @@ mod tests { // Ensure `BPE::from_file` works as expected. fn test_bpe_with_continuing_subword_prefix() { let vocab: Vocab = vec![ - ("a".to_string(), 0), - ("##b".to_string(), 1), - ("##c".to_string(), 2), - ("ab".to_string(), 3), - ("abc".to_string(), 4), + ("a".into(), 0), + ("##b".into(), 1), + ("##c".into(), 2), + ("ab".into(), 3), + ("abc".into(), 4), ] .into_iter() .collect(); - let merges = vec![ - ("a".to_string(), "##b".to_string()), - ("ab".to_string(), "##c".to_string()), - ]; + let merges = vec![("a".into(), "##b".into()), ("ab".into(), "##c".into())]; let bpe = BPE::builder() .vocab_and_merges(vocab, merges) - .unk_token("[UNK]".to_string()) - .continuing_subword_prefix("##".to_string()) + .unk_token("[UNK]") + .continuing_subword_prefix("##") .build() .unwrap(); let res = bpe.tokenize("ab"); - assert_eq!( - res.unwrap(), - vec![Token { - id: 3, - value: "ab".to_string(), - offsets: (0, 2) - }] - ); + assert_eq!(res.unwrap(), vec![Token::new(3, "ab", (0, 2))]); let res = bpe.tokenize("abc"); - assert_eq!( - res.unwrap(), - vec![Token { - id: 4, - value: "abc".to_string(), - offsets: (0, 3) - }] - ); + assert_eq!(res.unwrap(), vec![Token::new(4, "abc", (0, 3))]); } #[test] @@ -877,15 +858,15 @@ mod tests { .collect(); let bpe = BpeBuilder::default() .vocab_and_merges(vocab, vec![]) - .unk_token("".to_string()) + .unk_token("") .byte_fallback(true) .build() .unwrap(); let tokens = bpe.tokenize("c").unwrap(); - assert_eq!(tokens, vec![Token::new(0u32, "".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(0u32, "", (0, 1)),]); let tokens = bpe.tokenize("a").unwrap(); - assert_eq!(tokens, vec![Token::new(1u32, "<0x61>".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(1u32, "<0x61>", (0, 1)),]); } #[test] @@ -897,12 +878,12 @@ mod tests { .collect(); let bpe = BpeBuilder::default() .vocab_and_merges(vocab, vec![]) - .unk_token("".to_string()) + .unk_token("") .byte_fallback(true) .build() .unwrap(); let tokens = bpe.tokenize("\n").unwrap(); - assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>", (0, 1)),]); } #[test] @@ -954,13 +935,10 @@ mod tests { .build() .unwrap(); let tokens = bpe.tokenize(".:.:").unwrap(); - assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 4))]); + assert_eq!(tokens, vec![Token::new(0u32, ".:.:", (0, 4))]); let tokens = bpe.tokenize("Ġbelirtilen").unwrap(); - assert_eq!( - tokens, - vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 12))] - ); + assert_eq!(tokens, vec![Token::new(1u32, "Ġbelirtilen", (0, 12))]); bpe.ignore_merges = false; @@ -968,8 +946,8 @@ mod tests { assert_eq!( tokens, vec![ - Token::new(7u32, ".:".into(), (0, 2)), - Token::new(7u32, ".:".into(), (2, 4)) + Token::new(7u32, ".:", (0, 2)), + 
Token::new(7u32, ".:", (2, 4)) ] ); @@ -977,26 +955,10 @@ mod tests { assert_eq!( tokens, vec![ - Token { - id: 6, - value: "Ġ".into(), - offsets: (0, 2) - }, - Token { - id: 4, - value: "bel".into(), - offsets: (2, 5) - }, - Token { - id: 15, - value: "irtil".into(), - offsets: (5, 10) - }, - Token { - id: 14, - value: "en".into(), - offsets: (10, 12) - } + Token::new(6, "Ġ", (0, 2)), + Token::new(4, "bel", (2, 5)), + Token::new(15, "irtil", (5, 10)), + Token::new(14, "en", (10, 12)) ] ) } diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 98cc15102..fb3fc682a 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -1,10 +1,11 @@ use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE}; +use compact_str::CompactString; +use rustc_hash::FxHashMap; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer, }; -use std::collections::HashMap; impl Serialize for BPE { fn serialize(&self, serializer: S) -> Result @@ -80,13 +81,13 @@ impl<'de> Visitor<'de> for BPEVisitor { V: MapAccess<'de>, { let mut builder = BpeBuilder::new(); - let mut vocab: Option> = None; + let mut vocab: Option> = None; #[derive(Debug, Deserialize)] #[serde(untagged)] enum MergeType { - Tuple(Vec<(String, String)>), - Legacy(Vec), + Tuple(Vec<(CompactString, CompactString)>), + Legacy(Vec), } let mut merges: Option = None; while let Some(key) = map.next_key::()? { @@ -97,17 +98,17 @@ impl<'de> Visitor<'de> for BPEVisitor { } } "unk_token" => { - if let Some(unk) = map.next_value()? { + if let Some(unk) = map.next_value::>()? { builder = builder.unk_token(unk); } } "continuing_subword_prefix" => { - if let Some(prefix) = map.next_value()? { + if let Some(prefix) = map.next_value::>()? { builder = builder.continuing_subword_prefix(prefix); } } "end_of_word_suffix" => { - if let Some(suffix) = map.next_value()? { + if let Some(suffix) = map.next_value::>()? 
{ builder = builder.end_of_word_suffix(suffix); } } @@ -172,8 +173,8 @@ mod test { .cloned() .collect(); let bpe = BpeBuilder::default() - .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())]) - .unk_token("".to_string()) + .vocab_and_merges(vocab, vec![("a".into(), "b".into())]) + .unk_token("") .ignore_merges(true) .build() .unwrap(); @@ -201,8 +202,8 @@ mod test { .cloned() .collect(); let bpe = BpeBuilder::default() - .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())]) - .unk_token("".to_string()) + .vocab_and_merges(vocab, vec![("a".into(), "b c d".into())]) + .unk_token("") .ignore_merges(true) .build() .unwrap(); @@ -223,7 +224,7 @@ mod test { .collect(); let mut bpe = BpeBuilder::default() .vocab_and_merges(vocab, vec![]) - .unk_token("".to_string()) + .unk_token("") .ignore_merges(true) .build() .unwrap(); diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index a1a0aba76..0a05b43f4 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -4,15 +4,17 @@ use super::{Pair, WithFirstLastIterator, Word, BPE}; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use compact_str::{format_compact, CompactString, ToCompactString}; +use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; -use std::collections::{BinaryHeap, HashMap, HashSet}; +use std::collections::BinaryHeap; #[derive(Debug, Eq)] struct Merge { pair: Pair, count: u64, - pos: HashSet, + pos: FxHashSet, } impl PartialEq for Merge { fn eq(&self, other: &Self) -> bool { @@ -41,9 +43,9 @@ struct Config { show_progress: bool, special_tokens: Vec, limit_alphabet: Option, - initial_alphabet: HashSet, - continuing_subword_prefix: Option, - end_of_word_suffix: Option, + initial_alphabet: FxHashSet, + continuing_subword_prefix: Option, + end_of_word_suffix: Option, max_token_length: Option, } @@ -62,7 +64,7 @@ impl Default for BpeTrainerBuilder { show_progress: true, special_tokens: vec![], limit_alphabet: None, - initial_alphabet: HashSet::new(), + initial_alphabet: FxHashSet::default(), continuing_subword_prefix: None, end_of_word_suffix: None, max_token_length: None, @@ -114,21 +116,21 @@ impl BpeTrainerBuilder { /// Set the initial alphabet #[must_use] - pub fn initial_alphabet(mut self, alphabet: HashSet) -> Self { + pub fn initial_alphabet(mut self, alphabet: FxHashSet) -> Self { self.config.initial_alphabet = alphabet; self } /// Set the continuing_subword_prefix #[must_use] - pub fn continuing_subword_prefix(mut self, prefix: String) -> Self { + pub fn continuing_subword_prefix(mut self, prefix: CompactString) -> Self { self.config.continuing_subword_prefix = Some(prefix); self } /// Set the end_of_word_suffix #[must_use] - pub fn end_of_word_suffix(mut self, suffix: String) -> Self { + pub fn end_of_word_suffix(mut self, suffix: CompactString) -> Self { self.config.end_of_word_suffix = Some(suffix); self } @@ -151,7 +153,7 @@ impl BpeTrainerBuilder { continuing_subword_prefix: self.config.continuing_subword_prefix, end_of_word_suffix: self.config.end_of_word_suffix, max_token_length: self.config.max_token_length, - words: HashMap::new(), + words: FxHashMap::default(), } } } @@ -161,13 +163,14 @@ impl BpeTrainerBuilder { /// # Examples /// /// ``` +/// use compact_str::ToCompactString; /// use tokenizers::tokenizer::Trainer; /// use tokenizers::models::bpe::{BPE, BpeTrainer}; /// /// let 
sequences = vec![ "Hello", "World" ]; /// /// let mut trainer = BpeTrainer::default(); -/// trainer.feed(sequences.iter(), |s| Ok(vec![s.to_owned()])); +/// let _ = trainer.feed(sequences.iter(), |s| Ok(vec![s.to_compact_string()])); /// /// let mut model = BPE::default(); /// let special_tokens = trainer.train(&mut model).unwrap(); @@ -187,15 +190,15 @@ pub struct BpeTrainer { pub limit_alphabet: Option, /// The initial alphabet we want absolutely to include. This allows to cover /// some characters that are not necessarily in the training set - pub initial_alphabet: HashSet, + pub initial_alphabet: FxHashSet, /// An optional prefix to use on any subword that exist only behind another one - pub continuing_subword_prefix: Option, + pub continuing_subword_prefix: Option, /// An optional suffix to caracterize and end-of-word subword - pub end_of_word_suffix: Option, + pub end_of_word_suffix: Option, /// An optional parameter to limit the max length of any single token pub max_token_length: Option, - words: HashMap, + words: FxHashMap, } impl Default for BpeTrainer { @@ -251,7 +254,11 @@ impl BpeTrainer { } /// Add the provided special tokens to the initial vocabulary - fn add_special_tokens(&self, w2id: &mut HashMap, id2w: &mut Vec) { + fn add_special_tokens( + &self, + w2id: &mut FxHashMap, + id2w: &mut Vec, + ) { for token in &self.special_tokens { if !w2id.contains_key(&token.content) { id2w.push(token.content.to_owned()); @@ -263,12 +270,12 @@ impl BpeTrainer { /// Compute the initial alphabet and limit it if relevant fn compute_alphabet( &self, - wc: &HashMap, - w2id: &mut HashMap, - id2w: &mut Vec, + wc: &FxHashMap, + w2id: &mut FxHashMap, + id2w: &mut Vec, ) { // Compute the alphabet from seen words - let mut alphabet: HashMap = HashMap::new(); + let mut alphabet: FxHashMap = FxHashMap::default(); for (word, count) in wc { for c in word.chars() { alphabet @@ -311,7 +318,7 @@ impl BpeTrainer { // Keep the initial alphabet (sorted for determinism) kept.sort_unstable_by_key(|k| (*k.0) as u32); kept.into_iter().for_each(|(c, _)| { - let s = c.to_string(); + let s = c.to_compact_string(); if !w2id.contains_key(&s) { id2w.push(s.clone()); w2id.insert(s, (id2w.len() - 1) as u32); @@ -322,9 +329,9 @@ impl BpeTrainer { /// Tokenize words and add subwords to the vocabulary when relevant fn tokenize_words( &self, - wc: &HashMap, - w2id: &mut HashMap, - id2w: &mut Vec, + wc: &FxHashMap, + w2id: &mut FxHashMap, + id2w: &mut Vec, p: &Option, ) -> (Vec, Vec) { let mut words: Vec = Vec::with_capacity(wc.len()); @@ -335,20 +342,20 @@ impl BpeTrainer { counts.push(*count); for (is_first, is_last, c) in word.chars().with_first_and_last() { - let mut s = c.to_string(); + let mut s = c.to_compact_string(); if w2id.contains_key(&s) { // Found the initial char in the authorized alphabet // Add the `continuing_subword_prefix` if relevant if !is_first { if let Some(prefix) = &self.continuing_subword_prefix { - s = format!("{prefix}{s}"); + s = format_compact!("{prefix}{s}"); } } // Add the `end_of_word_suffix` if relevant if is_last { if let Some(suffix) = &self.end_of_word_suffix { - s = format!("{s}{suffix}"); + s = format_compact!("{s}{suffix}"); } } @@ -375,13 +382,13 @@ impl BpeTrainer { words: &[Word], counts: &[u64], p: &Option, - ) -> (HashMap, HashMap>) { + ) -> (FxHashMap, FxHashMap>) { words .maybe_par_iter() .enumerate() .map(|(i, word)| { - let mut pair_counts = HashMap::new(); - let mut where_to_update: HashMap> = HashMap::new(); + let mut pair_counts = FxHashMap::default(); + let mut 
where_to_update: FxHashMap> = FxHashMap::default(); for window in word.get_chars().windows(2) { let cur_pair: Pair = (window[0], window[1]); @@ -399,7 +406,7 @@ impl BpeTrainer { h.insert(i); }) .or_insert_with(|| { - let mut h = HashSet::new(); + let mut h = FxHashSet::default(); h.insert(i); h }); @@ -413,7 +420,7 @@ impl BpeTrainer { (pair_counts, where_to_update) }) .reduce( - || (HashMap::new(), HashMap::new()), + || (FxHashMap::default(), FxHashMap::default()), |(mut pair_counts, mut where_to_update), (pc, wtu)| { for (k, v) in pc { pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v); @@ -431,11 +438,12 @@ impl BpeTrainer { pub fn do_train( &self, - word_counts: &HashMap, + word_counts: &FxHashMap, model: &mut BPE, ) -> Result> { - let mut word_to_id: HashMap = HashMap::with_capacity(self.vocab_size); - let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); + let mut word_to_id: FxHashMap = + FxHashMap::with_capacity_and_hasher(self.vocab_size, Default::default()); + let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX); let progress = self.setup_progress(); @@ -504,16 +512,16 @@ impl BpeTrainer { } let part_a = &id_to_word[top.pair.0 as usize]; - let mut part_b = id_to_word[top.pair.1 as usize].to_owned(); + let mut part_b = id_to_word[top.pair.1 as usize].clone(); // Build new token if let Some(prefix) = &self.continuing_subword_prefix { - if part_b.starts_with(prefix) { + if part_b.starts_with(&**prefix) { let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum(); - part_b = part_b[prefix_byte_len..].to_string(); + part_b = part_b[prefix_byte_len..].into(); } } - let new_token = format!("{part_a}{part_b}"); + let new_token = format_compact!("{part_a}{part_b}"); // implement sentencepiece-like merge. // if this code were to be merged, integrate a way in the python bindings to communicate this variable // default should be 0/None to maintain previous behavior. 16 is the spm default. @@ -532,7 +540,7 @@ impl BpeTrainer { // Merge the new pair in every words // Safety: This is just a type assertion, the code below may no longer be safe // if the type of `pos` changes - let pos: &HashSet = &top.pos; + let pos: &FxHashSet = &top.pos; let words_len = words.len(); struct WordPtr(*mut Word); @@ -577,7 +585,7 @@ impl BpeTrainer { h.insert(iw); }) .or_insert_with(|| { - let mut h = HashSet::new(); + let mut h = FxHashSet::default(); h.insert(iw); h }); @@ -645,20 +653,20 @@ impl Trainer for BpeTrainer { where I: Iterator + Send, S: AsRef + Send, - F: Fn(&str) -> Result> + Sync, + F: Fn(&str) -> Result> + Sync, { - let words: Result> = iterator + let words: Result> = iterator .maybe_par_bridge() .map(|sequence| { let words = process(sequence.as_ref())?; - let mut map = HashMap::new(); + let mut map = FxHashMap::default(); for word in words { map.entry(word).and_modify(|c| *c += 1).or_insert(1); } Ok(map) }) .reduce( - || Ok(HashMap::new()), + || Ok(FxHashMap::default()), |acc, ws| { let mut acc = acc?; for (k, v) in ws? 
{ @@ -676,11 +684,12 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { use super::{BpeTrainer, Pair, BPE}; - use std::collections::HashMap; + use compact_str::{CompactString, ToCompactString}; + use rustc_hash::FxHashMap; #[test] fn test_train() { - let word_counts: HashMap = [ + let word_counts: FxHashMap = [ ("roses".into(), 1), ("are".into(), 2), ("red".into(), 1), @@ -705,7 +714,7 @@ mod tests { // Vocab should contain all of the characters from the `word_counts` mapping // as well as three merges: 're', 'are', and 'is'. - let expected_vocab: HashMap = [ + let expected_vocab: FxHashMap = [ ("-".into(), 0), ("2".into(), 1), ("B".into(), 2), @@ -741,7 +750,7 @@ mod tests { // where 'rank' determines the order in which this merge will be applied during // tokenization, and 'id' is the vocab id of the symbol resulting from merging // the pair of symbols in the corresponding key. - let expected_merges: HashMap = [ + let expected_merges: FxHashMap = [ ((17, 11), (0, 22)), // 'r' + 'e' -> 're' ((8, 22), (1, 23)), // 'a' + 're' -> 'are' ((13, 18), (2, 24)), // 'i' + 's' -> 'is' @@ -759,7 +768,7 @@ mod tests { */ let max_token_length = 16; - let long_word_counts: HashMap = [ + let long_word_counts: FxHashMap = [ ("singlelongtokenwithoutcasechange", 2), ("singleLongTokenWithCamelCaseChange", 2), ("Longsingletokenwithpunctu@t!onwithin", 2), @@ -774,7 +783,7 @@ mod tests { ("GPT-2", 2), ] .iter() - .map(|(key, value)| (key.to_string(), *value)) + .map(|(key, value)| (key.to_compact_string(), *value)) .collect(); let trainer = BpeTrainer::builder() .max_token_length(Some(max_token_length)) @@ -799,7 +808,7 @@ mod tests { // directly compares tokens with known expected values. // maybe unstable depending on specific settings or changes. */ - let long_word_counts: HashMap = [ + let long_word_counts: FxHashMap = [ ("sin", 2), ("Sin", 2), ("Lon", 2), @@ -814,7 +823,7 @@ mod tests { ("GP", 2), ] .iter() - .map(|(key, value)| (key.to_string(), *value)) + .map(|(key, value)| (key.to_compact_string(), *value)) .collect(); let trainer = BpeTrainer::builder() .max_token_length(Some(2)) @@ -823,8 +832,8 @@ mod tests { .build(); let mut model = BPE::default(); trainer.do_train(&long_word_counts, &mut model).unwrap(); - let trained_vocab: HashMap = model.get_vocab(); - let expected_vocab: HashMap = [ + let trained_vocab: FxHashMap = model.get_vocab(); + let expected_vocab: FxHashMap = [ ("短", 12), ("n", 6), ("i", 5), @@ -858,7 +867,7 @@ mod tests { ] .iter() .cloned() - .map(|(k, v)| (k.to_string(), v)) + .map(|(k, v)| (k.to_compact_string(), v)) .collect(); assert_eq!(trained_vocab, expected_vocab) } diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 93b3d9c37..94faf63bc 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -1,7 +1,8 @@ use super::Pair; use rand::{thread_rng, Rng}; use std::cmp::Ordering; -use std::collections::{BinaryHeap, HashMap}; +use std::collections::BinaryHeap; +use rustc_hash::FxHashMap; #[derive(Debug, Eq)] struct Merge { @@ -158,7 +159,7 @@ impl Word { changes } - pub(super) fn merge_all(&mut self, merges: &HashMap, dropout: Option) { + pub(super) fn merge_all(&mut self, merges: &FxHashMap, dropout: Option) { let mut queue = BinaryHeap::with_capacity(self.symbols.len()); let mut skip = Vec::with_capacity(queue.len()); diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 3a3a91adc..ea576c9c6 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -5,7 
+5,8 @@ pub mod unigram;
 pub mod wordlevel;
 pub mod wordpiece;
 
-use std::collections::HashMap;
+use compact_str::CompactString;
+use rustc_hash::FxHashMap;
 use std::path::{Path, PathBuf};
 
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -19,11 +20,11 @@ use crate::{AddedToken, Model, Result, Token, Trainer};
 /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
 /// of token ID, smallest to largest.
 struct OrderedVocabIter<'a> {
-    vocab_r: &'a HashMap<u32, String>,
+    vocab_r: &'a FxHashMap<u32, CompactString>,
 }
 
 impl<'a> OrderedVocabIter<'a> {
-    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+    fn new(vocab_r: &'a FxHashMap<u32, CompactString>) -> Self {
         Self { vocab_r }
     }
 }
@@ -161,7 +162,7 @@ impl Model for ModelWrapper {
         }
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
+    fn id_to_token(&self, id: u32) -> Option<CompactString> {
         match self {
             Self::WordLevel(t) => t.id_to_token(id),
             Self::WordPiece(t) => t.id_to_token(id),
@@ -170,7 +171,7 @@ impl Model for ModelWrapper {
         }
     }
 
-    fn get_vocab(&self) -> HashMap<String, u32> {
+    fn get_vocab(&self) -> FxHashMap<CompactString, u32> {
         match self {
             Self::WordLevel(t) => t.get_vocab(),
             Self::WordPiece(t) => t.get_vocab(),
@@ -269,7 +270,7 @@ impl Trainer for TrainerWrapper {
     where
         I: Iterator<Item = S> + Send,
         S: AsRef<str> + Send,
-        F: Fn(&str) -> Result<Vec<String>> + Sync,
+        F: Fn(&str) -> Result<Vec<CompactString>> + Sync,
     {
         match self {
             Self::BpeTrainer(bpe) => bpe.feed(iterator, process),
@@ -301,8 +302,12 @@ mod tests {
 
     #[test]
     fn incomplete_ordered_vocab() {
-        let vocab_r: HashMap<u32, String> =
-            HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
+        let vocab_r: FxHashMap<u32, CompactString> = {
+            let mut tmp = FxHashMap::default();
+            tmp.insert(0, "Hi".into());
+            tmp.insert(2, "There".into());
+            tmp
+        };
 
         let ordered = OrderedVocabIter::new(&vocab_r);
 
@@ -322,8 +327,8 @@ mod tests {
             .cloned()
             .collect();
         let bpe = BpeBuilder::default()
-            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
-            .unk_token("<unk>".to_string())
+            .vocab_and_merges(vocab, vec![("a".into(), "b".into())])
+            .unk_token("<unk>")
             .ignore_merges(true)
             .build()
             .unwrap();
diff --git a/tokenizers/src/models/unigram/lattice.rs b/tokenizers/src/models/unigram/lattice.rs
index 30b82245d..04f5df0ae 100644
--- a/tokenizers/src/models/unigram/lattice.rs
+++ b/tokenizers/src/models/unigram/lattice.rs
@@ -1,3 +1,4 @@
+use compact_str::CompactString;
 use rand::distributions::WeightedIndex;
 use rand::prelude::*;
 use std::cell::RefCell;
@@ -223,11 +224,11 @@ impl<'a> Lattice<'a> {
         results
     }
 
-    pub fn piece(&self, node: &Node) -> String {
-        self.sentence[node.pos..node.pos + node.length].to_owned()
+    pub fn piece(&self, node: &Node) -> CompactString {
+        self.sentence[node.pos..node.pos + node.length].into()
     }
 
-    pub fn tokens(&mut self) -> Vec<String> {
+    pub fn tokens(&mut self) -> Vec<CompactString> {
         self.viterbi()
             .iter()
             .map(|node| self.piece(&node.borrow()))
@@ -296,7 +297,7 @@ impl<'a> Lattice<'a> {
         }
     }
 
-    pub fn nbest_tokens(&mut self, n: usize) -> Vec<Vec<String>> {
+    pub fn nbest_tokens(&mut self, n: usize) -> Vec<Vec<CompactString>> {
         self.nbest(n)
             .iter()
             .map(|v| v.iter().map(|node| self.piece(&node.borrow())).collect())
@@ -422,7 +423,7 @@ impl<'a> Lattice<'a> {
         results
     }
 
-    pub fn sample_token(&self, theta: f64) -> Vec<String> {
+    pub fn sample_token(&self, theta: f64) -> Vec<CompactString> {
         self.sample(theta)
             .iter()
             .map(|node| self.piece(&node.borrow()))
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index da4d631ce..83f408b0d 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -6,19 +6,20 @@ use super::{
 use crate::tokenizer::{Model, Result, Token};
 use
crate::utils::cache::{Cache, MAX_LENGTH}; -use std::collections::HashMap; +use compact_str::{format_compact, CompactString}; +use rustc_hash::FxHashMap; use std::convert::TryInto; use std::fs::read_to_string; use std::path::{Path, PathBuf}; -type TokenMap = HashMap; -type Vocab = Vec<(String, f64)>; +type TokenMap = FxHashMap; +type Vocab = Vec<(CompactString, f64)>; /// A `Unigram` model to encode sentences. pub struct Unigram { token_to_ids: TokenMap, pub(crate) vocab: Vocab, - cache: Cache>, + cache: Cache>, trie: Trie, pub min_score: f64, pub(super) unk_id: Option, @@ -80,7 +81,7 @@ pub enum UnigramError { impl Default for Unigram { fn default() -> Self { - let vocab = vec![("".to_string(), 0.0)]; + let vocab = vec![("", 0.0)]; Self::from(vocab, Some(0), false).unwrap() } } @@ -93,12 +94,12 @@ impl Unigram { /// For now `Unigram` *requires* at least `unk` because we might find a never seen char. /// Further versions might allow that part to be hidden. pub fn from( - vocab: Vec<(String, f64)>, + vocab: Vec<(impl Into, f64)>, unk_id: Option, byte_fallback: bool, ) -> Result { let n = vocab.len(); - let mut token_to_ids: TokenMap = HashMap::new(); + let mut token_to_ids: TokenMap = FxHashMap::default(); let mut builder = TrieBuilder::default(); if let Some(unk_id) = unk_id { @@ -112,9 +113,12 @@ impl Unigram { let bos_id = n + 1; let eos_id = n + 2; + let vocab: Vec<(CompactString, f64)> = + vocab.into_iter().map(|(s, f)| (s.into(), f)).collect(); + let mut min_score = f64::INFINITY; for (id, (token, score)) in vocab.iter().enumerate() { - token_to_ids.insert(token.to_string(), id as u32); + token_to_ids.insert(token.clone(), id as u32); let bytes: Vec = token.bytes().collect(); builder.push(&bytes); if score < &min_score { @@ -177,7 +181,7 @@ impl Unigram { .common_prefix_search(lattice.sentence.bytes().skip(begin_pos)) { let n = bytes.len(); - let tok = String::from_utf8(bytes).unwrap(); + let tok = CompactString::from_utf8(bytes).unwrap(); let id = *self.token_to_ids.get(&tok).unwrap(); let item = &self.vocab[id as usize]; @@ -204,21 +208,21 @@ impl Unigram { /// use tokenizers::models::unigram::Unigram; /// /// let pieces = vec![ - /// ("".to_string(), 0.0), - /// ("a".to_string(), 0.0), - /// ("b".to_string(), 0.0), - /// ("c".to_string(), 0.0), - /// ("d".to_string(), 0.0), - /// ("cd".to_string(), 1.0), - /// ("ab".to_string(), 2.0), - /// ("abc".to_string(), 5.0), - /// ("abcd".to_string(), 10.0), + /// ("".into(), 0.0), + /// ("a".into(), 0.0), + /// ("b".into(), 0.0), + /// ("c".into(), 0.0), + /// ("d".into(), 0.0), + /// ("cd".into(), 1.0), + /// ("ab".into(), 2.0), + /// ("abc".into(), 5.0), + /// ("abcd".into(), 10.0), /// ]; /// let model = Unigram::from(pieces, Some(0), false).unwrap(); /// let result = model.encode("abcdacdxx").unwrap(); /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]); /// ``` - pub fn encode(&self, sentence: &str) -> Result> { + pub fn encode(&self, sentence: &str) -> Result> { if sentence.is_empty() { return Ok(vec![]); } @@ -231,13 +235,13 @@ impl Unigram { self.encode_unoptimized(sentence)? 
}; if sentence.len() < MAX_LENGTH { - self.cache.set(sentence.to_owned(), result.clone()); + self.cache.set(sentence.into(), result.clone()); } Ok(result) } } - fn encode_optimized(&self, sentence: &str) -> Result> { + fn encode_optimized(&self, sentence: &str) -> Result> { // https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600 #[derive(Debug, Clone)] struct BestPathNode { @@ -272,7 +276,7 @@ impl Unigram { .common_prefix_search(sentence.bytes().skip(starts_at)) { let key_pos = starts_at + tok_bytes.len(); - let token: String = String::from_utf8(tok_bytes).unwrap(); + let token: CompactString = CompactString::from_utf8(tok_bytes).unwrap(); let target_node = &mut best_path_ends_at[key_pos]; let length = key_pos - starts_at; let id = self.token_to_ids.get(&token).unwrap(); @@ -303,7 +307,7 @@ impl Unigram { starts_at += mblen } let mut ends_at = size; - let mut results: Vec = vec![]; + let mut results: Vec = vec![]; let mut token = vec![]; while ends_at > 0 { let node = &best_path_ends_at[ends_at]; @@ -313,34 +317,34 @@ impl Unigram { && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)? { token.push( - String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(), + CompactString::from_utf8(sentence[starts_at..ends_at].as_bytes()).unwrap(), ); } else { if !token.is_empty() { token.reverse(); - results.push(token.concat()); + results.push(token.concat().into()); token = vec![]; } results.push( - String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(), + CompactString::from_utf8(sentence[starts_at..ends_at].as_bytes()).unwrap(), ); } ends_at = starts_at; } if !token.is_empty() { token.reverse(); - results.push(token.concat()); + results.push(token.concat().into()); } results.reverse(); Ok(results) } - fn encode_unoptimized(&self, sentence: &str) -> Result> { + fn encode_unoptimized(&self, sentence: &str) -> Result> { let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id); self.populate_nodes(&mut lattice); if self.fuse_unk { let mut results = vec![]; - let mut token = String::new(); + let mut token = CompactString::default(); for node in lattice.viterbi().iter() { let item = lattice.piece(&node.borrow()); if node.borrow().id == self.unk_id.ok_or(UnigramError::MissingUnkId)? 
{ @@ -348,9 +352,9 @@ impl Unigram { } else { if !token.is_empty() { results.push(token); - token = String::new(); + token = CompactString::default(); } - results.push(item.to_string()); + results.push(item); } } if !token.is_empty() { @@ -398,7 +402,7 @@ pub struct UnigramIterator<'a> { } impl<'a> Iterator for UnigramIterator<'a> { - type Item = &'a (String, f64); + type Item = &'a (CompactString, f64); fn next(&mut self) -> Option { let i = self.i; @@ -415,7 +419,7 @@ impl<'a> Iterator for UnigramIterator<'a> { impl Model for Unigram { type Trainer = UnigramTrainer; - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self.token_to_ids.clone() } @@ -437,7 +441,7 @@ impl Model for Unigram { let byte_tokens: Option> = string .bytes() .map(|byte| -> Option { - let byte_string = format!("<0x{byte:02X}>"); + let byte_string = format_compact!("<0x{byte:02X}>"); let id = self.token_to_ids.get(&byte_string); id.map(|id| Token::new(*id, byte_string, (offset, offset + len))) }) @@ -463,7 +467,7 @@ impl Model for Unigram { self.token_to_ids.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option { + fn id_to_token(&self, id: u32) -> Option { self.vocab.get(id as usize).map(|item| item.0.clone()) } @@ -491,7 +495,7 @@ mod tests { #[test] fn test_populate_nodes_unk() { - let pieces = vec![("".to_string(), 0.0)]; + let pieces = vec![("", 0.0)]; let model = Unigram::from(pieces, Some(0), false).unwrap(); let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id); @@ -511,11 +515,11 @@ mod tests { #[test] fn test_populate_nodes() { let pieces = vec![ - ("".to_string(), 0.0), - ("a".to_string(), 0.1), - ("b".to_string(), 0.2), - ("ab".to_string(), 0.3), - ("bc".to_string(), 0.4), + ("", 0.0), + ("a", 0.1), + ("b", 0.2), + ("ab", 0.3), + ("bc", 0.4), ]; let model = Unigram::from(pieces, Some(0), false).unwrap(); @@ -543,15 +547,15 @@ mod tests { #[test] fn test_encode() { let sentencepieces = vec![ - ("".to_string(), 0.0), - ("a".to_string(), 0.0), - ("b".to_string(), 0.0), - ("c".to_string(), 0.0), - ("d".to_string(), 0.0), - ("cd".to_string(), 1.0), - ("ab".to_string(), 2.0), - ("abc".to_string(), 5.0), - ("abcd".to_string(), 10.0), + ("", 0.0), + ("a", 0.0), + ("b", 0.0), + ("c", 0.0), + ("d", 0.0), + ("cd", 1.0), + ("ab", 2.0), + ("abc", 5.0), + ("abcd", 10.0), ]; let model = Unigram::from(sentencepieces, Some(0), false).unwrap(); @@ -562,18 +566,18 @@ mod tests { #[test] fn test_encode2() { let sentencepieces = vec![ - ("".to_string(), 0.0), - ("ab".to_string(), 0.0), - ("cd".to_string(), -0.1), - ("abc".to_string(), -0.2), - ("a".to_string(), -0.3), - ("b".to_string(), -0.4), - ("c".to_string(), -0.5), - ("ABC".to_string(), -0.5), - ("abcdabcd".to_string(), 20.0), // User defined just max the scores. - ("q".to_string(), 20.5), - ("r".to_string(), 20.5), - ("qr".to_string(), -0.5), + ("", 0.0), + ("ab", 0.0), + ("cd", -0.1), + ("abc", -0.2), + ("a", -0.3), + ("b", -0.4), + ("c", -0.5), + ("ABC", -0.5), + ("abcdabcd", 20.0), // User defined just max the scores. 
+ ("q", 20.5), + ("r", 20.5), + ("qr", -0.5), ]; let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap(); @@ -618,26 +622,14 @@ mod tests { fn test_unigram_bytefallback() { // In [97]: processor.encode_as_pieces("⅐⅛⅑ ") // Out[97]: ['▁', '<0xE2>', '<0x85>', '<0x90>', '⅛', '<0xE2>', '<0x85>', '<0x91>', '▁'] - let sentencepieces = vec![ - ("".to_string(), 0.0), - ("<0xC3>".to_string(), -0.01), - ("<0xA9>".to_string(), -0.03), - ]; + let sentencepieces = vec![("", 0.0), ("<0xC3>", -0.01), ("<0xA9>", -0.03)]; let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap(); let tokens: Vec = unigram.tokenize("é").unwrap(); assert_eq!( tokens, [ - Token { - id: 1, - value: "<0xC3>".to_string(), - offsets: (0, 2) - }, - Token { - id: 2, - value: "<0xA9>".to_string(), - offsets: (0, 2) - } + Token::new(1, "<0xC3>", (0, 2)), + Token::new(2, "<0xA9>", (0, 2)) ] ); diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs index a6e56b735..26df86ed1 100644 --- a/tokenizers/src/models/unigram/serialization.rs +++ b/tokenizers/src/models/unigram/serialization.rs @@ -1,4 +1,5 @@ use super::model::Unigram; +use compact_str::CompactString; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, @@ -46,10 +47,10 @@ impl<'de> Visitor<'de> for UnigramVisitor { where V: MapAccess<'de>, { - let mut vocab: Option> = None; + let mut vocab: Option> = None; let mut unk_id: Option = None; let mut byte_fallback: bool = false; - while let Some(key) = map.next_key::()? { + while let Some(key) = map.next_key::()? { match key.as_ref() { "unk_id" => { unk_id = map.next_value()?; @@ -82,7 +83,7 @@ mod test { #[test] fn test_serialization() { - let vocab = vec![("".to_string(), 0.0), ("a".to_string(), -0.5)]; + let vocab = vec![("", 0.0), ("a", -0.5)]; let model = Unigram::from(vocab, Some(0), false).unwrap(); let data = serde_json::to_string(&model).unwrap(); @@ -93,7 +94,7 @@ mod test { #[test] fn test_serialization_unk_id_not_zero() { - let vocab = vec![("a".to_string(), -0.5), ("".to_string(), 0.0)]; + let vocab = vec![("a", -0.5), ("", 0.0)]; let model = Unigram::from(vocab, Some(1), false).unwrap(); let data = serde_json::to_string(&model).unwrap(); @@ -104,7 +105,7 @@ mod test { #[test] fn test_serialization_no_unk_id() { - let vocab = vec![("a".to_string(), -0.5)]; + let vocab = vec![("a", -0.5)]; let model = Unigram::from(vocab, None, false).unwrap(); let data = serde_json::to_string(&model).unwrap(); diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index 5d178e77b..9e83de113 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -2,17 +2,18 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram}; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use compact_str::{CompactString, ToCompactString}; use log::debug; +use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; -use std::collections::{HashMap, HashSet}; use std::convert::TryInto; // A token and a score -type SentencePiece = (String, f64); +type SentencePiece = (CompactString, f64); // A full sentence or word + it's count within the dataset -type Sentence = (String, u32); +type Sentence = (CompactString, u32); fn digamma(mut x: f64) -> f64 { let mut result = 0.0; @@ -57,18 +58,18 @@ pub struct UnigramTrainer { pub shrinking_factor: f64, 
#[builder(default = "vec![]")] pub special_tokens: Vec, - #[builder(default = "HashSet::new()")] - pub initial_alphabet: HashSet, + #[builder(default = "FxHashSet::default()")] + pub initial_alphabet: FxHashSet, #[builder(default = "None")] - pub unk_token: Option, + pub unk_token: Option, #[builder(default = "16")] pub max_piece_length: usize, #[builder(default = "1_000_000")] seed_size: usize, - #[builder(default = "HashMap::new()")] - words: HashMap, + #[builder(default = "FxHashMap::default()")] + words: FxHashMap, } impl Default for UnigramTrainer { @@ -110,17 +111,21 @@ impl UnigramTrainer { true } - fn finalize(&self, model: Unigram, required_chars: HashSet) -> Result { + fn finalize( + &self, + model: Unigram, + required_chars: FxHashSet, + ) -> Result { let mut min_score_penalty = 0.0; let min_score_penalty_delta = 0.0001; - let mut pieces: Vec<(String, f64)> = vec![]; - let mut inserted: HashSet = HashSet::new(); + let mut pieces: Vec<(CompactString, f64)> = vec![]; + let mut inserted: FxHashSet = FxHashSet::default(); // We don't want to include the that was used to train inserted.insert("".into()); - let existing_pieces: HashMap = model.iter().cloned().collect(); + let existing_pieces: FxHashMap = model.iter().cloned().collect(); for c in required_chars { if let Some(t) = existing_pieces.get(&c) { inserted.insert(c.clone()); @@ -159,8 +164,8 @@ impl UnigramTrainer { if inserted.contains::(token) { continue; } - inserted.insert(token.to_string()); - pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score })); + inserted.insert(token.clone()); + pieces.push((token.clone(), if score.is_nan() { 0.0 } else { *score })); if pieces.len() == vocab_size_without_special_tokens { break; @@ -185,12 +190,12 @@ impl UnigramTrainer { ) } - fn required_chars(&self, word_counts: &[Sentence]) -> HashSet { + fn required_chars(&self, word_counts: &[Sentence]) -> FxHashSet { word_counts .iter() .flat_map(|(s, _count)| s.chars()) .chain(self.initial_alphabet.iter().copied()) - .map(|c| c.to_string()) + .map(|c| c.to_compact_string()) .collect() } fn make_seed_sentence_pieces( @@ -204,8 +209,8 @@ impl UnigramTrainer { .map(|(s, _)| s.chars().count()) .sum::() + sentences.len(); - let mut flat_string = String::with_capacity(total); - let mut all_chars: HashMap = HashMap::new(); + let mut flat_string = CompactString::with_capacity(total); + let mut all_chars: FxHashMap = FxHashMap::default(); let c_sentence_boundary = '\0'; let k_sentence_boundary = '\0'.to_string(); for (string, n) in sentences { @@ -257,7 +262,7 @@ impl UnigramTrainer { // Fill seed_sentencepieces for (count, character) in sall_chars { - seed_sentencepieces.push((character.to_string(), count.into())); + seed_sentencepieces.push((character.to_compact_string(), count.into())); } // sort by decreasing score @@ -265,7 +270,7 @@ impl UnigramTrainer { for (score, char_string) in substr_index { // Just in case assert!(self.is_valid_sentencepiece(char_string)); - let string: String = char_string.iter().collect(); + let string: CompactString = char_string.iter().collect(); seed_sentencepieces.push((string, score.into())); if seed_sentencepieces.len() >= self.seed_size { break; @@ -378,7 +383,7 @@ impl UnigramTrainer { continue; } else if alternatives[id].is_empty() { // no alternatives. Keeps this entry. 
- new_pieces.push((token.to_string(), *score)); + new_pieces.push((token.to_compact_string(), *score)); } else { let mut f = 0.0; // the frequency of pieces[i]; @@ -616,7 +621,11 @@ impl Trainer for UnigramTrainer { /// Train a Unigram model fn train(&self, model: &mut Unigram) -> Result> { - let sentences: Vec<_> = self.words.iter().map(|(s, i)| (s.to_owned(), *i)).collect(); + let sentences: Vec<_> = self + .words + .iter() + .map(|(s, i)| (s.to_compact_string(), *i)) + .collect(); self.do_train(sentences, model) } @@ -629,20 +638,20 @@ impl Trainer for UnigramTrainer { where I: Iterator + Send, S: AsRef + Send, - F: Fn(&str) -> Result> + Sync, + F: Fn(&str) -> Result> + Sync, { - let words: Result> = iterator + let words: Result> = iterator .maybe_par_bridge() .map(|sequence| { let words = process(sequence.as_ref())?; - let mut map = HashMap::new(); + let mut map = FxHashMap::default(); for word in words { map.entry(word).and_modify(|c| *c += 1).or_insert(1); } Ok(map) }) .reduce( - || Ok(HashMap::new()), + || Ok(FxHashMap::default()), |acc, ws| { let mut acc = acc?; for (k, v) in ws? { @@ -670,10 +679,7 @@ mod tests { .build() .unwrap(); - let sentences = vec![ - ("This is a".to_string(), 1), - ("こんにちは友達".to_string(), 1), - ]; + let sentences = vec![("This is a".into(), 1), ("こんにちは友達".into(), 1)]; let required_chars = trainer.required_chars(&sentences); assert_eq!(required_chars.len(), 13); @@ -716,18 +722,18 @@ mod tests { fn test_initial_alphabet() { let trainer = UnigramTrainerBuilder::default() .show_progress(false) - .initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f'])) + .initial_alphabet(FxHashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f'])) .build() .unwrap(); - let sentences = vec![("こんにちは友達".to_string(), 1)]; + let sentences = vec![("こんにちは友達".into(), 1)]; let required_chars = trainer.required_chars(&sentences); assert_eq!( required_chars, vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"] .into_iter() - .map(|s| s.to_owned()) - .collect::>() + .map(|s| s.to_compact_string()) + .collect::>() ); } @@ -814,7 +820,7 @@ mod tests { #[test] fn test_to_log_prob() { - let mut a = vec![("".to_string(), 1.0), ("".to_string(), 2.0)]; + let mut a = vec![("".into(), 1.0), ("".into(), 2.0)]; to_log_prob(&mut a); let scores = a.iter().map(|(_, score)| *score).collect::>(); // ln(1) - ln(3) diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 545db13a7..f04b33853 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -1,7 +1,8 @@ use super::OrderedVocabIter; use crate::tokenizer::{Model, Result, Token}; +use compact_str::CompactString; +use rustc_hash::FxHashMap; use serde_json::Value; -use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::{Path, PathBuf}; @@ -12,7 +13,7 @@ mod trainer; // Re-export pub use trainer::*; -type Vocab = HashMap; +type Vocab = FxHashMap; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -23,9 +24,9 @@ pub enum Error { } struct Config { - files: Option, - vocab: HashMap, - unk_token: String, + files: Option, + vocab: FxHashMap, + unk_token: CompactString, } /// A `WordLevelBuilder` can be used to create a `WordLevel` @@ -39,8 +40,8 @@ impl Default for WordLevelBuilder { Self { config: Config { files: None, - vocab: HashMap::new(), - unk_token: String::from(""), + vocab: FxHashMap::default(), + unk_token: CompactString::from(""), }, } } @@ -54,21 +55,21 @@ impl WordLevelBuilder { 
/// Set the input files. #[must_use] - pub fn files(mut self, vocab: String) -> Self { + pub fn files(mut self, vocab: CompactString) -> Self { self.config.files = Some(vocab); self } /// Set the vocab (token -> ID) mapping. #[must_use] - pub fn vocab(mut self, vocab: HashMap) -> Self { + pub fn vocab(mut self, vocab: FxHashMap) -> Self { self.config.vocab = vocab; self } /// The the `UNK` token for the vocab. #[must_use] - pub fn unk_token(mut self, unk_token: String) -> Self { + pub fn unk_token(mut self, unk_token: CompactString) -> Self { self.config.unk_token = unk_token; self } @@ -96,9 +97,9 @@ impl WordLevelBuilder { #[derive(PartialEq, Clone, Eq)] pub struct WordLevel { - vocab: HashMap, - vocab_r: HashMap, - pub unk_token: String, + vocab: FxHashMap, + vocab_r: FxHashMap, + pub unk_token: CompactString, } impl std::fmt::Debug for WordLevel { @@ -119,7 +120,7 @@ impl WordLevel { let vocab_file = File::open(vocab_path)?; let mut vocab_file = BufReader::new(vocab_file); let mut buffer = String::new(); - let mut vocab = HashMap::new(); + let mut vocab = FxHashMap::default(); vocab_file.read_to_string(&mut buffer)?; let json: Value = serde_json::from_str(&buffer)?; @@ -129,7 +130,7 @@ impl WordLevel { for (token, id) in m { if let Value::Number(id) = id { let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32; - vocab.insert(token, id); + vocab.insert(token.into(), id); } } } @@ -139,7 +140,7 @@ impl WordLevel { } /// Initialize a WordLevel model from vocab and merges file. - pub fn from_file(vocab_path: &str, unk_token: String) -> Result { + pub fn from_file(vocab_path: &str, unk_token: CompactString) -> Result { let vocab = WordLevel::read_file(vocab_path)?; Self::builder().vocab(vocab).unk_token(unk_token).build() } @@ -148,9 +149,9 @@ impl WordLevel { impl Default for WordLevel { fn default() -> Self { Self { - vocab: HashMap::new(), - vocab_r: HashMap::new(), - unk_token: String::from(""), + vocab: FxHashMap::default(), + vocab_r: FxHashMap::default(), + unk_token: CompactString::from(""), } } } @@ -162,7 +163,7 @@ impl Model for WordLevel { if let Some(&id) = self.vocab.get(token) { Ok(vec![Token { id, - value: token.to_owned(), + value: token.into(), offsets: (0, token.len()), }]) } else if let Some(&unk_id) = self.vocab.get(&self.unk_token) { @@ -180,11 +181,11 @@ impl Model for WordLevel { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option { + fn id_to_token(&self, id: u32) -> Option { self.vocab_r.get(&id).cloned() } - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self.vocab.clone() } @@ -217,6 +218,8 @@ impl Model for WordLevel { #[cfg(test)] mod tests { + use compact_str::ToCompactString; + use super::*; #[test] @@ -227,14 +230,14 @@ mod tests { .collect(); let wordlevel = WordLevelBuilder::default() .vocab(vocab) - .unk_token("".to_string()) + .unk_token("".to_compact_string()) .build() .unwrap(); let tokens = wordlevel.tokenize("c").unwrap(); - assert_eq!(tokens, vec![Token::new(0u32, "".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(0u32, "", (0, 1)),]); let tokens = wordlevel.tokenize("a").unwrap(); - assert_eq!(tokens, vec![Token::new(1u32, "a".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(1u32, "a", (0, 1)),]); } #[test] @@ -242,7 +245,7 @@ mod tests { let vocab: Vocab = [("a".into(), 0), ("b".into(), 1)].iter().cloned().collect(); let wordlevel = WordLevelBuilder::default().vocab(vocab).build().unwrap(); let tokens = wordlevel.tokenize("a").unwrap(); - assert_eq!(tokens, vec![Token::new(0u32, 
"a".into(), (0, 1)),]); + assert_eq!(tokens, vec![Token::new(0u32, "a", (0, 1)),]); let error = wordlevel.tokenize("c").err().unwrap(); assert!(error.is::()); diff --git a/tokenizers/src/models/wordlevel/serialization.rs b/tokenizers/src/models/wordlevel/serialization.rs index a077a4999..d7a1f8b67 100644 --- a/tokenizers/src/models/wordlevel/serialization.rs +++ b/tokenizers/src/models/wordlevel/serialization.rs @@ -83,6 +83,8 @@ impl<'de> Visitor<'de> for WordLevelVisitor { #[cfg(test)] mod tests { + use compact_str::ToCompactString; + use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder}; #[test] @@ -102,7 +104,7 @@ mod tests { .collect(); let wordlevel = WordLevelBuilder::default() .vocab(vocab) - .unk_token("".to_string()) + .unk_token("".to_compact_string()) .build() .unwrap(); let wl_s = r#"{"type":"WordLevel","vocab":{"":0,"b":2},"unk_token":""}"#; diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs index c52ad08d7..f02b05e13 100644 --- a/tokenizers/src/models/wordlevel/trainer.rs +++ b/tokenizers/src/models/wordlevel/trainer.rs @@ -1,9 +1,10 @@ use super::WordLevel; use crate::utils::parallelism::*; use crate::{AddedToken, Result, Trainer}; +use compact_str::CompactString; +use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; -use std::collections::HashMap; #[non_exhaustive] #[derive(Debug, Clone, Builder, Serialize, Deserialize)] @@ -22,7 +23,7 @@ pub struct WordLevelTrainer { pub special_tokens: Vec, #[builder(default, private)] - words: HashMap, + words: FxHashMap, } impl Default for WordLevelTrainer { @@ -38,14 +39,14 @@ impl WordLevelTrainer { fn do_train( &self, - word_counts: &HashMap, + word_counts: &FxHashMap, model: &mut WordLevel, ) -> Result> { let mut ordered_counts = word_counts.iter().collect::>(); //sort the word counts first by inverse counts and then by word, in order //to keep the sorting deterministic in case of equal counts - let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering { + let cmp = |l: &(&CompactString, &u64), r: &(&CompactString, &u64)| -> Ordering { let count_comp: Ordering = l.1.cmp(r.1); if count_comp != Ordering::Equal { return count_comp.reverse(); @@ -98,20 +99,20 @@ impl Trainer for WordLevelTrainer { where I: Iterator + Send, S: AsRef + Send, - F: Fn(&str) -> Result> + Sync, + F: Fn(&str) -> Result> + Sync, { - let words: Result> = iterator + let words: Result> = iterator .maybe_par_bridge() .map(|sequence| { let words = process(sequence.as_ref())?; - let mut map = HashMap::new(); + let mut map = FxHashMap::default(); for word in words { map.entry(word).and_modify(|c| *c += 1).or_insert(1); } Ok(map) }) .reduce( - || Ok(HashMap::new()), + || Ok(FxHashMap::default()), |acc, ws| { let mut acc = acc?; for (k, v) in ws? 
{ @@ -132,7 +133,7 @@ mod tests { #[test] fn test_train() { - let word_counts: HashMap = [ + let word_counts: FxHashMap = [ ("the".into(), 25), ("roses".into(), 22), ("are".into(), 24), @@ -151,7 +152,7 @@ let mut model = WordLevel::default(); trainer.do_train(&word_counts, &mut model).unwrap(); - let expected_vocab: HashMap = [ + let expected_vocab: FxHashMap = [ ("the".into(), 0), ("are".into(), 1), ("roses".into(), 2), @@ -167,7 +168,7 @@ trainer.min_frequency = 15; let mut model = WordLevel::default(); trainer.do_train(&word_counts, &mut model).unwrap(); - let expected_vocab: HashMap = [ + let expected_vocab: FxHashMap = [ ("the".into(), 0), ("are".into(), 1), ("roses".into(), 2), diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 0c63405c1..aaf32ef0e 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -3,9 +3,10 @@ use crate::models::bpe::BPE; use crate::tokenizer::{Model, Result, Token}; +use compact_str::CompactString; +use rustc_hash::FxHashMap; use std::{ borrow::Cow, - collections::HashMap, fs::File, io::prelude::*, io::{BufRead, BufReader}, @@ -22,14 +23,14 @@ pub enum Error { MissingUnkToken, } -type Vocab = HashMap; -type VocabR = HashMap; +type Vocab = FxHashMap; +type VocabR = FxHashMap; struct Config { - files: Option, + files: Option, vocab: Vocab, - unk_token: String, - continuing_subword_prefix: String, + unk_token: CompactString, + continuing_subword_prefix: CompactString, max_input_chars_per_word: usize, } @@ -43,9 +44,9 @@ impl Default for WordPieceBuilder { Self { config: Config { files: None, - vocab: HashMap::new(), - unk_token: String::from("[UNK]"), - continuing_subword_prefix: String::from("##"), + vocab: FxHashMap::default(), + unk_token: CompactString::from("[UNK]"), + continuing_subword_prefix: CompactString::from("##"), max_input_chars_per_word: 100, }, } @@ -60,7 +61,7 @@ impl WordPieceBuilder { /// Set the input files. #[must_use] - pub fn files(mut self, vocab: String) -> Self { + pub fn files(mut self, vocab: CompactString) -> Self { self.config.files = Some(vocab); self } @@ -74,14 +75,14 @@ impl WordPieceBuilder { /// Set the `UNK` token for the vocab. #[must_use] - pub fn unk_token(mut self, unk_token: String) -> Self { + pub fn unk_token(mut self, unk_token: CompactString) -> Self { self.config.unk_token = unk_token; self } /// Set the prefix for continuing subwords.
#[must_use] - pub fn continuing_subword_prefix(mut self, continuing_subword_prefix: String) -> Self { + pub fn continuing_subword_prefix(mut self, continuing_subword_prefix: CompactString) -> Self { self.config.continuing_subword_prefix = continuing_subword_prefix; self } @@ -123,8 +124,8 @@ impl WordPieceBuilder { pub struct WordPiece { vocab: Vocab, vocab_r: VocabR, - pub unk_token: String, - pub continuing_subword_prefix: String, + pub unk_token: CompactString, + pub continuing_subword_prefix: CompactString, pub max_input_chars_per_word: usize, } @@ -142,10 +143,10 @@ impl std::fmt::Debug for WordPiece { impl Default for WordPiece { fn default() -> Self { Self { - vocab: HashMap::new(), - vocab_r: HashMap::new(), - unk_token: String::from("[UNK]"), - continuing_subword_prefix: String::from("##"), + vocab: FxHashMap::default(), + vocab_r: FxHashMap::default(), + unk_token: CompactString::from("[UNK]"), + continuing_subword_prefix: CompactString::from("##"), max_input_chars_per_word: 100, } } @@ -162,10 +163,10 @@ impl WordPiece { let file = File::open(vocab)?; let file = BufReader::new(file); - let mut vocab = HashMap::new(); + let mut vocab = FxHashMap::default(); for (index, line) in file.lines().enumerate() { let line = line?; - vocab.insert(line.trim_end().to_owned(), index as u32); + vocab.insert(line.trim_end().into(), index as u32); } Ok(vocab) @@ -173,7 +174,7 @@ impl WordPiece { /// Initialize a `WordPiece` model from a vocab mapping file. pub fn from_file(vocab: &str) -> WordPieceBuilder { - WordPiece::builder().files(vocab.to_owned()) + WordPiece::builder().files(vocab.into()) } /// Create a `WordPiece` model from a `BPE` model. @@ -192,7 +193,7 @@ impl WordPiece { impl Model for WordPiece { type Trainer = WordPieceTrainer; - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self.vocab.clone() } @@ -230,7 +231,7 @@ impl Model for WordPiece { if self.vocab.contains_key(substr.as_ref()) { cur_str = Some(Token { id: self.vocab[substr.as_ref()], - value: substr.to_string(), + value: substr.into(), offsets: (start, end), }); break; @@ -265,7 +266,7 @@ impl Model for WordPiece { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option { + fn id_to_token(&self, id: u32) -> Option { self.vocab_r.get(&id).cloned() } @@ -280,7 +281,7 @@ impl Model for WordPiece { .iter() .collect(); let mut vocab_file = File::create(&vocab_path)?; - let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect(); + let mut vocab: Vec<(&CompactString, &u32)> = self.vocab.iter().collect(); vocab.sort_unstable_by_key(|k| *k.1); vocab_file.write_all( &vocab diff --git a/tokenizers/src/models/wordpiece/trainer.rs b/tokenizers/src/models/wordpiece/trainer.rs index 58a5abc8f..2084a247e 100644 --- a/tokenizers/src/models/wordpiece/trainer.rs +++ b/tokenizers/src/models/wordpiece/trainer.rs @@ -1,8 +1,9 @@ use super::WordPiece; use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE}; use crate::tokenizer::{AddedToken, Result, Trainer}; +use compact_str::CompactString; +use rustc_hash::FxHashSet; use serde::{Deserialize, Serialize}; -use std::collections::HashSet; /// A `WordPieceTrainerBuilder` can be used to create a `WordPieceTrainer` with a custom /// configuration. 
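The following is an illustrative sketch, not part of the patch: it shows how calling code might construct a WordPiece against the builder signatures changed above (CompactString setters, FxHashMap vocab). The toy vocabulary and the helper name build_toy_wordpiece are made up for the example.

```rust
use compact_str::{CompactString, ToCompactString};
use rustc_hash::FxHashMap;
use tokenizers::models::wordpiece::WordPiece;

// Hypothetical helper showing the post-migration types in use.
fn build_toy_wordpiece() -> tokenizers::Result<WordPiece> {
    // Vocab keys are CompactString now, stored in an FxHashMap.
    let mut vocab: FxHashMap<CompactString, u32> = FxHashMap::default();
    for (id, token) in ["[UNK]", "hello", "##world"].into_iter().enumerate() {
        vocab.insert(token.to_compact_string(), id as u32);
    }

    WordPiece::builder()
        .vocab(vocab)
        // Builder setters that previously took String now take CompactString.
        .unk_token("[UNK]".to_compact_string())
        .continuing_subword_prefix("##".to_compact_string())
        .build()
}
```

Apart from the hasher and the string type, FxHashMap exposes the same API as std's HashMap, which is why the lookup and iteration code in the surrounding hunks is otherwise unchanged.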
@@ -61,21 +62,21 @@ impl WordPieceTrainerBuilder { /// Set the initial alphabet #[must_use] - pub fn initial_alphabet(mut self, alphabet: HashSet) -> Self { + pub fn initial_alphabet(mut self, alphabet: FxHashSet) -> Self { self.bpe_trainer_builder = self.bpe_trainer_builder.initial_alphabet(alphabet); self } /// Set the continuing_subword_prefix #[must_use] - pub fn continuing_subword_prefix(mut self, prefix: String) -> Self { + pub fn continuing_subword_prefix(mut self, prefix: CompactString) -> Self { self.bpe_trainer_builder = self.bpe_trainer_builder.continuing_subword_prefix(prefix); self } /// Set the end_of_word_suffix #[must_use] - pub fn end_of_word_suffix(mut self, suffix: String) -> Self { + pub fn end_of_word_suffix(mut self, suffix: CompactString) -> Self { self.bpe_trainer_builder = self.bpe_trainer_builder.end_of_word_suffix(suffix); self } @@ -134,27 +135,27 @@ impl WordPieceTrainer { self.bpe_trainer.limit_alphabet = limit; } - pub fn initial_alphabet(&self) -> &HashSet { + pub fn initial_alphabet(&self) -> &FxHashSet { &self.bpe_trainer.initial_alphabet } - pub fn set_initial_alphabet(&mut self, alphabet: HashSet) { + pub fn set_initial_alphabet(&mut self, alphabet: FxHashSet) { self.bpe_trainer.initial_alphabet = alphabet; } - pub fn continuing_subword_prefix(&self) -> &Option { + pub fn continuing_subword_prefix(&self) -> &Option { &self.bpe_trainer.continuing_subword_prefix } - pub fn set_continuing_subword_prefix(&mut self, prefix: Option) { + pub fn set_continuing_subword_prefix(&mut self, prefix: Option) { self.bpe_trainer.continuing_subword_prefix = prefix; } - pub fn end_of_word_suffix(&self) -> &Option { + pub fn end_of_word_suffix(&self) -> &Option { &self.bpe_trainer.end_of_word_suffix } - pub fn set_end_of_word_suffix(&mut self, suffix: Option) { + pub fn set_end_of_word_suffix(&mut self, suffix: Option) { self.bpe_trainer.end_of_word_suffix = suffix; } @@ -192,7 +193,7 @@ impl Trainer for WordPieceTrainer { where I: Iterator + Send, S: AsRef + Send, - F: Fn(&str) -> Result> + Sync, + F: Fn(&str) -> Result> + Sync, { self.bpe_trainer.feed(iterator, process) } diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs index 4e318c259..d6a7ae20f 100644 --- a/tokenizers/src/normalizers/prepend.rs +++ b/tokenizers/src/normalizers/prepend.rs @@ -1,15 +1,18 @@ use crate::tokenizer::{NormalizedString, Normalizer, Result}; +use compact_str::CompactString; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(tag = "type")] pub struct Prepend { - pub prepend: String, + pub prepend: CompactString, } impl Prepend { - pub fn new(prepend: String) -> Self { - Self { prepend } + pub fn new(prepend: impl Into) -> Self { + Self { + prepend: prepend.into(), + } } } diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index 565757483..1b8176b4a 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -2,24 +2,19 @@ use crate::tokenizer::pattern::Pattern; use crate::tokenizer::Decoder; use crate::tokenizer::{NormalizedString, Normalizer, Result}; use crate::utils::SysRegex; +use compact_str::{CompactString, ToCompactString}; use serde::{Deserialize, Serialize}; /// Represents the different patterns that `Replace` can use #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] pub enum ReplacePattern { - String(String), - Regex(String), + String(CompactString), + Regex(CompactString), } -impl From for ReplacePattern { - fn 
from(v: String) -> Self { - Self::String(v) - } -} - -impl From<&str> for ReplacePattern { - fn from(v: &str) -> Self { - Self::String(v.to_owned()) +impl> From for ReplacePattern { + fn from(v: T) -> Self { + Self::String(v.into()) } } @@ -29,7 +24,7 @@ impl From<&str> for ReplacePattern { #[serde(tag = "type")] struct ReplaceDeserializer { pattern: ReplacePattern, - content: String, + content: CompactString, } impl std::convert::TryFrom for Replace { @@ -46,14 +41,14 @@ impl std::convert::TryFrom for Replace { #[serde(tag = "type", try_from = "ReplaceDeserializer")] pub struct Replace { pattern: ReplacePattern, - pub content: String, + pub content: CompactString, #[serde(skip)] regex: SysRegex, } impl Clone for Replace { fn clone(&self) -> Self { - Self::new(self.pattern.clone(), &self.content).unwrap() + Self::new(self.pattern.clone(), &*self.content).unwrap() } } @@ -64,7 +59,10 @@ impl PartialEq for Replace { } impl Replace { - pub fn new, C: Into>(pattern: I, content: C) -> Result { + pub fn new, C: Into>( + pattern: I, + content: C, + ) -> Result { let pattern: ReplacePattern = pattern.into(); let regex = match &pattern { ReplacePattern::String(s) => SysRegex::new(®ex::escape(s))?, @@ -86,13 +84,17 @@ impl Normalizer for Replace { } impl Decoder for Replace { - fn decode_chain(&self, tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { tokens .into_iter() - .map(|token| -> Result { - let mut new_token = "".to_string(); + .map(|token| -> Result { + let token = token.to_compact_string(); + let mut new_token = CompactString::from(""); - for ((start, stop), is_match) in (&self.regex).find_matches(&token)? { + for ((start, stop), is_match) in (&self.regex).find_matches(token.as_str())? { if is_match { new_token.push_str(&self.content); } else { @@ -126,7 +128,7 @@ mod tests { let normalized = "This is a test"; let mut n = NormalizedString::from(original); - Replace::new(ReplacePattern::Regex(r"\s+".into()), ' ') + Replace::new(ReplacePattern::Regex(r"\s+".into()), " ") .unwrap() .normalize(&mut n) .unwrap(); @@ -141,7 +143,7 @@ mod tests { assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s); assert_eq!(serde_json::from_str::(replace_s).unwrap(), replace); - let replace = Replace::new(ReplacePattern::Regex(r"\s+".into()), ' ').unwrap(); + let replace = Replace::new(ReplacePattern::Regex(r"\s+".into()), " ").unwrap(); let replace_s = r#"{"type":"Replace","pattern":{"Regex":"\\s+"},"content":" "}"#; assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s); assert_eq!(serde_json::from_str::(replace_s).unwrap(), replace); @@ -149,10 +151,15 @@ mod tests { #[test] fn test_replace_decode() { - let original = vec!["hello".to_string(), "_hello".to_string()]; + let original = vec!["hello", "_hello"]; let replace = Replace::new("_", " ").unwrap(); assert_eq!( - replace.decode_chain(original).unwrap(), + replace + .decode_chain(original) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), vec!["hello", " hello"] ); } diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 8396f1a7b..e9078b380 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use crate::utils::SysRegex; +use compact_str::{CompactString, ToCompactString}; use serde::{Deserialize, Serialize}; use crate::tokenizer::{ @@ -157,10 +158,14 @@ impl PreTokenizer for ByteLevel { /// the fact that 
single token decoded might be a byte not representable as /// as String. impl Decoder for ByteLevel { - fn decode_chain(&self, tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { let toks = tokens .into_iter() .flat_map(|t| { + let t: CompactString = t.to_compact_string(); t.chars() .try_fold(vec![], |mut acc, c| { CHAR_BYTES.get(&c).map(|b| { @@ -171,7 +176,7 @@ impl Decoder for ByteLevel { .unwrap_or_else(|| t.as_bytes().to_vec()) }) .collect::>(); - Ok(vec![String::from_utf8_lossy(&toks).to_string()]) + Ok(Vec::from([CompactString::from_utf8_lossy(&toks)])) } } @@ -292,16 +297,13 @@ mod tests { let bytelevel = ByteLevel::default().add_prefix_space(false); assert_eq!( bytelevel - .decode_chain( - vec![ - "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", - "?" - ] - .into_iter() - .map(|s| s.into()) - .collect::>() - ) - .unwrap(), + .decode_chain(vec![ + "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?" + ]) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), vec!["Hello my friend, how is your day going?"] ); } @@ -353,10 +355,16 @@ mod tests { .get_splits(OffsetReferential::Original, OffsetType::Byte) .iter() .flat_map(|(s, _, _)| s.split("").map(|t| t.into())) - .collect::>(); + .collect::>(); assert_eq!( sample, - bytelevel.decode_chain(separated_tokens).unwrap().join("") + bytelevel + .decode_chain(separated_tokens) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>() + .join("") ); } } @@ -445,7 +453,7 @@ mod tests { let mut encoding = Encoding::new( vec![0; 5], vec![], - vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()], + vec!["Ġl", "ove", "Ġl", "ove"], vec![], vec![(0, 1), (1, 4), (0, 1), (1, 4)], vec![], @@ -459,7 +467,7 @@ mod tests { Encoding::new( vec![0; 5], vec![], - vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()], + vec!["Ġl", "ove", "Ġl", "ove"], vec![], vec![(0, 1), (1, 4), (0, 1), (1, 4)], vec![], @@ -475,13 +483,7 @@ mod tests { let start = Encoding::new( vec![0; 5], vec![], - vec![ - "Ġ".into(), - "ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), - ], + vec!["Ġ", "ĠĠĠĠHelloĠĠ", "ĠĠHello", "HelloĠĠ", "ĠĠĠĠ"], vec![], vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)], vec![], @@ -492,13 +494,7 @@ mod tests { let expected = Encoding::new( vec![0; 5], vec![0; 5], - vec![ - "Ġ".into(), - "ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), - ], + vec!["Ġ", "ĠĠĠĠHelloĠĠ", "ĠĠHello", "HelloĠĠ", "ĠĠĠĠ"], vec![], vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)], vec![], @@ -517,16 +513,16 @@ mod tests { vec![0; 10], vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1], vec![ - "Ġ".into(), - "ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), - "Ġ".into(), - "ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), + "Ġ", + "ĠĠĠĠHelloĠĠ", + "ĠĠHello", + "HelloĠĠ", + "ĠĠĠĠ", + "Ġ", + "ĠĠĠĠHelloĠĠ", + "ĠĠHello", + "HelloĠĠ", + "ĠĠĠĠ", ], vec![], vec![ @@ -559,15 +555,11 @@ mod tests { let byte_level = ByteLevel::default(); assert_eq!( byte_level - .decode_chain(vec![ - "Hello".into(), - "Ġthere".into(), - "Ġdear".into(), - "Ġfriend!".into(), - "Ġ".into(), - "[PA D]".into() - ]) - .unwrap(), + .decode_chain(vec!["Hello", "Ġthere", "Ġdear", "Ġfriend!", "Ġ", "[PA D]"]) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(), vec!["Hello there dear friend! 
[PA D]"] ); } @@ -585,13 +577,13 @@ mod tests { let byte_level: ByteLevel = serde_json::from_str( r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": true}"#, ) - .unwrap(); + .unwrap(); assert!(byte_level.use_regex); let byte_level: ByteLevel = serde_json::from_str( r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#, ) - .unwrap(); + .unwrap(); assert!(!byte_level.use_regex); } } diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index d821f1184..bcafb6875 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -1,4 +1,5 @@ use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; +use compact_str::{CompactString, ToCompactString}; use serde::{de, Deserialize, Deserializer, Serialize}; /// Enum representing options for the metaspace prepending scheme. @@ -28,7 +29,7 @@ pub struct Metaspace { pub prepend_scheme: PrependScheme, pub split: bool, #[serde(skip)] - str_rep: String, + str_rep: CompactString, } impl<'de> Deserialize<'de> for Metaspace { @@ -56,7 +57,7 @@ impl<'de> Deserialize<'de> for Metaspace { pub prepend_scheme: PrependScheme, pub split: Option, #[serde(rename = "str_rep")] - _str_rep: Option, + _str_rep: Option, } let mut helper = MetaspaceHelper::deserialize(deserializer)?; @@ -81,7 +82,7 @@ impl Metaspace { pub fn new(replacement: char, prepend_scheme: PrependScheme, split: bool) -> Self { Self { replacement, - str_rep: replacement.to_string(), + str_rep: replacement.to_compact_string(), prepend_scheme, split, } @@ -93,7 +94,7 @@ impl Metaspace { pub fn set_replacement(&mut self, replacement: char) { self.replacement = replacement; - self.str_rep = replacement.to_string(); + self.str_rep = replacement.to_compact_string(); } pub fn get_split(&self) -> bool { @@ -148,12 +149,16 @@ impl PreTokenizer for Metaspace { } impl Decoder for Metaspace { - fn decode_chain(&self, tokens: Vec) -> Result> { + fn decode_chain( + &self, + tokens: Vec, + ) -> Result> { Ok(tokens - .iter() + .into_iter() .enumerate() .map(|(i, token)| { token + .to_compact_string() .chars() .flat_map(|c| { if c == self.replacement { @@ -166,7 +171,7 @@ impl Decoder for Metaspace { Some(c) } }) - .collect::() + .collect::() }) .collect()) } @@ -357,14 +362,20 @@ mod tests { fn decode() { let decoder = Metaspace::new('▁', PrependScheme::Always, true); let res = decoder - .decode_chain(vec!["▁Hey".into(), "▁friend!".into()]) - .unwrap(); + .decode_chain(vec!["▁Hey", "▁friend!"]) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(); assert_eq!(res, vec!["Hey", " friend!"]); let decoder = Metaspace::new('▁', PrependScheme::Never, true); let res = decoder - .decode_chain(vec!["▁Hey".into(), "▁friend!".into()]) - .unwrap(); + .decode_chain(vec!["▁Hey", "▁friend!"]) + .unwrap() + .into_iter() + .map(|t| t.to_compact_string()) + .collect::>(); assert_eq!(res, vec![" Hey", " friend!"]); } } diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 5f7362f71..88c71f013 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -1,4 +1,5 @@ use crate::utils::SysRegex; +use compact_str::CompactString; use serde::{Deserialize, Deserializer, Serialize}; use crate::tokenizer::{ @@ -8,19 +9,13 @@ use crate::tokenizer::{ /// Represents the different patterns that `Split` can use #[derive(Debug, Clone, 
PartialEq, Serialize, Deserialize, Eq)] pub enum SplitPattern { - String(String), - Regex(String), + String(CompactString), + Regex(CompactString), } -impl From for SplitPattern { - fn from(v: String) -> Self { - Self::String(v) - } -} - -impl From<&str> for SplitPattern { - fn from(v: &str) -> Self { - Self::String(v.to_owned()) +impl> From for SplitPattern { + fn from(v: T) -> Self { + Self::String(v.into()) } } diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index 179391122..9a2820635 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ -1,4 +1,5 @@ use crate::tokenizer::{Encoding, PostProcessor, Result}; +use compact_str::CompactString; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; @@ -6,29 +7,29 @@ use std::iter::FromIterator; #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] #[serde(tag = "type")] pub struct BertProcessing { - pub sep: (String, u32), - pub cls: (String, u32), + pub sep: (CompactString, u32), + pub cls: (CompactString, u32), } impl Default for BertProcessing { fn default() -> Self { - Self { - sep: ("[SEP]".into(), 102), - cls: ("[CLS]".into(), 101), - } + Self::new(("[SEP]", 102), ("[CLS]", 101)) } } impl BertProcessing { - pub fn new(sep: (String, u32), cls: (String, u32)) -> Self { - Self { sep, cls } + pub fn new(sep: (impl Into, u32), cls: (impl Into, u32)) -> Self { + Self { + sep: (sep.0.into(), sep.1), + cls: (cls.0.into(), cls.1), + } } - pub fn get_sep_copy(&self) -> (String, u32) { + pub fn get_sep_copy(&self) -> (CompactString, u32) { (self.sep.0.clone(), self.sep.1) } - pub fn get_cls_copy(&self) -> (String, u32) { + pub fn get_cls_copy(&self) -> (CompactString, u32) { (self.cls.0.clone(), self.cls.1) } } @@ -213,24 +214,19 @@ mod tests { use crate::Token; let encoding = Encoding::from_tokens( vec![ - Token::new(12, "Hello".into(), (0, 5)), - Token::new(14, "there".into(), (6, 11)), + Token::new(12, "Hello", (0, 5)), + Token::new(14, "there", (6, 11)), ], 0, ); - let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0); + let pair = Encoding::from_tokens(vec![Token::new(15, "pair", (0, 4))], 0); let single_encoding = processor.process(encoding.clone(), None, true).unwrap(); assert_eq!( single_encoding, Encoding::new( vec![101, 12, 14, 102], vec![0, 0, 0, 0], - vec![ - "[CLS]".into(), - "Hello".into(), - "there".into(), - "[SEP]".into() - ], + vec!["[CLS]", "Hello", "there", "[SEP]"], vec![None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0)], vec![1, 0, 0, 1], @@ -249,14 +245,7 @@ mod tests { Encoding::new( vec![101, 12, 14, 102, 15, 102], vec![0, 0, 0, 0, 1, 1], - vec![ - "[CLS]".into(), - "Hello".into(), - "there".into(), - "[SEP]".into(), - "pair".into(), - "[SEP]".into() - ], + vec!["[CLS]", "Hello", "there", "[SEP]", "pair", "[SEP]"], vec![None, None, None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (0, 0)], vec![1, 0, 0, 1, 0, 1], @@ -277,7 +266,7 @@ mod tests { Encoding::new( vec![12, 14, 15], vec![0, 0, 1], - vec!["Hello".into(), "there".into(), "pair".into(),], + vec!["Hello", "there", "pair",], vec![None, None, None], vec![(0, 5), (6, 11), (0, 4)], vec![0, 0, 0], diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index 5bbc4ea63..88725c6fd 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -1,5 +1,6 @@ use crate::processors::byte_level::process_offsets; use 
crate::tokenizer::{Encoding, PostProcessor, Result}; +use compact_str::CompactString; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::iter::FromIterator; @@ -7,8 +8,8 @@ use std::iter::FromIterator; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(tag = "type")] pub struct RobertaProcessing { - pub sep: (String, u32), - pub cls: (String, u32), + pub sep: (CompactString, u32), + pub cls: (CompactString, u32), pub trim_offsets: bool, pub add_prefix_space: bool, } @@ -25,10 +26,10 @@ impl Default for RobertaProcessing { } impl RobertaProcessing { - pub fn new(sep: (String, u32), cls: (String, u32)) -> Self { + pub fn new(sep: (impl Into, u32), cls: (impl Into, u32)) -> Self { Self { - sep, - cls, + sep: (sep.0.into(), sep.1), + cls: (cls.0.into(), cls.1), ..Default::default() } } @@ -45,11 +46,11 @@ impl RobertaProcessing { self } - pub fn get_sep_copy(&self) -> (String, u32) { + pub fn get_sep_copy(&self) -> (CompactString, u32) { (self.sep.0.clone(), self.sep.1) } - pub fn get_cls_copy(&self) -> (String, u32) { + pub fn get_cls_copy(&self) -> (CompactString, u32) { (self.cls.0.clone(), self.cls.1) } } @@ -263,19 +264,19 @@ mod tests { use crate::Token; let encoding = Encoding::from_tokens( vec![ - Token::new(12, "Hello".into(), (0, 5)), - Token::new(14, "there".into(), (6, 11)), + Token::new(12, "Hello", (0, 5)), + Token::new(14, "there", (6, 11)), ], 0, ); - let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0); + let pair = Encoding::from_tokens(vec![Token::new(15, "pair", (0, 4))], 0); let single_encoding = processor.process(encoding.clone(), None, true).unwrap(); assert_eq!( single_encoding, Encoding::new( vec![0, 12, 14, 2], vec![0, 0, 0, 0], - vec!["".into(), "Hello".into(), "there".into(), "".into()], + vec!["", "Hello", "there", ""], vec![None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0)], vec![1, 0, 0, 1], @@ -294,15 +295,7 @@ mod tests { Encoding::new( vec![0, 12, 14, 2, 2, 15, 2], vec![0, 0, 0, 0, 0, 0, 0], - vec![ - "".into(), - "Hello".into(), - "there".into(), - "".into(), - "".into(), - "pair".into(), - "".into() - ], + vec!["", "Hello", "there", "", "", "pair", ""], vec![None, None, None, None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 0), (0, 4), (0, 0)], vec![1, 0, 0, 1, 1, 0, 1], @@ -324,7 +317,7 @@ mod tests { Encoding::new( vec![12, 14, 15], vec![0, 0, 0], - vec!["Hello".into(), "there".into(), "pair".into(),], + vec!["Hello", "there", "pair",], vec![None, None, None], vec![(0, 5), (6, 11), (0, 4)], vec![0, 0, 0], diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 5cfb3eb51..3a50b5417 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -81,13 +81,7 @@ mod tests { let start = Encoding::new( vec![0; 5], vec![0; 5], - vec![ - "Ġ".into(), - "ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), - ], + vec!["Ġ", "ĠĠĠĠHelloĠĠ", "ĠĠHello", "HelloĠĠ", "ĠĠĠĠ"], vec![], vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)], vec![], @@ -101,13 +95,7 @@ mod tests { let expected = Encoding::new( vec![0; 5], vec![0; 5], - vec![ - "Ġ".into(), - "ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), - ], + vec!["Ġ", "ĠĠĠĠHelloĠĠ", "ĠĠHello", "HelloĠĠ", "ĠĠĠĠ"], vec![], vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)], vec![], @@ -129,16 +117,16 @@ mod tests { vec![0; 10], vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1], vec![ - "Ġ".into(), - 
"ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), - "Ġ".into(), - "ĠĠĠĠHelloĠĠ".into(), - "ĠĠHello".into(), - "HelloĠĠ".into(), - "ĠĠĠĠ".into(), + "Ġ", + "ĠĠĠĠHelloĠĠ", + "ĠĠHello", + "HelloĠĠ", + "ĠĠĠĠ", + "Ġ", + "ĠĠĠĠHelloĠĠ", + "ĠĠHello", + "HelloĠĠ", + "ĠĠĠĠ", ], vec![], vec![ diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 74b4fe1c5..7a9091e7a 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -57,6 +57,7 @@ //! [`TemplateProcessing`]: struct.TemplateProcessing.html //! use crate::{Encoding, PostProcessor, Result}; +use compact_str::{format_compact, CompactString, ToCompactString}; use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -95,7 +96,7 @@ pub enum Sequence { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] pub enum Piece { Sequence { id: Sequence, type_id: u32 }, - SpecialToken { id: String, type_id: u32 }, + SpecialToken { id: CompactString, type_id: u32 }, } impl Piece { @@ -130,7 +131,7 @@ impl Piece { } } else { Some(Self::SpecialToken { - id: s.to_owned(), + id: s.into(), type_id: 0, }) } @@ -144,6 +145,26 @@ impl Piece { } } +impl TryFrom for Piece { + type Error = String; + + fn try_from(s: CompactString) -> StdResult { + let s = s.to_string(); + let parts = s.split(':').collect::>(); + + let err = || format!("Cannot build Piece from compact string \"{s}\""); + match parts.as_slice() { + [id, type_id] => { + let type_id: u32 = type_id.parse().map_err(|_| err())?; + let piece = Self::extract_id(id).ok_or_else(err)?; + Ok(piece.with_type_id(type_id)) + } + [id] => Self::extract_id(id).ok_or_else(err), + _ => Err(err()), + } + } +} + impl TryFrom for Piece { type Error = String; @@ -192,15 +213,15 @@ impl TryFrom<&str> for Piece { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)] pub struct SpecialToken { /// A unique id used to identify this SpecialToken in the template - id: String, + id: CompactString, /// The list of associated ids ids: Vec, /// The list of associated tokens - tokens: Vec, + tokens: Vec, } -impl From<(String, u32)> for SpecialToken { - fn from(v: (String, u32)) -> Self { +impl From<(CompactString, u32)> for SpecialToken { + fn from(v: (CompactString, u32)) -> Self { Self { id: v.0.clone(), ids: vec![v.1], @@ -210,22 +231,22 @@ impl From<(String, u32)> for SpecialToken { } impl From<(&str, u32)> for SpecialToken { fn from(v: (&str, u32)) -> Self { - Self::from((v.0.to_owned(), v.1)) + Self::from((v.0.to_compact_string(), v.1)) } } -impl From<(u32, String)> for SpecialToken { - fn from(v: (u32, String)) -> Self { +impl From<(u32, CompactString)> for SpecialToken { + fn from(v: (u32, CompactString)) -> Self { Self::from((v.1, v.0)) } } impl From<(u32, &str)> for SpecialToken { fn from(v: (u32, &str)) -> Self { - Self::from((v.1.to_owned(), v.0)) + Self::from((v.1.to_compact_string(), v.0)) } } impl SpecialToken { - pub fn new(id: String, ids: Vec, tokens: Vec) -> Result { + pub fn new(id: CompactString, ids: Vec, tokens: Vec) -> Result { if ids.len() != tokens.len() { Err("SpecialToken: ids and tokens must be of the same length".into()) } else { @@ -269,6 +290,14 @@ where } } +impl TryFrom for Template { + type Error = String; + + fn try_from(s: CompactString) -> StdResult { + Self::try_from(s.to_string().as_ref()) + } +} + impl TryFrom for Template { type Error = String; @@ -293,7 +322,7 @@ impl TryFrom<&str> for Template { #[derive(Debug, Clone, 
PartialEq, Default, Serialize, Deserialize, Eq)] #[serde(transparent)] pub struct Tokens( - #[serde(serialize_with = "crate::utils::ordered_map")] pub HashMap, + #[serde(serialize_with = "crate::utils::ordered_map")] pub HashMap, ); impl> From> for Tokens { @@ -309,8 +338,8 @@ impl> From> for Tokens { } } -impl From> for Tokens { - fn from(v: HashMap) -> Self { +impl From> for Tokens { + fn from(v: HashMap) -> Self { Self(v) } } @@ -353,8 +382,8 @@ pub struct TemplateProcessing { impl TemplateProcessing { // Getter for `single` - pub fn get_single(&self) -> String { - format!("{:?}", self.single) + pub fn get_single(&self) -> CompactString { + format_compact!("{:?}", self.single) } // Setter for `single` @@ -894,24 +923,19 @@ mod tests { use crate::Token; let encoding = Encoding::from_tokens( vec![ - Token::new(12, "Hello".into(), (0, 5)), - Token::new(14, "there".into(), (6, 11)), + Token::new(12, "Hello", (0, 5)), + Token::new(14, "there", (6, 11)), ], 0, ); - let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0); + let pair = Encoding::from_tokens(vec![Token::new(15, "pair", (0, 4))], 0); let single_encoding = processor.process(encoding.clone(), None, true).unwrap(); assert_eq!( single_encoding, Encoding::new( vec![1, 12, 14, 0], vec![0, 0, 0, 0], - vec![ - "[CLS]".into(), - "Hello".into(), - "there".into(), - "[SEP]".into() - ], + vec!["[CLS]", "Hello", "there", "[SEP]"], vec![None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0)], vec![1, 0, 0, 1], @@ -928,14 +952,7 @@ mod tests { Encoding::new( vec![1, 12, 14, 0, 15, 0], vec![0, 0, 0, 0, 1, 1], - vec![ - "[CLS]".into(), - "Hello".into(), - "there".into(), - "[SEP]".into(), - "pair".into(), - "[SEP]".into() - ], + vec!["[CLS]", "Hello", "there", "[SEP]", "pair", "[SEP]"], vec![None, None, None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (0, 0)], vec![1, 0, 0, 1, 0, 1], @@ -959,23 +976,22 @@ mod tests { use crate::Token; let mut encoding = Encoding::from_tokens( vec![ - Token::new(12, "Hello".into(), (0, 5)), - Token::new(14, "there".into(), (6, 11)), + Token::new(12, "Hello", (0, 5)), + Token::new(14, "there", (6, 11)), ], 0, ); - let overflowing = Encoding::from_tokens(vec![Token::new(13, "you".into(), (12, 15))], 0); + let overflowing = Encoding::from_tokens(vec![Token::new(13, "you", (12, 15))], 0); encoding.set_overflowing(vec![overflowing]); let mut pair = Encoding::from_tokens( vec![ - Token::new(15, "pair".into(), (0, 4)), - Token::new(16, "with".into(), (5, 9)), + Token::new(15, "pair", (0, 4)), + Token::new(16, "with", (5, 9)), ], 0, ); - let pair_overflowing = - Encoding::from_tokens(vec![Token::new(17, "info".into(), (10, 14))], 0); + let pair_overflowing = Encoding::from_tokens(vec![Token::new(17, "info", (10, 14))], 0); pair.set_overflowing(vec![pair_overflowing]); let single_encoding = processor.process(encoding.clone(), None, true).unwrap(); @@ -984,12 +1000,7 @@ mod tests { Encoding::new( vec![1, 12, 14, 0], vec![0, 0, 0, 0], - vec![ - "[CLS]".into(), - "Hello".into(), - "there".into(), - "[SEP]".into() - ], + vec!["[CLS]", "Hello", "there", "[SEP]"], vec![None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0)], vec![1, 0, 0, 1], @@ -997,7 +1008,7 @@ mod tests { vec![Encoding::new( vec![1, 13, 0], vec![0, 0, 0], - vec!["[CLS]".into(), "you".into(), "[SEP]".into()], + vec!["[CLS]", "you".into(), "[SEP]".into()], vec![None, None, None], vec![(0, 0), (12, 15), (0, 0)], vec![1, 0, 1], @@ -1017,15 +1028,7 @@ mod tests { Encoding::new( vec![1, 12, 14, 0, 15, 16, 
0], vec![0, 0, 0, 0, 1, 1, 1], - vec![ - "[CLS]".into(), - "Hello".into(), - "there".into(), - "[SEP]".into(), - "pair".into(), - "with".into(), - "[SEP]".into() - ], + vec!["[CLS]", "Hello", "there", "[SEP]", "pair", "with", "[SEP]"], vec![None, None, None, None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (5, 9), (0, 0)], vec![1, 0, 0, 1, 0, 0, 1], @@ -1034,14 +1037,7 @@ mod tests { Encoding::new( vec![1, 13, 0, 15, 16, 0], vec![0, 0, 0, 1, 1, 1], - vec![ - "[CLS]".into(), - "you".into(), - "[SEP]".into(), - "pair".into(), - "with".into(), - "[SEP]".into() - ], + vec!["[CLS]", "you", "[SEP]", "pair", "with", "[SEP]"], vec![None, None, None, None, None, None], vec![(0, 0), (12, 15), (0, 0), (0, 4), (5, 9), (0, 0)], vec![1, 0, 1, 0, 0, 1], @@ -1049,13 +1045,7 @@ mod tests { vec![Encoding::new( vec![1, 13, 0, 17, 0], vec![0, 0, 0, 0, 1], - vec![ - "[CLS]".into(), - "you".into(), - "[SEP]".into(), - "info".into(), - "[SEP]".into() - ], + vec!["[CLS]", "you", "[SEP]", "info", "[SEP]"], vec![None, None, None, None, None,], vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)], vec![1, 0, 1, 0, 1], @@ -1068,13 +1058,7 @@ mod tests { Encoding::new( vec![1, 13, 0, 17, 0], vec![0, 0, 0, 0, 1], - vec![ - "[CLS]".into(), - "you".into(), - "[SEP]".into(), - "info".into(), - "[SEP]".into() - ], + vec!["[CLS]", "you", "[SEP]", "info", "[SEP]"], vec![None, None, None, None, None,], vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)], vec![1, 0, 1, 0, 1], @@ -1085,14 +1069,7 @@ mod tests { Encoding::new( vec![1, 12, 14, 0, 17, 0], vec![0, 0, 0, 0, 0, 1], - vec![ - "[CLS]".into(), - "Hello".into(), - "there".into(), - "[SEP]".into(), - "info".into(), - "[SEP]".into() - ], + vec!["[CLS]", "Hello", "there", "[SEP]", "info", "[SEP]"], vec![None, None, None, None, None, None], vec![(0, 0), (0, 5), (6, 11), (0, 0), (10, 14), (0, 0)], vec![1, 0, 0, 1, 0, 1], @@ -1100,13 +1077,7 @@ mod tests { vec![Encoding::new( vec![1, 13, 0, 17, 0], vec![0, 0, 0, 0, 1], - vec![ - "[CLS]".into(), - "you".into(), - "[SEP]".into(), - "info".into(), - "[SEP]".into() - ], + vec!["[CLS]", "you", "[SEP]", "info", "[SEP]"], vec![None, None, None, None, None,], vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)], vec![1, 0, 1, 0, 1], diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index a0c2f4542..21a9fbc21 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -2,9 +2,10 @@ use super::{ normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token, }; use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; +use compact_str::CompactString; use regex::Regex; +use rustc_hash::{FxHashMap, FxHashSet}; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; -use std::collections::{HashMap, HashSet}; /// Represent a token added by the user on top of the existing Model vocabulary. 
/// AddedToken can be configured to specify the behavior they should have in various situations @@ -14,7 +15,7 @@ use std::collections::{HashMap, HashSet}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct AddedToken { /// The content of the added token - pub content: String, + pub content: CompactString, /// Whether this token must be a single word or can break words pub single_word: bool, /// Whether this token should strip whitespaces on its left @@ -30,7 +31,7 @@ pub struct AddedToken { impl AddedToken { /// Build this token from the given content, specifying if it is intented to be a /// special token. Special tokens are not normalized by default. - pub fn from>(content: S, special: bool) -> Self { + pub fn from>(content: S, special: bool) -> Self { Self { content: content.into(), normalized: !special, @@ -76,7 +77,7 @@ impl AddedToken { impl Default for AddedToken { fn default() -> Self { Self { - content: String::new(), + content: CompactString::default(), single_word: false, lstrip: false, rstrip: false, @@ -142,10 +143,10 @@ fn space_rightmost_at_start(sentence: &str) -> usize { pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. - added_tokens_map: HashMap, + added_tokens_map: FxHashMap, /// Contains the mapping from ID to AddedToken for all the added tokens, both special /// and classic. - added_tokens_map_r: HashMap, + added_tokens_map_r: FxHashMap, /// Contains only the classic AddedToken, in the specific order the user gave them. added_tokens: Vec, @@ -154,7 +155,7 @@ pub struct AddedVocabulary { /// A Set, containing all the special token for easy access while decoding. This let's /// us remove them easily with an O(1) complexity. 
- special_tokens_set: HashSet, + special_tokens_set: FxHashSet, /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens split_trie: MatchingSet, @@ -176,11 +177,11 @@ impl AddedVocabulary { .build::<_, &&[u8]>([]) .expect("The normalized trie should build correctly"); Self { - added_tokens_map: HashMap::new(), - added_tokens_map_r: HashMap::new(), + added_tokens_map: FxHashMap::default(), + added_tokens_map_r: FxHashMap::default(), added_tokens: vec![], special_tokens: vec![], - special_tokens_set: HashSet::new(), + special_tokens_set: FxHashSet::default(), split_trie: (trie, vec![]), split_normalized_trie: (normalized_trie, vec![]), encode_special_tokens: false, @@ -198,12 +199,12 @@ impl AddedVocabulary { } /// Get the additional vocabulary - pub fn get_vocab(&self) -> &HashMap { + pub fn get_vocab(&self) -> &FxHashMap { &self.added_tokens_map } /// Get the additional vocabulary with the AddedTokens - pub fn get_added_tokens_decoder(&self) -> &HashMap { + pub fn get_added_tokens_decoder(&self) -> &FxHashMap { &self.added_tokens_map_r } @@ -220,14 +221,14 @@ impl AddedVocabulary { since = "0.19.0", note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead" )] - pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option { + pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option { self.added_tokens_map_r .get(&id) .map(|t| t.content.clone()) .or_else(|| model.id_to_token(id)) } - pub fn simple_id_to_token(&self, id: u32) -> Option { + pub fn simple_id_to_token(&self, id: u32) -> Option { self.added_tokens_map_r.get(&id).map(|t| t.content.clone()) } @@ -350,7 +351,7 @@ impl AddedVocabulary { let patterns: Vec<_> = ntokens .iter() .map(|token| { - let mut content = NormalizedString::from(token.content.as_ref()); + let mut content = NormalizedString::from(token.content.clone()); if let Some(n) = normalizer { n.normalize(&mut content).unwrap(); } @@ -440,7 +441,7 @@ impl AddedVocabulary { .slice(Range::Normalized(byte_offsets.0..byte_offsets.1)) .expect("AddedVocabulary bad split"); if let Some(id) = id { - let value = slice.get().to_owned(); + let value: CompactString = slice.get().into(); let len = value.len(); (slice, Some(vec![Token::new(id, value, (0, len))])) } else { @@ -542,6 +543,8 @@ impl Serialize for AddedVocabulary { #[cfg(test)] mod tests { + use compact_str::ToCompactString; + use super::*; use crate::normalizers::byte_level::ByteLevel as ByteLevelNormalizer; use crate::normalizers::utils::Lowercase; @@ -551,17 +554,17 @@ mod tests { #[derive(Serialize, Deserialize)] struct ModelMock { - vocab: HashMap, - vocab_r: HashMap, + vocab: FxHashMap, + vocab_r: FxHashMap, } impl ModelMock { pub fn new(iter: I) -> Self where I: IntoIterator, { - let vocab: HashMap = iter + let vocab: FxHashMap = iter .into_iter() - .map(|&(tok, id)| (tok.to_string(), id)) + .map(|&(tok, id)| (tok.to_compact_string(), id)) .collect(); Self { vocab_r: vocab @@ -601,7 +604,7 @@ mod tests { where I: Iterator + Send, S: AsRef + Send, - F: Fn(&str) -> Result> + Sync, + F: Fn(&str) -> Result> + Sync, { unimplemented!() } @@ -616,10 +619,10 @@ mod tests { fn token_to_id(&self, token: &str) -> Option { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option { + fn id_to_token(&self, id: u32) -> Option { self.vocab_r.get(&id).cloned() } - fn get_vocab(&self) -> HashMap { + fn get_vocab(&self) -> FxHashMap { self.vocab.clone() } fn get_vocab_size(&self) -> usize { @@ -714,14 +717,13 @@ mod tests { ); 
assert_eq!(vocab.len(), 3); // New token was added assert!(vocab.is_special_token("test")); - assert_eq!( - *vocab.get_added_tokens_decoder(), - HashMap::from([ - (0, AddedToken::from("test", true)), - (2, AddedToken::from("added_token_1", true)), - (3, AddedToken::from("added_token_2", true)), - ]) - ); + assert_eq!(*vocab.get_added_tokens_decoder(), { + let mut map = FxHashMap::default(); + map.insert(0, AddedToken::from("test", true)); + map.insert(2, AddedToken::from("added_token_1", true)); + map.insert(3, AddedToken::from("added_token_2", true)); + map + }); assert!(vocab.added_tokens_map.contains_key("test")); assert!(vocab.added_tokens_map_r.contains_key(&0)); @@ -746,7 +748,7 @@ mod tests { // Just checking that we can set the content of the string in rust let mut token: AddedToken = AddedToken::from("Hey", false); - token.content = "hey".to_string(); + token.content = "hey".into(); assert_eq!(token.content, "hey"); // Token was already there token.special = true; diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs index 1732686e4..e52f297a2 100644 --- a/tokenizers/src/tokenizer/encoding.rs +++ b/tokenizers/src/tokenizer/encoding.rs @@ -2,6 +2,7 @@ use crate::parallelism::*; use crate::tokenizer::{Offsets, Token}; use crate::utils::padding::PaddingDirection; use crate::utils::truncation::TruncationDirection; +use compact_str::CompactString; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::ops::Range; @@ -14,7 +15,7 @@ pub struct Encoding { /// Type of the IDs type_ids: Vec, /// Tokens associated to each ID - tokens: Vec, + tokens: Vec, /// Indice of the word associated to each token/ID words: Vec>, /// Offsets of the token/ID from the NormalizedString @@ -34,7 +35,7 @@ impl Encoding { pub fn new( ids: Vec, type_ids: Vec, - tokens: Vec, + tokens: Vec>, words: Vec>, offsets: Vec, special_tokens_mask: Vec, @@ -45,7 +46,7 @@ impl Encoding { Self { ids, type_ids, - tokens, + tokens: tokens.into_iter().map(|t| t.into()).collect(), words, offsets, special_tokens_mask, @@ -122,7 +123,7 @@ impl Encoding { self.sequence_ranges.insert(sequence_id, 0..self.len()); } - pub fn get_tokens(&self) -> &[String] { + pub fn get_tokens(&self) -> &[CompactString] { &self.tokens[..] 
} @@ -190,7 +191,7 @@ impl Encoding { pub(crate) fn process_tokens_with_offsets_mut(&mut self, func: F) where - F: FnMut((usize, (&String, &mut Offsets))), + F: FnMut((usize, (&CompactString, &mut Offsets))), { self.tokens .iter() @@ -493,7 +494,7 @@ impl Encoding { .chain(self.type_ids.drain(..)) .collect(); self.tokens = (0..pad_length) - .map(|_| pad_token.to_owned()) + .map(|_| pad_token.into()) .chain(self.tokens.drain(..)) .collect(); self.words = (0..pad_length) @@ -522,7 +523,7 @@ impl Encoding { self.ids.extend((0..pad_length).map(|_| pad_id)); self.type_ids.extend((0..pad_length).map(|_| pad_type_id)); self.tokens - .extend((0..pad_length).map(|_| pad_token.to_owned())); + .extend((0..pad_length).map(|_| pad_token.into())); self.words.extend((0..pad_length).map(|_| None)); self.attention_mask.extend((0..pad_length).map(|_| 0)); self.special_tokens_mask.extend((0..pad_length).map(|_| 1)); @@ -538,8 +539,8 @@ impl std::iter::FromIterator for Encoding { } } -impl std::iter::FromIterator<(u32, String, (usize, usize), Option, u32)> for Encoding { - fn from_iter, u32)>>( +impl std::iter::FromIterator<(u32, CompactString, (usize, usize), Option, u32)> for Encoding { + fn from_iter, u32)>>( iter: I, ) -> Self { let items = iter.into_iter(); @@ -568,26 +569,28 @@ mod tests { #[test] fn merge_encodings() { - let mut a = Encoding { - ids: vec![1], - type_ids: vec![0], - tokens: vec![String::from("Hello ")], - words: vec![Some(0)], - offsets: vec![(0, 6)], - special_tokens_mask: vec![0], - attention_mask: vec![1], - ..Default::default() - }; - let b = Encoding { - ids: vec![2], - type_ids: vec![1], - tokens: vec![String::from("World!")], - words: vec![Some(0)], - offsets: vec![(0, 6)], - special_tokens_mask: vec![0], - attention_mask: vec![1], - ..Default::default() - }; + let mut a = Encoding::new( + vec![1], + vec![0], + vec!["Hello "], + vec![Some(0)], + vec![(0, 6)], + vec![0], + vec![1], + Default::default(), + Default::default(), + ); + let b = Encoding::new( + vec![2], + vec![1], + vec!["World!"], + vec![Some(0)], + vec![(0, 6)], + vec![0], + vec![1], + Default::default(), + Default::default(), + ); a.merge_with(b, true); assert_eq!( @@ -595,7 +598,7 @@ mod tests { Encoding { ids: vec![1, 2], type_ids: vec![0, 1], - tokens: vec![String::from("Hello "), String::from("World!")], + tokens: vec![CompactString::from("Hello "), CompactString::from("World!")], words: vec![Some(0), Some(0)], offsets: vec![(0, 6), (6, 12)], special_tokens_mask: vec![0, 0], @@ -607,184 +610,165 @@ mod tests { #[test] fn truncate() { - let mut a = Encoding { - ids: vec![1, 2, 3], - type_ids: vec![0, 0, 0], - tokens: vec![ - String::from("Hello"), - String::from("World"), - String::from("!"), - ], - words: vec![Some(0), Some(1), Some(2)], - offsets: vec![(0, 5), (6, 11), (11, 12)], - special_tokens_mask: vec![0, 0, 0], - attention_mask: vec![1, 1, 1], - ..Default::default() - }; + let mut a = Encoding::new( + vec![1, 2, 3], + vec![0, 0, 0], + vec!["Hello", "World", "!"], + vec![Some(0), Some(1), Some(2)], + vec![(0, 5), (6, 11), (11, 12)], + vec![0, 0, 0], + vec![1, 1, 1], + Default::default(), + Default::default(), + ); a.truncate(2, 0, TruncationDirection::Right); assert_eq!( a, - Encoding { - ids: vec![1, 2], - type_ids: vec![0, 0], - tokens: vec![String::from("Hello"), String::from("World")], - words: vec![Some(0), Some(1)], - offsets: vec![(0, 5), (6, 11)], - special_tokens_mask: vec![0, 0], - attention_mask: vec![1, 1], - overflowing: vec![Encoding { - ids: vec![3], - type_ids: vec![0], - tokens: 
vec![String::from("!")], - words: vec![Some(2)], - offsets: vec![(11, 12)], - special_tokens_mask: vec![0], - attention_mask: vec![1], - ..Default::default() - }], - ..Default::default() - } + Encoding::new( + vec![1, 2], + vec![0, 0], + vec!["Hello", "World"], + vec![Some(0), Some(1)], + vec![(0, 5), (6, 11)], + vec![0, 0], + vec![1, 1], + vec![Encoding::new( + vec![3], + vec![0], + vec!["!"], + vec![Some(2)], + vec![(11, 12)], + vec![0], + vec![1], + Default::default(), + Default::default(), + )], + Default::default() + ) ); } #[test] fn truncate_to_empty() { - let mut a = Encoding { - ids: vec![1, 2, 3], - type_ids: vec![0, 0, 0], - tokens: vec![ - String::from("Hello"), - String::from("World"), - String::from("!"), - ], - words: vec![Some(0), Some(1), Some(2)], - offsets: vec![(0, 5), (6, 11), (11, 12)], - special_tokens_mask: vec![0, 0, 0], - attention_mask: vec![1, 1, 1], - ..Default::default() - }; + let mut a = Encoding::new( + vec![1, 2, 3], + vec![0, 0, 0], + vec!["Hello", "World", "!"], + vec![Some(0), Some(1), Some(2)], + vec![(0, 5), (6, 11), (11, 12)], + vec![0, 0, 0], + vec![1, 1, 1], + Default::default(), + Default::default(), + ); a.truncate(0, 0, TruncationDirection::Right); assert_eq!( a, - Encoding { - overflowing: vec![Encoding { - ids: vec![1, 2, 3], - type_ids: vec![0, 0, 0], - tokens: vec![ - String::from("Hello"), - String::from("World"), - String::from("!"), - ], - words: vec![Some(0), Some(1), Some(2)], - offsets: vec![(0, 5), (6, 11), (11, 12)], - special_tokens_mask: vec![0, 0, 0], - attention_mask: vec![1, 1, 1], - overflowing: vec![], - ..Default::default() - }], - ..Default::default() - } + Encoding::new( + Default::default(), + Default::default(), + Vec::::new(), // Cannot use Default::default, since the argument is an impl trait. 
+ Default::default(), + Default::default(), + Default::default(), + Default::default(), + vec![Encoding::new( + vec![1, 2, 3], + vec![0, 0, 0], + vec!["Hello", "World", "!",], + vec![Some(0), Some(1), Some(2)], + vec![(0, 5), (6, 11), (11, 12)], + vec![0, 0, 0], + vec![1, 1, 1], + vec![], + Default::default() + )], + Default::default(), + ) ); } #[test] fn truncate_overflow_with_stride() { - let mut enc = Encoding { - ids: vec![1, 2, 3, 4, 5], - type_ids: vec![0, 0, 0, 0, 0], - tokens: vec![ - String::from("42"), - String::from("is"), - String::from("the"), - String::from("answer"), - String::from("!"), - ], - words: vec![Some(0), Some(1), Some(2), Some(3), Some(4)], - offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)], - special_tokens_mask: vec![0, 0, 0, 0, 0], - attention_mask: vec![1, 1, 1, 1, 1], - overflowing: vec![], - ..Default::default() - }; + let mut enc = Encoding::new( + vec![1, 2, 3, 4, 5], + vec![0, 0, 0, 0, 0], + vec!["42", "is", "the", "answer", "!"], + vec![Some(0), Some(1), Some(2), Some(3), Some(4)], + vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)], + vec![0, 0, 0, 0, 0], + vec![1, 1, 1, 1, 1], + vec![], + Default::default(), + ); enc.truncate(4, 2, TruncationDirection::Right); assert_eq!( enc, - Encoding { - ids: vec![1, 2, 3, 4], - type_ids: vec![0, 0, 0, 0], - tokens: vec![ - String::from("42"), - String::from("is"), - String::from("the"), - String::from("answer"), - ], - words: vec![Some(0), Some(1), Some(2), Some(3)], - offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13)], - special_tokens_mask: vec![0, 0, 0, 0], - attention_mask: vec![1, 1, 1, 1], - overflowing: vec![Encoding { - ids: vec![3, 4, 5], - type_ids: vec![0, 0, 0], - tokens: vec![ - String::from("the"), - String::from("answer"), - String::from("!"), - ], - words: vec![Some(2), Some(3), Some(4)], - offsets: vec![(4, 7), (7, 13), (13, 14)], - special_tokens_mask: vec![0, 0, 0], - attention_mask: vec![1, 1, 1], - overflowing: vec![], - ..Default::default() - }], - ..Default::default() - } + Encoding::new( + vec![1, 2, 3, 4], + vec![0, 0, 0, 0], + vec!["42", "is", "the", "answer",], + vec![Some(0), Some(1), Some(2), Some(3)], + vec![(0, 2), (2, 4), (4, 7), (7, 13)], + vec![0, 0, 0, 0], + vec![1, 1, 1, 1], + vec![Encoding::new( + vec![3, 4, 5], + vec![0, 0, 0], + vec!["the", "answer", "!",], + vec![Some(2), Some(3), Some(4)], + vec![(4, 7), (7, 13), (13, 14)], + vec![0, 0, 0], + vec![1, 1, 1], + vec![], + Default::default() + )], + Default::default() + ) ); } #[test] fn truncate_left() { - let mut a = Encoding { - ids: vec![1, 2, 3], - type_ids: vec![0, 0, 0], - tokens: vec![ - String::from("Hello"), - String::from("World"), - String::from("!"), - ], - words: vec![Some(0), Some(1), Some(2)], - offsets: vec![(0, 5), (6, 11), (11, 12)], - special_tokens_mask: vec![0, 0, 0], - attention_mask: vec![1, 1, 1], - ..Default::default() - }; + let mut a = Encoding::new( + vec![1, 2, 3], + vec![0, 0, 0], + vec!["Hello", "World", "!"], + vec![Some(0), Some(1), Some(2)], + vec![(0, 5), (6, 11), (11, 12)], + vec![0, 0, 0], + vec![1, 1, 1], + Default::default(), + Default::default(), + ); a.truncate(2, 0, TruncationDirection::Left); assert_eq!( a, - Encoding { - ids: vec![2, 3], - type_ids: vec![0, 0], - tokens: vec![String::from("World"), String::from("!")], - words: vec![Some(1), Some(2)], - offsets: vec![(6, 11), (11, 12)], - special_tokens_mask: vec![0, 0], - attention_mask: vec![1, 1], - overflowing: vec![Encoding { - ids: vec![1], - type_ids: vec![0], - tokens: vec![String::from("Hello")], - words: vec![Some(0)], - 
offsets: vec![(0, 5)], - special_tokens_mask: vec![0], - attention_mask: vec![1], - ..Default::default() - }], - ..Default::default() - } + Encoding::new( + vec![2, 3], + vec![0, 0], + vec!["World", "!"], + vec![Some(1), Some(2)], + vec![(6, 11), (11, 12)], + vec![0, 0], + vec![1, 1], + vec![Encoding::new( + vec![1], + vec![0], + vec!["Hello"], + vec![Some(0)], + vec![(0, 5)], + vec![0], + vec![1], + Default::default(), + Default::default() + )], + Default::default() + ) ); } @@ -882,17 +866,17 @@ mod tests { #[test] fn padding() { - let mut a = Encoding { - ids: vec![1], - type_ids: vec![0], - tokens: vec![String::from("Hello ")], - words: vec![Some(0)], - offsets: vec![(0, 6)], - special_tokens_mask: vec![0], - attention_mask: vec![1], - sequence_ranges: HashMap::from([(0, 0..1)]), - ..Default::default() - }; + let mut a = Encoding::new( + vec![1], + vec![0], + vec!["Hello "], + vec![Some(0)], + vec![(0, 6)], + vec![0], + vec![1], + Default::default(), + HashMap::from([(0, 0..1)]), + ); let target_length = 2; let pad_id = 99; let pad_type_id = 0; diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 808d120d5..45d5a53c9 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -9,8 +9,9 @@ //! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding, //! ...). +use compact_str::{CompactString, ToCompactString}; +use rustc_hash::FxHashMap; use std::{ - collections::HashMap, fs::{read_to_string, File}, io::{prelude::*, BufReader}, ops::{Deref, DerefMut}, @@ -75,9 +76,9 @@ pub trait Model { /// Find the ID associated to a string token fn token_to_id(&self, token: &str) -> Option; /// Find the string token associated to an ID - fn id_to_token(&self, id: u32) -> Option; + fn id_to_token(&self, id: u32) -> Option; /// Retrieve the entire vocabulary mapping (token -> ID) - fn get_vocab(&self) -> HashMap; + fn get_vocab(&self) -> FxHashMap; /// Retrieve the size of the vocabulary fn get_vocab_size(&self) -> usize; /// Save the current `Model` in the given folder, using the given `prefix` for the various @@ -151,11 +152,16 @@ pub enum ProcessorError { /// A `Decoder` changes the raw tokens into its more readable form. pub trait Decoder { - fn decode(&self, tokens: Vec) -> Result { - let results = self.decode_chain(tokens)?; + fn decode(&self, tokens: Vec) -> Result { + let results: Vec = self + .decode_chain(tokens)? + .iter() + .map(|r| r.to_compact_string()) + .collect(); Ok(results.join("")) } - fn decode_chain(&self, tokens: Vec) -> Result>; + fn decode_chain(&self, tokens: Vec) + -> Result>; } /// A `Trainer` has the responsibility to train a model. 
We feed it with lines/sentences @@ -173,18 +179,22 @@ pub trait Trainer { where I: Iterator + Send, S: AsRef + Send, - F: Fn(&str) -> Result> + Sync; + F: Fn(&str) -> Result> + Sync; } #[derive(Debug, Clone, PartialEq, Eq)] pub struct Token { pub id: u32, - pub value: String, + pub value: CompactString, pub offsets: (usize, usize), } impl Token { - pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self { - Self { id, value, offsets } + pub fn new(id: u32, value: impl Into, offsets: (usize, usize)) -> Self { + Self { + id, + value: value.into(), + offsets, + } } } @@ -193,7 +203,7 @@ use std::borrow::Cow; pub enum InputSequence<'s> { Raw(Cow<'s, str>), PreTokenized(Cow<'s, [&'s str]>), - PreTokenizedOwned(Cow<'s, [String]>), + PreTokenizedOwned(Cow<'s, [CompactString]>), PreTokenizedCow(Cow<'s, [Cow<'s, str>]>), } @@ -229,12 +239,26 @@ impl<'s> From> for InputSequence<'s> { impl<'s> From<&'s [String]> for InputSequence<'s> { fn from(input: &'s [String]) -> Self { - Self::PreTokenizedOwned(Cow::Borrowed(input)) + Self::PreTokenizedOwned(Cow::Owned( + input.iter().map(|s| s.to_compact_string()).collect(), + )) } } impl From> for InputSequence<'_> { fn from(input: Vec) -> Self { + Self::PreTokenizedOwned(Cow::Owned(input.into_iter().map(|s| s.into()).collect())) + } +} + +impl<'s> From<&'s [CompactString]> for InputSequence<'s> { + fn from(input: &'s [CompactString]) -> Self { + Self::PreTokenizedOwned(Cow::Borrowed(input)) + } +} + +impl From> for InputSequence<'_> { + fn from(input: Vec) -> Self { Self::PreTokenizedOwned(Cow::Owned(input)) } } @@ -658,7 +682,7 @@ where } /// Get the vocabulary - pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap { + pub fn get_vocab(&self, with_added_tokens: bool) -> FxHashMap { let mut final_vocab = self.model.get_vocab(); if with_added_tokens { @@ -675,7 +699,7 @@ where } /// Get the added tokens decoder - pub fn get_added_tokens_decoder(&self) -> HashMap { + pub fn get_added_tokens_decoder(&self) -> FxHashMap { self.added_vocabulary.get_added_tokens_decoder().clone() } @@ -696,7 +720,7 @@ where } /// Converts an id to the corresponding token. 
- pub fn id_to_token(&self, id: u32) -> Option { + pub fn id_to_token(&self, id: u32) -> Option { self.added_vocabulary .simple_id_to_token(id) .or_else(|| self.model.id_to_token(id)) @@ -886,7 +910,7 @@ where } /// Decode the given ids, back to a String - pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result { + pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result { let tokens = ids .iter() .filter_map(|id| { @@ -900,9 +924,9 @@ where .collect::>(); if let Some(decoder) = &self.decoder { - decoder.decode(tokens) + decoder.decode(tokens).map(|t| t.to_compact_string()) } else { - Ok(tokens.join(" ")) + Ok(tokens.join(" ").to_compact_string()) } } @@ -922,18 +946,19 @@ where /// Example: /// /// ``` -/// # #[cfg(not(target_os = "windows"))] +/// # use compact_str::ToCompactString; +/// #[cfg(not(target_os = "windows"))] /// # { /// use tokenizers::Tokenizer; /// let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); /// /// let mut decode_stream = tokenizer.decode_stream(false); -/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string())); -/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string())); -/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string())); +/// assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_compact_string())); +/// assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_compact_string())); +/// assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_compact_string())); /// assert_eq!( /// decode_stream.step(1246).unwrap(), -/// Some(" example".to_string()) +/// Some(" example".to_compact_string()) /// ); /// # } /// ``` @@ -944,14 +969,15 @@ where /// a valid chunk. /// ``` /// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, decoders::byte_fallback::ByteFallback, pre_tokenizers::byte_level::ByteLevel, normalizers::unicode::NFC}; -/// use std::collections::HashMap; /// use std::iter::FromIterator; +/// use compact_str::{CompactString, ToCompactString}; +/// use rustc_hash::FxHashMap; /// -/// let vocab = HashMap::from_iter([ -/// ("<0x20>".to_string(), 0), -/// ("<0xC3>".to_string(), 1), -/// ("<0xA9>".to_string(), 2), -/// (" This".to_string(), 3), +/// let vocab: FxHashMap = FxHashMap::from_iter([ +/// ("<0x20>".into(), 0), +/// ("<0xC3>".into(), 1), +/// ("<0xA9>".into(), 2), +/// (" This".into(), 3), /// ]); /// let merges = vec![]; /// let bpe = BPE::builder() @@ -969,11 +995,11 @@ where /// /// let mut decode_stream = tokenizer.decode_stream(false); /// // Single byte_fallback is valid utf-8 -/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); +/// assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_compact_string())); /// // Invalid utf-8 /// assert_eq!(decode_stream.step(1).unwrap(), None); /// // Valid utf-8 again, this corresponds to both tokens: [1, 2] -/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); +/// assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_compact_string())); /// ``` /// /// To see how [`DecodeStream`] is necessary, let's show how using raw [`TokenizerImpl::decode`] would @@ -981,11 +1007,12 @@ where /// /// ``` /// use tokenizers::{Tokenizer, TokenizerBuilder, models::bpe::BPE, pre_tokenizers::{byte_level::ByteLevel, metaspace::Metaspace}, normalizers::unicode::NFC}; -/// use std::collections::HashMap; /// use std::iter::FromIterator; +/// use compact_str::{CompactString, ToCompactString}; +/// use rustc_hash::FxHashMap; /// -/// let vocab = HashMap::from_iter([ -/// 
("▁This".to_string(), 0), +/// let vocab: FxHashMap = FxHashMap::from_iter([ +/// ("▁This".into(), 0), /// ]); /// let merges = vec![]; /// let bpe = BPE::builder() @@ -1009,8 +1036,8 @@ where /// /// // Using a stream fixes it by keeping the necessary state. /// let mut decode_stream = tokenizer.decode_stream(false); -/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string())); -/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string())); +/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_compact_string())); +/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_compact_string())); /// ``` pub struct DecodeStream<'tok, M, N, PT, PP, D> { /// A reference to the tokenizer @@ -1031,7 +1058,7 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> { ids: Vec, /// The previously returned chunk that needs to be discarded from the /// decoding of the current ids to produce the next chunk - prefix: String, + prefix: CompactString, /// The index within the ids corresponding to the prefix so we can drain /// correctly prefix_index: usize, @@ -1056,13 +1083,13 @@ where tokenizer, ids: vec![], skip_special_tokens, - prefix: "".to_string(), + prefix: "".into(), prefix_index: 0, } } /// See [`DecodeStream`] - pub fn step(&mut self, id: u32) -> Result> { + pub fn step(&mut self, id: u32) -> Result> { step_decode_stream( self.tokenizer, id, @@ -1080,9 +1107,9 @@ pub fn step_decode_stream( id: u32, skip_special_tokens: bool, ids: &mut Vec, - prefix: &mut String, + prefix: &mut CompactString, prefix_index: &mut usize, -) -> Result> +) -> Result> where M: Model, N: Normalizer, @@ -1091,17 +1118,21 @@ where D: Decoder, { ids.push(id); - let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?; + let string = tokenizer + .decode(ids.as_slice(), skip_special_tokens)? + .to_compact_string(); if string.len() > prefix.len() && !string.ends_with('�') { - if !(string.starts_with(&*prefix)) { + if !(string.starts_with(&**prefix)) { return Err(Box::new(DecodeStreamError::InvalidPrefix)); } let new_text = &string[prefix.len()..].to_string(); let new_prefix_index = ids.len() - *prefix_index; *ids = ids.drain(*prefix_index..).collect(); - *prefix = tokenizer.decode(ids, skip_special_tokens)?; + *prefix = tokenizer + .decode(ids, skip_special_tokens)? 
+ .to_compact_string(); *prefix_index = new_prefix_index; - Ok(Some(new_text.to_string())) + Ok(Some(new_text.into())) } else { Ok(None) } @@ -1327,13 +1358,16 @@ where &self, sentences: &[&[u32]], skip_special_tokens: bool, - ) -> Result> + ) -> Result> where M: Send + Sync, { sentences .into_maybe_par_iter() - .map(|sentence| self.decode(sentence, skip_special_tokens)) + .map(|sentence| { + self.decode(sentence, skip_special_tokens) + .map(|t| t.to_compact_string()) + }) .collect() } @@ -1386,12 +1420,12 @@ where } }), |seq| { - let normalized = self.do_normalize(seq.as_ref())?; + let normalized = self.do_normalize(seq)?; let pre_tokenized = self.do_pre_tokenize(normalized)?; Ok(pre_tokenized .get_splits(OffsetReferential::Original, OffsetType::Byte) .into_iter() - .map(|(s, _, _)| s.to_owned()) + .map(|(s, _, _)| s.into()) .collect()) }, )?; @@ -1437,12 +1471,12 @@ where } }), |seq| { - let normalized = self.do_normalize(seq.as_ref())?; + let normalized = self.do_normalize(seq)?; let pre_tokenized = self.do_pre_tokenize(normalized)?; Ok(pre_tokenized .get_splits(OffsetReferential::Original, OffsetType::Byte) .into_iter() - .map(|(s, _, _)| s.to_owned()) + .map(|(s, _, _)| s.into()) .collect()) }, )?; diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 432c6cc69..c25b6198f 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -1,5 +1,6 @@ use crate::pattern::Pattern; use crate::{Offsets, Result}; +use compact_str::CompactString; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; @@ -104,9 +105,9 @@ impl std::fmt::Display for SplitDelimiterBehavior { #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct NormalizedString { /// The original version of the string, before any modification - original: String, + original: CompactString, /// The normalized version of the string, after all modifications - normalized: String, + normalized: CompactString, /// Mapping from normalized string to original one: (start, end) for each /// byte of the normalized string alignments: Vec<(usize, usize)>, @@ -119,14 +120,14 @@ pub struct NormalizedString { impl NormalizedString { #[cfg(test)] pub(crate) fn new( - original: String, - normalized: String, + original: impl Into, + normalized: impl Into, alignments: Vec<(usize, usize)>, original_shift: usize, ) -> Self { Self { - original, - normalized, + original: original.into(), + normalized: normalized.into(), alignments, original_shift, } @@ -422,14 +423,16 @@ impl NormalizedString { // code could change to mutate `self` or `self.normalized` in the interim. // Perform it again and hope the optimizer collapses it. assert!(self.normalized.get(n_range.clone()).is_some()); + let mut tmp = self.normalized.to_string(); unsafe { - self.normalized + tmp // Safety: This is safe as long as we do not splice across a // UTF-8 character, and we only add UTF-8 text. `normalized` is a String // so the latter is trivially true, and we assert for the former above. .as_mut_vec() .splice(n_range, normalized.bytes()); } + self.normalized = tmp.into(); } /// Applies transformations to the current normalized version of the string, @@ -573,7 +576,7 @@ impl NormalizedString { /// Replace anything that matches the pattern with the given content. 
pub fn replace(&mut self, pattern: P, content: &str) -> Result<()> { - let mut new_normalized = String::with_capacity(self.normalized.len()); // Initially allocate for the input size + let mut new_normalized = CompactString::with_capacity(self.normalized.len()); // Initially allocate for the input size let mut new_alignments: Vec<(usize, usize)> = Vec::with_capacity(self.alignments.len()); let mut last_end = 0; // Keep track of the last end position @@ -1000,8 +1003,9 @@ pub fn char_to_bytes(s: &str, range: std::ops::Range) -> Option for NormalizedString { - fn from(s: String) -> Self { +impl> From for NormalizedString { + fn from(s: T) -> Self { + let s = s.into(); let alignments = s .char_indices() .flat_map(|(b, c)| { @@ -1018,12 +1022,6 @@ impl From for NormalizedString { } } -impl From<&str> for NormalizedString { - fn from(s: &str) -> Self { - Self::from(s.to_owned()) - } -} - #[cfg(test)] mod tests { use super::*; @@ -1252,10 +1250,10 @@ mod tests { assert_eq!( n, - NormalizedString { - original: "野口 No".into(), - normalized: " 野 口 No".into(), - alignments: vec![ + NormalizedString::new( + "野口 No", + " 野 口 No", + vec![ (0, 3), (0, 3), (0, 3), @@ -1270,8 +1268,8 @@ mod tests { (7, 8), (8, 9) ], - original_shift: 0 - } + 0 + ) ); assert_eq!( n.alignments_original(), @@ -1520,10 +1518,10 @@ mod tests { current.transform_range(Range::Original(0..4), vec![('Y', 0)], 3); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "Yo friend".into(), - alignments: vec![ + NormalizedString::new( + "Hello friend", + "Yo friend", + vec![ (3, 4), (4, 5), (5, 6), @@ -1534,8 +1532,8 @@ mod tests { (10, 11), (11, 12) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( @@ -1565,10 +1563,10 @@ mod tests { ); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "Hel_FRnd".into(), - alignments: vec![ + NormalizedString::new( + "Hello friend", + "Hel_FRnd", + vec![ (0, 1), (1, 2), (2, 3), @@ -1578,8 +1576,8 @@ mod tests { (10, 11), (11, 12) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( @@ -1605,12 +1603,12 @@ mod tests { current.transform_range(Range::Original(5..), vec![('_', 0), ('F', -5)], 0); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "Hello_F".into(), - alignments: vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], - original_shift: 0, - } + NormalizedString::new( + "Hello friend", + "Hello_F", + vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + 0, + ) ); assert_eq!( current.alignments_original(), @@ -1635,10 +1633,10 @@ mod tests { current.transform_range(Range::Original(0..1), vec![('H', 1), ('H', 0)], 0); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "HHello friend".into(), - alignments: vec![ + NormalizedString::new( + "Hello friend", + "HHello friend", + vec![ (0, 0), (0, 1), (1, 2), @@ -1653,8 +1651,8 @@ mod tests { (10, 11), (11, 12) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -1678,10 +1676,10 @@ mod tests { current.transform_range(Range::Original(0..0), vec![('H', 1)], 0); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "HHello friend".into(), - alignments: vec![ + NormalizedString::new( + "Hello friend", + "HHello friend", + vec![ (0, 0), (0, 1), (1, 2), @@ -1696,8 +1694,8 @@ mod tests { (10, 11), (11, 12) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -1721,10 +1719,10 @@ 
mod tests { current.transform_range(Range::Original(0..1), vec![('H', 0), ('H', 1)], 0); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "HHello friend".into(), - alignments: vec![ + NormalizedString::new( + "Hello friend", + "HHello friend", + vec![ (0, 1), (0, 1), (1, 2), @@ -1739,8 +1737,8 @@ mod tests { (10, 11), (11, 12) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( @@ -1770,10 +1768,10 @@ mod tests { ); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "Hello_my_friend".into(), - alignments: vec![ + NormalizedString::new( + "Hello friend", + "Hello_my_friend", + vec![ (0, 1), (1, 2), (2, 3), @@ -1790,8 +1788,8 @@ mod tests { (10, 11), (11, 12) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -1816,10 +1814,10 @@ mod tests { current.transform_range(Range::Original(11..), vec![('d', 0), ('_', 1), ('!', 1)], 0); assert_eq!( current, - NormalizedString { - original: "Hello friend".into(), - normalized: "Hello friend_!".into(), - alignments: vec![ + NormalizedString::new( + "Hello friend", + "Hello friend_!", + vec![ (0, 1), (1, 2), (2, 3), @@ -1835,8 +1833,8 @@ mod tests { (11, 12), (11, 12) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -1866,10 +1864,10 @@ mod tests { current.transform_range(Range::Original(0..8), vec![('G', -1)], 0); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "G𝕠𝕕".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "G𝕠𝕕", + vec![ (0, 4), (8, 12), (8, 12), @@ -1880,8 +1878,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -1920,10 +1918,10 @@ mod tests { current.transform_range(Range::Original(4..12), vec![('o', -1)], 0); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "𝔾o𝕕".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "𝔾o𝕕", + vec![ (0, 4), (0, 4), (0, 4), @@ -1934,8 +1932,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -1964,10 +1962,10 @@ mod tests { current.transform_range(Range::Original(12..), vec![('d', 0), ('!', 1)], 0); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "𝔾𝕠𝕠d!".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "𝔾𝕠𝕠d!", + vec![ (0, 4), (0, 4), (0, 4), @@ -1983,8 +1981,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); // Adding at the beginning @@ -1992,10 +1990,10 @@ mod tests { current.transform_range(Range::Original(0..4), vec![('_', 1), ('𝔾', 0)], 0); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "_𝔾𝕠𝕠𝕕".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "_𝔾𝕠𝕠𝕕", + vec![ (0, 0), (0, 4), (0, 4), @@ -2014,8 +2012,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -2054,10 +2052,10 @@ mod tests { current.transform_range(Range::Original(0..0), vec![('_', 1)], 0); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "_𝔾𝕠𝕠𝕕".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "_𝔾𝕠𝕠𝕕", + vec![ (0, 0), (0, 4), (0, 4), @@ -2076,8 +2074,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -2116,10 +2114,10 @@ mod tests { 
current.transform_range(Range::Original(0..4), vec![('𝔾', 0), ('o', 1)], 0); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "𝔾o𝕠𝕠𝕕".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "𝔾o𝕠𝕠𝕕", + vec![ (0, 4), (0, 4), (0, 4), @@ -2138,8 +2136,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -2182,10 +2180,10 @@ mod tests { ); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "𝔾𝕠ooo𝕠𝕕".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "𝔾𝕠ooo𝕠𝕕", + vec![ (0, 4), (0, 4), (0, 4), @@ -2206,8 +2204,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), @@ -2236,10 +2234,10 @@ mod tests { current.transform_range(Range::Original(16..), vec![('!', 1)], 0); assert_eq!( current, - NormalizedString { - original: "𝔾𝕠𝕠𝕕".into(), - normalized: "𝔾𝕠𝕠𝕕!".into(), - alignments: vec![ + NormalizedString::new( + "𝔾𝕠𝕠𝕕", + "𝔾𝕠𝕠𝕕!", + vec![ (0, 4), (0, 4), (0, 4), @@ -2258,8 +2256,8 @@ mod tests { (12, 16), (12, 16) ], - original_shift: 0, - } + 0, + ) ); assert_eq!( current.alignments_original(), diff --git a/tokenizers/src/tokenizer/pattern.rs b/tokenizers/src/tokenizer/pattern.rs index 9fa22dd9b..634bb71d9 100644 --- a/tokenizers/src/tokenizer/pattern.rs +++ b/tokenizers/src/tokenizer/pattern.rs @@ -1,5 +1,6 @@ use crate::utils::SysRegex; use crate::{Offsets, Result}; +use compact_str::CompactString; use regex::Regex; /// Pattern used to split a NormalizedString @@ -38,6 +39,13 @@ impl Pattern for &String { } } +impl Pattern for &CompactString { + fn find_matches(&self, inside: &str) -> Result> { + let s: &str = self; + s.find_matches(inside) + } +} + impl Pattern for &Regex { fn find_matches(&self, inside: &str) -> Result> { if inside.is_empty() { diff --git a/tokenizers/src/tokenizer/pre_tokenizer.rs b/tokenizers/src/tokenizer/pre_tokenizer.rs index 0d54cd62b..6fe4224b3 100644 --- a/tokenizers/src/tokenizer/pre_tokenizer.rs +++ b/tokenizers/src/tokenizer/pre_tokenizer.rs @@ -1,3 +1,5 @@ +use compact_str::CompactString; + use crate::{ normalizer::Range, Encoding, NormalizedString, OffsetReferential, Offsets, Result, Token, }; @@ -52,7 +54,7 @@ impl From<(NormalizedString, Option>)> for Split { /// original string. 
#[derive(Debug, Clone, PartialEq, Eq)] pub struct PreTokenizedString { - original: String, + original: CompactString, splits: Vec, } @@ -154,7 +156,7 @@ impl PreTokenizedString { .flat_map(|split| { split.tokens.unwrap().into_iter().map(|token| { // Replace this with the actual fields you need for the Encoding type - (token.id, String::with_capacity(0), (0, 0), None, 0) + (token.id, CompactString::with_capacity(0), (0, 0), None, 0) }) }) .collect(); @@ -241,7 +243,7 @@ impl PreTokenizedString { impl From for PreTokenizedString { fn from(s: NormalizedString) -> Self { Self { - original: s.get_original().to_owned(), + original: s.get_original().into(), splits: vec![Split { normalized: s, tokens: None, @@ -264,6 +266,13 @@ impl From for PreTokenizedString { } } +impl From for PreTokenizedString { + fn from(s: CompactString) -> Self { + let normalized: NormalizedString = s.into(); + normalized.into() + } +} + struct BytesToCharOffsetConverter { map: HashMap, } diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 39585a304..9a601a6d2 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,5 +1,6 @@ use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; +use compact_str::CompactString; use serde::{Deserialize, Serialize}; /// The various possible padding directions. @@ -25,7 +26,7 @@ pub struct PaddingParams { pub pad_to_multiple_of: Option, pub pad_id: u32, pub pad_type_id: u32, - pub pad_token: String, + pub pad_token: CompactString, } impl Default for PaddingParams { @@ -36,7 +37,7 @@ impl Default for PaddingParams { pad_to_multiple_of: None, pad_id: 0, pad_type_id: 0, - pad_token: String::from("[PAD]"), + pad_token: CompactString::from("[PAD]"), } } } @@ -93,7 +94,7 @@ mod tests { Encoding::new( vec![0, 1, 2, 3, 4], vec![], - vec![], + Vec::::new(), vec![], vec![], vec![], @@ -104,7 +105,7 @@ mod tests { Encoding::new( vec![0, 1, 2], vec![], - vec![], + Vec::::new(), vec![], vec![], vec![], @@ -123,7 +124,7 @@ mod tests { pad_to_multiple_of: Some(8), pad_id: 0, pad_type_id: 0, - pad_token: String::from("[PAD]"), + pad_token: CompactString::from("[PAD]"), }; pad_encodings(&mut encodings, ¶ms).unwrap(); assert!(encodings.iter().all(|e| e.get_ids().len() == 8)); diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 9acc297bf..dc5e170c4 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -168,6 +168,8 @@ pub fn truncate_encodings( #[cfg(test)] mod tests { + use compact_str::CompactString; + use super::*; use crate::tokenizer::Encoding; use std::collections::HashMap; @@ -176,7 +178,7 @@ mod tests { Encoding::new( vec![], vec![], - vec![], + Vec::::new(), vec![], vec![], vec![], @@ -190,7 +192,7 @@ mod tests { Encoding::new( vec![1, 2], vec![0, 0], - vec![String::from("a"), String::from("b")], + vec!["a", "b"], vec![Some(0), Some(1)], vec![(0, 1), (1, 2)], vec![0, 0], @@ -204,12 +206,7 @@ mod tests { Encoding::new( vec![3, 4, 5, 6], vec![0, 0, 0, 0], - vec![ - String::from("d"), - String::from("e"), - String::from("f"), - String::from("g"), - ], + vec!["d", "e", "f", "g"], vec![Some(0), Some(1), Some(2), Some(3)], vec![(0, 1), (1, 2), (2, 3), (3, 4)], vec![0, 0, 0, 0], @@ -223,16 +220,7 @@ mod tests { Encoding::new( vec![7, 8, 9, 10, 11, 12, 13, 14], vec![0, 0, 0, 0, 0, 0, 0, 0], - vec![ - String::from("h"), - String::from("i"), - String::from("j"), - String::from("k"), - String::from("l"), - String::from("m"), - String::from("n"), - 
String::from("o"), - ], + vec!["h", "i", "j", "k", "l", "m", "n", "o"], vec![ Some(0), Some(1), diff --git a/tokenizers/tests/common/mod.rs b/tokenizers/tests/common/mod.rs index 26129699b..544cd923a 100644 --- a/tokenizers/tests/common/mod.rs +++ b/tokenizers/tests/common/mod.rs @@ -1,3 +1,4 @@ +use compact_str::CompactString; use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder; use tokenizers::models::bpe::BPE; use tokenizers::models::wordpiece::WordPiece; @@ -48,10 +49,7 @@ pub fn get_bert() -> Tokenizer { .with_normalizer(Some(BertNormalizer::default())) .with_pre_tokenizer(Some(BertPreTokenizer)) .with_decoder(Some(WordPieceDecoder::default())) - .with_post_processor(Some(BertProcessing::new( - (String::from("[SEP]"), sep), - (String::from("[CLS]"), cls), - ))); + .with_post_processor(Some(BertProcessing::new(("[SEP]", sep), ("[CLS]", cls)))); tokenizer } diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index 304211e77..eeb2cc39b 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; use std::iter::FromIterator; +use compact_str::{CompactString, ToCompactString}; + use tokenizers::decoders::byte_fallback::ByteFallback; use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; use tokenizers::normalizers::{Sequence, Strip, NFC}; @@ -29,11 +31,11 @@ fn train_tokenizer() { .vocab_size(vocab_size) .min_frequency(0) .special_tokens(vec![ - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), - AddedToken::from(String::from(""), true), + AddedToken::from("", true), + AddedToken::from("", true), + AddedToken::from("", true), + AddedToken::from("", true), + AddedToken::from("", true), ]) .build(); @@ -58,7 +60,7 @@ fn load_tokenizer() { assert_eq!(encodings.get_ids(), ids); assert_eq!(encodings.get_tokens(), tokens); - let decoded = tokenizer.decode(&ids, false).unwrap(); + let decoded = tokenizer.decode(&ids, false).unwrap().to_compact_string(); assert_eq!(decoded, example); } @@ -67,35 +69,29 @@ fn streaming_tokenizer() { let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap(); let mut decode_stream = tokenizer.decode_stream(false); - assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string())); - assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string())); - assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string())); - assert_eq!( - decode_stream.step(1246).unwrap(), - Some(" example".to_string()) - ); + assert_eq!(decode_stream.step(713).unwrap(), Some("This".into())); + assert_eq!(decode_stream.step(16).unwrap(), Some(" is".into())); + assert_eq!(decode_stream.step(41).unwrap(), Some(" an".into())); + assert_eq!(decode_stream.step(1246).unwrap(), Some(" example".into())); let tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap(); let encoded = tokenizer.encode("This is an example", false).unwrap(); assert_eq!(encoded.get_ids(), &[48, 25, 40, 823]); let mut decode_stream = tokenizer.decode_stream(false); // No space anymore - assert_eq!(decode_stream.step(25).unwrap(), Some("is".to_string())); + assert_eq!(decode_stream.step(25).unwrap(), Some("is".into())); let mut decode_stream = tokenizer.decode_stream(false); - assert_eq!(decode_stream.step(48).unwrap(), Some("this".to_string())); - assert_eq!(decode_stream.step(25).unwrap(), Some(" is".to_string())); - 
assert_eq!(decode_stream.step(40).unwrap(), Some(" an".to_string())); - assert_eq!( - decode_stream.step(823).unwrap(), - Some(" example".to_string()) - ); + assert_eq!(decode_stream.step(48).unwrap(), Some("this".into())); + assert_eq!(decode_stream.step(25).unwrap(), Some(" is".into())); + assert_eq!(decode_stream.step(40).unwrap(), Some(" an".into())); + assert_eq!(decode_stream.step(823).unwrap(), Some(" example".into())); // None example let vocab = HashMap::from_iter([ - ("<0x20>".to_string(), 0), - ("<0xC3>".to_string(), 1), - ("<0xA9>".to_string(), 2), - (" This".to_string(), 3), + ("<0x20>".into(), 0), + ("<0xC3>".into(), 1), + ("<0xA9>".into(), 2), + (" This".into(), 3), ]); let merges = vec![]; let bpe = BPE::builder() @@ -115,9 +111,9 @@ fn streaming_tokenizer() { .build() .unwrap(); let mut decode_stream = tokenizer.decode_stream(false); - assert_eq!(decode_stream.step(0).unwrap(), Some(" ".to_string())); + assert_eq!(decode_stream.step(0).unwrap(), Some(" ".into())); assert_eq!(decode_stream.step(1).unwrap(), None); - assert_eq!(decode_stream.step(2).unwrap(), Some("é".to_string())); + assert_eq!(decode_stream.step(2).unwrap(), Some("é".into())); assert_eq!(decode_stream.step(2).unwrap(), None); } @@ -133,12 +129,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> { PreTokenizerWrapper, PostProcessorWrapper, DecoderWrapper, - > = TokenizerImpl::new( - BPE::builder() - .unk_token("[UNK]".to_string()) - .build() - .unwrap(), - ); + > = TokenizerImpl::new(BPE::builder().unk_token("[UNK]").build()?); // END quicktour_init_tokenizer // START quicktour_init_trainer use tokenizers::models::bpe::BpeTrainer; @@ -278,7 +269,7 @@ fn quicktour() -> tokenizers::Result<()> { tokenizer.with_padding(Some(PaddingParams { pad_id: 3, - pad_token: "[PAD]".to_string(), + pad_token: CompactString::from("[PAD]"), ..PaddingParams::default() })); // END quicktour_enable_padding @@ -395,10 +386,8 @@ fn pipeline() -> tokenizers::Result<()> { tokenizer.with_post_processor(Some( TemplateProcessing::builder() - .try_single("[CLS] $A [SEP]") - .unwrap() - .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") - .unwrap() + .try_single("[CLS] $A [SEP]")? + .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1")? .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)]) .build() .unwrap(), @@ -409,10 +398,12 @@ fn pipeline() -> tokenizers::Result<()> { println!("{:?}", output.get_ids()); // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2] - let decoded = tokenizer.decode( - &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2], - true, - )?; + let decoded = tokenizer + .decode( + &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2], + true, + )? + .to_compact_string(); println!("{decoded}"); // "Hello , y ' all ! How are you ?" 
// END pipeline_test_decoding @@ -427,12 +418,8 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { use tokenizers::models::wordpiece::WordPiece; use tokenizers::Tokenizer; - let mut bert_tokenizer = Tokenizer::new( - WordPiece::builder() - .unk_token("[UNK]".to_string()) - .build() - .unwrap(), - ); + let mut bert_tokenizer = + Tokenizer::new(WordPiece::builder().unk_token("[UNK]".into()).build()?); // END bert_setup_tokenizer // START bert_setup_normalizer use tokenizers::normalizers::utils::Sequence as NormalizerSequence; @@ -454,10 +441,8 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { bert_tokenizer.with_post_processor(Some( TemplateProcessing::builder() - .try_single("[CLS] $A [SEP]") - .unwrap() - .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") - .unwrap() + .try_single("[CLS] $A [SEP]")? + .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1")? .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)]) .build() .unwrap(), @@ -498,7 +483,9 @@ fn pipeline_bert() -> tokenizers::Result<()> { println!("{:?}", output.get_tokens()); // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"] - let decoded = bert_tokenizer.decode(output.get_ids(), true)?; + let decoded = bert_tokenizer + .decode(output.get_ids(), true)? + .to_compact_string(); println!("{decoded}"); // "welcome to the tok ##eni ##zer ##s library ." // END bert_test_decoding @@ -514,7 +501,9 @@ fn pipeline_bert() -> tokenizers::Result<()> { use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder; bert_tokenizer.with_decoder(Some(WordPieceDecoder::default())); - let decoded = bert_tokenizer.decode(output.get_ids(), true)?; + let decoded = bert_tokenizer + .decode(output.get_ids(), true)? + .to_compact_string(); // "welcome to the tokenizers library." // END bert_proper_decoding assert_eq!(decoded, "welcome to the tokenizers library."); diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs index dc0c95a57..0e4f944b9 100644 --- a/tokenizers/tests/serialization.rs +++ b/tokenizers/tests/serialization.rs @@ -84,7 +84,7 @@ fn normalizers() { #[test] fn processors() { - let bert = BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 0)); + let bert = BertProcessing::new(("SEP", 0), ("CLS", 0)); let bert_ser = serde_json::to_string(&bert).unwrap(); assert_eq!( bert_ser, @@ -157,7 +157,7 @@ fn pretoks() { ); assert_eq!(serde_json::from_str::(&pretok_str).unwrap(), pretok); - let pattern = SplitPattern::Regex("[SEP]".to_string()); + let pattern = SplitPattern::Regex("[SEP]".into()); let pretok = Split::new(pattern, SplitDelimiterBehavior::Isolated, false).unwrap(); let pretok_str = serde_json::to_string(&pretok).unwrap(); assert_eq!( diff --git a/tokenizers/tests/stream.rs b/tokenizers/tests/stream.rs index c4cfee3dd..939027d9d 100644 --- a/tokenizers/tests/stream.rs +++ b/tokenizers/tests/stream.rs @@ -1,3 +1,4 @@ +use compact_str::ToCompactString; use tokenizers::{ normalizers, pre_tokenizers::split::{Split, SplitPattern}, @@ -29,8 +30,11 @@ fn test_decoding_with_added_bpe() { ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"] ); - let decoded = tokenizer.decode(encoded.get_ids(), false); - assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎"); + let decoded = tokenizer + .decode(encoded.get_ids(), false) + .unwrap() + .to_compact_string(); + assert_eq!(decoded, "Hey! 
how is this token: 嗎"); tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]); let encoded = tokenizer @@ -44,8 +48,11 @@ fn test_decoding_with_added_bpe() { encoded.get_tokens(), ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"] ); - let decoded = tokenizer.decode(encoded.get_ids(), false); - assert_eq!(decoded.unwrap(), "Hey! how is this token: д") + let decoded = tokenizer + .decode(encoded.get_ids(), false) + .unwrap() + .to_compact_string(); + assert_eq!(decoded, "Hey! how is this token: д") } #[test] @@ -54,25 +61,25 @@ fn test_decode_stream_step_no_panic() { // "A B C D E F G H I J" let mut decode_stream = tokenizer.decode_stream(false); - assert_eq!(decode_stream.step(32).unwrap(), Some("A".to_string())); - assert_eq!(decode_stream.step(426).unwrap(), Some(" B".to_string())); - assert_eq!(decode_stream.step(356).unwrap(), Some(" C".to_string())); - assert_eq!(decode_stream.step(423).unwrap(), Some(" D".to_string())); - assert_eq!(decode_stream.step(469).unwrap(), Some(" E".to_string())); - assert_eq!(decode_stream.step(435).unwrap(), Some(" F".to_string())); - assert_eq!(decode_stream.step(480).unwrap(), Some(" G".to_string())); - assert_eq!(decode_stream.step(473).unwrap(), Some(" H".to_string())); - assert_eq!(decode_stream.step(358).unwrap(), Some(" I".to_string())); - assert_eq!(decode_stream.step(622).unwrap(), Some(" J".to_string())); + assert_eq!(decode_stream.step(32).unwrap(), Some("A".into())); + assert_eq!(decode_stream.step(426).unwrap(), Some(" B".into())); + assert_eq!(decode_stream.step(356).unwrap(), Some(" C".into())); + assert_eq!(decode_stream.step(423).unwrap(), Some(" D".into())); + assert_eq!(decode_stream.step(469).unwrap(), Some(" E".into())); + assert_eq!(decode_stream.step(435).unwrap(), Some(" F".into())); + assert_eq!(decode_stream.step(480).unwrap(), Some(" G".into())); + assert_eq!(decode_stream.step(473).unwrap(), Some(" H".into())); + assert_eq!(decode_stream.step(358).unwrap(), Some(" I".into())); + assert_eq!(decode_stream.step(622).unwrap(), Some(" J".into())); // for (i, &token) in output_tokens.iter().enumerate() {} // "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113]) let mut decode_stream = tokenizer.decode_stream(false); assert_eq!(decode_stream.step(80690).unwrap(), None); - assert_eq!(decode_stream.step(98).unwrap(), Some("삥".to_string())); + assert_eq!(decode_stream.step(98).unwrap(), Some("삥".into())); assert_eq!(decode_stream.step(167).unwrap(), None); assert_eq!(decode_stream.step(121).unwrap(), None); - assert_eq!(decode_stream.step(243).unwrap(), Some("뽕".to_string())); + assert_eq!(decode_stream.step(243).unwrap(), Some("뽕".into())); assert_eq!(decode_stream.step(102457).unwrap(), None); - assert_eq!(decode_stream.step(113).unwrap(), Some("빵".to_string())); + assert_eq!(decode_stream.step(113).unwrap(), Some("빵".into())); } diff --git a/tokenizers/tests/training.rs b/tokenizers/tests/training.rs index 37127d5ac..fa4caa41f 100644 --- a/tokenizers/tests/training.rs +++ b/tokenizers/tests/training.rs @@ -14,7 +14,7 @@ fn bpe_values_after_training() { >::default() .with_model( BPE::builder() - .unk_token("[UNK]".to_string()) + .unk_token("[UNK]") .dropout(0.1) .build() .unwrap(), @@ -23,10 +23,10 @@ fn bpe_values_after_training() { .unwrap(); let mut trainer = tokenizer.get_model().get_trainer(); tokenizer - .train_from_files(&mut trainer, vec!["./data/small.txt".to_string()]) + .train_from_files(&mut trainer, vec!["./data/small.txt".into()]) .unwrap(); 
assert_eq!(tokenizer.get_model().dropout, Some(0.1)); - assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".to_string())); + assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".into())); } #[test] @@ -40,8 +40,8 @@ fn bpe_continuing_subword_prefix_error() { >::default() .with_model( BPE::builder() - .unk_token("[UNK]".to_string()) - .continuing_subword_prefix("##".to_string()) + .unk_token("[UNK]") + .continuing_subword_prefix("##") .build() .unwrap(), ) diff --git a/tokenizers/tests/unigram.rs b/tokenizers/tests/unigram.rs index dc0dfdc07..b6c906c0f 100644 --- a/tokenizers/tests/unigram.rs +++ b/tokenizers/tests/unigram.rs @@ -57,10 +57,7 @@ fn test_train_unigram_from_file() { .unwrap(); let mut model = Unigram::default(); - let sentences: Vec<_> = word_counts - .iter() - .map(|(s, i)| (s.to_owned(), *i)) - .collect(); + let sentences: Vec<_> = word_counts.iter().map(|(s, i)| (s.into(), *i)).collect(); trainer.do_train(sentences, &mut model).unwrap(); assert_eq!(model.get_vocab_size(), 719); }
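The hunks above swap `std::collections::HashMap` for `rustc_hash::FxHashMap` in the vocabulary-facing APIs (`get_vocab`, `get_added_tokens_decoder`) and `String` for `compact_str::CompactString` in token text (`Token::value`, `id_to_token`, `decode`). As a reading aid, here is a minimal, self-contained sketch of how those two types are used together; the `Vocab` alias, the helper functions, and the toy token/id values are invented for this example and are not part of the patch.

```rust
// Illustrative sketch only: mirrors the FxHashMap/CompactString vocabulary
// pattern used by the patch, with made-up tokens and ids.
use compact_str::{CompactString, ToCompactString};
use rustc_hash::FxHashMap;

// FxHashMap is std's HashMap parameterized with the faster, non-cryptographic Fx hasher.
type Vocab = FxHashMap<CompactString, u32>;

fn build_vocab() -> Vocab {
    // Short strings (up to 24 bytes on 64-bit targets) are stored inline,
    // so typical vocabulary tokens never touch the heap.
    let mut vocab = Vocab::default();
    vocab.insert(CompactString::from("[PAD]"), 0);
    vocab.insert("hello".into(), 1);
    vocab.insert("world".to_compact_string(), 2);
    vocab
}

fn id_to_token(vocab: &Vocab, id: u32) -> Option<CompactString> {
    // Reverse lookup by linear scan; fine for a toy example.
    vocab
        .iter()
        .find_map(|(token, &tid)| (tid == id).then(|| token.clone()))
}

fn main() {
    let vocab = build_vocab();
    // CompactString borrows as &str, so map lookups and comparisons read like String.
    assert_eq!(vocab.get("hello").copied(), Some(1));
    assert_eq!(id_to_token(&vocab, 2).as_deref(), Some("world"));
}
```

The likely motivation for the swap is overhead on hot paths: vocabulary keys are short, so most of them fit within `CompactString`'s inline capacity, and FxHash is considerably cheaper than the default SipHash for such keys, at the cost of not being DoS-resistant.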
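The `DecodeStream`/`step_decode_stream` changes keep the incremental-decoding algorithm intact and only move its buffers and return type to `CompactString`. Below is a simplified, standalone sketch of that prefix-tracking loop, with a one-byte-per-id toy decoder standing in for `TokenizerImpl::decode`; the `ToyStream` type and its decoder are assumptions made for illustration, not code from the patch.

```rust
// Simplified sketch of the prefix-tracking idea behind `step_decode_stream`:
// keep feeding ids, re-decode, and only emit the part that grew past the
// previously returned prefix once it forms valid UTF-8.
use compact_str::{CompactString, ToCompactString};

struct ToyStream {
    ids: Vec<u32>,
    prefix: CompactString,
    prefix_index: usize,
}

impl ToyStream {
    fn new() -> Self {
        Self { ids: Vec::new(), prefix: CompactString::new(""), prefix_index: 0 }
    }

    // Toy "model": every id is one raw byte. Invalid UTF-8 becomes U+FFFD,
    // which is exactly the marker the real stream checks for.
    fn decode(ids: &[u32]) -> CompactString {
        let bytes: Vec<u8> = ids.iter().map(|&id| id as u8).collect();
        String::from_utf8_lossy(&bytes).to_compact_string()
    }

    fn step(&mut self, id: u32) -> Option<CompactString> {
        self.ids.push(id);
        let string = Self::decode(&self.ids);
        // Emit only once the decoded text grew past the prefix and is valid UTF-8.
        if string.len() > self.prefix.len() && !string.ends_with('\u{FFFD}') {
            // The real code returns DecodeStreamError::InvalidPrefix here instead.
            debug_assert!(string.starts_with(self.prefix.as_str()));
            let new_text = string.as_str()[self.prefix.len()..].to_compact_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            // Forget the ids already covered by the old prefix; keep the tail.
            self.ids.drain(..self.prefix_index);
            self.prefix = Self::decode(&self.ids);
            self.prefix_index = new_prefix_index;
            Some(new_text)
        } else {
            None
        }
    }
}

fn main() {
    let mut stream = ToyStream::new();
    // "é" is the byte pair 0xC3 0xA9: the first byte alone is not valid UTF-8,
    // so nothing is emitted until the second byte completes the character.
    assert_eq!(stream.step(0xC3), None);
    assert_eq!(stream.step(0xA9).as_deref(), Some("é"));
}
```

Unlike this sketch, the real `step_decode_stream` decodes through the tokenizer (so added tokens, decoders, and `skip_special_tokens` all apply) and surfaces a broken prefix as `DecodeStreamError::InvalidPrefix` rather than a debug assertion.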