Skip to content

Commit

Permalink
[Feature] Support regex and repetition range (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubospica authored Jan 8, 2025
1 parent fdf8d6d commit 6cb2db1
Show file tree
Hide file tree
Showing 19 changed files with 398 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cpp/compiled_grammar_data_structure.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <vector>

// matcher_data_structure.h is included to use RulePosition
#include "matcher_data_structure.h"
#include "grammar_matcher_data_structure.h"
#include "support/dynamic_bitset.h"
#include "support/utils.h"

Expand Down
2 changes: 1 addition & 1 deletion cpp/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include "compiled_grammar_data_structure.h"
#include "grammar_data_structure.h"
#include "matcher_base.h"
#include "grammar_matcher_base.h"
#include "support/thread_pool.h"
#include "support/thread_safe_cache.h"

Expand Down
3 changes: 3 additions & 0 deletions cpp/grammar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "grammar_parser.h"
#include "grammar_serializer.h"
#include "json_schema_converter.h"
#include "regex_converter.h"

namespace xgrammar {

Expand All @@ -32,6 +33,8 @@ Grammar Grammar::FromJSONSchema(
return FromEBNF(ebnf_string);
}

Grammar Grammar::FromRegex(const std::string& regex) { return FromEBNF(RegexToEBNF(regex)); }

// Optimized json grammar for the speed of the grammar matcher
const std::string kJSONGrammarString = R"(
root ::= (
Expand Down
7 changes: 5 additions & 2 deletions cpp/matcher.cc → cpp/grammar_matcher.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/matcher.cc
* \brief This source file implement the matcher class, especially the logic related to LLM tokens,
* like accepting tokens, leveraging the token mask cache to generate the mask, etc. matcher_base.cc
* implements the basic matching algorithm from strings to grammar.
*/

#include <xgrammar/matcher.h>
Expand All @@ -10,9 +13,9 @@

#include "compiled_grammar_data_structure.h"
#include "grammar_data_structure.h"
#include "grammar_matcher_base.h"
#include "grammar_matcher_data_structure.h"
#include "grammar_serializer.h"
#include "matcher_base.h"
#include "matcher_data_structure.h"
#include "support/dynamic_bitset.h"
#include "support/encoding.h"
#include "support/int_set.h"
Expand Down
9 changes: 6 additions & 3 deletions cpp/matcher_base.cc → cpp/grammar_matcher_base.cc
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/regex_converter.cc
* \file xgrammar/matcher_base.cc
* \brief This source file implements the basic matching algorithm from strings to grammar.
* matcher.cc will handle the logic related to LLM tokens, like accepting tokens, leveraging the
* token mask cache to generate the mask, etc.
*/

#include "matcher_base.h"
#include "grammar_matcher_base.h"

#include <algorithm>
#include <vector>

#include "grammar_data_structure.h"
#include "matcher_data_structure.h"
#include "grammar_matcher_data_structure.h"
#include "support/encoding.h"

namespace xgrammar {
Expand Down
2 changes: 1 addition & 1 deletion cpp/matcher_base.h → cpp/grammar_matcher_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <vector>

#include "grammar_data_structure.h"
#include "matcher_data_structure.h"
#include "grammar_matcher_data_structure.h"

namespace xgrammar {

Expand Down
File renamed without changes.
122 changes: 119 additions & 3 deletions cpp/grammar_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class EBNFParser {
int32_t ParseString();
int32_t ParseRuleRef();
int32_t ParseElement();
int64_t ParseInteger();
std::pair<int64_t, int64_t> ParseRepetitionRange();
int32_t ParseQuantifier();
int32_t ParseLookaheadAssertion();
int32_t ParseSequence();
Expand All @@ -40,6 +42,7 @@ class EBNFParser {
int32_t HandleStarQuantifier(int32_t rule_expr_id);
int32_t HandlePlusQuantifier(int32_t rule_expr_id);
int32_t HandleQuestionQuantifier(int32_t rule_expr_id);
int32_t HandleRepetitionRange(int32_t rule_expr_id, int64_t lower, int64_t upper);

// When parsing, we first find the names of all rules, and build the mapping from name to rule id.
void BuildRuleNameToId();
Expand Down Expand Up @@ -85,6 +88,8 @@ class EBNFParser {
// Whether the current element is in parentheses.
// A sequence expression cannot contain newline, unless it is in parentheses.
bool in_parentheses_ = false;

inline static constexpr int64_t MAX_INTEGER_IN_GRAMMAR = 1e10;
};

void EBNFParser::ConsumeSpace(bool allow_newline) {
Expand Down Expand Up @@ -289,7 +294,7 @@ int32_t EBNFParser::HandleStarQuantifier(int32_t rule_expr_id) {
auto new_rule_id = builder_.AddEmptyRule(new_rule_name);
auto ref_to_new_rule = builder_.AddRuleRef(new_rule_id);
auto new_rule_expr_id = builder_.AddChoices(
{builder_.AddSequence({rule_expr_id, ref_to_new_rule}), builder_.AddEmptyStr()}
{builder_.AddEmptyStr(), builder_.AddSequence({rule_expr_id, ref_to_new_rule})}
);
builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id);

Expand All @@ -314,17 +319,128 @@ int32_t EBNFParser::HandlePlusQuantifier(int32_t rule_expr_id) {
int32_t EBNFParser::HandleQuestionQuantifier(int32_t rule_expr_id) {
// a? --> rule ::= a | empty
auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_);
auto new_rule_expr_id = builder_.AddChoices({rule_expr_id, builder_.AddEmptyStr()});
auto new_rule_expr_id = builder_.AddChoices({builder_.AddEmptyStr(), rule_expr_id});
auto new_rule_id = builder_.AddRule({new_rule_name, new_rule_expr_id});
return builder_.AddRuleRef(new_rule_id);
}

int64_t EBNFParser::ParseInteger() {
if (!isdigit(Peek())) {
ReportParseError("Expect integer");
}
int64_t num = 0;
while (Peek() && isdigit(Peek())) {
num = num * 10 + (Peek() - '0');
Consume();
if (num > MAX_INTEGER_IN_GRAMMAR) {
ReportParseError(
"Integer is too large: parsed " + std::to_string(num) + ", max allowed is " +
std::to_string(MAX_INTEGER_IN_GRAMMAR)
);
}
}
return num;
}

// {x}: Match exactly x occurrences
// {x,}: Match at least x occurrences
// {x,y}: Match at least x occurrences, at most y occurrences
std::pair<int64_t, int64_t> EBNFParser::ParseRepetitionRange() {
Consume();
ConsumeSpace();
int64_t lower = ParseInteger();
ConsumeSpace();
if (Peek() == ',') {
Consume();
ConsumeSpace();
if (Peek() == '}') {
Consume();
return {lower, -1};
}
int64_t upper = ParseInteger();
if (upper < lower) {
ReportParseError(
"Lower bound is larger than upper bound: " + std::to_string(lower) + " > " +
std::to_string(upper)
);
}
Consume();
return {lower, upper};
} else if (Peek() == '}') {
Consume();
return {lower, lower};
}
ReportParseError("Expect ',' or '}' in repetition range");
}

int32_t EBNFParser::HandleRepetitionRange(int32_t rule_expr_id, int64_t lower, int64_t upper) {
// Construct expr expr ... expr (l times)
std::vector<int32_t> elements;
for (int64_t i = 0; i < lower; ++i) {
elements.push_back(rule_expr_id);
}

// Case 1: {l}:
// expr expr ... expr (l times)
if (upper == lower) {
return builder_.AddSequence(elements);
}

// Case 2: {l,}:
// expr expr ... expr (l times) rest
// rest ::= "" | expr rest
if (upper == -1) {
auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_);
auto new_rule_id = builder_.AddEmptyRule(new_rule_name);
auto ref_to_new_rule = builder_.AddRuleRef(new_rule_id);
auto new_rule_expr_id = builder_.AddChoices(
{builder_.AddEmptyStr(), builder_.AddSequence({rule_expr_id, ref_to_new_rule})}
);
builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id);
elements.push_back(builder_.AddRuleRef(new_rule_id));
return builder_.AddSequence(elements);
}

// Case 3: {l, r} (r - l >= 1)
// expr expr ... expr (l times) rest1
// rest1 ::= "" | expr rest2
// rest2 ::= "" | expr rest3
// ...
// rest(r - l) ::= "" | expr
std::vector<int32_t> rest_rule_ids;

for (int64_t i = 0; i < upper - lower; ++i) {
auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_);
rest_rule_ids.push_back(builder_.AddEmptyRule(new_rule_name));
}
for (int64_t i = 0; i < upper - lower - 1; ++i) {
auto ref_to_next_rule = builder_.AddRuleRef(rest_rule_ids[i + 1]);
auto new_rule_expr_id = builder_.AddChoices(
{builder_.AddEmptyStr(), builder_.AddSequence({rule_expr_id, ref_to_next_rule})}
);
builder_.UpdateRuleBody(rest_rule_ids[i], new_rule_expr_id);
}
auto last_rule_expr_id = builder_.AddChoices({builder_.AddEmptyStr(), rule_expr_id});
builder_.UpdateRuleBody(rest_rule_ids.back(), last_rule_expr_id);

elements.push_back(builder_.AddRuleRef(rest_rule_ids[0]));
return builder_.AddSequence(elements);
}

int32_t EBNFParser::ParseQuantifier() {
int32_t rule_expr_id = ParseElement();
ConsumeSpace(in_parentheses_);
if (Peek() != '*' && Peek() != '+' && Peek() != '?') {
if (Peek() != '*' && Peek() != '+' && Peek() != '?' && Peek() != '{') {
return rule_expr_id;
}

// Handle repetition range
if (Peek() == '{') {
auto [lower, upper] = ParseRepetitionRange();
return HandleRepetitionRange(rule_expr_id, lower, upper);
}

// Handle quantifiers
Consume();

// We will transform a*, a+, a? into a rule, and return the reference to this rule
Expand Down
1 change: 1 addition & 0 deletions cpp/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
pyGrammar.def("to_string", &Grammar::ToString)
.def_static("from_ebnf", &Grammar::FromEBNF)
.def_static("from_json_schema", &Grammar::FromJSONSchema)
.def_static("from_regex", &Grammar::FromRegex)
.def_static("builtin_json_grammar", &Grammar::BuiltinJSONGrammar);

auto pyCompiledGrammar = py::class_<CompiledGrammar>(m, "CompiledGrammar");
Expand Down
1 change: 1 addition & 0 deletions cpp/regex_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ std::string RegexConverter::HandleCharacterClass() {
// {x,y}
std::string RegexConverter::HandleRepetitionRange() {
std::string result = "{";
++current_;
if (!isdigit(*current_)) {
RaiseError("Invalid repetition count.");
}
Expand Down
6 changes: 6 additions & 0 deletions include/xgrammar/grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ class Grammar {
bool strict_mode = true
);

/*!
* \brief Construct a grammar from a regular expression string.
* \param regex The regular expression string.
*/
static Grammar FromRegex(const std::string& regex);

/*!
* \brief Get the grammar of standard JSON format. We have built-in support for JSON.
*/
Expand Down
31 changes: 31 additions & 0 deletions python/xgrammar/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def from_ebnf(ebnf_string: str, *, root_rule_name: str = "root") -> "Grammar":
root_rule_name : str, default: "root"
The name of the root rule in the grammar.
Raises
------
RuntimeError
When converting the regex pattern fails, with details about the parsing error.
"""
return Grammar._create_from_handle(_core.Grammar.from_ebnf(ebnf_string, root_rule_name))

Expand Down Expand Up @@ -100,6 +105,11 @@ def from_json_schema(
-------
grammar : Grammar
The constructed grammar.
Raises
------
RuntimeError
When converting the json schema fails, with details about the parsing error.
"""
if isinstance(schema, type) and issubclass(schema, BaseModel):
if hasattr(schema, "model_json_schema"):
Expand All @@ -117,6 +127,27 @@ def from_json_schema(
_core.Grammar.from_json_schema(schema, any_whitespace, indent, separators, strict_mode),
)

@staticmethod
def from_regex(regex_string: str) -> "Grammar":
"""Create a grammar from a regular expression string.
Parameters
----------
regex_string : str
The regular expression pattern to create the grammar from.
Returns
-------
grammar : Grammar
The constructed grammar from the regex pattern.
Raises
------
RuntimeError
When parsing the regex pattern fails, with details about the parsing error.
"""
return Grammar._create_from_handle(_core.Grammar.from_regex(regex_string))

@staticmethod
def builtin_json_grammar() -> "Grammar":
"""Get the grammar of standard JSON. This is compatible with the official JSON grammar
Expand Down
37 changes: 35 additions & 2 deletions python/xgrammar/testing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Testing utilities."""

import time
from typing import List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -84,14 +85,46 @@ def _regex_to_ebnf(regex: str, with_rule_name: bool = True) -> str:


def _is_grammar_accept_string(
grammar: Union[Grammar, str], input_str: str, debug_print: bool = False
grammar: Union[Grammar, str],
input_str: str,
*,
debug_print: bool = False,
print_time: bool = False,
) -> bool:
"""Check if a grammar accepts a string. For test purposes.
Parameters
----------
grammar : Union[Grammar, str]
The grammar to check. Can be either a Grammar object or a BNF grammar string.
input_str : str
The input string to check.
debug_print : bool, default: False
Whether to print debug information during matching.
print_time : bool, default: False
Whether to print timing information.
Returns
-------
bool
True if the grammar accepts the string, False otherwise.
"""

if isinstance(grammar, str):
grammar = Grammar.from_ebnf(grammar)
grammar_compiler = GrammarCompiler(TokenizerInfo([]), cache_enabled=False)
compiled_grammar = grammar_compiler.compile_grammar(grammar)
matcher = GrammarMatcher(compiled_grammar, terminate_without_stop_token=True)
if not matcher._debug_accept_string(input_str, debug_print=debug_print):

if print_time:
start = time.monotonic_ns()
accepted = matcher._debug_accept_string(input_str, debug_print=debug_print)

if print_time:
end = time.monotonic_ns()
print(f"Accepting {input_str}, result: {accepted}, time: {(end - start) / 1e3} us")

if not accepted:
return False
return matcher.is_terminated()

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_grammar_matcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""This test uses the optimized JSON grammar provided by the grammar library."""
"""This test tests the token-based operations for the grammar matcher."""

import sys
from typing import List, Optional
Expand Down
Loading

0 comments on commit 6cb2db1

Please sign in to comment.