11import sys
22import time
3- from typing import Any , Dict , List , Tuple
3+ from typing import Any , Dict , List , Tuple , Union
44
55import pytest
66from transformers import AutoTokenizer
77
88import xgrammar as xgr
9+ from xgrammar .structural_tag import Format , StructuralTag
910from xgrammar .testing import _is_grammar_accept_string
1011
1112PROFILER_ON = True
12- tokenizer_id = "meta-llama/Llama-3.1 -8B-Instruct"
13+ tokenizer_id = "meta-llama/Meta- Llama-3-8B-Instruct"
1314
1415
1516class Profiler :
@@ -22,13 +23,16 @@ def __init__(self, tokenizer_id: str):
2223 self .tokenizer_info , max_threads = 16 , cache_enabled = False
2324 )
2425
25- def profile_stag (self , structural_tag_format : Dict [str , Any ], instance : str ):
26- structural_tag = {"type" : "structural_tag" , "format" : structural_tag_format }
26+ def profile_stag (
27+ self , structural_tag : Union [Dict [str , Any ], StructuralTag , Format ], instance : str
28+ ):
29+ if isinstance (structural_tag , Dict ):
30+ structural_tag = {"type" : "structural_tag" , "format" : structural_tag }
2731 time_begin = time .monotonic_ns ()
2832 compiled_grammar = self .compiler .compile_structural_tag (structural_tag )
2933 time_end = time .monotonic_ns ()
3034 compiler_duration = time_end - time_begin
31- print (f"Compiling structural tag { structural_tag_format } " )
35+ print (f"Compiling structural tag { structural_tag } " )
3236 print (f"Compile time: { compiler_duration / 1000 / 1000 } ms" )
3337 matcher = xgr .GrammarMatcher (compiled_grammar )
3438 token_bitmask = xgr .allocate_token_bitmask (1 , self .tokenizer_info .vocab_size )
@@ -69,6 +73,33 @@ def check_stag_with_instance(
6973 profiler .profile_stag (structural_tag_format , instance )
7074
7175
76+ def check_stag_with_instance_in_stag_style (
77+ structural_tag : StructuralTag ,
78+ instance : str ,
79+ is_accepted : bool = True ,
80+ debug_print : bool = False ,
81+ ):
82+ stag_grammar = xgr .Grammar .from_structural_tag (structural_tag )
83+ accepted = _is_grammar_accept_string (stag_grammar , instance , debug_print = debug_print )
84+ assert accepted == is_accepted
85+ if PROFILER_ON :
86+ profiler .profile_stag (structural_tag , instance )
87+
88+
89+ def check_stag_with_instance_in_format_style (
90+ structural_tag_format : Format ,
91+ instance : str ,
92+ is_accepted : bool = True ,
93+ debug_print : bool = False ,
94+ ):
95+
96+ stag_grammar = xgr .Grammar .from_structural_tag (structural_tag_format )
97+ accepted = _is_grammar_accept_string (stag_grammar , instance , debug_print = debug_print )
98+ assert accepted == is_accepted
99+ if PROFILER_ON :
100+ profiler .profile_stag (structural_tag_format , instance )
101+
102+
72103const_string_stag_grammar = [
73104 (
74105 {"type" : "const_string" , "value" : "Hello!" },
@@ -1959,5 +1990,16 @@ def test_from_structural_tag_with_structural_tag_instance(
19591990 assert _is_grammar_accept_string (grammar , instance ) == is_accepted
19601991
19611992
1993+ @pytest .mark .parametrize (
1994+ "stag_format, instance, is_accepted" , basic_structural_tags_instance_is_accepted
1995+ )
1996+ def test_from_structural_tag_with_format_instance (
1997+ stag_format : xgr .structural_tag .Format , instance : str , is_accepted : bool
1998+ ):
1999+ stag = xgr .StructuralTag (format = stag_format )
2000+ grammar = xgr .Grammar .from_structural_tag (stag )
2001+ assert _is_grammar_accept_string (grammar , instance ) == is_accepted
2002+
2003+
19622004if __name__ == "__main__" :
19632005 pytest .main (sys .argv )
0 commit comments