Commit 91c07d3

Update TensorRT-LLM backend (#652)
1 parent 869c2e0 commit 91c07d3

16 files changed: +606 -84 lines changed


README.md

Lines changed: 6 additions & 6 deletions
@@ -73,7 +73,7 @@ repo. If you don't find your answer there you can ask questions on the
   - [Scheduling](#scheduling)
   - [Key-Value Cache](#key-value-cache)
   - [Decoding](#decoding)
-    - [Decoding Modes - Top-k, Top-p, Top-k Top-p, Beam Search and Medusa](#decoding-modes---top-k-top-p-top-k-top-p-beam-search-and-medusa)
+    - [Decoding Modes - Top-k, Top-p, Top-k Top-p, Beam Search, Medusa, ReDrafter, Lookahead and Eagle](#decoding-modes---top-k-top-p-top-k-top-p-beam-search-medusa-redrafter-lookahead-and-eagle)
   - [Speculative Decoding](#speculative-decoding)
   - [Chunked Context](#chunked-context)
   - [Quantization](#quantization)
@@ -606,15 +606,15 @@ TRT-LLM engine. Parameters for KV cache can be found in the

 ### Decoding

-#### Decoding Modes - Top-k, Top-p, Top-k Top-p, Beam Search and Medusa
+#### Decoding Modes - Top-k, Top-p, Top-k Top-p, Beam Search, Medusa, ReDrafter, Lookahead and Eagle

 TensorRT-LLM supports various decoding modes, including top-k, top-p,
-top-k top-p, beam search and Medusa. See the
+top-k top-p, beam search, Medusa, ReDrafter, Lookahead and Eagle. See the
 [Sampling Parameters](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/gpt-runtime.md#sampling-parameters)
 section to learn more about top-k, top-p, top-k top-p and beam search decoding.
-For more details on Medusa, please refer to the
-[Medusa Decoding](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/medusa)
-documentation.
+Please refer to the
+[speculative decoding documentation](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/speculative-decoding.md)
+for more details on Medusa, ReDrafter, Lookahead and Eagle.

 Parameters for decoding modes can be found in the
 [model config](./docs/model_config.md#tensorrt_llm_model) of tensorrt_llm model.
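For quick reference, the `decoding_mode` strings the backend accepts after this change are summarized in the sketch below. This is illustrative only and not part of the diff; the authoritative mapping is `convert_decoding_mode()` in `model.py` further down, and `SUPPORTED_DECODING_MODES` / `is_speculative()` are hypothetical helpers that do not exist in the backend.

```python
# Illustrative summary of the decoding_mode values handled after this commit.
SUPPORTED_DECODING_MODES = [
    "top_k",
    "top_p",
    "top_k_top_p",
    "beam_search",
    "medusa",
    "redrafter",  # maps to the executor's ExplicitDraftTokens mode
    "lookahead",
    "eagle",
]


def is_speculative(mode: str) -> bool:
    """Modes covered by the speculative decoding documentation linked above."""
    return mode in {"medusa", "redrafter", "lookahead", "eagle"}
```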

all_models/inflight_batcher_llm/tensorrt_llm/1/model.py

Lines changed: 134 additions & 31 deletions
@@ -3,8 +3,10 @@
 import os
 import sys
 import time
+from dataclasses import dataclass
 from random import randint
 from threading import Lock, Thread
+from typing import Any, List

 import numpy as np
 import torch
@@ -13,9 +15,24 @@
 from torch.utils.dlpack import from_dlpack

 import tensorrt_llm.bindings.executor as trtllm
+
+METRIC_TOTAL_OUTPUT_TOKENS = "total_output_tokens"
+METRIC_TOTAL_INPUT_TOKENS = "total_input_tokens"
 import tensorrt_llm.logger as logger


+@dataclass
+class RequestData:
+    triton_req_id: int
+    triton_user_id: str
+    batch_index: int
+    batch_size: int
+    num_return_sequences: int
+    num_input_tokens: int
+    num_output_tokens: int
+    response_sender: Any
+
+
 def mpi_comm():
     from mpi4py import MPI
     return MPI.COMM_WORLD
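As a quick illustration (not part of the commit) of how the new dataclass is used by the handlers further down: `execute()` creates one `RequestData` per executor request and `awaiter_loop()` accumulates token counts on it before the final metrics are observed.

```python
# Illustrative only: the per-request bookkeeping that previously lived in a
# plain tuple. In the backend, response_sender is a real Triton response
# sender; None is used here just to keep the sketch self-contained.
data = RequestData(
    triton_req_id=1,
    triton_user_id="",
    batch_index=0,
    batch_size=1,
    num_return_sequences=1,
    num_input_tokens=0,
    num_output_tokens=0,
    response_sender=None,
)
data.num_input_tokens += 42   # counted when the request is submitted
data.num_output_tokens += 7   # accumulated per response in awaiter_loop()
```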
@@ -136,6 +153,10 @@ def parse_medusa_choices(medusa_choices):
     return result


+def parse_eagle_choices(eagle_choices):
+    return parse_medusa_choices(eagle_choices)
+
+
 def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
     kwargs = {}
     kwargs['beam_width'] = get_input_scalar_by_name(
@@ -254,6 +275,29 @@ def get_lora_config_from_request(request, batch_size=1, batch_index=0):
     return None


+def build_1_2_5_buckets(max_value: int) -> List[int]:
+    """
+    Builds a list of buckets with increasing powers of 10 multiplied by
+    mantissa values (1, 5), starting from 10 until the value exceeds
+    the specified maximum.
+
+    Example:
+    >>> build_1_2_5_buckets(1000)
+    [10, 50, 100, 500, 1000]
+    """
+    mantissa_lst = [1, 5]
+    exponent = 1  # Start from exponent 1 instead of 0
+    buckets: List[int] = []
+    while True:
+        for m in mantissa_lst:
+            value = m * 10**exponent
+            if value <= max_value:
+                buckets.append(value)
+            else:
+                return buckets
+        exponent += 1
+
+
 def convert_request(request, exclude_input_from_output, decoupled):
     inputs = {}
     input_token_ids = get_input_tensor_by_name(request, 'input_ids')
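A quick sanity sketch (not part of the commit) of the bucket boundaries produced by `build_1_2_5_buckets()`; the values follow directly from the loop above.

```python
# Standalone check of the bucket builder added above.
assert build_1_2_5_buckets(1000) == [10, 50, 100, 500, 1000]
assert build_1_2_5_buckets(250) == [10, 50, 100]
assert build_1_2_5_buckets(5) == []  # nothing below the first bucket of 10
```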
@@ -281,7 +325,6 @@ def convert_request(request, exclude_input_from_output, decoupled):
         input_length = len(input_token_ids)
         # Trim input token ids with input_lengths
         inputs['input_token_ids'] = input_token_ids[0:input_length]
-
         inputs['max_new_tokens'] = get_input_scalar_by_name(
             request, 'request_output_len', batch_size, batch_index)
         if inputs['max_new_tokens'] is None:
@@ -377,7 +420,7 @@ def convert_response(response, batch_index, batch_size, num_return_sequences):
     if response.has_error():
         return pb_utils.InferenceResponse(output_tensors=[],
                                           error=pb_utils.TritonError(
-                                              response.error_msg)), True
+                                              response.error_msg)), True, 0
     result = response.result
     beam_lengths = np.expand_dims(
         np.array([len(beam) for beam in result.output_token_ids], np.int32), 0)
@@ -387,6 +430,7 @@
     for idx, beam in enumerate(result.output_token_ids):
         output_ids[0, idx, :len(beam)] = beam

+    output_lengths = output_ids.size
     output_tensors = [
         pb_utils.Tensor("output_ids", output_ids),
         pb_utils.Tensor("sequence_length", beam_lengths),
@@ -431,7 +475,8 @@
             np.expand_dims(np.array([result.sequence_index], np.int32),
                            0)))

-    return pb_utils.InferenceResponse(output_tensors), result.is_final
+    return pb_utils.InferenceResponse(
+        output_tensors), result.is_final, output_lengths


 def convert_scheduler_policy(batch_scheduler_policy: str):
@@ -472,6 +517,12 @@ def convert_decoding_mode(decoding_mode: str):
         return trtllm.DecodingMode.BeamSearch()
     elif decoding_mode == "medusa":
         return trtllm.DecodingMode.Medusa()
+    elif decoding_mode == "redrafter":
+        return trtllm.DecodingMode.ExplicitDraftTokens()
+    elif decoding_mode == "lookahead":
+        return trtllm.DecodingMode.Lookahead()
+    elif decoding_mode == "eagle":
+        return trtllm.DecodingMode.Eagle()
     raise pb_utils.TritonModelException(
         f"decoding_mode value of '{decoding_mode}' is not supported.")

@@ -569,10 +620,15 @@ def get_peft_cache_config(self, model_config):
         return trtllm.PeftCacheConfig(**kwargs)

     def get_decoding_config(self, model_config):
+        eagle_choices = parse_eagle_choices(
+            get_parameter(model_config, "eagle_choices"))
         kwargs = {
             "medusa_choices":
             parse_medusa_choices(get_parameter(model_config,
                                                "medusa_choices")),
+            "eagle_config":
+            None
+            if eagle_choices is None else trtllm.EagleConfig(eagle_choices),
             "decoding_mode":
             convert_decoding_mode(get_parameter(model_config,
                                                 "decoding_mode")),
@@ -653,6 +709,17 @@ def create_metrics(self, model: str, version: str, is_v1_model: bool):
             description="General TRT LLM metrics",
             kind=pb_utils.MetricFamily.GAUGE,
         )
+        # Token-length histogram families; observed once per completed request
+        # in update_metrics_per_request().
+        self.request_tokens_metric_family = pb_utils.MetricFamily(
+            name="nv_llm_input_token_len",
+            description="TRT LLM response metrics",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
+        self.response_tokens_metric_family = pb_utils.MetricFamily(
+            name="nv_llm_output_token_len",
+            description="TRT LLM response metrics",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
         common_labels = {"model": model, "version": version}
         self.all_metrics = {
             # Request metrics
@@ -724,6 +791,20 @@ def create_metrics(self, model: str, version: str, is_v1_model: bool):
                     "general_type": "iteration_counter",
                     **common_labels
                 }),
+            METRIC_TOTAL_OUTPUT_TOKENS:
+            self.response_tokens_metric_family.Metric(
+                labels={
+                    "response_metric_type": METRIC_TOTAL_OUTPUT_TOKENS,
+                    **common_labels
+                },
+                buckets=build_1_2_5_buckets(1000)),
+            METRIC_TOTAL_INPUT_TOKENS:
+            self.request_tokens_metric_family.Metric(
+                labels={
+                    "response_metric_type": METRIC_TOTAL_INPUT_TOKENS,
+                    **common_labels
+                },
+                buckets=build_1_2_5_buckets(1000)),
         }
         if is_v1_model:
             self.all_metrics.update({
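Once a model is serving with this change, the two histogram families should appear on Triton's Prometheus metrics endpoint. A hedged sketch of how one might check for them, assuming a locally running tritonserver and the default metrics port 8002 (adjust for your deployment):

```python
# Illustrative check, not part of the commit: scrape the metrics endpoint and
# print the new token-length histogram series.
from urllib.request import urlopen

metrics_text = urlopen("http://localhost:8002/metrics").read().decode()
for line in metrics_text.splitlines():
    if "nv_llm_input_token_len" in line or "nv_llm_output_token_len" in line:
        print(line)
```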
@@ -917,12 +998,21 @@ def execute(self, requests):
                 request_ids, triton_req_ids, triton_user_ids,
                 executor_requests, triton_requests, batch_indices):

-            self.req_id_to_request_data[
-                req_id] = triton_req_id, triton_user_id, batch_index, len(
-                    batch_indices
-                ), executor_request.num_return_sequences, triton_request.get_response_sender(
-                )
+            self.req_id_to_request_data[req_id] = RequestData(
+                triton_req_id, triton_user_id, batch_index,
+                len(batch_indices), executor_request.num_return_sequences,
+                0, 0, triton_request.get_response_sender())
             self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
+            input_len = len(
+                executor_request.input_token_ids
+            ) if executor_request.input_token_ids is not None else 0
+            self.req_id_to_request_data[
+                req_id].num_input_tokens += input_len
+            # This checks both request level and instance config level
+            if executor_request.output_config.exclude_input_from_output == False and executor_request.streaming == False:
+                self.req_id_to_request_data[
+                    req_id].num_output_tokens -= self.req_id_to_request_data[
+                        req_id].num_input_tokens * executor_request.sampling_config.beam_width
             if triton_user_id is not None and triton_user_id != "":
                 self.triton_user_id_to_req_ids[triton_user_id].add(req_id)
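The pre-decrement above is easier to follow with numbers: when the prompt is echoed back (exclude_input_from_output false and not streaming), the output lengths that `convert_response()` reports later include those echoed tokens once per beam, so subtracting `input_len * beam_width` up front leaves only generated tokens in the per-request total. A small worked sketch with illustrative values, assuming all beams end up the same length:

```python
# Worked example of the token accounting above; numbers are illustrative.
input_len = 100   # prompt tokens for the request
beam_width = 2    # beams in the sampling config
generated = 30    # new tokens produced per beam

# With the prompt echoed in the output, the accumulated output length is
# roughly beams * (prompt + generated tokens):
reported_output = beam_width * (input_len + generated)          # 260

# num_output_tokens starts at -input_len * beam_width, so after accumulation
# only the generated tokens remain:
num_output_tokens = -input_len * beam_width + reported_output   # 60
assert num_output_tokens == beam_width * generated
```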

@@ -934,53 +1024,60 @@ def awaiter_loop(self):
             for response in self.executor.await_responses(
                     timeout=datetime.timedelta(milliseconds=1)):
                 req_id = response.request_id
+                request_data = None
                 with self.lock:
                     if req_id not in self.req_id_to_request_data:
                         continue
-                    triton_req_id, triton_user_id, batch_index, batch_size, num_return_sequences, response_sender = self.req_id_to_request_data[
-                        req_id]
-
-                triton_response, is_final = convert_response(
-                    response, batch_index, batch_size, num_return_sequences)
+                    request_data = self.req_id_to_request_data[req_id]

+                triton_response, is_final, output_length = convert_response(
+                    response, request_data.batch_index,
+                    request_data.batch_size, request_data.num_return_sequences)
+                with self.lock:
+                    self.req_id_to_request_data[
+                        req_id].num_output_tokens += output_length
                 triton_request_final = False
                 if is_final:
                     with self.lock:
                         # Check if all executor requests part of that triton request are finished
-                        self.triton_req_id_to_req_ids[triton_req_id].remove(
-                            req_id)
-                        if len(self.triton_req_id_to_req_ids[triton_req_id]
-                               ) == 0:
+                        self.triton_req_id_to_req_ids[
+                            request_data.triton_req_id].remove(req_id)
+                        if len(self.triton_req_id_to_req_ids[
+                                request_data.triton_req_id]) == 0:
                             pb_utils.Logger.log_info(
-                                f"DELETING Req id {req_id}, triton_req_id {triton_req_id} "
+                                f"DELETING Req id {req_id}, triton_req_id {request_data.triton_req_id} "
                             )
                             triton_request_final = True
-                            del self.triton_req_id_to_req_ids[triton_req_id]
-                            if triton_user_id is not None and triton_user_id != "":
+                            del self.triton_req_id_to_req_ids[
+                                request_data.triton_req_id]
+                            if request_data.triton_user_id is not None and request_data.triton_user_id != "":
                                 del self.triton_user_id_to_req_ids[
-                                    triton_user_id]
+                                    request_data.triton_user_id]
+                        self.update_metrics_per_request(req_id)
                         del self.req_id_to_request_data[req_id]

-                response_sender.send(
+                request_data.response_sender.send(
                     triton_response,
                     flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                     if triton_request_final else 0)

-                # Remove local reference so response_sender can be cleaned properly.
-                del response_sender
-
     def cancellation_loop(self):
         """Checks if any pending requests have been cancelled."""
         while self.running:
             time.sleep(self.cancellation_check_period_ms / 1000.0)
             with self.lock:
-                for req_id, (triton_req_id, triton_user_id, batch_index,
-                             batch_size, num_return_sequences, response_sender
-                             ) in self.req_id_to_request_data.items():
-                    if response_sender.is_cancelled():
+                for req_id, request_data in self.req_id_to_request_data.items(
+                ):
+                    if request_data.response_sender.is_cancelled():
                         self.executor.cancel_request(req_id)
-                        # Remove local reference so response_sender can be cleaned properly.
-                        del response_sender
+
+    def update_metrics_per_request(self, req_id):
+        """Updates triton metrics after completing one request"""
+        output_tokens = self.req_id_to_request_data[req_id].num_output_tokens
+        input_tokens = self.req_id_to_request_data[req_id].num_input_tokens
+
+        self.all_metrics[METRIC_TOTAL_OUTPUT_TOKENS].observe(output_tokens)
+        self.all_metrics[METRIC_TOTAL_INPUT_TOKENS].observe(input_tokens)

     def metrics_loop(self):
         """Updates triton metrics using stats from the executor."""
@@ -989,6 +1086,12 @@ def metrics_loop(self):
             for stat in self.executor.get_latest_iteration_stats():
                 try:
                     for key, metric in self.all_metrics.items():
+                        # Skip processing for both histogram metrics
+                        if isinstance(key, str) and key in [
+                                METRIC_TOTAL_OUTPUT_TOKENS,
+                                METRIC_TOTAL_INPUT_TOKENS
+                        ]:
+                            continue
                         value = None
                         if hasattr(stat, key):
                             value = getattr(stat, key)

all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 6 additions & 0 deletions
@@ -624,6 +624,12 @@ parameters: {
     string_value: "${medusa_choices}"
   }
 }
+parameters: {
+  key: "eagle_choices"
+  value: {
+    string_value: "${eagle_choices}"
+  }
+}
 parameters: {
   key: "gpu_weights_percent"
   value: {
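Like the existing placeholders, `${eagle_choices}` is meant to be filled before the model repository is deployed (the backend's tools/fill_template.py is the usual route). Below is a hypothetical sketch of the substitution only; the path and parameter values are examples, and the eagle_choices string must follow whatever format `parse_eagle_choices()` / `parse_medusa_choices()` expect in your version.

```python
# Hypothetical illustration of filling config.pbtxt placeholders.
from string import Template

with open("all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt") as f:
    config_template = Template(f.read())

filled = config_template.safe_substitute(
    decoding_mode="eagle",
    eagle_choices="[[0], [0, 0], [1]]",  # example value only
)
print(filled[:200])
```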

all_models/tests/test_python_backend.py

Lines changed: 11 additions & 6 deletions
@@ -541,8 +541,8 @@ def test_convert_response(trtllm_response: trtllm.Response):
     batch_index = 2
     batch_size = 3
     num_return_sequences = 1
-    response, is_final = convert_response(trtllm_response, batch_index,
-                                          batch_size, num_return_sequences)
+    response, is_final, output_length = convert_response(
+        trtllm_response, batch_index, batch_size, num_return_sequences)
     assert is_final == True
     assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2, 3]
                                                                     ])).all()
@@ -564,8 +564,8 @@ def test_convert_response_minimal(trtllm_response_minimal: trtllm.Response):
     batch_index = 2
     batch_size = 3
     num_return_sequences = 1
-    response, is_final = convert_response(trtllm_response_minimal, batch_index,
-                                          batch_size, num_return_sequences)
+    response, is_final, output_length = convert_response(
+        trtllm_response_minimal, batch_index, batch_size, num_return_sequences)
     assert is_final == False
     assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2, 3]
                                                                     ])).all()
@@ -584,8 +584,8 @@ def test_convert_response_error(trtllm_response_error: trtllm.Response):
     batch_index = 2
     batch_size = 3
     num_return_sequences = 1
-    response, is_final = convert_response(trtllm_response_error, batch_index,
-                                          batch_size, num_return_sequences)
+    response, is_final, output_length = convert_response(
+        trtllm_response_error, batch_index, batch_size, num_return_sequences)
     assert is_final == True
     assert response.has_error() and response.error.message == "internal error"

@@ -622,6 +622,9 @@ def test_convert_decoding_mode():
     assert convert_decoding_mode("top_k_top_p").isTopKandTopP()
     assert convert_decoding_mode("beam_search").isBeamSearch()
     assert convert_decoding_mode("medusa").isMedusa()
+    assert convert_decoding_mode("redrafter").isExplicitDraftTokens()
+    assert convert_decoding_mode("lookahead").isLookahead()
+    assert convert_decoding_mode("eagle").isEagle()
     with pytest.raises(
             Exception,
             match="decoding_mode value of 'other' is not supported"):
@@ -709,6 +712,8 @@ def test_get_executor_config_minimal():
     assert config.batching_type == trtllm.BatchingType.INFLIGHT
     assert config.decoding_config.decoding_mode is None
     assert config.decoding_config.medusa_choices is None
+    assert config.decoding_config.eagle_config is None
+    assert config.decoding_config.lookahead_decoding_config is None
     assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
     assert config.kv_cache_config.enable_block_reuse == False
     assert config.kv_cache_config.max_tokens is None
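To exercise the tests touched by this commit locally, something along these lines should work from the repository root (illustrative; the `-k` expression just selects the updated tests):

```python
# Illustrative invocation of the updated tests via pytest's Python API.
import pytest

pytest.main([
    "all_models/tests/test_python_backend.py",
    "-k", "convert_response or convert_decoding_mode or executor_config_minimal",
    "-q",
])
```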
