sentencepiece tokenizer support for TokenizerInfo (#120)
zanderjiang authored Dec 22, 2024
1 parent 1509eac commit 8bc7b1d
Showing 4 changed files with 122 additions and 12 deletions.
3 changes: 2 additions & 1 deletion cpp/tokenizer_info.cc
@@ -332,7 +332,8 @@ class HFTokenizerAnalyzer {

bool TokenizerInfo::Impl::IsSpecialToken(const std::string& token) {
// gemma treats [@BOS@] as a special token
-  return (token[0] == '<' && token.back() == '>' && token.size() >= 3) || token == "[@BOS@]";
+  return (token[0] == '<' && token.back() == '>' && token.size() >= 3) || token == "[@BOS@]" ||
+         token == "";
}

TokenizerInfo::Impl::Impl(
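
For illustration, a minimal Python sketch of the same heuristic (a re-implementation for clarity, not the shipped C++; the function name is hypothetical). The empty-string case is what this commit adds; it matches the padding entries that from_huggingface now inserts for unused token ids.

def is_special_token(token: str) -> bool:
    # special: "<...>"-wrapped tokens of length >= 3, the gemma-style
    # "[@BOS@]" marker, and the empty string used for padded vocab slots;
    # unlike the C++, the length check must come first to avoid an IndexError
    return (
        (len(token) >= 3 and token[0] == "<" and token[-1] == ">")
        or token == "[@BOS@]"
        or token == ""
    )

assert is_special_token("<s>") and is_special_token("[@BOS@]") and is_special_token("")
assert not is_special_token("<>") and not is_special_token("not<special>")
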
2 changes: 2 additions & 0 deletions python/requirements.txt
@@ -1,5 +1,7 @@
pybind11
pydantic
pytest
sentencepiece
tiktoken
torch
transformers
91 changes: 82 additions & 9 deletions python/xgrammar/tokenizer_info.py
Expand Up @@ -3,6 +3,8 @@
from enum import Enum
from typing import List, Optional, Union

+import sentencepiece
+import tiktoken
from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast

from .base import XGRObject, _core
@@ -94,6 +96,35 @@ def __init__(
)
)

+    @staticmethod
+    def _is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
+        # helper to check if the tokenizer is a tiktoken tokenizer
+        has_tiktoken_encoding = hasattr(tokenizer, "tokenizer") and isinstance(
+            tokenizer.tokenizer, tiktoken.Encoding
+        )
+
+        filename_pattern = (
+            "vocab_file" in tokenizer.vocab_files_names
+            and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]
+        )
+
+        return has_tiktoken_encoding or filename_pattern
+
+    @staticmethod
+    def _is_sentencepiece_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
+        # helper to check if the tokenizer is a sentencepiece tokenizer
+        has_sp_model_attr = hasattr(tokenizer, "sp_model") and isinstance(
+            tokenizer.sp_model, sentencepiece.SentencePieceProcessor
+        )
+
+        has_nested_sp_model_attr = (
+            hasattr(tokenizer, "tokenizer")
+            and hasattr(tokenizer.tokenizer, "sp_model")
+            and isinstance(tokenizer.tokenizer.sp_model, sentencepiece.SentencePieceProcessor)
+        )
+
+        return has_sp_model_attr or has_nested_sp_model_attr

@staticmethod
def from_huggingface(
tokenizer: PreTrainedTokenizerBase,
@@ -139,17 +170,28 @@ def from_huggingface(
raise ValueError("stop_token_ids cannot be empty")

try:
-            encoded_vocab = tokenizer.get_vocab()
-            encoded_vocab = [
-                token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
-            ]
+            vocab_dict = tokenizer.get_vocab()
except AttributeError as e:
msg = (
f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer "
"should have a get_vocab method."
)
raise ValueError(msg) from e

+        max_id = max(vocab_dict.values()) if vocab_dict else -1
+        detected_vocab_size = max(len(vocab_dict), max_id + 1)
+        if vocab_size is None:
+            vocab_size = detected_vocab_size
+        else:
+            if vocab_size < detected_vocab_size:
+                msg = f"Input vocab_size is less than the minimum viable vocab size for tokenizer {type(tokenizer)}."
+                raise ValueError(msg)
+
+        # maintain the tokenizer's indexing: pad unused ids with empty strings
+        encoded_vocab = ["" for _ in range(vocab_size)]
+        for token, idx in vocab_dict.items():
+            encoded_vocab[idx] = token

if isinstance(tokenizer, PreTrainedTokenizerFast):
# huggingface fast tokenizer
# - the vocabulary is directly obtained from tokenizer.get_vocab()
@@ -174,10 +216,7 @@ def from_huggingface(
encoded_vocab, backend_str, vocab_size, stop_token_ids
)
)
-        elif (
-            "vocab_file" in tokenizer.vocab_files_names
-            and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]
-        ):
+        elif TokenizerInfo._is_tiktoken_tokenizer(tokenizer):
# tiktoken tokenizer
# e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously)
if stop_token_ids is None:
@@ -196,8 +235,42 @@ def from_huggingface(
stop_token_ids=stop_token_ids,
prepend_space_in_tokenization=False,
)
+        elif TokenizerInfo._is_sentencepiece_tokenizer(tokenizer):
+            # sentencepiece tokenizer
+            # e.g. Chatglm3-6b
+            if hasattr(tokenizer, "sp_model"):
+                sp_model = tokenizer.sp_model
+            elif hasattr(tokenizer, "tokenizer") and hasattr(tokenizer.tokenizer, "sp_model"):
+                sp_model = tokenizer.tokenizer.sp_model
+
+            if stop_token_ids is None:
+                if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
+                    stop_token_ids = [tokenizer.eos_token_id]
+                else:
+                    eos_id = sp_model.eos_id()
+                    if eos_id != -1:
+                        stop_token_ids = [eos_id]
+                    else:
+                        logger.warning(
+                            "When constructing TokenizerInfo from a huggingface tokenizer, "
+                            "stop_token_ids is neither provided by user nor found from the tokenizer. "
+                            "It will be automatically detected."
+                        )
+            # detect vocab_type of tokenizer
+            if "<0x0A>" in vocab_dict:
+                vocab_type = VocabType.BYTE_FALLBACK
+            else:
+                vocab_type = VocabType.RAW
+
+            return TokenizerInfo(
+                encoded_vocab,
+                vocab_type=vocab_type,
+                vocab_size=vocab_size,
+                stop_token_ids=stop_token_ids,
+                prepend_space_in_tokenization=True,
+            )
else:
-            # TODO(yixin): sentencepiece tokenizer
+            # TODO(yixin): unsupported tokenizer
raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")

@property
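
To see what the new vocabulary handling in from_huggingface does, a toy, self-contained illustration (the tokens and ids are made up): some tokenizers, e.g. chatglm3-6b, can report token ids beyond len(vocab_dict), so the detected size is max(len(vocab_dict), max_id + 1) and the holes are padded with empty strings, preserving the tokenizer's own indexing.

vocab_dict = {"<s>": 0, "</s>": 1, "hello": 2, "<extra>": 9}  # hypothetical vocab with a hole
max_id = max(vocab_dict.values()) if vocab_dict else -1
detected_vocab_size = max(len(vocab_dict), max_id + 1)  # 10, not len(vocab_dict) == 4
encoded_vocab = ["" for _ in range(detected_vocab_size)]
for token, idx in vocab_dict.items():
    encoded_vocab[idx] = token
assert encoded_vocab[3] == ""  # padded slot; IsSpecialToken treats "" as special
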
38 changes: 36 additions & 2 deletions tests/python/test_tokenizer_info.py
@@ -41,6 +41,8 @@ def tokenizer_info_storage() -> Dict[str, Tuple[PreTrainedTokenizerBase, xgr.TokenizerInfo]]:
("Qwen/Qwen2.5-1.5B", xgr.VocabType.BYTE_LEVEL, False),
("internlm/internlm2_5-7b-chat", xgr.VocabType.BYTE_FALLBACK, False),
("mistralai/Mixtral-8x22B-Instruct-v0.1", xgr.VocabType.BYTE_FALLBACK, True),
("THUDM/glm-4-9b-chat", xgr.VocabType.RAW, False),
("THUDM/chatglm3-6b", xgr.VocabType.BYTE_FALLBACK, True),
]

tokenizer_paths = [path for path, *_ in tokenizer_paths_metadata]
@@ -72,7 +74,9 @@ def test_properties(
tokenizer_info_storage: Dict[str, Tuple[PreTrainedTokenizerBase, xgr.TokenizerInfo]],
):
tokenizer, tokenizer_info = tokenizer_info_storage[tokenizer_path]
-    assert tokenizer_info.vocab_size == len(tokenizer.get_vocab())
+    vocab_dict = tokenizer.get_vocab()
+    max_id = max(vocab_dict.values()) if vocab_dict else -1
+    assert tokenizer_info.vocab_size == max(len(vocab_dict), max_id + 1)
assert tokenizer_info.vocab_type == vocab_type
assert tokenizer_info.prepend_space_in_tokenization == prepend_space_in_tokenization

@@ -84,9 +88,11 @@ def test_decoded_vocab(
):
tokenizer, tokenizer_info = tokenizer_info_storage[tokenizer_path]
decoded_vocab = tokenizer_info.decoded_vocab
+    vocab_dict = tokenizer.get_vocab()
+    max_id = max(vocab_dict.values()) if vocab_dict else -1
assert isinstance(decoded_vocab, list)
assert all(isinstance(token, bytes) for token in decoded_vocab)
-    assert len(decoded_vocab) == len(tokenizer.get_vocab())
+    assert len(decoded_vocab) == max(len(vocab_dict), max_id + 1)
assert len(decoded_vocab) == tokenizer_info.vocab_size


@@ -217,5 +223,33 @@ def test_customized_tokenizer_info(tokenizer_path: str):
assert tokenizer_info.special_token_ids[-5:] == [original_vocab_size + i for i in range(5)]


+def test_special_token_detection():
+    # build a small vocabulary by hand; the empty token (id 0) exercises the
+    # rule newly added to IsSpecialToken
+    vocab_dict = {
+        "": 0,
+        "<s>": 1,
+        "</s>": 2,
+        "[@BOS@]": 3,
+        "regular": 4,
+        "<test_token>": 5,
+        "not<special>": 6,
+        "<>": 7,
+    }
+    tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
+        list(vocab_dict.keys()),
+        '{"vocab_type":"BYTE_FALLBACK","vocab_size":8,"prepend_space_in_tokenization":true,"stop_token_ids":[2]}',
+    )
+    # "<>" is too short and "not<special>" is not fully bracketed, so neither is special
+    expected_special_tokens = {0, 1, 2, 3, 5}
+    assert set(tokenizer_info.special_token_ids) == expected_special_tokens


if __name__ == "__main__":
pytest.main(sys.argv)

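For context, a minimal usage sketch of the new sentencepiece path (the model is taken from the test list above; downloading it requires network access and trust_remote_code, and the expected values follow the test metadata):

from transformers import AutoTokenizer

import xgrammar as xgr

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", use_fast=True, trust_remote_code=True)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
print(tokenizer_info.vocab_type)  # VocabType.BYTE_FALLBACK, per the test metadata
print(tokenizer_info.vocab_size)  # padded size: max(len(vocab_dict), max_id + 1)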