diff --git a/examples/apps/fastapi_server.py b/examples/apps/fastapi_server.py
index b2aa0baf2a..e254e2b018 100755
--- a/examples/apps/fastapi_server.py
+++ b/examples/apps/fastapi_server.py
@@ -72,8 +72,10 @@ async def generate(self, request: Request) -> Response:
                                           sampling_params=sampling_params)
 
         async def stream_results() -> AsyncGenerator[bytes, None]:
+            last_text_len: int = 0
             async for output in promise:
-                yield output.outputs[0].text_diff.encode("utf-8")
+                text_diff, last_text_len = output.outputs[0].text_diff_safe(last_text_len)
+                yield text_diff.encode("utf-8")
 
         if streaming:
             return StreamingResponse(stream_results())
diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py
index 2e5a3cd296..53240a04f2 100644
--- a/tensorrt_llm/executor/postproc_worker.py
+++ b/tensorrt_llm/executor/postproc_worker.py
@@ -29,6 +29,9 @@
 @dataclass(kw_only=True)
 class PostprocArgs:
     first_iteration: bool = True
+    last_text_len: int = 0
+    last_logprobs_len: int = 0
+    last_token_ids_len: int = 0
     num_prompt_tokens: Optional[int] = None
     tokenizer: Optional[TransformersTokenizer] = None
 
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index 67c1d3d120..d3476a6122 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -4,7 +4,7 @@
 from dataclasses import dataclass, field
 from queue import Empty, Queue
 from typing import (TYPE_CHECKING, Any, Callable, List, Literal, NamedTuple,
-                    Optional, TypeAlias, Union)
+                    Optional, Tuple, TypeAlias, Union)
 from weakref import WeakMethod
 
 import torch
@@ -89,8 +89,11 @@ class CompletionOutput:
     Attributes:
         length (int): The number of generated tokens.
         token_ids_diff (List[int]): Newly generated token ids.
-        logprobs_diff (List[float]): Logprobs of newly generated tokens.
-        text_diff (str): Newly generated tokens.
+
+    Accessors:
+        token_ids_diff_safe(int) -> Tuple[List[int], int]: Newly generated token ids since the given length.
+        logprobs_diff_safe(int) -> Tuple[List[float], int]: Logprobs of newly generated tokens since the given length.
+        text_diff_safe(int) -> Tuple[str, int]: Newly generated text since the given length.
""" index: int text: str = "" @@ -119,17 +122,26 @@ class CompletionOutput: def length(self) -> int: return len(self.token_ids) - @property - def text_diff(self) -> str: - return self.text[self._last_text_len:] + def text_diff_safe(self, last_text_len) -> Tuple[str, int]: + return self.text[last_text_len:], len(self.text) + + def logprobs_diff_safe(self, last_logprobs_len) -> Tuple[List[float], int]: + return self.logprobs[last_logprobs_len:], len(self.logprobs) + + def token_ids_diff_safe(self, last_token_ids_len) -> Tuple[List[int], int]: + return self.logprobs[last_token_ids_len:], len(self.logprobs) + + #@property + #def text_diff(self) -> str: + # return self.text[self._last_text_len:] @property def token_ids_diff(self) -> List[int]: return self.token_ids[self._last_token_ids_len:] - @property - def logprobs_diff(self) -> List[float]: - return self.logprobs[self._last_logprobs_len:] + #@property + #def logprobs_diff(self) -> List[float]: + # return self.logprobs[self._last_logprobs_len:] class GenerationResultBase: diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index 321ff6cc90..03fe4977e1 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -139,7 +139,9 @@ def yield_first_chat(num_tokens: int, if finish_reason_sent[i]: continue - delta_text = output.text_diff + delta_text, args.last_text_len = output.text_diff_safe(args.last_text_len) + if delta_text == '' and not output.finish_reason and not output.stop_reason: + continue in_reasoning, delta_text, reasoning_delta_text = apply_reasoning_parser( args, i, delta_text, True) @@ -162,8 +164,8 @@ def yield_first_chat(num_tokens: int, delta=delta_message, finish_reason=None) if args.return_logprobs: - logprobs = output.logprobs_diff - token_ids = output.token_ids_diff + logprobs, args.last_logprobs_len = output.logprobs_diff_safe(args.last_logprobs_len) + token_ids, args.last_token_ids_len = output.token_ids_diff_safe(args.last_token_ids_len) choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs) if output.finish_reason is not None: choice.finish_reason = output.finish_reason @@ -282,9 +284,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: include_continuous_usage = False for output in rsp.outputs: - delta_text = output.text_diff + delta_text, args.last_text_len = output.text_diff_safe(args.last_text_len) if args.echo and args.first_iteration: delta_text = args.prompt + delta_text + if delta_text == '' and not output.finish_reason and not output.stop_reason: + continue choice = CompletionResponseStreamChoice( index=args.prompt_idx * args.num_choices + output.index, text=delta_text, diff --git a/tests/unittest/llmapi/run_llm_with_postproc.py b/tests/unittest/llmapi/run_llm_with_postproc.py index 6ee365c952..1b20fefe4f 100644 --- a/tests/unittest/llmapi/run_llm_with_postproc.py +++ b/tests/unittest/llmapi/run_llm_with_postproc.py @@ -47,7 +47,7 @@ def yield_first_chat(idx: int, role: str = None, content: str = None): if finish_reason_sent[i]: continue - delta_text = output.text_diff + delta_text, args.last_text_len = output.text_diff_safe(args.last_text_len) delta_message = DeltaMessage(content=delta_text) choice = ChatCompletionResponseStreamChoice(index=i,