diff --git a/bindings/python/examples/train_random_chunk_bpe.py b/bindings/python/examples/train_random_chunk_bpe.py
new file mode 100644
index 000000000..64c3c23b3
--- /dev/null
+++ b/bindings/python/examples/train_random_chunk_bpe.py
@@ -0,0 +1,93 @@
+import argparse
+import glob
+import json
+import os
+from os.path import join
+
+from tokenizers import Tokenizer, normalizers, trainers
+from tokenizers.models import BPE
+from tokenizers.pre_tokenizers import RandomChunkSplit
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--files",
+    default=None,
+    metavar="path",
+    type=str,
+    required=True,
+    help="The files to use for training; accepts '**/*.txt'-style patterns \
+                    if enclosed in quotes",
+)
+parser.add_argument(
+    "--out",
+    default="./",
+    type=str,
+    help="Path to the output directory, where the files will be saved",
+)
+parser.add_argument("--name", default="random-chunk-bpe", type=str, help="The name of the output vocab files")
+parser.add_argument("--min-length", default=2, type=int, help="Minimum length of chunks")
+parser.add_argument("--max-length", default=5, type=int, help="Maximum length of chunks")
+parser.add_argument("--vocab-size", default=10000, type=int, help="Size of vocabulary")
+parser.add_argument("--min-frequency", default=2, type=int, help="Minimum frequency for a token to be included")
+args = parser.parse_args()
+
+files = glob.glob(args.files)
+if not files:
+    print(f"No files found matching pattern: {args.files}")
+    exit(1)
+
+
+# Initialize a tokenizer with a BPE model
+tokenizer = Tokenizer(BPE())
+
+# Use RandomChunkSplit as pre-tokenizer
+tokenizer.pre_tokenizer = RandomChunkSplit(min_length=args.min_length, max_length=args.max_length)
+
+# Optional: add NFKC normalization, as SentencePieceBPE does
+tokenizer.normalizer = normalizers.NFKC()
+
+# Configure the BPE trainer
+trainer = trainers.BpeTrainer(
+    vocab_size=args.vocab_size,
+    min_frequency=args.min_frequency,
+    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
+    show_progress=True
+)
+
+# Train the model
+print(f"Training BPE with RandomChunkSplit (min_length={args.min_length}, max_length={args.max_length})")
+tokenizer.train(files, trainer)
+
+# Save the trained tokenizer
+output_path = join(args.out, f"{args.name}.json")
+tokenizer.save(output_path)
+print(f"Trained tokenizer saved to: {output_path}")
+
+# Create an inference version without the pre-tokenizer
+# First save to a temporary file
+temp_tokenizer_path = join(args.out, "temp_tokenizer.json")
+tokenizer.save(temp_tokenizer_path)
+
+# Read the JSON
+with open(temp_tokenizer_path, "r") as f:
+    tokenizer_data = json.load(f)
+
+# Remove the pre-tokenizer field if present
+if "pre_tokenizer" in tokenizer_data:
+    del tokenizer_data["pre_tokenizer"]
+
+# Write the modified tokenizer to the inference file
+inference_path = join(args.out, f"{args.name}_inference.json")
+with open(inference_path, "w") as f:
+    json.dump(tokenizer_data, f, indent=2)
+
+# Clean up the temp file
+os.remove(temp_tokenizer_path)
+
+print(f"Inference-ready tokenizer (no pre-tokenizer) saved to: {inference_path}")
+
+# Test encoding with the inference tokenizer
+tokenizer = Tokenizer.from_file(inference_path)
+example = "Training BPE with multi-word tokens is very easy"
+print(f"\nTest encoding: {tokenizer.encode(example).tokens}")
\ No newline at end of file
diff --git a/bindings/python/examples/train_random_whitespace_bpe.py b/bindings/python/examples/train_random_whitespace_bpe.py
new file mode 100644
index 000000000..4d8feaddf
--- /dev/null
+++ b/bindings/python/examples/train_random_whitespace_bpe.py
@@ -0,0 +1,93 @@
+import argparse
+import glob
+import json
+import os
+from os.path import join
+
+from tokenizers import Tokenizer, normalizers, trainers
+from tokenizers.models import BPE
+from tokenizers.pre_tokenizers import RandomWhitespaceSplit
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--files",
+    default=None,
+    metavar="path",
+    type=str,
+    required=True,
+    help="The files to use for training; accepts '**/*.txt'-style patterns \
+                    if enclosed in quotes",
+)
+parser.add_argument(
+    "--out",
+    default="./",
+    type=str,
+    help="Path to the output directory, where the files will be saved",
+)
+parser.add_argument("--name", default="random-whitespace-bpe", type=str, help="The name of the output vocab files")
+parser.add_argument("--split-prob", default=0.3, type=float, help="Probability of splitting at whitespace (0.0-1.0)")
+parser.add_argument("--vocab-size", default=10000, type=int, help="Size of vocabulary")
+parser.add_argument("--min-frequency", default=2, type=int, help="Minimum frequency for a token to be included")
+args = parser.parse_args()
+
+files = glob.glob(args.files)
+if not files:
+    print(f"No files found matching pattern: {args.files}")
+    exit(1)
+
+
+# Initialize a tokenizer with a BPE model
+tokenizer = Tokenizer(BPE())
+
+# Use RandomWhitespaceSplit as pre-tokenizer
+tokenizer.pre_tokenizer = RandomWhitespaceSplit(split_probability=args.split_prob)
+
+# Optional: add NFKC normalization, as SentencePieceBPE does
+tokenizer.normalizer = normalizers.NFKC()
+
+# Configure the BPE trainer
+trainer = trainers.BpeTrainer(
+    vocab_size=args.vocab_size,
+    min_frequency=args.min_frequency,
+    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
+    show_progress=True
+)
+
+# Train the model
+print(f"Training BPE with RandomWhitespaceSplit (split_probability={args.split_prob})")
+tokenizer.train(files, trainer)
+
+# Save the trained tokenizer
+output_path = join(args.out, f"{args.name}.json")
+tokenizer.save(output_path)
+print(f"Trained tokenizer saved to: {output_path}")
+
+# Create an inference version without the pre-tokenizer
+# First save to a temporary file
+temp_tokenizer_path = join(args.out, "temp_tokenizer.json")
+tokenizer.save(temp_tokenizer_path)
+
+# Read the JSON
+with open(temp_tokenizer_path, "r") as f:
+    tokenizer_data = json.load(f)
+
+# Remove the pre-tokenizer field if present
+if "pre_tokenizer" in tokenizer_data:
+    del tokenizer_data["pre_tokenizer"]
+
+# Write the modified tokenizer to the inference file
+inference_path = join(args.out, f"{args.name}_inference.json")
+with open(inference_path, "w") as f:
+    json.dump(tokenizer_data, f, indent=2)
+
+# Clean up the temp file
+os.remove(temp_tokenizer_path)
+
+print(f"Inference-ready tokenizer (no pre-tokenizer) saved to: {inference_path}")
+
+# Test encoding with the inference tokenizer
+tokenizer = Tokenizer.from_file(inference_path)
+example = "Training BPE with multi-word tokens is very easy"
+print(f"\nTest encoding: {tokenizer.encode(example).tokens}")
+
diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
index 48277f0d2..440b1e768 100644
--- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
+++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
@@ -8,6 +8,8 @@
 Digits = pre_tokenizers.Digits
 Metaspace = pre_tokenizers.Metaspace
 Punctuation = pre_tokenizers.Punctuation
+RandomChunkSplit = pre_tokenizers.RandomChunkSplit
+RandomWhitespaceSplit = pre_tokenizers.RandomWhitespaceSplit
 Sequence = pre_tokenizers.Sequence
 Split = pre_tokenizers.Split
 UnicodeScripts = pre_tokenizers.UnicodeScripts
diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
index 6f31ff3a2..4919e254f 100644
--- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
@@ -367,6 +367,61 @@ class Punctuation(PreTokenizer):
     """
     pass

+class RandomChunkSplit(PreTokenizer):
+    """
+    RandomChunkSplit PreTokenizer
+
+    This pre-tokenizer splits text into random-length chunks regardless of whitespace
+    boundaries. It's useful for enabling BPE to learn tokens that span across whitespace.
+
+    Args:
+        min_length (:obj:`int`, `optional`, defaults to :obj:`1`):
+            The minimum length (in characters) for each chunk.
+        max_length (:obj:`int`, `optional`, defaults to :obj:`5`):
+            The maximum length (in characters) for each chunk.
+        deterministic (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to use deterministic (fixed-length) chunking for inference.
+    """
+    def __init__(self, min_length=1, max_length=5, deterministic=False):
+        pass
+
+    def pre_tokenize(self, pretok):
+        """
+        Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place
+
+        This method allows to modify a :class:`~tokenizers.PreTokenizedString` to
+        keep track of the pre-tokenization, and leverage the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of
+        the pre-tokenization of a raw string, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str`
+
+        Args:
+            pretok (:class:`~tokenizers.PreTokenizedString`):
+                The pre-tokenized string on which to apply this
+                :class:`~tokenizers.pre_tokenizers.PreTokenizer`
+        """
+        pass
+
+    def pre_tokenize_str(self, sequence):
+        """
+        Pre-tokenize the given string
+
+        This method provides a way to visualize the effect of a
+        :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the
+        alignment, nor does it provide all the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize`
+
+        Args:
+            sequence (:obj:`str`):
+                A string to pre-tokenize
+
+        Returns:
+            :obj:`List[Tuple[str, Offsets]]`:
+                A list of tuples with the pre-tokenized parts and their offsets
+        """
+        pass
+
 class Sequence(PreTokenizer):
     """
     This pre-tokenizer composes other pre_tokenizers and applies them in sequence
@@ -607,4 +660,4 @@ class WhitespaceSplit(PreTokenizer):
             :obj:`List[Tuple[str, Offsets]]`:
                 A list of tuple with the pre-tokenized parts and their offsets
         """
-    pass
+    pass
\ No newline at end of file
diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 8140ade1d..b73d3c2ea 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -14,6 +14,8 @@ use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
 use tk::pre_tokenizers::digits::Digits;
 use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
 use tk::pre_tokenizers::punctuation::Punctuation;
+use tk::pre_tokenizers::random_chunk::RandomChunkSplit;
+use tk::pre_tokenizers::random_whitespace::RandomWhitespaceSplit;
 use tk::pre_tokenizers::split::Split;
 use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
 use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
@@ -118,6 +120,18 @@ impl PyPreTokenizer {
                     .into_any()
                     .into()
             }
+            PreTokenizerWrapper::RandomChunkSplit(_) => {
+                Py::new(py, (PyRandomChunkSplit {}, base))?
+                    .into_pyobject(py)?
+                    .into_any()
+                    .into()
+            }
+            PreTokenizerWrapper::RandomWhitespaceSplit(_) => {
+                Py::new(py, (PyRandomWhitespaceSplit {}, base))?
+                    .into_pyobject(py)?
+                    .into_any()
+                    .into()
+            }
         },
     }
 }
@@ -750,6 +764,121 @@ impl PyUnicodeScripts {
     }
 }

+/// RandomChunkSplit PreTokenizer
+///
+/// This pre-tokenizer splits text into random-length chunks regardless of whitespace
+/// boundaries. It's useful for enabling BPE to learn tokens that span across whitespace.
+///
+/// Args:
+///     min_length (:obj:`int`, `optional`, defaults to :obj:`1`):
+///         The minimum length (in characters) for each chunk.
+///     max_length (:obj:`int`, `optional`, defaults to :obj:`5`):
+///         The maximum length (in characters) for each chunk.
+///     deterministic (:obj:`bool`, `optional`, defaults to :obj:`False`):
+///         Whether to use deterministic mode for inference. In deterministic mode,
+///         instead of random-length chunks, fixed-length chunks of average size
+///         between min_length and max_length will be used, ensuring consistent
+///         tokenization for the same input.
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "RandomChunkSplit")]
+pub struct PyRandomChunkSplit {}
+#[pymethods]
+impl PyRandomChunkSplit {
+    #[getter]
+    fn get_min_length(self_: PyRef<Self>) -> usize {
+        getter!(self_, RandomChunkSplit, min_length)
+    }
+
+    #[setter]
+    fn set_min_length(self_: PyRef<Self>, min_length: usize) {
+        setter!(self_, RandomChunkSplit, min_length, min_length);
+    }
+
+    #[getter]
+    fn get_max_length(self_: PyRef<Self>) -> usize {
+        getter!(self_, RandomChunkSplit, max_length)
+    }
+
+    #[setter]
+    fn set_max_length(self_: PyRef<Self>, max_length: usize) {
+        setter!(self_, RandomChunkSplit, max_length, max_length);
+    }
+
+    #[getter]
+    fn get_deterministic(self_: PyRef<Self>) -> bool {
+        getter!(self_, RandomChunkSplit, deterministic)
+    }
+
+    #[setter]
+    fn set_deterministic(self_: PyRef<Self>, deterministic: bool) {
+        setter!(self_, RandomChunkSplit, deterministic, deterministic);
+    }
+
+    #[new]
+    #[pyo3(signature = (min_length = 1, max_length = 5, deterministic = false), text_signature = "(self, min_length=1, max_length=5, deterministic=False)")]
+    fn new(min_length: usize, max_length: usize, deterministic: bool) -> (Self, PyPreTokenizer) {
+        (
+            PyRandomChunkSplit {},
+            RandomChunkSplit::new(min_length, max_length)
+                .with_deterministic(deterministic)
+                .into(),
+        )
+    }
+}
+
+/// RandomWhitespaceSplit PreTokenizer
+///
+/// Split the text by randomly deciding at each whitespace character whether to
+/// split there or continue. This enables the tokenizer to learn common multi-word
+/// expressions as a single token.
+///
+/// Args:
+///     split_probability (:obj:`float`, `optional`, defaults to :obj:`0.5`):
+///         The probability (0.0-1.0) of splitting at each whitespace character.
+///         Higher values (closer to 1.0) make the tokenizer behave more like traditional
+///         whitespace splitting, while lower values promote more multi-word tokens.
+///     deterministic (:obj:`bool`, `optional`, defaults to :obj:`False`):
+///         Whether to use deterministic mode for inference. In deterministic mode:
+///         - If split_probability > 0.5, all whitespace is split (like WhitespaceSplit)
+///         - If split_probability <= 0.5, no whitespace is split (preserving multi-word tokens)
+///         This ensures consistent tokenization for the same input during inference.
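+///
+/// Example (illustrative; split points are sampled, so the output varies
+/// between runs unless ``deterministic=True``)::
+///
+///     from tokenizers.pre_tokenizers import RandomWhitespaceSplit
+///
+///     pretok = RandomWhitespaceSplit(split_probability=0.5)
+///     pretok.pre_tokenize_str("one two three")
+///     # One possible outcome:
+///     # [('one two', (0, 7)), (' ', (7, 8)), ('three', (8, 13))]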
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "RandomWhitespaceSplit")]
+pub struct PyRandomWhitespaceSplit {}
+#[pymethods]
+impl PyRandomWhitespaceSplit {
+    #[getter]
+    fn get_split_probability(self_: PyRef<Self>) -> f32 {
+        getter!(self_, RandomWhitespaceSplit, split_probability)
+    }
+
+    #[setter]
+    fn set_split_probability(self_: PyRef<Self>, split_probability: f32) {
+        // Clamp the value between 0.0 and 1.0
+        let clamped_probability = split_probability.min(1.0).max(0.0);
+        setter!(self_, RandomWhitespaceSplit, split_probability, clamped_probability);
+    }
+
+    #[getter]
+    fn get_deterministic(self_: PyRef<Self>) -> bool {
+        getter!(self_, RandomWhitespaceSplit, deterministic)
+    }
+
+    #[setter]
+    fn set_deterministic(self_: PyRef<Self>, deterministic: bool) {
+        setter!(self_, RandomWhitespaceSplit, deterministic, deterministic);
+    }
+
+    #[new]
+    #[pyo3(signature = (split_probability = 0.5, deterministic = false), text_signature = "(self, split_probability=0.5, deterministic=False)")]
+    fn new(split_probability: f32, deterministic: bool) -> (Self, PyPreTokenizer) {
+        (
+            PyRandomWhitespaceSplit {},
+            RandomWhitespaceSplit::new(split_probability)
+                .with_deterministic(deterministic)
+                .into(),
+        )
+    }
+}
+
 #[derive(Clone)]
 pub(crate) struct CustomPreTokenizer {
     inner: PyObject,
@@ -926,6 +1055,8 @@ pub fn pre_tokenizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PySequence>()?;
     m.add_class::<PyDigits>()?;
     m.add_class::<PyUnicodeScripts>()?;
+    m.add_class::<PyRandomChunkSplit>()?;
+    m.add_class::<PyRandomWhitespaceSplit>()?;
     Ok(())
 }
diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py
index 3611930ae..893e7f07b 100644
--- a/bindings/python/tests/bindings/test_pre_tokenizers.py
+++ b/bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -11,6 +11,8 @@
     Metaspace,
     PreTokenizer,
     Punctuation,
+    RandomChunkSplit,
+    RandomWhitespaceSplit,
     Sequence,
     Split,
     UnicodeScripts,
@@ -260,6 +262,135 @@ def test_instantiate(self):
         assert isinstance(pickle.loads(pickle.dumps(UnicodeScripts())), UnicodeScripts)


+class TestRandomChunkSplit:
+    def test_instantiate(self):
+        assert RandomChunkSplit() is not None
+        assert RandomChunkSplit(min_length=2, max_length=5) is not None
+        assert isinstance(RandomChunkSplit(), PreTokenizer)
+        assert isinstance(RandomChunkSplit(), RandomChunkSplit)
+        assert isinstance(pickle.loads(pickle.dumps(RandomChunkSplit())), RandomChunkSplit)
+
+    def test_can_modify(self):
+        pretok = RandomChunkSplit(min_length=2, max_length=10)
+
+        assert pretok.min_length == 2
+        assert pretok.max_length == 10
+
+        # Modify these
+        pretok.min_length = 3
+        assert pretok.min_length == 3
+        pretok.max_length = 8
+        assert pretok.max_length == 8
+
+    def test_pre_tokenize_str(self):
+        pretok = RandomChunkSplit(min_length=2, max_length=2)  # Fixed length for deterministic testing
+
+        # Test with a simple string - chunks of exactly 2 chars, except possibly the last one
+        text = "Hello world"
+        result = pretok.pre_tokenize_str(text)
+
+        # Make sure all characters are accounted for
+        joined = "".join(token for token, _ in result)
+        assert joined == text
+
+        # Most tokens should have length 2 (except possibly the last one if odd length)
+        for i, (token, _) in enumerate(result[:-1]):
+            assert len(token) == 2, f"Token at position {i} has length {len(token)}, expected 2"
+
+        # If total length is odd, last token might be length 1
+        if len(text) % 2 == 1:
+            assert 1 <= len(result[-1][0]) <= 2
+        else:
+            assert len(result[-1][0]) == 2
+
+        # Test with unicode characters
+        text = "こんにちは"  # Japanese for "Hello"
+ result = pretok.pre_tokenize_str(text) + + # Make sure all characters are accounted for + joined = "".join(token for token, _ in result) + assert joined == text + + +class TestRandomWhitespaceSplit: + def test_instantiate(self): + assert RandomWhitespaceSplit() is not None + assert RandomWhitespaceSplit(split_probability=0.5) is not None + assert RandomWhitespaceSplit(split_probability=0.0) is not None + assert RandomWhitespaceSplit(split_probability=1.0) is not None + assert isinstance(RandomWhitespaceSplit(), PreTokenizer) + assert isinstance(RandomWhitespaceSplit(), RandomWhitespaceSplit) + assert isinstance(pickle.loads(pickle.dumps(RandomWhitespaceSplit())), RandomWhitespaceSplit) + + def test_can_modify(self): + pretok = RandomWhitespaceSplit(split_probability=0.5) + + assert pretok.split_probability == 0.5 + + # Modify the split probability + pretok.split_probability = 0.75 + assert pretok.split_probability == 0.75 + + # Test value limiting (should clamp between 0 and 1) + pretok.split_probability = 1.5 # Should be clamped to 1.0 + assert pretok.split_probability == 1.0 + + pretok.split_probability = -0.5 # Should be clamped to 0.0 + assert pretok.split_probability == 0.0 + + def test_pre_tokenize_str_full_probability(self): + # With split_probability = 1.0, should behave like WhitespaceSplit + pretok = RandomWhitespaceSplit(split_probability=1.0) + + text = "Hello world!" + result = pretok.pre_tokenize_str(text) + + # Make sure all characters are accounted for + joined = "".join(token for token, _ in result) + assert joined == text + + # Should split on the whitespace + assert len(result) == 3 + assert result[0][0] == "Hello" + assert result[1][0] == " " + assert result[2][0] == "world!" + + def test_pre_tokenize_str_zero_probability(self): + # With split_probability = 0.0, should not split + pretok = RandomWhitespaceSplit(split_probability=0.0) + + text = "Hello world!" + result = pretok.pre_tokenize_str(text) + + # Make sure all characters are accounted for + joined = "".join(token for token, _ in result) + assert joined == text + + # Should not split at all + assert len(result) == 1 + assert result[0][0] == "Hello world!" + + def test_pre_tokenize_multiple_whitespaces(self): + # Test with multiple whitespaces and probability = 1.0 + pretok = RandomWhitespaceSplit(split_probability=1.0) + + text = "Hello world!\nTest" + result = pretok.pre_tokenize_str(text) + + # Make sure all characters are accounted for + joined = "".join(token for token, _ in result) + assert joined == text + + # Should split on all whitespaces + assert len(result) == 6 + assert result[0][0] == "Hello" + assert result[1][0] == " " + assert result[2][0] == " " + assert result[3][0] == "world!" + assert result[4][0] == "\n" + assert result[5][0] == "Test" + + class TestCustomPreTokenizer: class BadCustomPretok: def pre_tokenize(self, pretok, wrong): diff --git a/docs/source-doc-builder/components/random-chunk-split.mdx b/docs/source-doc-builder/components/random-chunk-split.mdx new file mode 100644 index 000000000..4a153c63e --- /dev/null +++ b/docs/source-doc-builder/components/random-chunk-split.mdx @@ -0,0 +1,77 @@ +# RandomChunkSplit + +PreTokenizer + +The `RandomChunkSplit` pre-tokenizer splits text into random-length chunks regardless of whitespace boundaries. This enables [BPE](../components/bpe.mdx) models to learn tokens that span across whitespace, which can be particularly useful for recognizing common multi-word expressions as single tokens. 
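+
+For a quick, concrete look at the effect, a minimal sketch (chunk boundaries are sampled at random, so the exact output varies from run to run):
+
+```python
+from tokenizers.pre_tokenizers import RandomChunkSplit
+
+pretok = RandomChunkSplit(min_length=2, max_length=5)
+print(pretok.pre_tokenize_str("New York City"))
+# One possible run:
+# [('New ', (0, 4)), ('Yo', (4, 6)), ('rk Ci', (6, 11)), ('ty', (11, 13))]
+```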
+
+## How it works
+
+Unlike traditional pre-tokenizers such as [WhitespaceSplit](../components/whitespace-split.mdx) that split text at whitespace before tokenization, `RandomChunkSplit` randomly segments text into chunks of configurable lengths. The algorithm:
+
+1. Determines a random length between `min_length` and `max_length` for each chunk
+2. Splits the input text into chunks of this random length
+3. Ensures proper handling of Unicode characters by splitting at character boundaries, not byte positions
+
+This approach allows the BPE algorithm to discover and learn multi-word tokens that occur frequently in the training corpus.
+
+## Example
+
+```python
+from tokenizers import Tokenizer
+from tokenizers.models import BPE
+from tokenizers.pre_tokenizers import RandomChunkSplit
+from tokenizers.trainers import BpeTrainer
+
+# Initialize a tokenizer with a BPE model
+tokenizer = Tokenizer(BPE())
+
+# Add the RandomChunkSplit pre-tokenizer
+tokenizer.pre_tokenizer = RandomChunkSplit(min_length=2, max_length=5)
+
+# Train on your data (any iterator over strings)
+trainer = BpeTrainer(
+    vocab_size=25000,
+    min_frequency=2,
+    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+)
+tokenizer.train_from_iterator(your_training_data, trainer)
+
+# Now the tokenizer can recognize multi-word expressions
+encoded = tokenizer.encode("New York is a city in the United States")
+print(encoded.tokens)
+```
+
+## Benefits
+
+1. **Multi-word expressions**: Enables learning tokens that span whitespace boundaries
+2. **Domain adaptation**: Better handles domain-specific expressions like "New York" or "machine learning"
+3. **Tokenization efficiency**: Can reduce sequence lengths by representing common phrases as single tokens
+4. **Semantic coherence**: Helps maintain the semantic meaning of expressions that should be treated as units
+
+## Parameters
+
+- `min_length` (`int`, defaults to `1`): The minimum length (in characters) for each chunk.
+- `max_length` (`int`, defaults to `5`): The maximum length (in characters) for each chunk.
+- `deterministic` (`bool`, defaults to `False`): Whether to use fixed-length chunks (the average of `min_length` and `max_length`) for consistent tokenization at inference time.
+
+## Usage Recommendations
+
+1. **Chunk Size Tuning**: The optimal chunk size depends on your language and domain:
+   - Smaller chunks (1-3): Good for character-rich languages or short words
+   - Medium chunks (2-5): Balanced approach for most languages
+   - Larger chunks (5-10): Better for capturing longer expressions
+
+2. **Training Impact**:
+   - RandomChunkSplit typically requires more training data
+   - It may produce larger vocabularies with more diverse tokens
+   - Training might be slightly slower due to more complex pattern discovery
+
+3. **Use with other components**:
+   - Combines well with normalization steps like lowercasing
+   - Works with various BPE trainers and parameters
+
+4. 
**Evaluation**: Best compared against traditional methods using perplexity or downstream task performance \ No newline at end of file diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index db56865d2..f3d223be4 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -82,6 +82,8 @@ tempfile = "3.10" assert_approx_eq = "1.1" tracing = "0.1" tracing-subscriber = "0.3.18" +half = "=2.4.0" +anyhow = "1.0" [profile.release] lto = "fat" diff --git a/tokenizers/benches/bpe_benchmark.rs b/tokenizers/benches/bpe_benchmark.rs index f0097bf82..ed33cd0c9 100644 --- a/tokenizers/benches/bpe_benchmark.rs +++ b/tokenizers/benches/bpe_benchmark.rs @@ -11,6 +11,7 @@ use criterion::Criterion; use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; use tokenizers::models::TrainerWrapper; use tokenizers::pre_tokenizers::byte_level::ByteLevel; +use tokenizers::pre_tokenizers::random_chunk::RandomChunkSplit; use tokenizers::pre_tokenizers::whitespace::Whitespace; use tokenizers::tokenizer::{AddedToken, EncodeInput}; use tokenizers::Tokenizer; @@ -69,13 +70,14 @@ fn bench_gpt2(c: &mut Criterion) { } fn bench_train(c: &mut Criterion) { + // Standard Whitespace pre-tokenizer let mut trainer: TrainerWrapper = BpeTrainerBuilder::default() .show_progress(false) .build() .into(); let mut tokenizer = Tokenizer::new(BPE::default()).into_inner(); tokenizer.with_pre_tokenizer(Some(Whitespace {})); - c.bench_function("BPE Train vocabulary (small)", |b| { + c.bench_function("BPE Train vocabulary (small, whitespace)", |b| { b.iter_custom(|iters| { iter_bench_train( iters, @@ -86,9 +88,52 @@ fn bench_train(c: &mut Criterion) { }) }); + // RandomChunkSplit pre-tokenizer with small chunks + let mut tokenizer = Tokenizer::new(BPE::default()).into_inner(); + tokenizer.with_pre_tokenizer(Some(RandomChunkSplit::new(1, 3))); + c.bench_function("BPE Train vocabulary (small, random-1-3)", |b| { + b.iter_custom(|iters| { + iter_bench_train( + iters, + &mut tokenizer, + &mut trainer, + vec!["data/small.txt".to_string()], + ) + }) + }); + + // RandomChunkSplit pre-tokenizer with medium chunks + let mut tokenizer = Tokenizer::new(BPE::default()).into_inner(); + tokenizer.with_pre_tokenizer(Some(RandomChunkSplit::new(2, 5))); + c.bench_function("BPE Train vocabulary (small, random-2-5)", |b| { + b.iter_custom(|iters| { + iter_bench_train( + iters, + &mut tokenizer, + &mut trainer, + vec!["data/small.txt".to_string()], + ) + }) + }); + + // Big file benchmarks with whitespace pre-tokenizer let mut tokenizer = Tokenizer::new(BPE::default()).into_inner(); tokenizer.with_pre_tokenizer(Some(Whitespace {})); - c.bench_function("BPE Train vocabulary (big)", |b| { + c.bench_function("BPE Train vocabulary (big, whitespace)", |b| { + b.iter_custom(|iters| { + iter_bench_train( + iters, + &mut tokenizer, + &mut trainer, + vec!["data/big.txt".to_string()], + ) + }) + }); + + // RandomChunkSplit on big file + let mut tokenizer = Tokenizer::new(BPE::default()).into_inner(); + tokenizer.with_pre_tokenizer(Some(RandomChunkSplit::new(2, 5))); + c.bench_function("BPE Train vocabulary (big, random-2-5)", |b| { b.iter_custom(|iters| { iter_bench_train( iters, diff --git a/tokenizers/benches/random_chunk_benchmark.rs b/tokenizers/benches/random_chunk_benchmark.rs new file mode 100644 index 000000000..3ceacb267 --- /dev/null +++ b/tokenizers/benches/random_chunk_benchmark.rs @@ -0,0 +1,252 @@ +#[macro_use] +extern crate criterion; + +mod common; + +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::Path; +use 
std::time::{Duration, Instant}; + +use criterion::{black_box, Criterion}; +use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; +use tokenizers::models::TrainerWrapper; +use tokenizers::pre_tokenizers::random_chunk::RandomChunkSplit; +use tokenizers::pre_tokenizers::whitespace::WhitespaceSplit; +use tokenizers::tokenizer::{PreTokenizedString, PreTokenizer, TokenizerImpl}; +use tokenizers::{OffsetReferential, OffsetType, Tokenizer}; +use tokenizers::models::ModelWrapper; +use tokenizers::normalizers::NormalizerWrapper; +use tokenizers::pre_tokenizers::PreTokenizerWrapper; +use tokenizers::processors::PostProcessorWrapper; +use tokenizers::decoders::DecoderWrapper; + +use common::iter_bench_train; + +// Benchmark different RandomChunkSplit configurations for pre-tokenization +fn bench_pre_tokenize(c: &mut Criterion) { + // Load sample text for benchmarking + let mut text = String::new(); + for line in BufReader::new(File::open(Path::new("data/small.txt")).unwrap()) + .lines() + .take(100) + { + text.push_str(&line.unwrap()); + text.push(' '); + } + + // Create a benchmark group for pre-tokenization + let mut group = c.benchmark_group("PreTokenization"); + + // Benchmark WhitespaceSplit as baseline + let whitespace_split = WhitespaceSplit; + group.bench_function("WhitespaceSplit", |b| { + b.iter_custom(|iters| { + let mut duration = Duration::new(0, 0); + for _ in 0..iters { + let mut pretokenized = PreTokenizedString::from(text.clone()); + let start = Instant::now(); + let _ = black_box(whitespace_split.pre_tokenize(&mut pretokenized)); + duration = duration.checked_add(start.elapsed()).unwrap(); + } + duration + }); + }); + + // Benchmark different RandomChunkSplit configurations + let configs = vec![ + (1, 3, "RandomChunkSplit(1,3)"), + (1, 5, "RandomChunkSplit(1,5)"), + (1, 10, "RandomChunkSplit(1,10)"), + (2, 5, "RandomChunkSplit(2,5)"), + (3, 7, "RandomChunkSplit(3,7)"), + (5, 10, "RandomChunkSplit(5,10)"), + ]; + + for (min_len, max_len, name) in configs { + let random_chunk_split = RandomChunkSplit::new(min_len, max_len); + group.bench_function(name, |b| { + b.iter_custom(|iters| { + let mut duration = Duration::new(0, 0); + for _ in 0..iters { + let mut pretokenized = PreTokenizedString::from(text.clone()); + let start = Instant::now(); + let _ = black_box(random_chunk_split.pre_tokenize(&mut pretokenized)); + duration = duration.checked_add(start.elapsed()).unwrap(); + } + duration + }); + }); + } + + group.finish(); +} + +// Benchmark token statistics (average length, whitespace percentage, etc.) 
+fn bench_token_statistics(c: &mut Criterion) { + // Load sample text for analysis + let mut text = String::new(); + for line in BufReader::new(File::open(Path::new("data/small.txt")).unwrap()) + .lines() + .take(100) + { + text.push_str(&line.unwrap()); + text.push(' '); + } + + // Create a benchmark group for token statistics + let mut group = c.benchmark_group("TokenStatistics"); + group.sample_size(50); // Reduce number of samples for this analysis + + // Analyze WhitespaceSplit (baseline) + let whitespace_split = WhitespaceSplit; + group.bench_function("WhitespaceSplit_Stats", |b| { + b.iter_custom(|iters| { + let mut duration = Duration::new(0, 0); + for _ in 0..iters { + let mut pretokenized = PreTokenizedString::from(text.clone()); + whitespace_split.pre_tokenize(&mut pretokenized).unwrap(); + + let start = Instant::now(); + let splits = pretokenized.get_splits(OffsetReferential::Original, OffsetType::Char); + + // Calculate token statistics + let total_tokens = splits.len(); + let total_chars: usize = splits.iter().map(|(s, _, _)| s.chars().count()).sum(); + let _avg_token_length = if total_tokens > 0 { + total_chars as f64 / total_tokens as f64 + } else { + 0.0 + }; + + // Count tokens with whitespace + let tokens_with_whitespace = splits + .iter() + .filter(|(s, _, _)| s.contains(char::is_whitespace)) + .count(); + let _whitespace_percentage = if total_tokens > 0 { + tokens_with_whitespace as f64 / total_tokens as f64 * 100.0 + } else { + 0.0 + }; + + black_box((_avg_token_length, _whitespace_percentage)); + duration = duration.checked_add(start.elapsed()).unwrap(); + } + duration + }); + }); + + // Analyze different RandomChunkSplit configurations + let configs = vec![ + (1, 3, "RandomChunkSplit(1,3)_Stats"), + (1, 5, "RandomChunkSplit(1,5)_Stats"), + (2, 5, "RandomChunkSplit(2,5)_Stats"), + (5, 10, "RandomChunkSplit(5,10)_Stats"), + ]; + + for (min_len, max_len, name) in configs { + let random_chunk_split = RandomChunkSplit::new(min_len, max_len); + group.bench_function(name, |b| { + b.iter_custom(|iters| { + let mut duration = Duration::new(0, 0); + for _ in 0..iters { + let mut pretokenized = PreTokenizedString::from(text.clone()); + random_chunk_split.pre_tokenize(&mut pretokenized).unwrap(); + + let start = Instant::now(); + let splits = pretokenized.get_splits(OffsetReferential::Original, OffsetType::Char); + + // Calculate token statistics + let total_tokens = splits.len(); + let total_chars: usize = splits.iter().map(|(s, _, _)| s.chars().count()).sum(); + let _avg_token_length = if total_tokens > 0 { + total_chars as f64 / total_tokens as f64 + } else { + 0.0 + }; + + // Count tokens with whitespace + let tokens_with_whitespace = splits + .iter() + .filter(|(s, _, _)| s.contains(char::is_whitespace)) + .count(); + let _whitespace_percentage = if total_tokens > 0 { + tokens_with_whitespace as f64 / total_tokens as f64 * 100.0 + } else { + 0.0 + }; + + black_box((_avg_token_length, _whitespace_percentage)); + duration = duration.checked_add(start.elapsed()).unwrap(); + } + duration + }); + }); + } + + group.finish(); +} + +// Benchmark model training with different pre-tokenizers +fn bench_train_with_different_pretok(c: &mut Criterion) { + let mut group = c.benchmark_group("TrainingWithDifferentPreTokenizers"); + group.sample_size(10); // Training is expensive, use fewer samples + + type TokenizerType = TokenizerImpl; + + // Using enum to store different types + let configs = vec![ + ("WhitespaceSplit_Train", Box::new(|t: &mut TokenizerType| { + 
t.with_pre_tokenizer(Some(WhitespaceSplit {})); + }) as Box), + ("RandomChunkSplit(1,3)_Train", Box::new(|t: &mut TokenizerType| { + t.with_pre_tokenizer(Some(RandomChunkSplit::new(1, 3))); + })), + ("RandomChunkSplit(2,5)_Train", Box::new(|t: &mut TokenizerType| { + t.with_pre_tokenizer(Some(RandomChunkSplit::new(2, 5))); + })), + ("RandomChunkSplit(5,10)_Train", Box::new(|t: &mut TokenizerType| { + t.with_pre_tokenizer(Some(RandomChunkSplit::new(5, 10))); + })), + ]; + + for (name, setup_fn) in configs { + let mut trainer: TrainerWrapper = BpeTrainerBuilder::default() + .show_progress(false) + .vocab_size(1000) // Smaller vocab for benchmark + .min_frequency(2) + .build() + .into(); + + let mut tokenizer = Tokenizer::new(BPE::default()).into_inner(); + setup_fn(&mut tokenizer); + + group.bench_function(name, |b| { + b.iter_custom(|iters| { + iter_bench_train( + iters, + &mut tokenizer, + &mut trainer, + vec!["data/small.txt".to_string()], + ) + }); + }); + } + + group.finish(); +} + +criterion_group! { + name = random_chunk_benches; + config = Criterion::default().sample_size(30); + targets = bench_pre_tokenize, bench_token_statistics +} + +criterion_group! { + name = random_chunk_train_benches; + config = Criterion::default().sample_size(10); + targets = bench_train_with_different_pretok +} + +criterion_main!(random_chunk_benches, random_chunk_train_benches); \ No newline at end of file diff --git a/tokenizers/examples/random_chunk_split.rs b/tokenizers/examples/random_chunk_split.rs new file mode 100644 index 000000000..a2c606925 --- /dev/null +++ b/tokenizers/examples/random_chunk_split.rs @@ -0,0 +1,58 @@ +use tokenizers::models::bpe::BPE; +use tokenizers::pre_tokenizers::random_chunk::RandomChunkSplit; +use tokenizers::pre_tokenizers::whitespace::WhitespaceSplit; +use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Tokenizer}; + +fn main() { + // The example text that contains multi-word expressions + let text = "We want to learn multi-word expressions like 'New York' or 'machine learning'"; + println!("Original text: {}", text); + + // Demonstrate how pre-tokenization works + println!("\n=== Pre-tokenization Demonstration ==="); + + // WhitespaceSplit pre-tokenization + let mut pretokenized_whitespace = PreTokenizedString::from(text); + let whitespace_split = WhitespaceSplit; + whitespace_split.pre_tokenize(&mut pretokenized_whitespace).unwrap(); + let splits_whitespace = pretokenized_whitespace + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s.to_owned(), o)) + .collect::>(); + println!("WhitespaceSplit splits:"); + for (token, offsets) in &splits_whitespace { + println!(" '{}' at position {:?}", token, offsets); + } + println!("Number of splits: {}", splits_whitespace.len()); + + // RandomChunkSplit pre-tokenization + let mut pretokenized_random = PreTokenizedString::from(text); + let random_chunk_split = RandomChunkSplit::new(2, 5); + random_chunk_split.pre_tokenize(&mut pretokenized_random).unwrap(); + let splits_random = pretokenized_random + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s.to_owned(), o)) + .collect::>(); + println!("\nRandomChunkSplit splits (min_length=2, max_length=5):"); + for (token, offsets) in &splits_random { + println!(" '{}' at position {:?}", token, offsets); + } + println!("Number of splits: {}", splits_random.len()); + + // Show a more extreme example with large chunks + let mut pretokenized_large = 
PreTokenizedString::from(text); + let large_chunk_split = RandomChunkSplit::new(10, 15); + large_chunk_split.pre_tokenize(&mut pretokenized_large).unwrap(); + let splits_large = pretokenized_large + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s.to_owned(), o)) + .collect::>(); + println!("\nRandomChunkSplit splits (min_length=10, max_length=15):"); + for (token, offsets) in &splits_large { + println!(" '{}' at position {:?}", token, offsets); + } + println!("Number of splits: {}", splits_large.len()); +} \ No newline at end of file diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 93b3d9c37..781dba8ff 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -199,6 +199,7 @@ impl Word { // Make sure we are not processing an expired queue entry let target_new_pair = (self.symbols[top.pos].c, right.c); + use crate::utils::OptionExt; if merges .get(&target_new_pair) .is_none_or(|(_, new_id)| *new_id != top.new_id) diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 6195d170b..50ef17b43 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -4,6 +4,8 @@ pub mod delimiter; pub mod digits; pub mod metaspace; pub mod punctuation; +pub mod random_chunk; +pub mod random_whitespace; pub mod sequence; pub mod split; pub mod unicode_scripts; @@ -17,6 +19,8 @@ use crate::pre_tokenizers::delimiter::CharDelimiterSplit; use crate::pre_tokenizers::digits::Digits; use crate::pre_tokenizers::metaspace::Metaspace; use crate::pre_tokenizers::punctuation::Punctuation; +use crate::pre_tokenizers::random_chunk::RandomChunkSplit; +use crate::pre_tokenizers::random_whitespace::RandomWhitespaceSplit; use crate::pre_tokenizers::sequence::Sequence; use crate::pre_tokenizers::split::Split; use crate::pre_tokenizers::unicode_scripts::UnicodeScripts; @@ -37,6 +41,8 @@ pub enum PreTokenizerWrapper { WhitespaceSplit(WhitespaceSplit), Digits(Digits), UnicodeScripts(UnicodeScripts), + RandomChunkSplit(RandomChunkSplit), + RandomWhitespaceSplit(RandomWhitespaceSplit), } impl PreTokenizer for PreTokenizerWrapper { @@ -53,6 +59,8 @@ impl PreTokenizer for PreTokenizerWrapper { Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized), Self::Digits(wspt) => wspt.pre_tokenize(normalized), Self::UnicodeScripts(us) => us.pre_tokenize(normalized), + Self::RandomChunkSplit(rcs) => rcs.pre_tokenize(normalized), + Self::RandomWhitespaceSplit(rws) => rws.pre_tokenize(normalized), } } } @@ -82,6 +90,8 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { WhitespaceSplit, Digits, UnicodeScripts, + RandomChunkSplit, + RandomWhitespaceSplit, } #[derive(Deserialize)] @@ -105,6 +115,8 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { WhitespaceSplit(WhitespaceSplit), Digits(Digits), UnicodeScripts(UnicodeScripts), + RandomChunkSplit(RandomChunkSplit), + RandomWhitespaceSplit(RandomWhitespaceSplit), } let helper = PreTokenizerHelper::deserialize(deserializer)?; @@ -152,6 +164,12 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), + EnumType::RandomChunkSplit => PreTokenizerWrapper::RandomChunkSplit( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), + EnumType::RandomWhitespaceSplit => PreTokenizerWrapper::RandomWhitespaceSplit( + 
serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), } } @@ -187,6 +205,12 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => { PreTokenizerWrapper::UnicodeScripts(unicode_scripts) } + PreTokenizerUntagged::RandomChunkSplit(random_chunk_split) => { + PreTokenizerWrapper::RandomChunkSplit(random_chunk_split) + } + PreTokenizerUntagged::RandomWhitespaceSplit(random_whitespace_split) => { + PreTokenizerWrapper::RandomWhitespaceSplit(random_whitespace_split) + } } } }) @@ -204,6 +228,8 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace); impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit); impl_enum_from!(Digits, PreTokenizerWrapper, Digits); impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts); +impl_enum_from!(RandomChunkSplit, PreTokenizerWrapper, RandomChunkSplit); +impl_enum_from!(RandomWhitespaceSplit, PreTokenizerWrapper, RandomWhitespaceSplit); #[cfg(test)] mod tests { @@ -280,6 +306,26 @@ mod tests { PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}) ); } + + #[test] + fn test_deserialize_random_chunk_split() { + let pre_tokenizer: PreTokenizerWrapper = + serde_json::from_str(r#"{"type":"RandomChunkSplit","min_length":2,"max_length":5}"#).unwrap(); + assert_eq!( + pre_tokenizer, + PreTokenizerWrapper::RandomChunkSplit(RandomChunkSplit::new(2, 5)) + ); + } + + #[test] + fn test_deserialize_random_whitespace_split() { + let pre_tokenizer: PreTokenizerWrapper = + serde_json::from_str(r#"{"type":"RandomWhitespaceSplit","split_probability":0.7}"#).unwrap(); + assert_eq!( + pre_tokenizer, + PreTokenizerWrapper::RandomWhitespaceSplit(RandomWhitespaceSplit::new(0.7)) + ); + } #[test] fn pre_tokenizer_deserialization_no_type() { diff --git a/tokenizers/src/pre_tokenizers/random_chunk.rs b/tokenizers/src/pre_tokenizers/random_chunk.rs new file mode 100644 index 000000000..b151d1808 --- /dev/null +++ b/tokenizers/src/pre_tokenizers/random_chunk.rs @@ -0,0 +1,281 @@ +use rand::Rng; +use serde::{Deserialize, Serialize}; + +use crate::tokenizer::{pattern::Pattern, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; + +/// Pre-tokenizes text by splitting it into random-length chunks. +/// +/// This allows tokenization across traditional whitespace boundaries, enabling BPE +/// to learn multi-word expressions as a single token. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type")] +pub struct RandomChunkSplit { + /// Minimum chunk length (in characters) + pub min_length: usize, + /// Maximum chunk length (in characters) + pub max_length: usize, + /// When true, uses deterministic behavior instead of random chunks (for inference) + #[serde(default)] + pub deterministic: bool, +} + +impl RandomChunkSplit { + /// Create a new `RandomChunkSplit` with the given min and max chunk lengths. + pub fn new(min_length: usize, max_length: usize) -> Self { + // Ensure min_length and max_length are valid + let min_length = min_length.max(1); + let max_length = max_length.max(min_length); + + Self { + min_length, + max_length, + deterministic: false, + } + } + + /// Sets the deterministic mode for inference + /// + /// When deterministic is true, chunks of fixed length will be created + /// instead of random-length chunks. This provides consistent tokenization + /// at inference time. 
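+    ///
+    /// A minimal sketch of the intended train/inference split (assuming the
+    /// pre-tokenizer is rebuilt or toggled when preparing an inference tokenizer):
+    ///
+    /// ```
+    /// use tokenizers::pre_tokenizers::random_chunk::RandomChunkSplit;
+    ///
+    /// // Random chunk lengths while training a vocabulary...
+    /// let train_pretok = RandomChunkSplit::new(2, 6);
+    /// // ...and fixed, average-length chunks for reproducible inference.
+    /// let infer_pretok = RandomChunkSplit::new(2, 6).with_deterministic(true);
+    /// assert!(infer_pretok.deterministic && !train_pretok.deterministic);
+    /// ```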
+ pub fn with_deterministic(mut self, deterministic: bool) -> Self { + self.deterministic = deterministic; + self + } +} + +/// Split pattern that creates chunks of random or fixed lengths +struct RandomChunkPattern<'a> { + min_length: usize, + max_length: usize, + deterministic: bool, + chars: &'a [char], + current_pos: usize, +} + +impl<'a> RandomChunkPattern<'a> { + fn new(chars: &'a [char], min_length: usize, max_length: usize, deterministic: bool) -> Self { + Self { + min_length, + max_length, + deterministic, + chars, + current_pos: 0, + } + } +} + +impl<'a> Pattern for RandomChunkPattern<'a> { + fn find_matches(&self, _text: &str) -> Result> { + let mut result = Vec::new(); + let mut current_pos = self.current_pos; + let chars = self.chars; + let mut char_start_byte = 0; + + // Get byte offset of current_pos + for i in 0..current_pos { + char_start_byte += chars[i].len_utf8(); + } + + while current_pos < chars.len() { + // Calculate remaining characters + let remaining = chars.len() - current_pos; + + // Calculate effective max length (limited by remaining chars) + let effective_max = self.max_length.min(remaining); + + // If we can't satisfy minimum length, just take all remaining chars + if effective_max < self.min_length { + let chunk_len = remaining; + let mut chunk_bytes = 0; + for i in 0..chunk_len { + chunk_bytes += chars[current_pos + i].len_utf8(); + } + + result.push(((char_start_byte, char_start_byte + chunk_bytes), false)); + break; + } + + // Choose chunk length based on deterministic mode + let chunk_len = if self.deterministic { + // In deterministic mode, use a consistent length + // For inference, we use the average of min and max length for consistency + // This gives a predictable behavior while still allowing multi-word tokens + let avg_length = (self.min_length + self.max_length) / 2; + avg_length.min(effective_max) + } else if self.min_length == effective_max { + self.min_length + } else { + // In random mode, generate a random length + let mut rng = rand::thread_rng(); + rng.gen_range(self.min_length..=effective_max) + }; + + // Calculate byte length of this chunk + let mut chunk_bytes = 0; + for i in 0..chunk_len { + chunk_bytes += chars[current_pos + i].len_utf8(); + } + + // Add segment + result.push(((char_start_byte, char_start_byte + chunk_bytes), false)); + + // Update positions + current_pos += chunk_len; + char_start_byte += chunk_bytes; + } + + Ok(result) + } +} + +impl PreTokenizer for RandomChunkSplit { + fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { + let deterministic = self.deterministic; + pretokenized.split(|_, normalized| { + let chars: Vec = normalized.get().chars().collect(); + let pattern = RandomChunkPattern::new(&chars, self.min_length, self.max_length, deterministic); + normalized.split(pattern, SplitDelimiterBehavior::Isolated) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{OffsetReferential, OffsetType}; + + #[test] + fn test_empty_string() { + let pretok = RandomChunkSplit::new(1, 5); + let mut pretokenized = PreTokenizedString::from(""); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + + let splits = pretokenized.get_splits(OffsetReferential::Original, OffsetType::Byte); + assert_eq!(splits.len(), 0); + } + + #[test] + fn test_deterministic_chunks() { + // With min_length = max_length, the chunking should be deterministic + let pretok = RandomChunkSplit::new(3, 3); + let s = "Hello world!"; + let mut pretokenized = PreTokenizedString::from(s); + pretok.pre_tokenize(&mut 
pretokenized).unwrap(); + + let splits = pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + assert_eq!( + splits, + vec![ + ("Hel", (0, 3)), + ("lo ", (3, 6)), + ("wor", (6, 9)), + ("ld!", (9, 12)), + ] + ); + } + + #[test] + fn test_unicode_handling() { + // Ensure proper handling of multi-byte Unicode characters + let pretok = RandomChunkSplit::new(1, 1); + let s = "こんにちは"; // "Hello" in Japanese + let mut pretokenized = PreTokenizedString::from(s); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + + let splits = pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + assert_eq!(splits.len(), 5); // 5 characters + assert_eq!(splits[0].0, "こ"); + assert_eq!(splits[1].0, "ん"); + assert_eq!(splits[2].0, "に"); + assert_eq!(splits[3].0, "ち"); + assert_eq!(splits[4].0, "は"); + } + + #[test] + fn test_min_max_validation() { + // If min > max, it should be corrected + let pretok = RandomChunkSplit::new(5, 3); + assert_eq!(pretok.min_length, 5); + assert_eq!(pretok.max_length, 5); + + // Min can't be 0 + let pretok = RandomChunkSplit::new(0, 5); + assert_eq!(pretok.min_length, 1); + assert_eq!(pretok.max_length, 5); + } + + #[test] + fn test_random_chunks() { + // Test with a range of chunk sizes + let pretok = RandomChunkSplit::new(1, 5); + let s = "The quick brown fox jumps over the lazy dog."; + let mut pretokenized = PreTokenizedString::from(s); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + + let splits = pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + // Ensure all characters are accounted for + let joined: String = splits.iter().map(|(s, _)| s.to_string()).collect(); + assert_eq!(joined, s); + + // Verify that each chunk has a length within the specified range + for (chunk, _) in &splits { + let chunk_chars = chunk.chars().count(); + assert!(chunk_chars >= 1 && chunk_chars <= 5); + } + } + + #[test] + fn test_deterministic_mode() { + // Test with deterministic mode enabled + let s = "The quick brown fox jumps over the lazy dog."; + + // Create a pre-tokenizer with deterministic mode enabled + let deterministic_pretok = RandomChunkSplit::new(2, 6).with_deterministic(true); + let mut deterministic_pretokenized1 = PreTokenizedString::from(s); + deterministic_pretok.pre_tokenize(&mut deterministic_pretokenized1).unwrap(); + + // Run it again to verify consistency + let mut deterministic_pretokenized2 = PreTokenizedString::from(s); + deterministic_pretok.pre_tokenize(&mut deterministic_pretokenized2).unwrap(); + + // Get the splits from both runs + let splits1 = deterministic_pretokenized1 + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + let splits2 = deterministic_pretokenized2 + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + // The two runs should produce identical results in deterministic mode + assert_eq!(splits1, splits2); + + // Each chunk should have the average length of min and max (4 in this case) + // unless we hit the end of the string + let avg_length = (deterministic_pretok.min_length + deterministic_pretok.max_length) / 2; + for (chunk, _) in splits1.iter().take(splits1.len() - 1) { // Skip the last chunk + let chunk_chars = chunk.chars().count(); + assert_eq!(chunk_chars, 
avg_length); + } + } +} \ No newline at end of file diff --git a/tokenizers/src/pre_tokenizers/random_whitespace.rs b/tokenizers/src/pre_tokenizers/random_whitespace.rs new file mode 100644 index 000000000..3d866b65b --- /dev/null +++ b/tokenizers/src/pre_tokenizers/random_whitespace.rs @@ -0,0 +1,226 @@ +use rand::Rng; +use serde::{Deserialize, Serialize}; + +use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior}; + +/// Pre-tokenizes text by deciding at each whitespace character whether to split +/// or continue based on a given probability. +/// +/// This pre-tokenizer is similar to `WhitespaceSplit` but randomly decides for +/// each whitespace character whether to split at that position. This allows +/// tokenization to occasionally span across whitespace boundaries, enabling BPE +/// to learn multi-word expressions as single tokens. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type")] +pub struct RandomWhitespaceSplit { + /// Probability of splitting at each whitespace character (0.0-1.0) + pub split_probability: f32, + /// When true, uses deterministic behavior instead of random decisions (for inference) + #[serde(default)] + pub deterministic: bool, +} + +impl RandomWhitespaceSplit { + /// Create a new `RandomWhitespaceSplit` with the given probability. + /// + /// The `split_probability` determines how likely the tokenizer is to split + /// at each whitespace character. Higher values (closer to 1.0) make the behavior + /// more similar to traditional `WhitespaceSplit`, while lower values encourage + /// more multi-word tokens. + pub fn new(split_probability: f32) -> Self { + // Ensure probability is within valid range + let split_probability = split_probability.min(1.0).max(0.0); + + Self { + split_probability, + deterministic: false, + } + } + + /// Sets the deterministic mode for inference + /// + /// When deterministic is true, a consistent behavior will be used for + /// whitespace splitting. In deterministic mode: + /// - If split_probability > 0.5, all whitespace is split (like WhitespaceSplit) + /// - If split_probability <= 0.5, no whitespace is split (preserving multi-word tokens) + /// + /// This provides consistent tokenization at inference time. 
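+    ///
+    /// A small illustration of the thresholding rule (with `split_probability <= 0.5`,
+    /// deterministic mode keeps the whole input as one piece):
+    ///
+    /// ```
+    /// use tokenizers::pre_tokenizers::random_whitespace::RandomWhitespaceSplit;
+    /// use tokenizers::tokenizer::{PreTokenizedString, PreTokenizer};
+    /// use tokenizers::{OffsetReferential, OffsetType};
+    ///
+    /// let keep = RandomWhitespaceSplit::new(0.3).with_deterministic(true);
+    /// let mut pretok = PreTokenizedString::from("New York City");
+    /// keep.pre_tokenize(&mut pretok).unwrap();
+    /// // No whitespace split: a single piece survives.
+    /// assert_eq!(pretok.get_splits(OffsetReferential::Original, OffsetType::Byte).len(), 1);
+    /// ```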
+ pub fn with_deterministic(mut self, deterministic: bool) -> Self { + self.deterministic = deterministic; + self + } +} + +impl PreTokenizer for RandomWhitespaceSplit { + fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { + // This function captures the split probability and deterministic mode + let split_prob = self.split_probability; + let deterministic = self.deterministic; + + pretokenized.split(|_, normalized| { + // Create a pattern closure that decides whether to split at whitespace + let whitespace_pattern = |c: char| { + if c.is_whitespace() { + if deterministic { + // In deterministic mode, make a consistent decision based on the probability + // If split_prob > 0.5, always split (like WhitespaceSplit) + // If split_prob <= 0.5, never split (preserve multi-word tokens) + split_prob > 0.5 + } else { + // In random mode, use randomness to decide + let mut rng = rand::thread_rng(); + rng.gen::() < split_prob + } + } else { + false + } + }; + + // Use the pattern with the normalized string + normalized.split(whitespace_pattern, SplitDelimiterBehavior::Isolated) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{OffsetReferential, OffsetType}; + + #[test] + fn test_empty_string() { + let pretok = RandomWhitespaceSplit::new(0.5); + let mut pretokenized = PreTokenizedString::from(""); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + + let splits = pretokenized.get_splits(OffsetReferential::Original, OffsetType::Byte); + assert_eq!(splits.len(), 0); + } + + #[test] + fn test_full_split_probability() { + // With split_probability = 1.0, should behave like WhitespaceSplit + let pretok = RandomWhitespaceSplit::new(1.0); + let s = "Hello world!"; + let mut pretokenized = PreTokenizedString::from(s); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + + let splits = pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + assert_eq!( + splits, + vec![ + ("Hello", (0, 5)), + (" ", (5, 6)), + ("world!", (6, 12)), + ] + ); + } + + #[test] + fn test_zero_split_probability() { + // With split_probability = 0.0, should not split at all + let pretok = RandomWhitespaceSplit::new(0.0); + let s = "Hello world!"; + let mut pretokenized = PreTokenizedString::from(s); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + + let splits = pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + assert_eq!( + splits, + vec![ + ("Hello world!", (0, 12)), + ] + ); + } + + #[test] + fn test_multiple_whitespaces() { + // Test with multiple whitespaces and full split probability + let pretok = RandomWhitespaceSplit::new(1.0); + let s = "Hello world!\nTest"; + let mut pretokenized = PreTokenizedString::from(s); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + + let splits = pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + assert_eq!( + splits, + vec![ + ("Hello", (0, 5)), + (" ", (5, 6)), + (" ", (6, 7)), + ("world!", (7, 13)), + ("\n", (13, 14)), + ("Test", (14, 18)), + ] + ); + } + + #[test] + fn test_deterministic_mode() { + let s = "Hello world! 
How are you?"; + + // Test with high probability (>0.5) in deterministic mode + // Should behave like WhitespaceSplit (all whitespace split) + let high_prob_pretok = RandomWhitespaceSplit::new(0.7).with_deterministic(true); + let mut high_prob_pretokenized1 = PreTokenizedString::from(s); + high_prob_pretok.pre_tokenize(&mut high_prob_pretokenized1).unwrap(); + + // Run it again to verify consistency + let mut high_prob_pretokenized2 = PreTokenizedString::from(s); + high_prob_pretok.pre_tokenize(&mut high_prob_pretokenized2).unwrap(); + + // Get the splits from both runs + let high_prob_splits1 = high_prob_pretokenized1 + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + let high_prob_splits2 = high_prob_pretokenized2 + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + // The two runs should produce identical results in deterministic mode + assert_eq!(high_prob_splits1, high_prob_splits2); + + // With high probability, should behave like WhitespaceSplit (everything split) + // Expected: ["Hello", " ", "world!", " ", "How", " ", "are", " ", "you?"] + assert_eq!(high_prob_splits1.len(), 9); + assert_eq!(high_prob_splits1[0].0, "Hello"); + assert_eq!(high_prob_splits1[1].0, " "); + assert_eq!(high_prob_splits1[2].0, "world!"); + + // Test with low probability (<=0.5) in deterministic mode + // Should never split on whitespace + let low_prob_pretok = RandomWhitespaceSplit::new(0.3).with_deterministic(true); + let mut low_prob_pretokenized = PreTokenizedString::from(s); + low_prob_pretok.pre_tokenize(&mut low_prob_pretokenized).unwrap(); + + let low_prob_splits = low_prob_pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(); + + // With low probability, no whitespace should be split at all + assert_eq!(low_prob_splits.len(), 1); + assert_eq!(low_prob_splits[0].0, s); + } +} \ No newline at end of file diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 6c9cf9a74..9611439fc 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -466,6 +466,7 @@ impl TemplateProcessingBuilder { } fn validate(&self) -> std::result::Result<(), String> { + use crate::utils::OptionExt; let pair_has_both = self.pair.as_ref().is_none_or(|pair| { let mut has_a = false; let mut has_b = false; diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs index deda862eb..e46d49aa1 100644 --- a/tokenizers/src/utils/mod.rs +++ b/tokenizers/src/utils/mod.rs @@ -216,3 +216,23 @@ macro_rules! 
 // Re-export macro_rules_attribute
 pub use macro_rules_attribute::macro_rules_attribute;
+
+/// Extension trait for Option to add the is_none_or method
+pub trait OptionExt<T> {
+    /// Returns true if the option is None or if the predicate returns true for the contained value
+    fn is_none_or<F>(&self, predicate: F) -> bool
+    where
+        F: FnOnce(&T) -> bool;
+}
+
+impl<T> OptionExt<T> for Option<T> {
+    fn is_none_or<F>(&self, predicate: F) -> bool
+    where
+        F: FnOnce(&T) -> bool,
+    {
+        match self {
+            None => true,
+            Some(val) => predicate(val),
+        }
+    }
+}
diff --git a/tokenizers/tests/training.rs b/tokenizers/tests/training.rs
index 37127d5ac..2e2c3e631 100644
--- a/tokenizers/tests/training.rs
+++ b/tokenizers/tests/training.rs
@@ -1,5 +1,7 @@
 use tokenizers::models::bpe::BPE;
 use tokenizers::pre_tokenizers::whitespace::Whitespace;
+use tokenizers::pre_tokenizers::random_chunk::RandomChunkSplit;
+use tokenizers::pre_tokenizers::random_whitespace::RandomWhitespaceSplit;
 use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
 use tokenizers::{Model, Tokenizer, TokenizerBuilder};
 
@@ -58,3 +60,138 @@ fn bpe_continuing_subword_prefix_error() {
 
     std::fs::remove_file("tokenizer.json").unwrap();
 }
+
+#[test]
+fn bpe_training_with_random_chunk_split() {
+    let mut tokenizer = TokenizerBuilder::<
+        BPE,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    >::default()
+    .with_model(
+        BPE::builder()
+            .unk_token("[UNK]".to_string())
+            .build()
+            .unwrap(),
+    )
+    .with_pre_tokenizer(Some(PreTokenizerWrapper::RandomChunkSplit(
+        RandomChunkSplit::new(2, 5),
+    )))
+    .build()
+    .unwrap();
+
+    let mut trainer = tokenizer.get_model().get_trainer();
+    tokenizer
+        .train_from_files(&mut trainer, vec!["./data/small.txt".to_string()])
+        .unwrap();
+
+    // Save and reload the tokenizer to test serialization
+    tokenizer.save("random_chunk_tokenizer.json", true).unwrap();
+    let tokenizer = Tokenizer::from_file("random_chunk_tokenizer.json").unwrap();
+
+    // Verify the model works by encoding a text
+    let encoding = tokenizer.encode("Hello world", false).unwrap();
+    assert!(!encoding.get_tokens().is_empty());
+
+    std::fs::remove_file("random_chunk_tokenizer.json").unwrap();
+}
+
+#[test]
+fn bpe_training_with_random_whitespace_split() {
+    let mut tokenizer = TokenizerBuilder::<
+        BPE,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    >::default()
+    .with_model(
+        BPE::builder()
+            .unk_token("[UNK]".to_string())
+            .build()
+            .unwrap(),
+    )
+    .with_pre_tokenizer(Some(PreTokenizerWrapper::RandomWhitespaceSplit(
+        RandomWhitespaceSplit::new(0.3),
+    )))
+    .build()
+    .unwrap();
+
+    let mut trainer = tokenizer.get_model().get_trainer();
+    tokenizer
+        .train_from_files(&mut trainer, vec!["./data/small.txt".to_string()])
+        .unwrap();
+
+    // Save and reload the tokenizer to test serialization
+    tokenizer.save("random_whitespace_tokenizer.json", true).unwrap();
+    let tokenizer = Tokenizer::from_file("random_whitespace_tokenizer.json").unwrap();
+
+    // Verify the model works by encoding a text
+    let encoding = tokenizer.encode("Hello world", false).unwrap();
+    assert!(!encoding.get_tokens().is_empty());
+
+    std::fs::remove_file("random_whitespace_tokenizer.json").unwrap();
+}
+
+#[test]
+fn bpe_training_with_deterministic_random_pretokenizers() {
+    // Test with RandomChunkSplit in deterministic mode
+    let mut tokenizer1 = TokenizerBuilder::<
+        BPE,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    >::default()
+    .with_model(
+        BPE::builder()
+            .unk_token("[UNK]".to_string())
+            .build()
+            .unwrap(),
+    )
+    .with_pre_tokenizer(Some(PreTokenizerWrapper::RandomChunkSplit(
+        RandomChunkSplit::new(2, 4).with_deterministic(true),
+    )))
+    .build()
+    .unwrap();
+
+    let mut trainer1 = tokenizer1.get_model().get_trainer();
+    tokenizer1
+        .train_from_files(&mut trainer1, vec!["./data/small.txt".to_string()])
+        .unwrap();
+
+    // Test with RandomWhitespaceSplit in deterministic mode
+    let mut tokenizer2 = TokenizerBuilder::<
+        BPE,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    >::default()
+    .with_model(
+        BPE::builder()
+            .unk_token("[UNK]".to_string())
+            .build()
+            .unwrap(),
+    )
+    .with_pre_tokenizer(Some(PreTokenizerWrapper::RandomWhitespaceSplit(
+        RandomWhitespaceSplit::new(0.3).with_deterministic(true),
+    )))
+    .build()
+    .unwrap();
+
+    let mut trainer2 = tokenizer2.get_model().get_trainer();
+    tokenizer2
+        .train_from_files(&mut trainer2, vec!["./data/small.txt".to_string()])
+        .unwrap();
+
+    // Encode the same text with both tokenizers to verify they work
+    let sample_text = "Hello world, this is a test for multi-word tokenization";
+    let encoding1 = tokenizer1.encode(sample_text, false).unwrap();
+    let encoding2 = tokenizer2.encode(sample_text, false).unwrap();
+
+    assert!(!encoding1.get_tokens().is_empty());
+    assert!(!encoding2.get_tokens().is_empty());
+}
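
Reviewer note: the deterministic mode is what makes the two training runs in bpe_training_with_deterministic_random_pretokenizers comparable, since it trades randomness for a fixed threshold decision. A minimal standalone sketch of that behavior follows; the module path comes from the test imports in this patch, while the program itself is illustrative and not part of the change:

use tokenizers::pre_tokenizers::random_whitespace::RandomWhitespaceSplit;
use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer};

fn main() -> tokenizers::Result<()> {
    // In deterministic mode, split_probability > 0.5 isolates every
    // whitespace character; <= 0.5 leaves the input as one multi-word piece.
    let pretok = RandomWhitespaceSplit::new(0.7).with_deterministic(true);
    let mut pretokenized = PreTokenizedString::from("multi word tokens");
    pretok.pre_tokenize(&mut pretokenized)?;
    for (piece, offsets, _) in
        pretokenized.get_splits(OffsetReferential::Original, OffsetType::Byte)
    {
        // Prints "multi" (0, 5), " " (5, 6), "word" (6, 10), ...
        println!("{:?} {:?}", piece, offsets);
    }
    Ok(())
}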
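Similarly, a quick sanity check of the OptionExt polyfill added in utils/mod.rs. It mirrors the standard library's Option::is_none_or, which to my knowledge was only stabilized in Rust 1.82, hence the local trait; the import path below assumes the utils module is public, and the values are arbitrary:

use tokenizers::utils::OptionExt; // path as added in this patch, assuming `utils` is public

fn main() {
    let none: Option<u32> = None;
    let some = Some(4u32);
    assert!(none.is_none_or(|v| *v > 10)); // None is vacuously true for any predicate
    assert!(!some.is_none_or(|v| *v > 10)); // Some(4) fails this predicate
    assert!(some.is_none_or(|v| *v % 2 == 0)); // Some(4) satisfies this one
    println!("OptionExt::is_none_or behaves like std's Option::is_none_or");
}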