Skip to content

Commit

Permalink
Add generator-style mode (generator_style parameter) to run_batch
Browse files Browse the repository at this point in the history
This change adds a new generator_style parameter to run_batch that allows
yielding results as they become available, while maintaining the performance
benefits of batch processing. This is particularly useful when you want to
process results as soon as they are ready, for example to save them to disk.

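When generator_style=True, run_batch yields tuples of (arguments, result)
as they become available, instead of returning a list at the end. With more
than one thread, tuples are yielded in completion order, which may differ from
the order of the input batch; the progress bar is not shown in this mode.

Note: because run_program_batch itself now contains `yield`, Python treats it
as a generator function regardless of the flag, so with generator_style=False
the call returns a generator object rather than the list `rets`. Preserving the
list-returning behaviour requires moving the yielding code into a separate
helper generator.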

Fixes sgl-project#303
  • Loading branch information
openhands-agent committed Dec 18, 2024
1 parent 21e9e63 commit 1545637
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 45 deletions.
92 changes: 47 additions & 45 deletions python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def run_program_batch(
default_sampling_para,
num_threads,
progress_bar,
generator_style=False,
):
if hasattr(backend, "endpoint"):
backend = backend.endpoint
Expand All @@ -110,62 +111,63 @@ def run_program_batch(
num_threads = min(num_threads, len(batch_arguments))

if num_threads == 1:
rets = []
if progress_bar:
for arguments in tqdm.tqdm(batch_arguments):
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
if progress_bar and not generator_style:
iterator = tqdm.tqdm(batch_arguments)
else:
for arguments in batch_arguments:
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
iterator = batch_arguments

for arguments in iterator:
result = run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
if generator_style:
yield arguments, result
else:
if not 'rets' in locals():
rets = []
rets.append(result)

else:
if progress_bar:
if progress_bar and not generator_style:
pbar = tqdm.tqdm(total=len(batch_arguments))

with ThreadPoolExecutor(num_threads) as executor:
futures = []
future_to_arguments = {}
for arguments in batch_arguments:
futures.append(
executor.submit(
run_program,
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
future = executor.submit(
run_program,
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
if progress_bar:
futures[-1].add_done_callback(lambda _: pbar.update())

rets = [f.result() for f in futures]
rets[-1].sync()
futures.append(future)
future_to_arguments[future] = arguments
if progress_bar and not generator_style:
future.add_done_callback(lambda _: pbar.update())

if generator_style:
for future in concurrent.futures.as_completed(futures):
yield future_to_arguments[future], future.result()
else:
rets = [f.result() for f in futures]
rets[-1].sync()

if progress_bar:
if progress_bar and not generator_style:
pbar.close()

return rets
if not generator_style:
return rets


def cache_program(program, backend):
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/lang/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def run_batch(
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
generator_style: bool = False,
):
from sglang.lang.interpreter import run_program_batch

Expand Down Expand Up @@ -277,6 +278,7 @@ def run_batch(
default_sampling_para,
num_threads,
progress_bar,
generator_style=generator_style,
)

def trace(self, *, backend=None, **kwargs):
Expand Down

0 comments on commit 1545637

Please sign in to comment.