Commit 3bc85f8

Coalesce text diffs in streaming requests.
Signed-off-by: Patrick Reiter Horn <[email protected]>
1 parent 37ac564 commit 3bc85f8

5 files changed: +29 -16 lines

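Why the change, in a minimal sketch: the old text_diff property read a cursor stored on the output object itself (self._last_text_len, visible in the result.py hunk below), so the diff could only be consumed once per update and could not be shared safely across consumers. The _safe accessors make the cursor explicit: the caller passes its last-seen length and gets back the delta plus the new length. The class below is a simplified stand-in, not the real CompletionOutput:

    from typing import Tuple

    class Output:
        # Simplified stand-in for CompletionOutput (illustration only).
        def __init__(self) -> None:
            self.text: str = ""

        def text_diff_safe(self, last_text_len: int) -> Tuple[str, int]:
            # Everything past the caller's cursor, plus the updated cursor.
            return self.text[last_text_len:], len(self.text)

    out = Output()
    cursor = 0
    for snapshot in ["Hel", "Hel", "Hello, wor", "Hello, world!"]:
        out.text = snapshot                  # the detokenizer grows the text
        delta, cursor = out.text_diff_safe(cursor)
        if delta:                            # coalesce: skip empty deltas
            print(repr(delta))               # 'Hel'  'lo, wor'  'ld!'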

examples/apps/fastapi_server.py
Lines changed: 3 additions & 1 deletion

@@ -71,8 +71,10 @@ async def generate(self, request: Request) -> Response:
                                           sampling_params=sampling_params)
 
         async def stream_results() -> AsyncGenerator[bytes, None]:
+            last_text_len: int = 0
             async for output in promise:
-                yield output.outputs[0].text_diff.encode("utf-8")
+                text_diff, last_text_len = output.outputs[0].text_diff_safe(last_text_len)
+                yield text_diff.encode("utf-8")
 
         if streaming:
             return StreamingResponse(stream_results())
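A runnable sketch of the consumer-side pattern this hunk adopts. fake_promise and _Out are stand-ins, not part of the API, and the real promise yields objects with an outputs list; that indirection is dropped here:

    import asyncio
    from typing import AsyncGenerator

    class _Out:
        # Stand-in for a streamed output with a growing .text field.
        def __init__(self, text: str):
            self.text = text
        def text_diff_safe(self, n: int):
            return self.text[n:], len(self.text)

    async def fake_promise():
        for text in ["He", "Hello", "Hello!"]:
            yield _Out(text)

    async def stream_results() -> AsyncGenerator[bytes, None]:
        last_text_len = 0                    # the cursor lives in the consumer
        async for output in fake_promise():
            diff, last_text_len = output.text_diff_safe(last_text_len)
            yield diff.encode("utf-8")

    async def main():
        print([chunk async for chunk in stream_results()])
        # [b'He', b'llo', b'!']

    asyncio.run(main())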

tensorrt_llm/executor/postproc_worker.py
Lines changed: 3 additions & 0 deletions

@@ -31,6 +31,9 @@
 @dataclass(kw_only=True)
 class PostprocArgs:
     first_iteration: bool = True
+    last_text_len: int = 0
+    last_logprobs_len: int = 0
+    last_token_ids_len: int = 0
     num_prompt_tokens: Optional[int] = None
     tokenizer: Optional[TransformersTokenizer] = None
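The cursors move onto PostprocArgs because the post-processing worker sees the same args object for a stream across iterations, which makes it a natural home for per-stream state. A sketch under that assumption; handle_iteration is hypothetical and the dataclass is trimmed to the fields added here:

    from dataclasses import dataclass

    @dataclass(kw_only=True)
    class PostprocArgs:
        first_iteration: bool = True
        last_text_len: int = 0
        last_logprobs_len: int = 0
        last_token_ids_len: int = 0

    def handle_iteration(full_text: str, args: PostprocArgs) -> str:
        # Hypothetical per-iteration handler: emit only what is new.
        delta = full_text[args.last_text_len:]
        args.last_text_len = len(full_text)
        args.first_iteration = False
        return delta

    args = PostprocArgs()
    print(handle_iteration("Hel", args))     # 'Hel'
    print(handle_iteration("Hello", args))   # 'lo'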

tensorrt_llm/executor/result.py
Lines changed: 14 additions & 10 deletions

@@ -4,7 +4,7 @@
 from dataclasses import dataclass, field
 from queue import Empty, Queue
 from typing import (TYPE_CHECKING, Any, Callable, List, Literal, NamedTuple,
-                    Optional, TypeAlias, Union)
+                    Optional, TypeAlias, Union, Tuple)
 from weakref import WeakMethod
 
 import torch

@@ -83,8 +83,11 @@ class CompletionOutput:
     Attributes:
         length (int): The number of generated tokens.
         token_ids_diff (List[int]): Newly generated token ids.
-        logprobs_diff (List[float]): Logprobs of newly generated tokens.
-        text_diff (str): Newly generated tokens.
+
+    Accessors:
+        token_ids_diff_safe(int) -> Tuple[List[int], int]: Newly generated token ids since the given length.
+        logprobs_diff_safe(int) -> Tuple[List[float], int]: Logprobs of newly generated tokens since the given length.
+        text_diff_safe(int) -> Tuple[str, int]: Newly generated text since the given length.
     """
     index: int
     text: str = ""

@@ -112,18 +115,19 @@ class CompletionOutput:
     def length(self) -> int:
         return len(self.token_ids)
 
-    @property
-    def text_diff(self) -> str:
-        return self.text[self._last_text_len:]
+    def text_diff_safe(self, last_text_len) -> Tuple[str, int]:
+        return self.text[last_text_len:], len(self.text)
+
+    def logprobs_diff_safe(self, last_logprobs_len) -> Tuple[List[float], int]:
+        return self.logprobs[last_logprobs_len:], len(self.logprobs)
+
+    def token_ids_diff_safe(self, last_token_ids_len) -> Tuple[List[int], int]:
+        return self.token_ids[last_token_ids_len:], len(self.token_ids)
 
     @property
     def token_ids_diff(self) -> List[int]:
         return self.token_ids[self._last_token_ids_len:]
 
-    @property
-    def logprobs_diff(self) -> List[float]:
-        return self.logprobs[self._last_logprobs_len:]
-
 
 class GenerationResultBase:
     ''' This holds the core logic of the GenerationResult class. '''
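The tuple return is the point of the _safe naming: the accessor never mutates the output, so a read with the same cursor is repeatable. A small sketch using a stand-in that mirrors the three signatures above (not the real class):

    from typing import List, Tuple

    class FakeOutput:
        # Stand-in mirroring the _safe accessor signatures.
        def __init__(self, text: str, token_ids: List[int], logprobs: List[float]):
            self.text, self.token_ids, self.logprobs = text, token_ids, logprobs

        def text_diff_safe(self, n: int) -> Tuple[str, int]:
            return self.text[n:], len(self.text)

        def token_ids_diff_safe(self, n: int) -> Tuple[List[int], int]:
            return self.token_ids[n:], len(self.token_ids)

        def logprobs_diff_safe(self, n: int) -> Tuple[List[float], int]:
            return self.logprobs[n:], len(self.logprobs)

    out = FakeOutput("Hi!", [72, 105, 33], [-0.1, -0.3, -0.05])
    d1, n1 = out.text_diff_safe(0)
    d2, n2 = out.text_diff_safe(0)           # same cursor -> same answer
    assert (d1, n1) == (d2, n2) == ("Hi!", 3)
    ids, _ = out.token_ids_diff_safe(1)      # [105, 33]
    lps, _ = out.logprobs_diff_safe(2)       # [-0.05]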

tensorrt_llm/serve/postprocess_handlers.py
Lines changed: 8 additions & 4 deletions

@@ -139,7 +139,9 @@ def yield_first_chat(num_tokens: int,
         if finish_reason_sent[i]:
             continue
 
-        delta_text = output.text_diff
+        delta_text, args.last_text_len = output.text_diff_safe(args.last_text_len)
+        if delta_text == '' and not output.finish_reason and not output.stop_reason:
+            continue
 
         in_reasoning, delta_text, reasoning_delta_text = apply_reasoning_parser(
             args, i, delta_text, True)

@@ -162,8 +164,8 @@ def yield_first_chat(num_tokens: int,
                 delta=delta_message,
                 finish_reason=None)
             if args.return_logprobs:
-                logprobs = output.logprobs_diff
-                token_ids = output.token_ids_diff
+                logprobs, args.last_logprobs_len = output.logprobs_diff_safe(args.last_logprobs_len)
+                token_ids, args.last_token_ids_len = output.token_ids_diff_safe(args.last_token_ids_len)
                 choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs)
             if output.finish_reason is not None:
                 choice.finish_reason = output.finish_reason

@@ -282,9 +284,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
     include_continuous_usage = False
 
     for output in rsp.outputs:
-        delta_text = output.text_diff
+        delta_text, args.last_text_len = output.text_diff_safe(args.last_text_len)
         if args.echo and args.first_iteration:
             delta_text = args.prompt + delta_text
+        if delta_text == '' and not output.finish_reason and not output.stop_reason:
+            continue
         choice = CompletionResponseStreamChoice(
             index=args.prompt_idx * args.num_choices + output.index,
             text=delta_text,
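The empty-delta guard in the two hunks above is where the coalescing named in the commit message happens: an iteration that produced no new text and no finish or stop reason emits nothing, so its empty chunk folds into the next non-empty one, while a final empty chunk still goes out when a finish reason must be delivered. A sketch with hypothetical names:

    from dataclasses import dataclass

    @dataclass
    class _Args:
        last_text_len: int = 0

    def stream_chunks(updates, args):
        # updates: (full_text_so_far, finish_reason) pairs per iteration.
        for full_text, finish_reason in updates:
            delta = full_text[args.last_text_len:]
            args.last_text_len = len(full_text)
            if delta == '' and not finish_reason:
                continue                     # coalesce: suppress empty chunks
            yield delta, finish_reason

    updates = [("He", None), ("He", None), ("Hello", None), ("Hello", "stop")]
    print(list(stream_chunks(updates, _Args())))
    # [('He', None), ('llo', None), ('', 'stop')]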

tests/unittest/llmapi/run_llm_with_postproc.py
Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def yield_first_chat(idx: int, role: str = None, content: str = None):
         if finish_reason_sent[i]:
             continue
 
-        delta_text = output.text_diff
+        delta_text, args.last_text_len = output.text_diff_safe(args.last_text_len)
         delta_message = DeltaMessage(content=delta_text)
 
         choice = ChatCompletionResponseStreamChoice(index=i,
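For completeness, a standalone sketch of the invariant the updated call relies on, written as a test; the class is illustrative, not the real CompletionOutput:

    def test_text_diff_safe_is_repeatable_and_lossless():
        class Out:
            text = "Hello, world"
            def text_diff_safe(self, n):
                return self.text[n:], len(self.text)

        out = Out()
        first, cursor = out.text_diff_safe(0)
        again, cursor2 = out.text_diff_safe(0)   # same cursor, same result
        assert (first, cursor) == (again, cursor2)
        rest, _ = out.text_diff_safe(cursor)
        assert first + rest == out.text          # nothing lost or duplicated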
