diff --git a/cpp/grammar_matcher.cc b/cpp/grammar_matcher.cc
index 44187173..33a0e722 100644
--- a/cpp/grammar_matcher.cc
+++ b/cpp/grammar_matcher.cc
@@ -8,8 +8,12 @@
 #include <algorithm>
+#include <cstdint>
 #include <memory>
+#include <optional>
+#include <thread>
 #include <utility>
+#include <variant>
 #include <vector>
 
 #include "compiled_grammar_impl.h"
@@ -19,6 +23,7 @@
 #include "support/encoding.h"
 #include "support/int_set.h"
 #include "support/logging.h"
+#include "support/thread_pool.h"
 #include "testing.h"
 
 namespace xgrammar {
@@ -343,6 +348,52 @@ class GrammarMatcher::Impl : public EarleyParser {
   std::vector<int32_t> tmp_rejected_indices_delta_;
 };
 
+class BatchGrammarMatcher::Impl {
+ public:
+  Impl(std::variant<int32_t, std::string> max_threads) {
+    if (std::holds_alternative<int32_t>(max_threads)) {
+      int32_t num_threads = std::get<int32_t>(max_threads);
+      XGRAMMAR_CHECK(num_threads >= 1)
+          << "The num_threads should be at least 1, but got " << num_threads;
+      if (num_threads > 1) {
+        if (num_threads > static_cast<int32_t>(std::thread::hardware_concurrency())) {
+          XGRAMMAR_LOG(WARNING) << "The num_threads " << num_threads << " is larger than the "
+                                << "number of hardware threads. Using "
+                                << static_cast<int32_t>(std::thread::hardware_concurrency())
+                                << " instead.";
+        }
+        max_threads_ =
+            std::min(num_threads, static_cast<int32_t>(std::thread::hardware_concurrency()));
+      }
+    } else {
+      std::string str = std::get<std::string>(max_threads);
+      XGRAMMAR_CHECK(str == "auto");
+      max_threads_ = std::thread::hardware_concurrency() / 2;
+    }
+  }
+
+  void BatchFillNextTokenBitmask(
+      std::vector<GrammarMatcher*>* matchers,
+      DLTensor* next_token_bitmask,
+      const std::optional<std::vector<int32_t>>& indices,
+      bool debug_print
+  );
+
+  static std::vector<int8_t> BatchAcceptToken(
+      std::vector<GrammarMatcher*>* matchers,
+      const std::vector<int32_t>& token_ids,
+      bool debug_print
+  );
+
+  static std::vector<int8_t> BatchAcceptString(
+      std::vector<GrammarMatcher*>* matchers,
+      const std::vector<std::string>& input_strs,
+      bool debug_print
+  );
+
+ private:
+  std::optional<ThreadPool> thread_pool_ = std::nullopt;
+  int32_t max_threads_ = 1;
+};
+
 bool GrammarMatcher::Impl::AcceptStopToken() {
   if (terminate_without_stop_token_) {
     return false;
   }
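For reference, the constructor's thread-count resolution can be mirrored in Python. This is an illustrative analogue only; the helper name and the use of `os.cpu_count()` as a stand-in for `std::thread::hardware_concurrency()` are assumptions, not part of the diff:

```python
import os

def resolve_max_threads(max_threads="auto"):
    """Illustrative analogue of BatchGrammarMatcher::Impl's constructor logic."""
    hw = os.cpu_count() or 1  # stand-in for std::thread::hardware_concurrency()
    if max_threads == "auto":
        return hw // 2  # "auto" resolves to half the hardware threads
    if max_threads < 1:
        raise ValueError(f"max_threads must be >= 1, got {max_threads}")
    # Requests beyond the hardware thread count are clamped (with a warning in C++).
    return min(max_threads, hw)
```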
@@ -852,6 +903,75 @@ int GrammarMatcher::Impl::GetNextUncertainToken(
   }
 }
 
+void BatchGrammarMatcher::Impl::BatchFillNextTokenBitmask(
+    std::vector<GrammarMatcher*>* matchers,
+    DLTensor* next_token_bitmask,
+    const std::optional<std::vector<int32_t>>& indices,
+    bool debug_print
+) {
+  XGRAMMAR_CHECK(!indices.has_value() || indices->size() == matchers->size())
+      << "The size of indices (" << (indices.has_value() ? indices->size() : 0)
+      << ") should be the same as the size of matchers (" << matchers->size() << ").";
+  // Initialize the thread pool if needed. It should be initialized each time,
+  // because ThreadPool cannot be reused after Join().
+  if (max_threads_ > 1) {
+    thread_pool_.emplace(max_threads_);
+  }
+  if (!thread_pool_.has_value()) {
+    for (int i = 0; i < static_cast<int>(matchers->size()); i++) {
+      auto& matcher = (*matchers)[i];
+      int index = indices.has_value() ? (*indices)[i] : i;
+      XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
+          << "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
+          << ") for batch_id " << i << ".";
+      matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
+    }
+  } else {
+    auto fill_next_token_mask = [&](int32_t batch_id) {
+      auto& matcher = (*matchers)[batch_id];
+      int index = indices.has_value() ? (*indices)[batch_id] : batch_id;
+      XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
+          << "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
+          << ") for batch_id " << batch_id << ".";
+      matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
+    };
+    for (int i = 0; i < static_cast<int>(matchers->size()); i++) {
+      thread_pool_->Execute([fill_next_token_mask, i]() { fill_next_token_mask(i); });
+    }
+    thread_pool_->Join();
+  }
+}
+
+std::vector<int8_t> BatchGrammarMatcher::Impl::BatchAcceptString(
+    std::vector<GrammarMatcher*>* matchers,
+    const std::vector<std::string>& input_strs,
+    bool debug_print
+) {
+  XGRAMMAR_CHECK(matchers->size() == input_strs.size())
+      << "The size of matchers (" << matchers->size() << ") and input_strs (" << input_strs.size()
+      << ") should be the same.";
+  std::vector<int8_t> accepted(matchers->size());
+  for (int i = 0; i < static_cast<int>(matchers->size()); i++) {
+    auto& matcher = (*matchers)[i];
+    accepted[i] = matcher->AcceptString(input_strs[i], debug_print);
+  }
+  return accepted;
+}
+
+std::vector<int8_t> BatchGrammarMatcher::Impl::BatchAcceptToken(
+    std::vector<GrammarMatcher*>* matchers, const std::vector<int32_t>& token_ids, bool debug_print
+) {
+  XGRAMMAR_CHECK(matchers->size() == token_ids.size())
+      << "The size of matchers (" << matchers->size() << ") and token_ids (" << token_ids.size()
+      << ") should be the same.";
+  std::vector<int8_t> accepted(matchers->size());
+  for (int i = 0; i < static_cast<int>(matchers->size()); i++) {
+    auto& matcher = (*matchers)[i];
+    accepted[i] = matcher->AcceptToken(token_ids[i], debug_print);
+  }
+  return accepted;
+}
+
 GrammarMatcher::GrammarMatcher(
     const CompiledGrammar& compiled_grammar,
     std::optional<std::vector<int32_t>> override_stop_tokens,
@@ -894,4 +1014,30 @@ std::string GrammarMatcher::_DebugPrintInternalState() const {
   return pimpl_->_DebugPrintInternalState();
 }
 
+void BatchGrammarMatcher::BatchFillNextTokenBitmask(
+    std::vector<GrammarMatcher*>* matchers,
+    DLTensor* next_token_bitmask,
+    const std::optional<std::vector<int32_t>>& indices,
+    bool debug_print
+) {
+  return pimpl_->BatchFillNextTokenBitmask(matchers, next_token_bitmask, indices, debug_print);
+}
+
+std::vector<int8_t> BatchGrammarMatcher::BatchAcceptString(
+    std::vector<GrammarMatcher*>* matchers,
+    const std::vector<std::string>& input_strs,
+    bool debug_print
+) {
+  return Impl::BatchAcceptString(matchers, input_strs, debug_print);
+}
+
+std::vector<int8_t> BatchGrammarMatcher::BatchAcceptToken(
+    std::vector<GrammarMatcher*>* matchers, const std::vector<int32_t>& token_ids, bool debug_print
+) {
+  return Impl::BatchAcceptToken(matchers, token_ids, debug_print);
+}
+
+BatchGrammarMatcher::BatchGrammarMatcher(std::variant<int32_t, std::string> max_threads)
+    : pimpl_(std::make_shared<Impl>(max_threads)) {}
+
 }  // namespace xgrammar
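The dispatch above runs a plain sequential loop when the resolved thread count is 1 and otherwise fans the per-matcher work out to a `ThreadPool`. From Python, the whole machinery is driven in a few lines. A minimal usage sketch, assuming the Llama-2 tokenizer and the built-in JSON grammar (model name and batch size are illustrative):

```python
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiled = xgr.GrammarCompiler(tokenizer_info).compile_builtin_json_grammar()

# One matcher per sequence in the batch.
matchers = [xgr.GrammarMatcher(compiled) for _ in range(64)]
bitmask = xgr.allocate_token_bitmask(len(matchers), tokenizer_info.vocab_size)

# Fill all 64 bitmask rows in one call; "auto" picks the thread count.
batch_matcher = xgr.BatchGrammarMatcher("auto")
batch_matcher.batch_fill_next_token_bitmask(matchers, bitmask)
```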
diff --git a/cpp/nanobind/nanobind.cc b/cpp/nanobind/nanobind.cc
index 69ba9d30..6d39ee11 100644
--- a/cpp/nanobind/nanobind.cc
+++ b/cpp/nanobind/nanobind.cc
@@ -13,12 +13,19 @@
 #include <nanobind/stl/vector.h>
 #include <xgrammar/xgrammar.h>
 
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <variant>
+#include <vector>
+
 #include "../grammar_functor.h"
 #include "../json_schema_converter.h"
 #include "../regex_converter.h"
 #include "../testing.h"
 #include "python_methods.h"
 #include "xgrammar/exception.h"
+#include "xgrammar/matcher.h"
 
 namespace nb = nanobind;
 
@@ -74,6 +81,48 @@ bool GrammarMatcher_FillNextTokenBitmask(
   return matcher.FillNextTokenBitmask(bitmask_dltensor_ptr, index, debug_print);
 }
 
+void GrammarMatcher_BatchFillNextTokenMask(
+    BatchGrammarMatcher& batch_matcher,
+    std::vector<GrammarMatcher*>* matchers,
+    nb::ndarray<> arr,
+    const std::optional<std::vector<int32_t>>& indices,
+    bool debug_print
+) {
+  if (arr.ndim() != 2) {
+    throw std::runtime_error("batch_token_bitmask tensor must be 2D");
+  }
+  if (arr.device_type() != nb::device::cpu::value) {
+    throw std::runtime_error("token_bitmask array must be on CPU");
+  }
+  if (arr.dtype() != nb::dtype<int32_t>()) {
+    throw std::runtime_error("token_bitmask array must be int32");
+  }
+  static_assert(sizeof(arr) == sizeof(void*) + sizeof(nb::dlpack::dltensor));
+
+  DLTensor* bitmask_dltensor_ptr =
+      reinterpret_cast<::DLTensor*>(reinterpret_cast<char*>(&arr) + sizeof(void*));
+
+  batch_matcher.BatchFillNextTokenBitmask(matchers, bitmask_dltensor_ptr, indices, debug_print);
+}
+
+std::vector<int8_t> GrammarMatcher_BatchAcceptString(
+    std::vector<GrammarMatcher*>* matchers,
+    const std::vector<std::variant<std::string, nb::bytes>>& input_strs,
+    bool debug_print
+) {
+  std::vector<std::string> input_strs_converted;
+  input_strs_converted.reserve(input_strs.size());
+  for (const auto& str : input_strs) {
+    if (std::holds_alternative<std::string>(str)) {
+      input_strs_converted.emplace_back(std::get<std::string>(str));
+    } else {
+      // Use the explicit length so byte strings with embedded NULs survive.
+      const auto& bytes = std::get<nb::bytes>(str);
+      input_strs_converted.emplace_back(bytes.c_str(), bytes.size());
+    }
+  }
+  return BatchGrammarMatcher::BatchAcceptString(matchers, input_strs_converted, debug_print);
+}
+
 std::vector<nb::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer) {
   const auto& decoded_vocab = tokenizer.GetDecodedVocab();
   std::vector<nb::bytes> py_result;
@@ -209,7 +258,29 @@ NB_MODULE(xgrammar_bindings, m) {
       .def("clear_cache", &GrammarCompiler::ClearCache)
       .def("get_cache_size_bytes", &GrammarCompiler::GetCacheSizeBytes)
       .def_prop_ro("cache_limit_bytes", &GrammarCompiler::CacheLimitBytes);
-
+  auto pyBatchGrammarMatcher = nb::class_<BatchGrammarMatcher>(m, "BatchGrammarMatcher");
+  pyBatchGrammarMatcher
+      .def(nb::init<std::variant<int32_t, std::string>>(), nb::arg("max_threads") = "auto")
+      .def(
+          "batch_fill_next_token_bitmask",
+          &GrammarMatcher_BatchFillNextTokenMask,
+          nb::arg("matchers"),
+          nb::arg("batch_token_bitmask"),
+          nb::arg("indices").none(),
+          nb::arg("debug_print") = false,
+          nb::call_guard<nb::gil_scoped_release>()
+      )
+      .def_static(
+          "batch_accept_string",
+          &GrammarMatcher_BatchAcceptString,
+          nb::call_guard<nb::gil_scoped_release>()
+      )
+      .def_static(
+          "batch_accept_token",
+          &BatchGrammarMatcher::BatchAcceptToken,
+          nb::call_guard<nb::gil_scoped_release>()
+      );
+
   auto pyGrammarMatcher = nb::class_<GrammarMatcher>(m, "GrammarMatcher");
   pyGrammarMatcher
       .def(
diff --git a/docs/xgrammar_features/runtime_safeguards.md b/docs/xgrammar_features/runtime_safeguards.md
index 5aaae0a0..fc0a7ae3 100644
--- a/docs/xgrammar_features/runtime_safeguards.md
+++ b/docs/xgrammar_features/runtime_safeguards.md
@@ -15,7 +15,7 @@
 If the recursion depth exceeds the limit, the matcher operations (including
 [`xgr.GrammarMatcher.accept_token`](xgrammar.GrammarMatcher.accept_token),
 [`xgr.GrammarMatcher.accept_string`](xgrammar.GrammarMatcher.accept_string),
 [`xgr.GrammarMatcher.fill_next_token_bitmask`](xgrammar.GrammarMatcher.fill_next_token_bitmask),
 [`xgr.GrammarMatcher.find_jump_forward_string`](xgrammar.GrammarMatcher.find_jump_forward_string)) will raise
-`RuntimeError`.
+`RuntimeError` (before XGrammar v0.1.21).
 
 You can also use the [`xgr.max_recursion_depth`](xgrammar.max_recursion_depth) context manager to
 set the maximum recursion depth for a code block.
@@ -27,6 +27,8 @@
 with max_recursion_depth(10000):
     matcher.accept_token(token_id)
 ```
 
+Since XGrammar v0.1.21, the pushdown automaton parser has been replaced with an Earley parser,
+which involves no recursion. The recursion depth limit can therefore no longer be exceeded
+during parsing, and this exception is no longer raised.
+
 ## Cache Size Limit
 
 The {py:class}`xgr.GrammarCompiler` class uses a cache to store the compiled grammars.
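As a quick illustration of the documentation change above: with the Earley parser, inputs that nest arbitrarily deep are matched iteratively rather than recursively. A toy check (the vocabulary and grammar here are made up for the example):

```python
import xgrammar as xgr

# A grammar with unbounded nesting depth.
tokenizer_info = xgr.TokenizerInfo(["</s>", "(", ")"])
compiled = xgr.GrammarCompiler(tokenizer_info).compile_grammar(
    xgr.Grammar.from_ebnf('root ::= "(" root ")" | ""')
)
matcher = xgr.GrammarMatcher(compiled)

# Deeply nested input; with the pre-v0.1.21 recursive parser, inputs like
# this could hit the recursion limit and raise RuntimeError.
assert matcher.accept_string("(" * 10000)
```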
diff --git a/include/xgrammar/matcher.h b/include/xgrammar/matcher.h
index cdfee8ff..8bfcf16f 100644
--- a/include/xgrammar/matcher.h
+++ b/include/xgrammar/matcher.h
@@ -14,6 +14,7 @@
 #include <cstdint>
 #include <optional>
 #include <string>
+#include <variant>
 #include <vector>
 
 namespace xgrammar {
@@ -145,6 +146,64 @@ class GrammarMatcher {
   XGRAMMAR_DEFINE_PIMPL_METHODS(GrammarMatcher);
 };
 
+/*!
+ * \brief A batched version of GrammarMatcher for better efficiency. It supports batch processing
+ * of multiple GrammarMatcher objects in parallel.
+ *
+ * \details This class provides batched versions of the core methods of GrammarMatcher, including
+ * FillNextTokenBitmask, AcceptString, and AcceptToken. It uses multi-threading to process
+ * multiple GrammarMatcher objects simultaneously, significantly improving efficiency when dealing
+ * with a large number of matchers.
+ */
+class BatchGrammarMatcher {
+ public:
+  BatchGrammarMatcher(std::variant<int32_t, std::string> max_threads = "auto");
+
+  /*!
+   * \brief A batched version of FillNextTokenBitmask for better efficiency.
+   * \param matchers The array of GrammarMatcher objects.
+   * \param next_token_bitmask The pre-allocated DLTensor to store the result bitmasks.
+   * \param indices The optional array of indices specifying which matcher writes to which slice
+   * of the bitmask tensor. If not provided, matchers[i] writes to next_token_bitmask[i].
+   * \param debug_print Whether to print debug information. Default is false.
+   */
+  void BatchFillNextTokenBitmask(
+      std::vector<GrammarMatcher*>* matchers,
+      DLTensor* next_token_bitmask,
+      const std::optional<std::vector<int32_t>>& indices = std::nullopt,
+      bool debug_print = false
+  );
+
+  /*!
+   * \brief A batched version of AcceptString for better efficiency.
+   * \param matchers The array of GrammarMatcher objects.
+   * \param input_strs The array of input strings to be accepted.
+   * \param debug_print Whether to print debug information. Default is false.
+   * \return A vector of bytes indicating whether each string is accepted.
+   */
+  static std::vector<int8_t> BatchAcceptString(
+      std::vector<GrammarMatcher*>* matchers,
+      const std::vector<std::string>& input_strs,
+      bool debug_print = false
+  );
+
+  /*!
+   * \brief A batched version of AcceptToken for better efficiency.
+   * \param matchers The array of GrammarMatcher objects.
+   * \param token_ids The array of token ids to be accepted.
+   * \param debug_print Whether to print debug information. Default is false.
+   * \return A vector of bytes indicating whether each token is accepted.
+   */
+  static std::vector<int8_t> BatchAcceptToken(
+      std::vector<GrammarMatcher*>* matchers,
+      const std::vector<int32_t>& token_ids,
+      bool debug_print = false
+  );
+
+  XGRAMMAR_DEFINE_PIMPL_METHODS(BatchGrammarMatcher);
+};
+
 }  // namespace xgrammar
 
 #endif  // XGRAMMAR_MATCHER_H_
diff --git a/python/xgrammar/__init__.py b/python/xgrammar/__init__.py
index 8e625467..cf81929f 100644
--- a/python/xgrammar/__init__.py
+++ b/python/xgrammar/__init__.py
@@ -15,6 +15,7 @@
 )
 from .grammar import Grammar, StructuralTagItem
 from .matcher import (
+    BatchGrammarMatcher,
     GrammarMatcher,
     allocate_token_bitmask,
     apply_token_bitmask_inplace,
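The two static helpers advance a whole batch of matchers in one call, pairing matchers[i] with the i-th input. A sketch with a toy vocabulary (the grammar and token ids are illustrative):

```python
import xgrammar as xgr

tokenizer_info = xgr.TokenizerInfo(["<unk>", "</s>", "a", "b", "ab"])
compiler = xgr.GrammarCompiler(tokenizer_info)
matchers = [
    xgr.GrammarMatcher(compiler.compile_grammar(xgr.Grammar.from_ebnf('root ::= "ab"')))
    for _ in range(2)
]

# matchers[0] consumes "a" (a valid prefix), matchers[1] consumes "ab".
assert xgr.BatchGrammarMatcher.batch_accept_string(matchers, ["a", "ab"]) == [True, True]

# Token 3 is "b": it completes "ab" for matchers[0]. Token 2 is "a": it is
# rejected by matchers[1], whose match is already complete.
assert xgr.BatchGrammarMatcher.batch_accept_token(matchers, [3, 2]) == [True, False]
```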
diff --git a/python/xgrammar/matcher.py b/python/xgrammar/matcher.py
index 3a135843..4a39cc55 100644
--- a/python/xgrammar/matcher.py
+++ b/python/xgrammar/matcher.py
@@ -4,7 +4,7 @@
 
 import math
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import torch
 from numpy.typing import ArrayLike
@@ -249,11 +249,6 @@ def accept_token(self, token_id: int, *, debug_print: bool = False) -> bool:
         -------
         accepted : bool
             Whether the token is accepted.
-
-        Raises
-        ------
-        RuntimeError
-            If the recursion depth is exceeded.
         """
         return self._handle.accept_token(token_id, debug_print)
@@ -275,11 +270,6 @@ def accept_string(self, input_str: Union[str, bytes], *, debug_print: bool = Fal
         -------
         accepted : bool
             Whether the string is accepted.
-
-        Raises
-        ------
-        RuntimeError
-            If the recursion depth is exceeded.
         """
         return self._handle.accept_string(input_str, debug_print)
@@ -314,8 +304,6 @@ def fill_next_token_bitmask(
         ------
         RuntimeError
             If the bitmask is invalid (not on CPU, not int32, shape mismatch).
-
-            If the recursion depth is exceeded.
         """
         return self._handle.fill_next_token_bitmask(bitmask, index, debug_print)
@@ -330,11 +318,6 @@ def find_jump_forward_string(self) -> str:
         -------
         jump_forward_string : str
             The jump-forward string.
-
-        Raises
-        ------
-        RuntimeError
-            If the recursion depth is exceeded.
         """
         return self._handle.find_jump_forward_string()
@@ -400,3 +383,119 @@ def _debug_print_internal_state(self) -> str:
             The internal state of the matcher.
         """
         return self._handle._debug_print_internal_state()
+
+
+class BatchGrammarMatcher(XGRObject):
+    """A batched version of GrammarMatcher that can fill the next token bitmask for multiple
+    matchers in parallel. It uses multiple threads to speed up the computation and is
+    especially useful when the batch size is large.
+    """
+
+    def __init__(self, max_threads: Union[int, Literal["auto"]] = "auto") -> None:
+        """Construct the batch grammar matcher.
+
+        Parameters
+        ----------
+        max_threads : Union[int, Literal["auto"]], default: "auto"
+            The maximum number of threads to use for parallel processing. If set to "auto",
+            max_threads is set to std::thread::hardware_concurrency() / 2.
+        """
+        self._init_handle(_core.BatchGrammarMatcher(max_threads))
+
+    def batch_fill_next_token_bitmask(
+        self,
+        matchers: List["GrammarMatcher"],
+        bitmask: ArrayLike,
+        indices: Optional[List[int]] = None,
+        debug_print: bool = False,
+    ) -> None:
+        """Fill the next token bitmask for multiple matchers in parallel.
+
+        Parameters
+        ----------
+        matchers : List[GrammarMatcher]
+            The list of matchers to fill the bitmask for.
+
+        bitmask : ArrayLike
+            Must be a 2-dimensional int32 tensor of shape (bitmask_batch_size, bitmask_size).
+            bitmask_batch_size may be larger than the actual batch size to allow padding.
+            bitmask_size equals ceil(vocab_size / 32); such a bitmask can be allocated via
+            xgrammar.allocate_token_bitmask.
+
+        indices : Optional[List[int]], default: None
+            A list of indices specifying which row of the bitmask each matcher fills. If None,
+            matcher i fills row i, i.e. rows [0, len(matchers)).
+
+        debug_print : bool, default: False
+            Whether to print information about the generated bitmask. Helpful for debugging.
+
+        Raises
+        ------
+        RuntimeError
+            If the bitmask is invalid (not on CPU, not int32, shape mismatch).
+        """
+        matcher_handles = [matcher._handle for matcher in matchers]
+        self._handle.batch_fill_next_token_bitmask(matcher_handles, bitmask, indices, debug_print)
+
+    @staticmethod
+    def batch_accept_token(
+        matchers: List["GrammarMatcher"], tokens: List[int], debug_print: bool = False
+    ) -> List[bool]:
+        """Accept one token for each matcher in the batch.
+
+        Parameters
+        ----------
+        matchers : List[GrammarMatcher]
+            The list of matchers to accept tokens for.
+
+        tokens : List[int]
+            The list of tokens to accept, one per matcher.
+
+        debug_print : bool, default: False
+            Whether to print information about the acceptance process. Helpful for debugging.
+
+        Returns
+        -------
+        accepted : List[bool]
+            A list of booleans indicating whether each token was accepted by its corresponding
+            matcher.
+
+        Raises
+        ------
+        RuntimeError
+            If the sizes of matchers and tokens do not match.
+        """
+        matcher_handles = [matcher._handle for matcher in matchers]
+        return _core.BatchGrammarMatcher.batch_accept_token(matcher_handles, tokens, debug_print)
+
+    @staticmethod
+    def batch_accept_string(
+        matchers: List["GrammarMatcher"],
+        strings: List[Union[str, bytes]],
+        debug_print: bool = False,
+    ) -> List[bool]:
+        """Accept one string for each matcher in the batch.
+
+        Parameters
+        ----------
+        matchers : List[GrammarMatcher]
+            The list of matchers to accept strings for.
+
+        strings : List[Union[str, bytes]]
+            The list of strings to accept, one per matcher.
+
+        debug_print : bool, default: False
+            Whether to print information about the acceptance process. Helpful for debugging.
+
+        Returns
+        -------
+        accepted : List[bool]
+            A list of booleans indicating whether each string was accepted by its corresponding
+            matcher.
+
+        Raises
+        ------
+        RuntimeError
+            If the sizes of matchers and strings do not match.
+ """ + matcher_handles = [matcher._handle for matcher in matchers] + return _core.BatchGrammarMatcher.batch_accept_string(matcher_handles, strings, debug_print) diff --git a/tests/python/test_grammar_matcher_basic.py b/tests/python/test_grammar_matcher_basic.py index 15fab1d3..e2b2675f 100644 --- a/tests/python/test_grammar_matcher_basic.py +++ b/tests/python/test_grammar_matcher_basic.py @@ -1,6 +1,7 @@ """Test the basic functionality of GrammarMatcher.""" import math +import random import sys from typing import List, Optional, Union @@ -396,5 +397,220 @@ def test_fill_next_token_bitmask_errors(): matcher.fill_next_token_bitmask(bitmask_correct) +test_batch_accept_string_grammars_inputs_expecteds = [ + (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], ["a", b"123", "ab"], [True, True, True]), + ( + ['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], + ["b", "123a", "d"], + [False, False, False], + ), + ( + ['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], + ["a", b"123a", b"ab"], + [True, False, True], + ), + (['root ::= "a"'], ["a"], [True]), + (['root ::= "a"'], ["b"], [False]), + ( + ['root ::= "你好"', 'root ::= "こんにちは"', 'root ::= "안녕하세요"'], + ["你好", "こんにちは", "안녕하세요"], + [True, True, True], + ), +] + + +@pytest.mark.parametrize( + "grammars, inputs, expecteds", test_batch_accept_string_grammars_inputs_expecteds +) +def test_batch_accept_string( + grammars: List[str], inputs: List[Union[str, bytes]], expecteds: List[bool] +): + matchers = [_get_matcher_from_grammar(grammar) for grammar in grammars] + results = xgr.BatchGrammarMatcher.batch_accept_string(matchers, inputs) + assert results == expecteds + + +test_batch_accept_token_grammars_inputs_expecteds = [ + (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], [2, 5, 2], [True, True, True]), + (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], [3, 2, 4], [False, False, False]), + (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], [2, 8, 9], [True, False, True]), + (['root ::= "a"'], [2], [True]), + (['root ::= "a"'], [3], [False]), +] + + +@pytest.mark.parametrize( + "grammars, inputs, expecteds", test_batch_accept_token_grammars_inputs_expecteds +) +def test_batch_accept_token(grammars: List[str], inputs: List[int], expecteds: List[bool]): + vocab = [ + # fmt: off + "", "", "a", "b", "c", "1", "2", "3", "123a", "ab", + # fmt: on + ] + tokenizer_info = xgr.TokenizerInfo(vocab) + + matchers = [ + _get_matcher_from_grammar_and_tokenizer_info(xgr.Grammar.from_ebnf(grammar), tokenizer_info) + for grammar in grammars + ] + results = xgr.BatchGrammarMatcher.batch_accept_token(matchers, inputs) + assert results == expecteds + + +def test_batch_fill_next_token_bitmask(): + grammars = ['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"', "root ::= [a-z0-9]+"] + vocab = [ + # fmt: off + "ab", "", "a", "b", "c", "1", "2", "3", "123a" + # fmt: on + ] + tokenizer_info = xgr.TokenizerInfo(vocab) + + matchers = [ + _get_matcher_from_grammar_and_tokenizer_info(xgr.Grammar.from_ebnf(grammar), tokenizer_info) + for grammar in grammars + ] + + batch_size = len(matchers) + token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size) + + input_str = ["a", "1", "a", "123a"] + + expected_accepted_tokens = [ + [[2], [5, 6, 7], [0, 2], [0, 2, 3, 4, 5, 6, 7, 8]], + [[1], [1, 5, 6, 7], [3], [0, 1, 2, 3, 4, 5, 6, 7, 8]], + ] + + batch_grammar_matcher = xgr.BatchGrammarMatcher(2) + batch_grammar_matcher.batch_fill_next_token_bitmask(matchers, token_bitmask) + + for i in range(batch_size): + rejected_token_ids = 
diff --git a/tests/python/test_grammar_matcher_basic.py b/tests/python/test_grammar_matcher_basic.py
index 15fab1d3..e2b2675f 100644
--- a/tests/python/test_grammar_matcher_basic.py
+++ b/tests/python/test_grammar_matcher_basic.py
@@ -1,6 +1,7 @@
 """Test the basic functionality of GrammarMatcher."""
 
 import math
+import random
 import sys
 from typing import List, Optional, Union
 
@@ -396,5 +397,220 @@ def test_fill_next_token_bitmask_errors():
     matcher.fill_next_token_bitmask(bitmask_correct)
 
 
+test_batch_accept_string_grammars_inputs_expecteds = [
+    (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], ["a", b"123", "ab"], [True, True, True]),
+    (
+        ['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'],
+        ["b", "123a", "d"],
+        [False, False, False],
+    ),
+    (
+        ['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'],
+        ["a", b"123a", b"ab"],
+        [True, False, True],
+    ),
+    (['root ::= "a"'], ["a"], [True]),
+    (['root ::= "a"'], ["b"], [False]),
+    (
+        ['root ::= "你好"', 'root ::= "こんにちは"', 'root ::= "안녕하세요"'],
+        ["你好", "こんにちは", "안녕하세요"],
+        [True, True, True],
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "grammars, inputs, expecteds", test_batch_accept_string_grammars_inputs_expecteds
+)
+def test_batch_accept_string(
+    grammars: List[str], inputs: List[Union[str, bytes]], expecteds: List[bool]
+):
+    matchers = [_get_matcher_from_grammar(grammar) for grammar in grammars]
+    results = xgr.BatchGrammarMatcher.batch_accept_string(matchers, inputs)
+    assert results == expecteds
+
+
+test_batch_accept_token_grammars_inputs_expecteds = [
+    (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], [2, 5, 2], [True, True, True]),
+    (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], [3, 2, 4], [False, False, False]),
+    (['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"'], [2, 8, 9], [True, False, True]),
+    (['root ::= "a"'], [2], [True]),
+    (['root ::= "a"'], [3], [False]),
+]
+
+
+@pytest.mark.parametrize(
+    "grammars, inputs, expecteds", test_batch_accept_token_grammars_inputs_expecteds
+)
+def test_batch_accept_token(grammars: List[str], inputs: List[int], expecteds: List[bool]):
+    vocab = [
+        # fmt: off
+        "<unk>", "</s>", "a", "b", "c", "1", "2", "3", "123a", "ab",
+        # fmt: on
+    ]
+    tokenizer_info = xgr.TokenizerInfo(vocab)
+
+    matchers = [
+        _get_matcher_from_grammar_and_tokenizer_info(xgr.Grammar.from_ebnf(grammar), tokenizer_info)
+        for grammar in grammars
+    ]
+    results = xgr.BatchGrammarMatcher.batch_accept_token(matchers, inputs)
+    assert results == expecteds
+
+
+def test_batch_fill_next_token_bitmask():
+    grammars = ['root ::= "a"', "root ::= [0-9]+", 'root ::= "ab"', "root ::= [a-z0-9]+"]
+    vocab = [
+        # fmt: off
+        "ab", "</s>", "a", "b", "c", "1", "2", "3", "123a"
+        # fmt: on
+    ]
+    tokenizer_info = xgr.TokenizerInfo(vocab)
+
+    matchers = [
+        _get_matcher_from_grammar_and_tokenizer_info(xgr.Grammar.from_ebnf(grammar), tokenizer_info)
+        for grammar in grammars
+    ]
+
+    batch_size = len(matchers)
+    token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size)
+
+    input_strs = ["a", "1", "a", "123a"]
+
+    expected_accepted_tokens = [
+        [[2], [5, 6, 7], [0, 2], [0, 2, 3, 4, 5, 6, 7, 8]],
+        [[1], [1, 5, 6, 7], [3], [0, 1, 2, 3, 4, 5, 6, 7, 8]],
+    ]
+
+    batch_grammar_matcher = xgr.BatchGrammarMatcher(2)
+    batch_grammar_matcher.batch_fill_next_token_bitmask(matchers, token_bitmask)
+
+    for i in range(batch_size):
+        rejected_token_ids = _get_masked_tokens_from_bitmask(
+            token_bitmask[i : i + 1], tokenizer_info.vocab_size
+        )
+        accepted = list(set(range(len(vocab))) - set(rejected_token_ids))
+        accepted.sort()
+        assert accepted == expected_accepted_tokens[0][i]
+
+    assert xgr.BatchGrammarMatcher.batch_accept_string(matchers, input_strs) == [
+        True,
+        True,
+        True,
+        True,
+    ]
+
+    batch_grammar_matcher.batch_fill_next_token_bitmask(matchers, token_bitmask)
+
+    for i in range(batch_size):
+        rejected_token_ids = _get_masked_tokens_from_bitmask(
+            token_bitmask[i : i + 1], tokenizer_info.vocab_size
+        )
+        accepted = list(set(range(len(vocab))) - set(rejected_token_ids))
+        accepted.sort()
+        assert accepted == expected_accepted_tokens[1][i]
+
+
+@pytest.mark.hf_token_required
+def test_batch_fill_next_token_bitmask_pressure():
+    tokenizer_path = "meta-llama/Llama-2-7b-chat-hf"
+    input_str = '{"id": 1,"name": "Example"}'
+    rejected_token_size = [
+        # fmt: off
+        31989, 31912, 270, 270, 270, 31973, 31846, 31846, 31948, 31915, 270, 270, 270, 270,
+        270, 31973, 31846, 31846, 263, 263, 263, 263, 263, 263, 263, 263, 31974, 31999,
+        # fmt: on
+    ]
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
+    matchers = [
+        _get_matcher_from_grammar_and_tokenizer_info(json_grammar, tokenizer_info)
+        for _ in range(len(input_str) + 1)
+    ]
+    input_strs = [input_str[:i] for i in range(len(input_str))] + [input_str]
+    xgr.BatchGrammarMatcher.batch_accept_string(matchers, input_strs)
+
+    bitmask_2d = xgr.allocate_token_bitmask(len(matchers), tokenizer_info.vocab_size)
+    batch_grammar_matcher = xgr.BatchGrammarMatcher(2)
+    batch_grammar_matcher.batch_fill_next_token_bitmask(matchers, bitmask_2d)
+    for i in range(len(matchers)):
+        rejected_token_ids = _get_masked_tokens_from_bitmask(
+            bitmask_2d[i], tokenizer_info.vocab_size
+        )
+        assert len(rejected_token_ids) == rejected_token_size[i], (
+            i,
+            len(rejected_token_ids),
+            rejected_token_size[i],
+        )
+
+
+@pytest.mark.hf_token_required
+def test_batch_fill_next_token_bitmask_pressure_single_thread():
+    tokenizer_path = "meta-llama/Llama-2-7b-chat-hf"
+    input_str = '{"id": 1,"name": "Example"}'
+    rejected_token_size = [
+        # fmt: off
+        31989, 31912, 270, 270, 270, 31973, 31846, 31846, 31948, 31915, 270, 270, 270, 270,
+        270, 31973, 31846, 31846, 263, 263, 263, 263, 263, 263, 263, 263, 31974, 31999,
+        # fmt: on
+    ]
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
+    matchers = [
+        _get_matcher_from_grammar_and_tokenizer_info(json_grammar, tokenizer_info)
+        for _ in range(len(input_str) + 1)
+    ]
+    input_strs = [input_str[:i] for i in range(len(input_str))] + [input_str]
+    xgr.BatchGrammarMatcher.batch_accept_string(matchers, input_strs)
+
+    bitmask_2d = xgr.allocate_token_bitmask(len(matchers), tokenizer_info.vocab_size)
+    batch_grammar_matcher = xgr.BatchGrammarMatcher(1)
+    batch_grammar_matcher.batch_fill_next_token_bitmask(matchers, bitmask_2d)
+    for i in range(len(matchers)):
+        rejected_token_ids = _get_masked_tokens_from_bitmask(
+            bitmask_2d[i], tokenizer_info.vocab_size
+        )
+        assert len(rejected_token_ids) == rejected_token_size[i], (
+            i,
+            len(rejected_token_ids),
+            rejected_token_size[i],
+        )
"Example"}' + rejected_token_size = [ + # fmt: off + 31989, 31912, 270, 270, 270, 31973, 31846, 31846, 31948, 31915, 270, 270, 270, 270, + 270, 31973, 31846, 31846, 263, 263, 263, 263, 263, 263, 263, 263, 31974, 31999, + # fmt: on + ] + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True) + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer) + matchers = [ + _get_matcher_from_grammar_and_tokenizer_info(json_grammar, tokenizer_info) + for _ in range(len(input_str) + 1) + ] + input_strs = [input_str[:i] for i in range(len(input_str))] + [input_str] + xgr.BatchGrammarMatcher.batch_accept_string(matchers, input_strs) + + shuffled_indices = list(range(len(matchers))) + random.shuffle(shuffled_indices) + bitmask_2d = xgr.allocate_token_bitmask(len(matchers), tokenizer_info.vocab_size) + batch_grammar_matcher = xgr.BatchGrammarMatcher() + batch_grammar_matcher.batch_fill_next_token_bitmask(matchers, bitmask_2d, shuffled_indices) + for i in range(len(matchers)): + rejected_token_ids = _get_masked_tokens_from_bitmask( + bitmask_2d[shuffled_indices[i]], tokenizer_info.vocab_size + ) + assert len(rejected_token_ids) == rejected_token_size[i], ( + i, + len(rejected_token_ids), + rejected_token_size[i], + ) + + if __name__ == "__main__": pytest.main(sys.argv)