
Commit 1f24b4e

feat(scaffolding): add streaming scaffolding_llm.generate_async support
Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent ddfe4fc commit 1f24b4e

File tree

13 files changed: +334, -199 lines changed
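
The commit threads a streaming flag from the controllers down to GenerationTask and turns ScaffoldingLlm.generate_async into an async iterator. A minimal usage sketch, assembled from the run_basic_generation.py and token_budget_majority_vote.py diffs below (the model path is a placeholder, and worker.shutdown() is assumed to mirror the examples' cleanup):

import asyncio

from tensorrt_llm.scaffolding import (NativeGenerationController,
                                      ScaffoldingLlm, TRTLLMWorker)

worker = TRTLLMWorker.init_with_new_llm(
    "/path/to/model",  # placeholder model directory
    max_batch_size=32,
    max_num_tokens=4096,
)
controller = NativeGenerationController(sampling_params={"temperature": 0.9},
                                        streaming=True)
llm = ScaffoldingLlm(
    controller, {NativeGenerationController.WorkerTag.GENERATION: worker})


async def consume(prompt: str):
    # generate_async now yields intermediate results; each result.output is
    # itself async-iterable and produces partial outputs as tokens arrive.
    async for result in llm.generate_async(prompt):
        async for output in result.output:
            print(output.outputs[0].text)


# llm.loop runs on a background thread, so submit the coroutine thread-safely.
asyncio.run_coroutine_threadsafe(consume("Hello, world"), llm.loop).result()
llm.shutdown()
worker.shutdown()  # assumed to exist, mirroring the examples' shutdown prints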

examples/scaffolding/contrib/AsyncGeneration/stream_generation_controller.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@
 from typing import List
 
 from tensorrt_llm.scaffolding import Controller, GenerationTask, Task
-from tensorrt_llm.scaffolding.contrib import StreamGenerationTask
+from tensorrt_llm.scaffolding.contrib.AsyncGeneration import \
+    StreamGenerationTask
 
 
 class NativeStreamGenerationController(Controller):

examples/scaffolding/contrib/Dynasor/scaffolding_dynasor_run.py

Lines changed: 27 additions & 7 deletions
@@ -1,8 +1,9 @@
 import argparse
+import asyncio
 
 from tensorrt_llm.scaffolding import (MajorityVoteController, ScaffoldingLlm,
                                       TRTLLMWorker)
-from tensorrt_llm.scaffolding.contrib import DynasorGenerationController
+from tensorrt_llm.scaffolding.contrib.Dynasor import DynasorGenerationController
 
 
 def parse_arguments():
@@ -16,13 +17,16 @@ def parse_arguments():
     parser.add_argument("--max_num_tokens", type=int, default=7000)
     parser.add_argument("--majority_vote", action='store_true')
     parser.add_argument('--sample_num', type=int, default=3)
+    parser.add_argument('--streaming', action='store_true')
     args = parser.parse_args()
     return args
 
 
 def test(prompts, proposer_worker, args):
     dynasor_generation_controller = DynasorGenerationController(
-        generation_dir=args.model_dir, max_tokens=args.max_num_tokens)
+        generation_dir=args.model_dir,
+        max_tokens=args.max_num_tokens,
+        streaming=args.streaming)
 
     # If majority voting is requested, wrap the controller in MajorityVoteController
     if args.majority_vote:
@@ -47,9 +51,25 @@ def test(prompts, proposer_worker, args):
         },
     )
 
-    results = llm.generate(prompts)
-    for result in results:
-        print(result.output.output_str)
+    if args.streaming:
+
+        async def task(prompt: str):
+            i = 0
+            async for result in llm.generate_async(prompt):
+                i += 1
+                print(">>>", i, result)
+                async for output in result.output:
+                    print(i, len(output.outputs[0].text))
+            print(f">>> final output {len(output.outputs[0].text)}\n",
+                  output.outputs[0].text)
+            # print(f">>> final result.output {len(result.output.outputs[0].text)} {result.output}\n", result.output.outputs[0].text)
+
+        asyncio.run_coroutine_threadsafe(task(prompts[0]), llm.loop).result()
+    else:
+        results = llm.generate(prompts)
+        for result in results:
+            print(result.output.outputs[0].text)
+
     print(f"main shutting down...")
     llm.shutdown()
     print(f"worker shutting down...")
@@ -62,8 +82,8 @@ def main():
 
     prompts = [
         "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n",
-        "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
-        "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
+        # "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
+        # "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
     ]
 
     llm_worker = TRTLLMWorker.init_with_new_llm(
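
Note the dispatch pattern in the streaming branch above: ScaffoldingLlm keeps its own asyncio event loop (exposed as llm.loop, evidently running on a background thread), so synchronous code hands the coroutine over with asyncio.run_coroutine_threadsafe and blocks on the returned concurrent.futures.Future. A stripped-down sketch of just that handoff, assuming llm is already constructed as in the example:

import asyncio


async def consume(prompt: str):
    async for result in llm.generate_async(prompt):
        async for output in result.output:
            print(len(output.outputs[0].text))  # partial text grows over time


# Schedule the coroutine on the scaffolding instance's background loop and
# block the calling thread until streaming finishes.
future = asyncio.run_coroutine_threadsafe(consume("2 + 2 = ?"), llm.loop)
future.result()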

examples/scaffolding/run_basic_generation.py

Lines changed: 15 additions & 7 deletions
@@ -28,7 +28,7 @@ def test_sync(prompts, proposer_worker):
     )
     results = llm.generate(prompts)
     for result in results:
-        print(result.output.output_str)
+        print(result.output.outputs[0].text)
     print(f'main shutting down...')
     llm.shutdown()
     print(f'worker shutting down...')
@@ -40,16 +40,24 @@ def test_async(prompt, proposer_worker):
 
     async def test_async_func(prompt, proposer_worker):
         prototype_controller = NativeGenerationController(
-            sampling_params={"temperature": 0.9})
+            sampling_params={"temperature": 0.9}, streaming=True)
         llm = ScaffoldingLlm(
             prototype_controller,
             {NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
         )
-
-        future = llm.generate_async(prompt)
-
-        result = await future.aresult()
-        print(result.output.output_str)
+        i = 0
+
+        async for result in llm.generate_async(prompt):
+            i += 1
+            print(">>>", i, result)
+            async for output in result.output:
+                print(len(output.outputs[0].text))
+            # print(result.output,
+            #       end='\n' if result.finished else '\r',
+            #       flush=True)
+
+        # result = await future.aresult()
+        # print(result.output.output_str)
 
         print(f'main shutting down...')
         llm.shutdown()

examples/scaffolding/token_budget_majority_vote.py

Lines changed: 7 additions & 6 deletions
@@ -77,16 +77,17 @@ def main():
     args = parse_arguments()
     workers = {}
 
-    llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
-                                                backend="pytorch",
-                                                max_batch_size=32,
-                                                max_num_tokens=4096,
-                                                temperature=0.9)
+    llm_worker = TRTLLMWorker.init_with_new_llm(
+        args.model_dir,
+        max_batch_size=32,
+        max_num_tokens=4096,
+    )
 
     prototype_generation_controller = NativeGenerationController(
-        custom_sampling_params={
+        sampling_params={
             "max_tokens": 4096,
             "top_p": 0.9,
+            "temperature": 0.9,
         })
     workers[NativeGenerationController.WorkerTag.GENERATION] = llm_worker
tensorrt_llm/scaffolding/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from .benchmark import ScaffoldingBenchRequest, async_scaffolding_benchmark
 from .controller import (BestOfNController, Controller, MajorityVoteController,
                          NativeGenerationController, NativeRewardController,
-                         ParallelProcess, ScaffoldingOutput)
+                         ParallelProcess)
 from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
                          get_digit_majority_vote_result)
 from .scaffolding_llm import ScaffoldingLlm

tensorrt_llm/scaffolding/contrib/AsyncGeneration/stream_generation.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ async def get_step_or_more_tokens(task: StreamGenerationTask):
     if task.request_handle._done:
         task.end_flag = True
 
-    sampling_params = worker.combine_sampling_params_with_generation_task(task)
+    sampling_params = worker.convert_task_params(task)
     if task.request_handle is None:
         task.request_handle = worker.llm.generate_async(
             task.input_str, sampling_params=sampling_params, streaming=True)

tensorrt_llm/scaffolding/contrib/Dynasor/dynasor_controller.py

Lines changed: 21 additions & 5 deletions
@@ -15,11 +15,14 @@ class WorkerTag(Enum):
 
     # Certainty_threshold and chunk_size controls the compute saving level
     # Decreasing the certainty_threshold and chunk_size will save tokens but may risk at compromising accuracy.
-    def __init__(self,
-                 generation_dir,
-                 max_tokens=8192,
-                 certainty_threshold=3,
-                 chunk_size=64):
+    def __init__(
+        self,
+        generation_dir,
+        max_tokens=8192,
+        certainty_threshold=3,
+        chunk_size=64,
+        streaming=False,
+    ):
         """
         Initializes the controller with parameters controlling token limits and certainty thresholds.
 
@@ -46,6 +49,7 @@ def __init__(self,
             trust_remote_code=False,
             use_fast=True,
         )
+        self.streaming = streaming
 
     def process(self, tasks: List[GenerationTask], **kwargs):
         """
@@ -70,12 +74,14 @@ def process(self, tasks: List[GenerationTask], **kwargs):
         proposer_task.temperature = 0.6
         proposer_task.top_p = 0.95
         proposer_task.worker_tag = self.WorkerTag.GENERATION
+        proposer_task.streaming = self.streaming
 
         probe_task = GenerationTask()
         probe_task.max_tokens = 20
         probe_task.temperature = 0.6
         probe_task.top_p = 0.95
         probe_task.worker_tag = self.WorkerTag.GENERATION
+        probe_task.streaming = self.streaming
 
         probe_answers = []
         probe_responses = []
@@ -96,9 +102,13 @@ def process(self, tasks: List[GenerationTask], **kwargs):
             probe_task.input_str = current_prompt + self.probe_suffix
 
             # For the probe task, append the suffix to force a chain-of-thought leading to an answer.
+            print("[DynasorGenerationController] probe_task")
             yield [probe_task]
 
             # Retrieve the output from the probe task.
+            # if probe_task.streaming:
+            #     print("[DynasorGenerationController] wait result for probe_task")
+            #     probe_task.result.result()
             probe_text = probe_task.output_str
 
             # Extract the potential answer from the probe response.
@@ -120,6 +130,7 @@ def process(self, tasks: List[GenerationTask], **kwargs):
                     probe_answers[-self.certainty_threshold:])
                     == self.certainty_threshold
                     and sum(probe_certain_count) == self.certainty_threshold):
+                tasks[0].result = probe_task.result
                 # If the current prompt indicates the chain-of-thought phase has ended, use one type of suffix.
                 if "</think>" in current_prompt:
                     tasks[0].output_str = (current_prompt + self.answer_suffix +
@@ -133,13 +144,18 @@ def process(self, tasks: List[GenerationTask], **kwargs):
                 return
 
             # if not confident, do another round of generation
+            print("[DynasorGenerationController] proposer_task")
            yield [proposer_task]
 
             # Append the newly generated text from the proposer to the current prompt for the next iteration.
+            # if proposer_task.streaming:
+            #     print("[DynasorGenerationController] wait result for proposer_task")
+            #     proposer_task.result.result()
             current_prompt += proposer_task.output_str
 
         # If the maximum token limit is reached without satisfying the certainty condition,
         # output the accumulated prompt as the final output.
+        tasks[0].result = proposer_task.result
         tasks[0].output_str = current_prompt
         return
tensorrt_llm/scaffolding/contrib/__init__.py

Lines changed: 0 additions & 21 deletions

@@ -1,21 +0,0 @@
-from tensorrt_llm.scaffolding import *  # noqa
-
-from .AsyncGeneration import StreamGenerationTask, stream_generation_handler
-from .Dynasor import DynasorGenerationController
-from .mcp import (ChatTask, MCPCallTask, MCPController, MCPListTask, MCPWorker,
-                  chat_handler)
-
-__all__ = [
-    # AsyncGeneration
-    "stream_generation_handler",
-    "StreamGenerationTask",
-    # Dynasor
-    "DynasorGenerationController",
-    #mcp
-    "MCPController",
-    "MCPWorker",
-    "MCPCallTask",
-    "MCPListTask",
-    "ChatTask",
-    "chat_handler"
-]

tensorrt_llm/scaffolding/controller.py

Lines changed: 8 additions & 11 deletions
@@ -6,18 +6,12 @@
 import torch
 from torch.nn import functional as F
 
+from tensorrt_llm.executor.result import GenerationResult
 from tensorrt_llm.logger import logger
 from tensorrt_llm.scaffolding.math_utils import get_digit_majority_vote_result
-from tensorrt_llm.scaffolding.task import (GenerationTask, ScaffoldingOutput,
-                                           Task)
+from tensorrt_llm.scaffolding.task import GenerationTask, Task
 
-
-class ScaffoldingOutput:
-
-    def __init__(self):
-        self.output_str = None
-        # reserved for customized controller
-        self.customized_output = None
+# from .result import ScaffoldingOutput
 
 
 class Controller(ABC):
@@ -28,11 +22,12 @@ def __init__(self):
     def clone(self):
         return copy.deepcopy(self)
 
-    def generate(self, prompt: str, **kwargs) -> ScaffoldingOutput:
+    def generate(self, prompt: str, **kwargs) -> GenerationResult:
         task = GenerationTask.create_from_prompt(prompt)
 
         yield from self.process([task], **kwargs)
 
+        # print("[Controller.generate] task.output in generate", task.result)
         return task.create_scaffolding_output()
 
     def process(self, tasks: List[Task], **kwargs):
@@ -57,7 +52,7 @@ class NativeGenerationController(Controller):
     class WorkerTag(Enum):
         GENERATION = "generation"
 
-    def __init__(self, sampling_params: dict = None):
+    def __init__(self, sampling_params: dict = None, streaming: bool = False):
         super().__init__()
         if sampling_params is None:
             sampling_params = {}
@@ -67,13 +62,15 @@ def __init__(self, sampling_params: dict = None):
                     f"{key} is not a supported field for GenerationTask")
                 sampling_params.pop(key)
         self.sampling_params = sampling_params
+        self.streaming = streaming
 
     def process(self, tasks: List[Task], **kwargs):
         for task in tasks:
             task.worker_tag = self.WorkerTag.GENERATION
             for key, value in self.sampling_params.items():
                 if getattr(task, key) is None:
                     setattr(task, key, value)
+            task.streaming = self.streaming
 
         yield tasks
 