@Seven-Streams commented on Sep 29, 2025

This PR adds batched methods for filling the next-token bitmask and for accepting strings/tokens with GrammarMatcher. In total, there are three new methods:

  • batch_fill_next_token_bitmask(matchers: List["GrammarMatcher"], bitmask: ArrayLike, index: int = 0, max_threads: int = 16, debug_print: bool = False)
  • batch_accept_string(matchers: List["GrammarMatcher"], strings: List[str], debug_print: bool = False)
  • batch_accept_token(matchers: List["GrammarMatcher"], tokens: List[int], debug_print: bool = False)

These methods let users fill the token masks of multiple GrammarMatchers in a single batched call, reducing the overhead of crossing between C++ and Python.
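
A minimal usage sketch (not part of the PR itself): the setup mirrors the benchmark script later in this thread, and the example strings are hypothetical inputs that must be valid under the grammar.

import xgrammar as xgr
from transformers import AutoTokenizer

# Same setup as the benchmark script below.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = compiler.compile_regex("(\\s\\S){24}")

# One matcher per request in the batch.
matchers = [xgr.GrammarMatcher(compiled_grammar) for _ in range(4)]

# Allocate one bitmask row per matcher and fill all rows in a single batched call,
# instead of looping over matcher.fill_next_token_bitmask in Python.
bitmask = xgr.allocate_token_bitmask(len(matchers), tokenizer.vocab_size)
xgr.GrammarMatcher.batch_fill_next_token_bitmask(matchers, bitmask)

# Batched acceptance: the i-th string is applied to the i-th matcher.
# (Strings are hypothetical; each must be valid under the grammar.)
xgr.GrammarMatcher.batch_accept_string(matchers, [" a", " b", " c", " d"])

# batch_accept_token works the same way, with one token id per matcher.

The optional index, max_threads, and debug_print arguments are left at their defaults here.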


@Ubospica left a comment

x grammar, latency

Before (sequential); After (batch)

Use one example in JSB

@Seven-Streams marked this pull request as draft on September 30, 2025 03:08
@Seven-Streams

(image: comparison_bar_chart)

@Seven-Streams marked this pull request as ready for review on September 30, 2025 04:12

@Ubospica left a comment

LGTM. Is the y axis the average time or the total time? And why is batch=1 so slow? Didn't you warm it up? Can you also show the speedup when the batch is large, e.g. 300?

@Seven-Streams

LGTM. Is the y axis the average time or the total time? And why is batch=1 so slow? Didn't you warm it up? Can you also show the speedup when the batch is large, e.g. 300?

I did warm it up. I'll test it more.

@Seven-Streams

@Ubospica (image attached)

The testing code is here:

import xgrammar as xgr
from transformers import AutoTokenizer
import time
import matplotlib.pyplot as plt
import numpy as np

tokenizer_path = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info)
regex = "(\\s\\S){24}"
compiled_grammar = compiler.compile_regex(regex)

batch_time_list = dict()
naive_time_list = dict()

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

# Warm up
for batch_size in batch_sizes:
    matchers = [
        xgr.GrammarMatcher(compiled_grammar)
        for _ in range(batch_size)
    ]
    dummy_next_token_bitmask = xgr.allocate_token_bitmask(1, tokenizer.vocab_size)
    for matcher in matchers:
        matcher.fill_next_token_bitmask(dummy_next_token_bitmask)
    
    matchers = [
        xgr.GrammarMatcher(compiled_grammar)
        for _ in range(batch_size)
    ]
    dummy_next_batch_token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer.vocab_size)
    xgr.GrammarMatcher.batch_fill_next_token_bitmask(matchers, dummy_next_batch_token_bitmask)

# Actual benchmarking
for _ in range(10):
    for batch_size in batch_sizes:
        matchers = [
            xgr.GrammarMatcher(compiled_grammar)
            for _ in range(batch_size)
        ]
        
        next_token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer.vocab_size)
        
        # warm up
        for i, matcher in enumerate(matchers):
            matcher.fill_next_token_bitmask(next_token_bitmask, i)

        start_time = time.time_ns()
        for i, matcher in enumerate(matchers):
            matcher.fill_next_token_bitmask(next_token_bitmask, i)
        end_time = time.time_ns()
        elapsed_time = end_time - start_time
        if batch_size not in naive_time_list:
            naive_time_list[batch_size] = []
        naive_time_list[batch_size].append(elapsed_time)
        
        matchers = [
            xgr.GrammarMatcher(compiled_grammar)
            for _ in range(batch_size)
        ]

        next_batch_token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer.vocab_size)
        
        # warm up
        xgr.GrammarMatcher.batch_fill_next_token_bitmask(matchers, next_batch_token_bitmask)
        
        start_time = time.time_ns()
        xgr.GrammarMatcher.batch_fill_next_token_bitmask(matchers, next_batch_token_bitmask)
        end_time = time.time_ns()
        elapsed_time = end_time - start_time
        if batch_size not in batch_time_list:
            batch_time_list[batch_size] = []
        batch_time_list[batch_size].append(elapsed_time)
        
# Draw results
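# NOTE: the extra leading "avg" bar is the mean per-matcher fill time
# (each run's total time divided by its batch size), pooled over all batch sizes.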
naive_avg_time = [np.mean(naive_time_list[bs]) for bs in batch_sizes]
single_times = []
for bs in batch_sizes:
    single_times.extend([t / bs for t in naive_time_list[bs]])
naive_avg_time.insert(0, np.mean(single_times))
batch_avg_time = [np.mean(batch_time_list[bs]) for bs in batch_sizes]
single_times = []
for bs in batch_sizes:
    single_times.extend([t / bs for t in batch_time_list[bs]])
batch_avg_time.insert(0, np.mean(single_times))
    
x = np.arange(len(batch_sizes) + 1)
width = 0.2 
fig, ax = plt.subplots()
rects1 = ax.bar(x - width/2, naive_avg_time, width, label='Naive')
rects2 = ax.bar(x + width/2, batch_avg_time, width, label='Batch')
ax.set_ylabel('Time (ns)')
ax.set_xlabel('Batch Size')
ax.set_title('Time by Batch Size and Method')
ax.set_xticks(x)

ax.set_xticklabels(["avg"] + batch_sizes)
ax.legend()
ax.bar_label(rects1, padding=3)
ax.bar_label(rects2, padding=-3)
ax.set_yscale('log')
plt.savefig("batch_vs_naive.png")
plt.show()

Commits pushed (each signed off by Yuchuan <[email protected]>):

  • refactor nanobind.
  • refactor python files.
  • finish.
  • fix batch filling.
@Seven-Streams force-pushed the main-dev/2025-09-28/batch branch from 2a63ee5 to 6374909 on October 3, 2025 02:02
@Seven-Streams merged commit 01678b1 into mlc-ai:main on Oct 8, 2025
38 checks passed