Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions cpp/grammar_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@

#include <xgrammar/matcher.h>

#include <algorithm>
#include <cstdint>
#include <optional>
#include <thread>
#include <utility>
#include <variant>
#include <vector>

#include "compiled_grammar_impl.h"
Expand All @@ -19,6 +23,7 @@
#include "support/encoding.h"
#include "support/int_set.h"
#include "support/logging.h"
#include "support/thread_pool.h"
#include "testing.h"

namespace xgrammar {
Expand Down Expand Up @@ -343,6 +348,52 @@ class GrammarMatcher::Impl : public EarleyParser {
std::vector<int32_t> tmp_rejected_indices_delta_;
};

class BatchGrammarMatcher::Impl {
 public:
  /*!
   * \brief Construct the implementation object and resolve the worker-thread count.
   * \param max_threads Either an explicit thread count (must be >= 1) or the string
   * "auto", which selects half of the hardware threads (but never fewer than 1).
   */
  explicit Impl(std::variant<std::string, int32_t> max_threads) {
    if (std::holds_alternative<int32_t>(max_threads)) {
      int32_t num_threads = std::get<int32_t>(max_threads);
      XGRAMMAR_CHECK(num_threads >= 1)
          << "The num_threads should be at least 1, but got " << num_threads;
      if (num_threads > 1) {
        // hardware_concurrency() may return 0 when it cannot be determined; in that
        // case trust the caller's request instead of clamping the count to 0.
        int32_t hardware_threads = static_cast<int32_t>(std::thread::hardware_concurrency());
        if (hardware_threads > 0 && num_threads > hardware_threads) {
          XGRAMMAR_LOG(WARNING) << "The num_threads " << num_threads << " is larger than the "
                                << "number of hardware threads. Using " << hardware_threads
                                << " instead.";
        }
        max_threads_ = hardware_threads > 0 ? std::min(num_threads, hardware_threads)
                                            : num_threads;
      }
    } else {
      std::string str = std::get<std::string>(max_threads);
      XGRAMMAR_CHECK(str == "auto");
      // Half of the hardware threads, but never below 1: hardware_concurrency() can
      // return 0 or 1, and the rest of the class treats max_threads_ <= 1 as "serial".
      max_threads_ =
          std::max<int32_t>(1, static_cast<int32_t>(std::thread::hardware_concurrency()) / 2);
    }
  }

  /*! \brief Fill the bitmask rows for a batch of matchers, possibly in parallel. */
  void BatchFillNextTokenBitmask(
      std::vector<GrammarMatcher>* matchers,
      DLTensor* next_token_bitmask,
      const std::optional<std::vector<int32_t>>& indices,
      bool debug_print
  );

  /*! \brief Accept one token per matcher; returns one accepted-flag per matcher. */
  static std::vector<uint8_t> BatchAcceptToken(
      std::vector<GrammarMatcher>* matchers, const std::vector<int32_t>& token_ids, bool debug_print
  );

  /*! \brief Accept one string per matcher; returns one accepted-flag per matcher. */
  static std::vector<uint8_t> BatchAcceptString(
      std::vector<GrammarMatcher>* matchers,
      const std::vector<std::string>& input_strs,
      bool debug_print
  );

 private:
  // (Re)created for every parallel call: ThreadPool cannot be reused after Join().
  std::optional<ThreadPool> thread_pool_ = std::nullopt;
  // Resolved worker count; 1 means everything runs on the calling thread.
  int32_t max_threads_ = 1;
};

bool GrammarMatcher::Impl::AcceptStopToken() {
if (terminate_without_stop_token_) {
return false;
Expand Down Expand Up @@ -852,6 +903,75 @@ int GrammarMatcher::Impl::GetNextUncertainToken(
}
}

/*!
 * \brief Fill one bitmask row per matcher, serially when max_threads_ <= 1 and via the
 * thread pool otherwise.
 * \param matchers The matchers to run; matcher i writes to row indices[i] (or row i).
 * \param next_token_bitmask Pre-allocated destination tensor; shape[0] bounds the rows.
 * \param indices Optional per-matcher destination rows; size must equal matchers->size().
 * \param debug_print Whether to print debug information.
 */
void BatchGrammarMatcher::Impl::BatchFillNextTokenBitmask(
    std::vector<GrammarMatcher>* matchers,
    DLTensor* next_token_bitmask,
    const std::optional<std::vector<int32_t>>& indices,
    bool debug_print
) {
  XGRAMMAR_CHECK(!indices.has_value() || indices->size() == matchers->size())
      << "The size of indices (" << (indices.has_value() ? indices->size() : 0)
      << ") should be the same as the size of matchers (" << matchers->size() << ").";
  // Per-matcher work shared by the serial and parallel paths: validate the
  // destination row, then fill it. Previously this logic was duplicated in both branches.
  auto fill_next_token_mask = [&](int32_t batch_id) {
    auto& matcher = (*matchers)[batch_id];
    int index = indices.has_value() ? (*indices)[batch_id] : batch_id;
    XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
        << "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
        << ") for batch_id " << batch_id << ".";
    matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
  };
  // Initialize the thread pool if needed. It should be initialized each time,
  // because ThreadPool cannot be reused after Join().
  if (max_threads_ > 1) {
    thread_pool_.emplace(max_threads_);
  }
  if (!thread_pool_.has_value()) {
    // Serial path: run every matcher on the calling thread.
    for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
      fill_next_token_mask(i);
    }
  } else {
    // Parallel path: one task per matcher; Join() blocks until all tasks finish,
    // so the by-reference captures in fill_next_token_mask stay alive long enough.
    for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
      thread_pool_->Execute([fill_next_token_mask, i]() { fill_next_token_mask(i); });
    }
    thread_pool_->Join();
  }
}

std::vector<uint8_t> BatchGrammarMatcher::Impl::BatchAcceptString(
std::vector<GrammarMatcher>* matchers,
const std::vector<std::string>& input_strs,
bool debug_print
) {
XGRAMMAR_CHECK(matchers->size() == input_strs.size())
<< "The size of matchers (" << matchers->size() << ") and input_strs (" << input_strs.size()
<< ") should be the same.";
std::vector<uint8_t> accepted(matchers->size());
for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
auto& matcher = (*matchers)[i];
accepted[i] = matcher->AcceptString(input_strs[i], debug_print);
}
return accepted;
}

std::vector<uint8_t> BatchGrammarMatcher::Impl::BatchAcceptToken(
std::vector<GrammarMatcher>* matchers, const std::vector<int32_t>& token_ids, bool debug_print
) {
XGRAMMAR_CHECK(matchers->size() == token_ids.size())
<< "The size of matchers (" << matchers->size() << ") and token_ids (" << token_ids.size()
<< ") should be the same.";
std::vector<uint8_t> accepted(matchers->size());
for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
auto& matcher = (*matchers)[i];
accepted[i] = matcher->AcceptToken(token_ids[i], debug_print);
}
return accepted;
}

GrammarMatcher::GrammarMatcher(
const CompiledGrammar& compiled_grammar,
std::optional<std::vector<int>> override_stop_tokens,
Expand Down Expand Up @@ -894,4 +1014,30 @@ std::string GrammarMatcher::_DebugPrintInternalState() const {
return pimpl_->_DebugPrintInternalState();
}

void BatchGrammarMatcher::BatchFillNextTokenBitmask(
    std::vector<GrammarMatcher>* matchers,
    DLTensor* next_token_bitmask,
    const std::optional<std::vector<int32_t>>& indices,
    bool debug_print
) {
  // Delegate to the pimpl, which owns the thread pool and performs the work.
  pimpl_->BatchFillNextTokenBitmask(matchers, next_token_bitmask, indices, debug_print);
}

std::vector<uint8_t> BatchGrammarMatcher::BatchAcceptString(
    std::vector<GrammarMatcher>* matchers,
    const std::vector<std::string>& input_strs,
    bool debug_print
) {
  // Stateless operation: forward straight to the static Impl helper.
  return Impl::BatchAcceptString(matchers, input_strs, debug_print);
}

std::vector<uint8_t> BatchGrammarMatcher::BatchAcceptToken(
    std::vector<GrammarMatcher>* matchers, const std::vector<int32_t>& token_ids, bool debug_print
) {
  // Stateless operation: forward straight to the static Impl helper.
  return Impl::BatchAcceptToken(matchers, token_ids, debug_print);
}

// Move the variant into the Impl: it may hold a std::string, so copying would
// allocate needlessly. The parameter is a by-value sink.
BatchGrammarMatcher::BatchGrammarMatcher(std::variant<std::string, int32_t> max_threads)
    : pimpl_(std::make_shared<BatchGrammarMatcher::Impl>(std::move(max_threads))) {}

} // namespace xgrammar
73 changes: 72 additions & 1 deletion cpp/nanobind/nanobind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
#include <nanobind/stl/vector.h>
#include <xgrammar/xgrammar.h>

#include <cstdint>
#include <optional>
#include <thread>
#include <variant>
#include <vector>

#include "../grammar_functor.h"
#include "../json_schema_converter.h"
#include "../regex_converter.h"
#include "../testing.h"
#include "python_methods.h"
#include "xgrammar/exception.h"
#include "xgrammar/matcher.h"

namespace nb = nanobind;

Expand Down Expand Up @@ -74,6 +81,48 @@ bool GrammarMatcher_FillNextTokenBitmask(
return matcher.FillNextTokenBitmask(bitmask_dltensor_ptr, index, debug_print);
}

void GrammarMatcher_BatchFillNextTokenMask(
BatchGrammarMatcher& batch_matcher,
std::vector<GrammarMatcher>* matchers,
nb::ndarray<> arr,
const std::optional<std::vector<int32_t>>& indices,
std::variant<int32_t, std::string> max_threads,
bool debug_print
) {
if (arr.ndim() != 2) {
throw std::runtime_error("batch_token_bitmask tensor must be 2D");
}
if (arr.device_type() != nb::device::cpu::value) {
throw std::runtime_error("token_bitmask array must be on CPU");
}
if (arr.dtype() != nb::dtype<int32_t>()) {
throw std::runtime_error("token_bitmask array must be int32");
}
static_assert(sizeof(arr) == sizeof(void*) + sizeof(nb::dlpack::dltensor));

DLTensor* bitmask_dltensor_ptr =
reinterpret_cast<::DLTensor*>(reinterpret_cast<char*>(&arr) + sizeof(void*));

batch_matcher.BatchFillNextTokenBitmask(matchers, bitmask_dltensor_ptr, indices, debug_print);
}

std::vector<uint8_t> GrammarMatcher_BatchAcceptString(
std::vector<GrammarMatcher>* matchers,
const std::vector<std::variant<nb::bytes, std::string>>& input_strs,
bool debug_print
) {
std::vector<std::string> input_strs_converted;
input_strs_converted.reserve(input_strs.size());
for (const auto& str : input_strs) {
if (std::holds_alternative<std::string>(str)) {
input_strs_converted.emplace_back(std::get<std::string>(str));
} else {
input_strs_converted.emplace_back(std::get<nb::bytes>(str).c_str());
}
}
return BatchGrammarMatcher::BatchAcceptString(matchers, input_strs_converted);
}

std::vector<nanobind::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer) {
const auto& decoded_vocab = tokenizer.GetDecodedVocab();
std::vector<nanobind::bytes> py_result;
Expand Down Expand Up @@ -209,7 +258,29 @@ NB_MODULE(xgrammar_bindings, m) {
.def("clear_cache", &GrammarCompiler::ClearCache)
.def("get_cache_size_bytes", &GrammarCompiler::GetCacheSizeBytes)
.def_prop_ro("cache_limit_bytes", &GrammarCompiler::CacheLimitBytes);

// Python bindings for the batched matcher API.
auto pyBatchGrammarMatcher = nb::class_<BatchGrammarMatcher>(m, "BatchGrammarMatcher");
pyBatchGrammarMatcher
    .def(nb::init<std::variant<std::string, int32_t>>(), nb::arg("max_threads") = "auto")
    .def(
        "batch_fill_next_token_bitmask",
        &GrammarMatcher_BatchFillNextTokenMask,
        nb::arg("matchers"),
        nb::arg("batch_token_bitmask"),
        nb::arg("indices").none(),
        // NOTE(review): the Python keyword here is "max_thread" but the shim's C++
        // parameter is named max_threads (and is unused by the shim) — confirm the
        // intended keyword spelling before callers depend on it.
        nb::arg("max_thread") = "auto",
        nb::arg("debug_print") = false,
        // Release the GIL while the (potentially multi-threaded) fill runs.
        nb::call_guard<nb::gil_scoped_release>()
    )
    .def_static(
        "batch_accept_string",
        &GrammarMatcher_BatchAcceptString,
        nb::call_guard<nb::gil_scoped_release>()
    )
    .def_static(
        // Bound directly to the C++ static method with no nb::arg annotations, so the
        // C++ default for debug_print is presumably not visible from Python — verify
        // that callers always pass it.
        "batch_accept_token",
        &BatchGrammarMatcher::BatchAcceptToken,
        nb::call_guard<nb::gil_scoped_release>()
    );
auto pyGrammarMatcher = nb::class_<GrammarMatcher>(m, "GrammarMatcher");
pyGrammarMatcher
.def(
Expand Down
4 changes: 3 additions & 1 deletion docs/xgrammar_features/runtime_safeguards.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ If the recursion depth exceeds the limit,
the matcher operations (including [`xgr.GrammarMatcher.accept_token`](xgrammar.GrammarMatcher.accept_token),
[`xgr.GrammarMatcher.accept_string`](xgrammar.GrammarMatcher.accept_string), [`xgr.GrammarMatcher.fill_next_token_bitmask`](xgrammar.GrammarMatcher.fill_next_token_bitmask),
[`xgr.GrammarMatcher.find_jump_forward_string`](xgrammar.GrammarMatcher.find_jump_forward_string)) will raise
`RuntimeError`.
`RuntimeError` (before XGrammar v0.1.21).

You can also use the [`xgr.max_recursion_depth`](xgrammar.max_recursion_depth) context manager to set the maximum
recursion depth for a code block.
Expand All @@ -27,6 +27,8 @@ with max_recursion_depth(10000):
matcher.accept_token(token_id)
```

After XGrammar v0.1.21, the pushdown automaton parser was replaced with an Earley parser, which involves no recursion. As a result, the recursion depth limit can never be exceeded during parsing, and this exception is no longer raised.

## Cache Size Limit

The {py:class}`xgr.GrammarCompiler` class uses a cache to store the compiled grammars.
Expand Down
59 changes: 59 additions & 0 deletions include/xgrammar/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

namespace xgrammar {
Expand Down Expand Up @@ -145,6 +146,64 @@ class GrammarMatcher {
XGRAMMAR_DEFINE_PIMPL_METHODS(GrammarMatcher);
};

/*!
 * \brief A batched version of GrammarMatcher for better efficiency. It supports batch processing
 * of multiple GrammarMatcher objects in parallel.
 *
 * \details This class provides batched versions of the core methods of GrammarMatcher, including
 * FillNextTokenBitmask, AcceptString, and AcceptToken. It utilizes multi-threading to process
 * multiple GrammarMatcher objects simultaneously, significantly improving efficiency when dealing
 * with a large number of matchers.
 */
class BatchGrammarMatcher {
 public:
  /*!
   * \brief Construct a batched matcher.
   * \param max_threads Either an explicit thread count or the string "auto" to let the
   * implementation choose one. Defaults to "auto".
   */
  BatchGrammarMatcher(std::variant<std::string, int32_t> max_threads = "auto");

  /*!
   * \brief A batched version of FillNextTokenBitmask for better efficiency.
   * \param matchers The array of GrammarMatcher objects.
   * \param next_token_bitmask The pre-allocated DLTensor to store the result bitmasks.
   * \param indices The optional array of indices specifying which matcher writes to which
   * slice of the bitmask tensor. If not provided, each matcher writes to the corresponding
   * slice (matchers[i] to next_token_bitmask[i]).
   * \param debug_print Whether to print debug information. Default is false.
   */
  void BatchFillNextTokenBitmask(
      std::vector<GrammarMatcher>* matchers,
      DLTensor* next_token_bitmask,
      const std::optional<std::vector<int32_t>>& indices = std::nullopt,
      bool debug_print = false
  );

  /*!
   * \brief A batched version of AcceptString for better efficiency.
   * \param matchers The array of GrammarMatcher objects.
   * \param input_strs The array of input strings to be accepted, one per matcher.
   * \param debug_print Whether to print debug information. Default is false.
   * \return A vector of flags (one per matcher) indicating whether each string was accepted.
   */
  static std::vector<uint8_t> BatchAcceptString(
      std::vector<GrammarMatcher>* matchers,
      const std::vector<std::string>& input_strs,
      bool debug_print = false
  );

  /*!
   * \brief A batched version of AcceptToken for better efficiency.
   * \param matchers The array of GrammarMatcher objects.
   * \param token_ids The array of token ids to be accepted, one per matcher.
   * \param debug_print Whether to print debug information. Default is false.
   * \return A vector of flags (one per matcher) indicating whether each token was accepted.
   */
  static std::vector<uint8_t> BatchAcceptToken(
      std::vector<GrammarMatcher>* matchers,
      const std::vector<int32_t>& token_ids,
      bool debug_print = false
  );

  XGRAMMAR_DEFINE_PIMPL_METHODS(BatchGrammarMatcher);
};

} // namespace xgrammar

#endif // XGRAMMAR_MATCHER_H_
1 change: 1 addition & 0 deletions python/xgrammar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from .grammar import Grammar, StructuralTagItem
from .matcher import (
BatchGrammarMatcher,
GrammarMatcher,
allocate_token_bitmask,
apply_token_bitmask_inplace,
Expand Down
Loading
Loading