Add generator-style run_batch function #2513

Open · wants to merge 27 commits into main from generator-run-batch
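This PR adds a `generator_style` flag to `run_batch`, letting callers consume batch results one by one as they finish instead of waiting for a fully materialized list. A minimal usage sketch, assuming a backend has already been configured with `sgl.set_default_backend(...)`; the `answer` program and its arguments are illustrative, not part of this PR:

```python
import sglang as sgl


@sgl.function
def answer(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=32)


arguments = [{"question": f"What is {i} + {i}?"} for i in range(100)]

# generator_style=True returns a generator that yields program states
# one by one, in input order, as they complete.
for state in answer.run_batch(arguments, num_threads=8, generator_style=True):
    print(state["answer"])
```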
Commits (27)
1545637  Add generator-style run_batch function (openhands-agent, Dec 18, 2024)
a8b28a2  Simplify generator-style run_batch to only yield results (openhands-agent, Dec 18, 2024)
779479d  Update python/sglang/lang/interpreter.py (xingyaoww, Dec 18, 2024)
8aa9f30  Update python/sglang/lang/interpreter.py (xingyaoww, Dec 18, 2024)
9e4f4ee  Merge branch 'main' into generator-run-batch (xingyaoww, Dec 20, 2024)
f726142  Maintain input order in generator_style=True mode and improve docstrings (openhands-agent, Dec 26, 2024)
2d21542  Remove docstrings to fix linting errors (openhands-agent, Dec 26, 2024)
d860e97  Fix test cases to handle both list and generator results (openhands-agent, Dec 26, 2024)
58fb561  Fix formatting (openhands-agent, Dec 26, 2024)
c1f44d3  Fix generator check in test cases (openhands-agent, Dec 26, 2024)
a20ed9d  remove unused future_to_arguments (xingyaoww, Dec 26, 2024)
eb27ec7  fix indentation (xingyaoww, Dec 26, 2024)
f509eee  Merge branch 'main' into generator-run-batch (xingyaoww, Dec 26, 2024)
c2a0feb  revert test change (xingyaoww, Dec 26, 2024)
c7d6573  Merge commit 'f509eee799a8bd06c1d81058fd1a85eb4eb89146' into generato… (xingyaoww, Dec 26, 2024)
d41e23f  Merge branch 'main' into generator-run-batch (xingyaoww, Jan 2, 2025)
03399f6  simplify generator (xingyaoww, Jan 2, 2025)
3b8a151  linter fix (xingyaoww, Jan 2, 2025)
13401ef  add test for generator style True (xingyaoww, Jan 2, 2025)
6e2750f  fix the issue where it start yield late (xingyaoww, Jan 2, 2025)
44ec998  fix yield for a large number of tasks (xingyaoww, Jan 2, 2025)
42c5ba7  Merge branch 'main' into generator-run-batch (xingyaoww, Jan 2, 2025)
acaf7ba  fix linter (xingyaoww, Jan 2, 2025)
194909d  Refactor run_program_batch to preserve original behavior and add gene… (openhands-agent, Jan 3, 2025)
d908645  fix linter (xingyaoww, Jan 3, 2025)
a06e4df  Merge branch 'main' into generator-run-batch (xingyaoww, Jan 3, 2025)
2159cdb  fix linter (xingyaoww, Jan 3, 2025)
70 changes: 70 additions & 0 deletions python/sglang/lang/interpreter.py
@@ -96,6 +96,7 @@ def run_program_batch(
    default_sampling_para,
    num_threads,
    progress_bar,
    generator_style=False,
):
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint
@@ -109,6 +110,17 @@
        num_threads = max(96, multiprocessing.cpu_count() * 16)
    num_threads = min(num_threads, len(batch_arguments))

    if generator_style:
        return _run_program_batch_generator(
            program,
            backend,
            batch_arguments,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    # Original code path when generator_style=False
    if num_threads == 1:
        rets = []
        if progress_bar:
@@ -168,6 +180,64 @@ def run_program_batch(
    return rets


def _run_program_batch_generator(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Helper that yields results one by one, chunking submissions to avoid overwhelming the ThreadPoolExecutor."""
    if num_threads == 1:
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        for arguments in iterator:
            yield run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
    else:
        pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None

        # Process in chunks to avoid overwhelming the ThreadPoolExecutor.
        # Otherwise, ThreadPoolExecutor.submit will block after adding a certain
        # number of tasks, so we will never reach "yield" until all tasks are done.
        chunk_size = 200

        with ThreadPoolExecutor(num_threads) as executor:
            for chunk_start in range(0, len(batch_arguments), chunk_size):
                chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
                chunk_futures = []

                # Submit one chunk of tasks
                for i in range(chunk_start, chunk_end):
                    future = executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        batch_arguments[i],
                        default_sampling_para,
                        False,
                        True,
                    )
                    if pbar:
                        future.add_done_callback(lambda _: pbar.update())
                    chunk_futures.append(future)

                # Yield results from this chunk as they complete
                for future in chunk_futures:
                    yield future.result()

        if pbar:
            pbar.close()

def cache_program(program, backend):
    from sglang.lang.tracer import extract_prefix_by_tracing

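The chunked submit-then-yield pattern above keeps the number of in-flight futures bounded while still letting results stream out. A self-contained sketch of the same pattern, with a hypothetical `slow_task` standing in for `run_program`:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_task(i):
    time.sleep(0.01)
    return i * i


def run_chunked(inputs, num_threads=8, chunk_size=200):
    with ThreadPoolExecutor(num_threads) as executor:
        for start in range(0, len(inputs), chunk_size):
            # Submit one chunk of tasks at a time.
            futures = [
                executor.submit(slow_task, x)
                for x in inputs[start : start + chunk_size]
            ]
            # Iterating futures in submission order preserves input order,
            # at the cost of head-of-line waiting within a chunk.
            for future in futures:
                yield future.result()


for result in run_chunked(list(range(500))):
    pass  # each result is usable as soon as its chunk produces it
```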
2 changes: 2 additions & 0 deletions python/sglang/lang/ir.py
@@ -227,6 +227,7 @@ def run_batch(
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
        generator_style: bool = False,
    ):
        from sglang.lang.interpreter import run_program_batch

@@ -277,6 +278,7 @@
            default_sampling_para,
            num_threads,
            progress_bar,
            generator_style=generator_style,
        )

    def trace(self, *, backend=None, **kwargs):
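With the flag off, `run_batch` keeps returning a plain list; with it on, it returns a generator. A short sketch of how caller code can tell the two apart, continuing the illustrative `answer` program from above:

```python
import types

rets = answer.run_batch(arguments, generator_style=True)
assert isinstance(rets, types.GeneratorType)

# list(...) restores the old eager behavior, in the same input order.
results = list(rets)
```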
24 changes: 23 additions & 1 deletion python/sglang/test/test_programs.py
@@ -509,13 +509,35 @@ def few_shot_hellaswag(s, question, choices):
        temperature=0,
        num_threads=64,
        progress_bar=True,
        generator_style=False,
    )
    preds = []
    for i, ret in enumerate(rets):
        preds.append(choices[i].index(ret["answer"]))
    latency = time.time() - tic

    # Compute accuracy
    accuracy = np.mean(np.array(preds) == np.array(labels))

    # Test generator style of run_batch
    tic = time.time()
    rets = few_shot_hellaswag.run_batch(
        arguments,
        temperature=0,
        num_threads=64,
        progress_bar=True,
        generator_style=True,
    )
    preds_gen = []
    for i, ret in enumerate(rets):
        preds_gen.append(choices[i].index(ret["answer"]))
    latency_gen = time.time() - tic

    # Compute accuracy
    accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
    assert np.abs(accuracy_gen - accuracy) < 0.01
    assert np.abs(latency_gen - latency) < 1

    return accuracy, latency


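The latency assertion above checks that streaming adds no measurable end-to-end overhead; the practical gain is time to first result, since the generator yields as soon as the first submitted task finishes. A sketch, reusing the illustrative `answer` program:

```python
import time

tic = time.time()
gen = answer.run_batch(arguments, num_threads=8, generator_style=True)
first = next(gen)  # available before the rest of the batch completes
print(f"time to first result: {time.time() - tic:.2f}s")
results = [first, *gen]  # drain the remaining results
```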
25 changes: 12 additions & 13 deletions scripts/playground/reference_hf.py
@@ -26,11 +26,12 @@
import argparse

import requests

import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
)

from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -39,8 +40,7 @@
@torch.no_grad()
def vlm_text_with_image(args):
    # Load the processor and model for ImageTextToText tasks
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
@@ -64,11 +64,8 @@ def vlm_text_with_image(args):
                {
                    "type": "image",
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

@@ -84,11 +81,13 @@ def vlm_text_with_image(args):
    if not hasattr(processor, "apply_chat_template"):
        raise ValueError("The processor does not support chat templates.")
    text_prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True
    )

    # Prepare inputs for the model
    inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
        "cuda:0"
    )

    # Generate output from the model
    output_ids = model.generate(