Skip to content

Commit e1f3741

Browse files
committed
add some syntactic sugar to structural tags.
Signed-off-by: Yuchuan <[email protected]>
1 parent 605049e commit e1f3741

File tree

3 files changed

+60
-12
lines changed

3 files changed

+60
-12
lines changed

python/xgrammar/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
_convert_schema_to_str,
1313
_get_structural_tag_str_from_args,
1414
)
15-
from .structural_tag import StructuralTag
15+
from .structural_tag import Format, StructuralTag
1616
from .tokenizer_info import TokenizerInfo
1717

1818

@@ -220,7 +220,7 @@ def compile_regex(self, regex: str) -> CompiledGrammar:
220220

221221
@overload
222222
def compile_structural_tag(
223-
self, structural_tag: Union[StructuralTag, str, Dict[str, Any]]
223+
self, structural_tag: Union[StructuralTag, str, Dict[str, Any], Format]
224224
) -> CompiledGrammar: ...
225225

226226
@overload
@@ -246,8 +246,9 @@ def compile_structural_tag(self, *args, **kwargs) -> CompiledGrammar:
246246
247247
Parameters
248248
----------
249-
structural_tag : Union[StructuralTag, str, Dict[str, Any]]
249+
structural_tag : Union[StructuralTag, str, Dict[str, Any], Format]
250250
The structural tag either as a StructuralTag object, or a JSON string or a dictionary.
251+
If the input is a format enum, it will be converted to a StructuralTag object automatically.
251252
252253
tags : List[StructuralTagItem]
253254
(Deprecated) The structural tags. Use StructuralTag class instead.

python/xgrammar/grammar.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""This module provides classes representing grammars."""
22

33
import json
4-
from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload
4+
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, overload
55

66
from pydantic import BaseModel
77
from typing_extensions import deprecated
88

99
from .base import XGRObject, _core
10-
from .structural_tag import StructuralTag, StructuralTagItem
10+
from .structural_tag import Format, StructuralTag, StructuralTagItem
1111

1212

1313
def _convert_instance_to_str(instance: Union[str, Dict[str, Any], StructuralTag]) -> str:
@@ -111,8 +111,12 @@ def _get_structural_tag_str_from_args(args: List[Any], kwargs: Dict[str, Any]) -
111111
When the arguments are invalid.
112112
"""
113113
if len(args) == 1:
114+
possible_formats = get_args(Format)
114115
if isinstance(args[0], (str, dict, StructuralTag)):
115116
return _convert_instance_to_str(args[0])
117+
elif any(isinstance(args[0], fmt) for fmt in possible_formats):
118+
structural_tag = StructuralTag(format=args[0])
119+
return _convert_instance_to_str(structural_tag)
116120
else:
117121
raise TypeError("Invalid argument type for from_structural_tag")
118122
elif len(args) == 2 and isinstance(args[0], list) and isinstance(args[1], list):
@@ -285,7 +289,7 @@ def from_regex(regex_string: str, *, print_converted_ebnf: bool = False) -> "Gra
285289
@overload
286290
@staticmethod
287291
def from_structural_tag(
288-
structural_tag: Union[StructuralTag, str, Dict[str, Any]]
292+
structural_tag: Union[StructuralTag, str, Dict[str, Any], Format]
289293
) -> "Grammar": ...
290294

291295
@overload
@@ -311,8 +315,9 @@ def from_structural_tag(*args, **kwargs) -> "Grammar":
311315
312316
Parameters
313317
----------
314-
structural_tag : Union[StructuralTag, str, Dict[str, Any]]
318+
structural_tag : Union[StructuralTag, str, Dict[str, Any], Format]
315319
The structural tag either as a StructuralTag object, or a JSON string or a dictionary.
320+
If the input is a format enum, it will be converted to a StructuralTag object automatically.
316321
317322
tags : List[StructuralTagItem]
318323
(Deprecated) The structural tags. Use StructuralTag class instead.

tests/python/test_structural_tag_converter.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import sys
22
import time
3-
from typing import Any, Dict, List, Tuple
3+
from typing import Any, Dict, List, Tuple, Union
44

55
import pytest
66
from transformers import AutoTokenizer
77

88
import xgrammar as xgr
9+
from xgrammar.structural_tag import Format, StructuralTag
910
from xgrammar.testing import _is_grammar_accept_string
1011

1112
PROFILER_ON = True
12-
tokenizer_id = "meta-llama/Llama-3.1-8B-Instruct"
13+
tokenizer_id = "meta-llama/Meta-Llama-3-8B-Instruct"
1314

1415

1516
class Profiler:
@@ -22,13 +23,16 @@ def __init__(self, tokenizer_id: str):
2223
self.tokenizer_info, max_threads=16, cache_enabled=False
2324
)
2425

25-
def profile_stag(self, structural_tag_format: Dict[str, Any], instance: str):
26-
structural_tag = {"type": "structural_tag", "format": structural_tag_format}
26+
def profile_stag(
27+
self, structural_tag: Union[Dict[str, Any], StructuralTag, Format], instance: str
28+
):
29+
if isinstance(structural_tag, Dict):
30+
structural_tag = {"type": "structural_tag", "format": structural_tag}
2731
time_begin = time.monotonic_ns()
2832
compiled_grammar = self.compiler.compile_structural_tag(structural_tag)
2933
time_end = time.monotonic_ns()
3034
compiler_duration = time_end - time_begin
31-
print(f"Compiling structural tag {structural_tag_format}")
35+
print(f"Compiling structural tag {structural_tag}")
3236
print(f"Compile time: {compiler_duration / 1000 / 1000} ms")
3337
matcher = xgr.GrammarMatcher(compiled_grammar)
3438
token_bitmask = xgr.allocate_token_bitmask(1, self.tokenizer_info.vocab_size)
@@ -69,6 +73,33 @@ def check_stag_with_instance(
6973
profiler.profile_stag(structural_tag_format, instance)
7074

7175

76+
def check_stag_with_instance_in_stag_style(
77+
structural_tag: StructuralTag,
78+
instance: str,
79+
is_accepted: bool = True,
80+
debug_print: bool = False,
81+
):
82+
stag_grammar = xgr.Grammar.from_structural_tag(structural_tag)
83+
accepted = _is_grammar_accept_string(stag_grammar, instance, debug_print=debug_print)
84+
assert accepted == is_accepted
85+
if PROFILER_ON:
86+
profiler.profile_stag(structural_tag, instance)
87+
88+
89+
def check_stag_with_instance_in_format_style(
90+
structural_tag_format: Format,
91+
instance: str,
92+
is_accepted: bool = True,
93+
debug_print: bool = False,
94+
):
95+
96+
stag_grammar = xgr.Grammar.from_structural_tag(structural_tag_format)
97+
accepted = _is_grammar_accept_string(stag_grammar, instance, debug_print=debug_print)
98+
assert accepted == is_accepted
99+
if PROFILER_ON:
100+
profiler.profile_stag(structural_tag_format, instance)
101+
102+
72103
const_string_stag_grammar = [
73104
(
74105
{"type": "const_string", "value": "Hello!"},
@@ -1959,5 +1990,16 @@ def test_from_structural_tag_with_structural_tag_instance(
19591990
assert _is_grammar_accept_string(grammar, instance) == is_accepted
19601991

19611992

1993+
@pytest.mark.parametrize(
1994+
"stag_format, instance, is_accepted", basic_structural_tags_instance_is_accepted
1995+
)
1996+
def test_from_structural_tag_with_format_instance(
1997+
stag_format: xgr.structural_tag.Format, instance: str, is_accepted: bool
1998+
):
1999+
stag = xgr.StructuralTag(format=stag_format)
2000+
grammar = xgr.Grammar.from_structural_tag(stag)
2001+
assert _is_grammar_accept_string(grammar, instance) == is_accepted
2002+
2003+
19622004
if __name__ == "__main__":
19632005
pytest.main(sys.argv)

0 commit comments

Comments
 (0)