import os
import sys
import time
+from dataclasses import dataclass
from random import randint
from threading import Lock, Thread
+from typing import Any, List

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack

import tensorrt_llm.bindings.executor as trtllm
+
+METRIC_TOTAL_OUTPUT_TOKENS = "total_output_tokens"
+METRIC_TOTAL_INPUT_TOKENS = "total_input_tokens"
import tensorrt_llm.logger as logger


+@dataclass
+class RequestData:
+    triton_req_id: int
+    triton_user_id: str
+    batch_index: int
+    batch_size: int
+    num_return_sequences: int
+    num_input_tokens: int
+    num_output_tokens: int
+    response_sender: Any
+
+
def mpi_comm():
    from mpi4py import MPI
    return MPI.COMM_WORLD
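
The RequestData dataclass replaces the anonymous six-field tuple the handler previously stowed in req_id_to_request_data, and adds the two token counters the new metrics need. A minimal sketch of the intended usage (None stands in for the Triton response sender, which only exists inside a running model):

data = RequestData(triton_req_id=1, triton_user_id="user-1", batch_index=0,
                   batch_size=1, num_return_sequences=1, num_input_tokens=0,
                   num_output_tokens=0, response_sender=None)
data.num_input_tokens += 42   # named fields replace positional tuple unpacking
print(data.batch_index, data.num_input_tokens)
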
@@ -136,6 +153,10 @@ def parse_medusa_choices(medusa_choices):
    return result


+def parse_eagle_choices(eagle_choices):
+    return parse_medusa_choices(eagle_choices)
+
+
def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
    kwargs = {}
    kwargs['beam_width'] = get_input_scalar_by_name(
@@ -254,6 +275,29 @@ def get_lora_config_from_request(request, batch_size=1, batch_index=0):
    return None


+def build_1_2_5_buckets(max_value: int) -> List[int]:
+    """
+    Builds a list of buckets with increasing powers of 10 multiplied by
+    mantissa values (1, 5), starting from 10 until the value exceeds
+    the specified maximum.
+
+    Example:
+    >>> build_1_2_5_buckets(1000)
+    [10, 50, 100, 500, 1000]
+    """
+    mantissa_lst = [1, 5]
+    exponent = 1  # Start from exponent 1 instead of 0
+    buckets: List[int] = []
+    while True:
+        for m in mantissa_lst:
+            value = m * 10**exponent
+            if value <= max_value:
+                buckets.append(value)
+            else:
+                return buckets
+        exponent += 1
+
+
def convert_request(request, exclude_input_from_output, decoupled):
    inputs = {}
    input_token_ids = get_input_tensor_by_name(request, 'input_ids')
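
Despite the 1-2-5 name, the helper uses only the (1, 5) mantissas and starts at exponent 1, so the smallest bucket boundary is 10. A quick sanity check of the behavior (not part of the change):

assert build_1_2_5_buckets(1000) == [10, 50, 100, 500, 1000]
assert build_1_2_5_buckets(100) == [10, 50, 100]
assert build_1_2_5_buckets(9) == []  # no boundary fits below the first power of 10
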
@@ -281,7 +325,6 @@ def convert_request(request, exclude_input_from_output, decoupled):
        input_length = len(input_token_ids)
        # Trim input token ids with input_lengths
        inputs['input_token_ids'] = input_token_ids[0:input_length]
-
        inputs['max_new_tokens'] = get_input_scalar_by_name(
            request, 'request_output_len', batch_size, batch_index)
        if inputs['max_new_tokens'] is None:
@@ -377,7 +420,7 @@ def convert_response(response, batch_index, batch_size, num_return_sequences):
    if response.has_error():
        return pb_utils.InferenceResponse(output_tensors=[],
                                          error=pb_utils.TritonError(
-                                              response.error_msg)), True
+                                              response.error_msg)), True, 0
    result = response.result
    beam_lengths = np.expand_dims(
        np.array([len(beam) for beam in result.output_token_ids], np.int32), 0)
@@ -387,6 +430,7 @@ def convert_response(response, batch_index, batch_size, num_return_sequences):
    for idx, beam in enumerate(result.output_token_ids):
        output_ids[0, idx, :len(beam)] = beam

+    output_lengths = output_ids.size
    output_tensors = [
        pb_utils.Tensor("output_ids", output_ids),
        pb_utils.Tensor("sequence_length", beam_lengths),
@@ -431,7 +475,8 @@ def convert_response(response, batch_index, batch_size, num_return_sequences):
                np.expand_dims(np.array([result.sequence_index], np.int32),
                               0)))

-    return pb_utils.InferenceResponse(output_tensors), result.is_final
+    return pb_utils.InferenceResponse(
+        output_tensors), result.is_final, output_lengths


def convert_scheduler_policy(batch_scheduler_policy: str):
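
convert_response now returns a token-count third element alongside the response and the final flag; the error path returns 0 for it. Note that output_lengths is output_ids.size, the element count of the padded (1, num_beams, max_beam_len) tensor, so beams of unequal length are counted at the longest beam's length. A sketch of the new calling convention:

triton_response, is_final, output_length = convert_response(
    response, batch_index, batch_size, num_return_sequences)
# e.g. beams of lengths 3 and 5 give output_length == 2 * 5 == 10, not 8
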
@@ -472,6 +517,12 @@ def convert_decoding_mode(decoding_mode: str):
        return trtllm.DecodingMode.BeamSearch()
    elif decoding_mode == "medusa":
        return trtllm.DecodingMode.Medusa()
+    elif decoding_mode == "redrafter":
+        return trtllm.DecodingMode.ExplicitDraftTokens()
+    elif decoding_mode == "lookahead":
+        return trtllm.DecodingMode.Lookahead()
+    elif decoding_mode == "eagle":
+        return trtllm.DecodingMode.Eagle()
    raise pb_utils.TritonModelException(
        f"decoding_mode value of '{decoding_mode}' is not supported.")
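
The three new branches map the decoding_mode model-config string onto the corresponding executor decoding modes, with "redrafter" selecting ExplicitDraftTokens; anything unrecognized still raises. A quick exercise of the new values (a sketch, assuming the tensorrt_llm bindings are importable in the environment):

for mode in ("redrafter", "lookahead", "eagle"):
    print(mode, convert_decoding_mode(mode))
# convert_decoding_mode("typo") raises pb_utils.TritonModelException
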
@@ -569,10 +620,15 @@ def get_peft_cache_config(self, model_config):
        return trtllm.PeftCacheConfig(**kwargs)

    def get_decoding_config(self, model_config):
+        eagle_choices = parse_eagle_choices(
+            get_parameter(model_config, "eagle_choices"))
        kwargs = {
            "medusa_choices":
            parse_medusa_choices(get_parameter(model_config,
                                               "medusa_choices")),
+            "eagle_config":
+            None
+            if eagle_choices is None else trtllm.EagleConfig(eagle_choices),
            "decoding_mode":
            convert_decoding_mode(get_parameter(model_config,
                                                "decoding_mode")),
@@ -653,6 +709,17 @@ def create_metrics(self, model: str, version: str, is_v1_model: bool):
            description="General TRT LLM metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
+        # Histogram families; individual metrics record values via observe().
+        self.request_tokens_metric_family = pb_utils.MetricFamily(
+            name="nv_llm_input_token_len",
+            description="TRT LLM response metrics",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
+        self.response_tokens_metric_family = pb_utils.MetricFamily(
+            name="nv_llm_output_token_len",
+            description="TRT LLM response metrics",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
        common_labels = {"model": model, "version": version}
        self.all_metrics = {
            # Request metrics
@@ -724,6 +791,20 @@ def create_metrics(self, model: str, version: str, is_v1_model: bool):
                    "general_type": "iteration_counter",
                    **common_labels
                }),
+            METRIC_TOTAL_OUTPUT_TOKENS:
+            self.response_tokens_metric_family.Metric(
+                labels={
+                    "response_metric_type": METRIC_TOTAL_OUTPUT_TOKENS,
+                    **common_labels
+                },
+                buckets=build_1_2_5_buckets(1000)),
+            METRIC_TOTAL_INPUT_TOKENS:
+            self.request_tokens_metric_family.Metric(
+                labels={
+                    "response_metric_type": METRIC_TOTAL_INPUT_TOKENS,
+                    **common_labels
+                },
+                buckets=build_1_2_5_buckets(1000)),
        }
        if is_v1_model:
            self.all_metrics.update({
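
Each histogram metric is created from its family with a response_metric_type label plus the shared model/version labels, and bucket boundaries from build_1_2_5_buckets(1000); observations above 1000 should still land in the histogram's implicit +Inf bucket, just without finer resolution. A condensed sketch of the same wiring (labels shortened; pb_utils exists only inside a Triton Python model):

family = pb_utils.MetricFamily(name="nv_llm_input_token_len",
                               description="TRT LLM response metrics",
                               kind=pb_utils.MetricFamily.HISTOGRAM)
metric = family.Metric(labels={"response_metric_type": "total_input_tokens"},
                       buckets=build_1_2_5_buckets(1000))
metric.observe(37)    # counted in the le="50" bucket
metric.observe(4096)  # above the top boundary: only +Inf and sum/count move
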
@@ -917,12 +998,21 @@ def execute(self, requests):
                request_ids, triton_req_ids, triton_user_ids,
                executor_requests, triton_requests, batch_indices):

-            self.req_id_to_request_data[
-                req_id] = triton_req_id, triton_user_id, batch_index, len(
-                    batch_indices
-                ), executor_request.num_return_sequences, triton_request.get_response_sender(
-                )
+            self.req_id_to_request_data[req_id] = RequestData(
+                triton_req_id, triton_user_id, batch_index,
+                len(batch_indices), executor_request.num_return_sequences,
+                0, 0, triton_request.get_response_sender())
            self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
+            input_len = len(
+                executor_request.input_token_ids
+            ) if executor_request.input_token_ids is not None else 0
+            self.req_id_to_request_data[
+                req_id].num_input_tokens += input_len
+            # exclude_input_from_output reflects both the request-level flag and the instance-config default
+            if not executor_request.output_config.exclude_input_from_output and not executor_request.streaming:
+                self.req_id_to_request_data[
+                    req_id].num_output_tokens -= self.req_id_to_request_data[
+                        req_id].num_input_tokens * executor_request.sampling_config.beam_width
            if triton_user_id is not None and triton_user_id != "":
                self.triton_user_id_to_req_ids[triton_user_id].add(req_id)
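
When responses echo the prompt (non-streaming with exclude_input_from_output disabled), each beam's output_token_ids includes the input tokens, so the counter is pre-charged with -input_len * beam_width at enqueue time and the echoed tokens cancel out later. A worked example with hypothetical numbers:

input_len, beam_width, generated = 100, 2, 30
num_output_tokens = -input_len * beam_width                # -200 when the request is enqueued
num_output_tokens += beam_width * (input_len + generated)  # +260 as responses arrive
assert num_output_tokens == beam_width * generated         # the histogram observes 60
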
@@ -934,53 +1024,60 @@ def awaiter_loop(self):
            for response in self.executor.await_responses(
                    timeout=datetime.timedelta(milliseconds=1)):
                req_id = response.request_id
+                request_data = None
                with self.lock:
                    if req_id not in self.req_id_to_request_data:
                        continue
-                    triton_req_id, triton_user_id, batch_index, batch_size, num_return_sequences, response_sender = self.req_id_to_request_data[
-                        req_id]
-
-                triton_response, is_final = convert_response(
-                    response, batch_index, batch_size, num_return_sequences)
+                    request_data = self.req_id_to_request_data[req_id]

+                triton_response, is_final, output_length = convert_response(
+                    response, request_data.batch_index,
+                    request_data.batch_size, request_data.num_return_sequences)
+                with self.lock:
+                    self.req_id_to_request_data[
+                        req_id].num_output_tokens += output_length
                triton_request_final = False
                if is_final:
                    with self.lock:
                        # Check if all executor requests part of that triton request are finished
-                        self.triton_req_id_to_req_ids[triton_req_id].remove(
-                            req_id)
-                        if len(self.triton_req_id_to_req_ids[triton_req_id]
-                               ) == 0:
+                        self.triton_req_id_to_req_ids[
+                            request_data.triton_req_id].remove(req_id)
+                        if len(self.triton_req_id_to_req_ids[
+                                request_data.triton_req_id]) == 0:
                            pb_utils.Logger.log_info(
-                                f"DELETING Req id {req_id}, triton_req_id {triton_req_id}"
+                                f"DELETING Req id {req_id}, triton_req_id {request_data.triton_req_id}"
                            )
                            triton_request_final = True
-                            del self.triton_req_id_to_req_ids[triton_req_id]
-                            if triton_user_id is not None and triton_user_id != "":
+                            del self.triton_req_id_to_req_ids[
+                                request_data.triton_req_id]
+                            if request_data.triton_user_id is not None and request_data.triton_user_id != "":
                                del self.triton_user_id_to_req_ids[
-                                    triton_user_id]
+                                    request_data.triton_user_id]
+                            self.update_metrics_per_request(req_id)
                        del self.req_id_to_request_data[req_id]

-                response_sender.send(
+                request_data.response_sender.send(
                    triton_response,
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                    if triton_request_final else 0)

-                # Remove local reference so response_sender can be cleaned properly.
-                del response_sender
-
    def cancellation_loop(self):
        """Checks if any pending requests have been cancelled."""
        while self.running:
            time.sleep(self.cancellation_check_period_ms / 1000.0)
            with self.lock:
-                for req_id, (triton_req_id, triton_user_id, batch_index,
-                             batch_size, num_return_sequences, response_sender
-                             ) in self.req_id_to_request_data.items():
-                    if response_sender.is_cancelled():
+                for req_id, request_data in self.req_id_to_request_data.items():
+                    if request_data.response_sender.is_cancelled():
                        self.executor.cancel_request(req_id)
-                    # Remove local reference so response_sender can be cleaned properly.
-                    del response_sender
+
+    def update_metrics_per_request(self, req_id):
+        """Updates triton metrics after completing one request."""
+        output_tokens = self.req_id_to_request_data[req_id].num_output_tokens
+        input_tokens = self.req_id_to_request_data[req_id].num_input_tokens
+
+        self.all_metrics[METRIC_TOTAL_OUTPUT_TOKENS].observe(output_tokens)
+        self.all_metrics[METRIC_TOTAL_INPUT_TOKENS].observe(input_tokens)

    def metrics_loop(self):
        """Updates triton metrics using stats from the executor."""
@@ -989,6 +1086,12 @@ def metrics_loop(self):
            for stat in self.executor.get_latest_iteration_stats():
                try:
                    for key, metric in self.all_metrics.items():
+                        # Skip the histograms; they are updated per request, not per iteration
+                        if isinstance(key, str) and key in [
+                                METRIC_TOTAL_OUTPUT_TOKENS,
+                                METRIC_TOTAL_INPUT_TOKENS
+                        ]:
+                            continue
                        value = None
                        if hasattr(stat, key):
                            value = getattr(stat, key)