@@ -311,6 +311,11 @@ def convert_decoding_mode(decoding_mode: str):
         f"decoding_mode value of '{decoding_mode}' is not supported.")


+def convert_timestamp_to_seconds(timestamp: str):
+    return int(
+        datetime.datetime.strptime(timestamp, "%m-%d-%Y %H:%M:%S").timestamp())
+
+
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model
     that is created must have "TritonPythonModel" as the class name.
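For context on the helper added above: it assumes the executor reports timestamps as strings in "%m-%d-%Y %H:%M:%S" form. A minimal standalone check (illustrative only; strptime() yields a naive datetime, so the resulting epoch value depends on the machine's local timezone):

import datetime

def convert_timestamp_to_seconds(timestamp: str):
    return int(
        datetime.datetime.strptime(timestamp, "%m-%d-%Y %H:%M:%S").timestamp())

# On a UTC machine this prints 1704067200 (2024-01-01T00:00:00Z).
print(convert_timestamp_to_seconds("01-01-2024 00:00:00"))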
@@ -422,6 +427,155 @@ def get_executor_config(self, model_config):
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.ExecutorConfig(**kwargs)

+    def create_metrics(self, model: str, version: str, is_v1_model: bool):
+        self.request_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_request_metrics",
+            description="TRT LLM request metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        self.runtime_memory_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_runtime_memory_metrics",
+            description="TRT LLM runtime memory metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        self.kv_cache_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_kv_cache_block_metrics",
+            description="TRT LLM KV cache block metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        model_type = "v1" if is_v1_model else "inflight_batcher"
+        self.model_type_metric_family = pb_utils.MetricFamily(
+            name=f"nv_trt_llm_{model_type}_metrics",
+            description=f"TRT LLM {model_type}-specific metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        self.general_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_general_metrics",
+            description="General TRT LLM metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        common_labels = {"model": model, "version": version}
+        self.all_metrics = {
+            # Request metrics
+            "num_active_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "active",
+                **common_labels
+            }),
+            "max_num_active_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "max",
+                **common_labels
+            }),
+            "num_scheduled_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "scheduled",
+                **common_labels
+            }),
+            "num_context_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "context",
+                **common_labels
+            }),
+            # Runtime metrics
+            "cpu_mem_usage":
+            self.runtime_memory_metric_family.Metric(labels={
+                "memory_type": "cpu",
+                **common_labels
+            }),
+            "gpu_mem_usage":
+            self.runtime_memory_metric_family.Metric(labels={
+                "memory_type": "gpu",
+                **common_labels
+            }),
+            "pinned_mem_usage":
+            self.runtime_memory_metric_family.Metric(labels={
+                "memory_type": "pinned",
+                **common_labels
+            }),
+            # KV cache metrics
+            "max_num_blocks":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "max",
+                **common_labels
+            }),
+            "free_num_blocks":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "free",
+                **common_labels
+            }),
+            "used_num_blocks":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "used",
+                **common_labels
+            }),
+            "tokens_per_block":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "tokens_per",
+                **common_labels
+            }),
+            # General metrics
+            "timestamp":
+            self.general_metric_family.Metric(labels={
+                "general_type": "timestamp",
+                **common_labels
+            }),
+            "iter":
+            self.general_metric_family.Metric(labels={
+                "general_type": "iteration_counter",
+                **common_labels
+            }),
+        }
+        if is_v1_model:
+            self.all_metrics.update({
+                "num_ctx_tokens":
+                self.model_type_metric_family.Metric(labels={
+                    "v1_specific_metric": "total_context_tokens",
+                    **common_labels
+                }),
+                "num_gen_tokens":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "v1_specific_metric": "total_generation_tokens",
+                        **common_labels
+                    }),
+                "empty_gen_slots":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "v1_specific_metric": "empty_generation_slots",
+                        **common_labels
+                    }),
+            })
+        else:
+            self.all_metrics.update({
+                "num_ctx_tokens":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric":
+                        "total_context_tokens",
+                        **common_labels
+                    }),
+                "num_gen_requests":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric":
+                        "generation_requests",
+                        **common_labels
+                    }),
+                "micro_batch_id":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric": "micro_batch_id",
+                        **common_labels
+                    }),
+                "num_paused_requests":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric": "paused_requests",
+                        **common_labels
+                    }),
+            })
+
     def initialize(self, args):
         """`initialize` is called only once when the model is being loaded.
         Implementing `initialize` function is optional. This function allows
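The gauge families registered in create_metrics above surface through Triton's standard Prometheus metrics endpoint. A small verification sketch, not part of this change, assuming the default metrics port 8002 and a locally running server:

import requests

metrics_text = requests.get("http://localhost:8002/metrics", timeout=5).text
# Print only the TRT-LLM metric families added in this change
# (nv_trt_llm_request_metrics, nv_trt_llm_kv_cache_block_metrics, ...).
for line in metrics_text.splitlines():
    if line.startswith("nv_trt_llm_"):
        print(line)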
@@ -453,22 +607,30 @@ def initialize(self):
             model_config)
         self.cancellation_check_period_ms = get_parameter(
             model_config, "cancellation_check_period_ms", int) or 100
+        self.stats_check_period_ms = get_parameter(
+            model_config, "stats_check_period_ms", int) or 100

         if not self.decoupled:
             raise pb_utils.TritonModelException(
                 "Please enable decoupled transaction policy in the model configuration to serve this model"
             )

+        self.create_metrics(args["model_name"],
+                            args["model_version"],
+                            is_v1_model=executor_config.batching_type ==
+                            trtllm.BatchingType.STATIC)
         self.triton_id_to_req_id = {}
         self.req_id_to_response_sender = {}
         self.lock = Lock()
         self.running = False
         self.awaiter_thread = Thread(target=self.awaiter_loop)
         self.cancellation_thread = Thread(target=self.cancellation_loop)
+        self.metrics_thread = Thread(target=self.metrics_loop)
         if self.executor.can_enqueue_requests():
             self.running = True
             self.awaiter_thread.start()
             self.cancellation_thread.start()
+            self.metrics_thread.start()
         else:
             # In leader mode, worker ranks will wait here until leader is done.
             self.executor.shutdown()
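The new stats_check_period_ms knob is read the same way as the existing cancellation_check_period_ms and falls back to 100 ms when absent. Below is a sketch of the assumed shape of the parsed model configuration plus a simplified stand-in for the get_parameter helper (parameter names come from this change; values are examples, and Triton carries every parameter value as a string):

# Illustrative shape of json.loads(args["model_config"])["parameters"].
model_config = {
    "parameters": {
        "cancellation_check_period_ms": {"string_value": "100"},
        "stats_check_period_ms": {"string_value": "100"},
    }
}

def get_parameter(model_config, name, cast=str):
    # Simplified stand-in: look up the parameter and cast its string value,
    # returning None when the key is missing so the "or 100" default applies.
    if name not in model_config.get("parameters", {}):
        return None
    return cast(model_config["parameters"][name]["string_value"])

assert get_parameter(model_config, "stats_check_period_ms", int) == 100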
@@ -564,7 +726,6 @@ def awaiter_loop(self):
                         del self.req_id_to_response_sender[req_id]
                 # Remove local reference so response_sender can be cleaned properly.
                 del response_sender
-        # TODO: Read stats: https://jirasw.nvidia.com/browse/TRTLLM-563

     def cancellation_loop(self):
         """Checks if any pending requests have been cancelled."""
@@ -578,6 +739,36 @@ def cancellation_loop(self):
                 # Remove local reference so response_sender can be cleaned properly.
                 del response_sender

+    def metrics_loop(self):
+        """Updates triton metrics using stats from the executor."""
+        while self.running:
+            time.sleep(self.stats_check_period_ms / 1000.0)
+            for stat in self.executor.get_latest_iteration_stats():
+                try:
+                    for key, metric in self.all_metrics.items():
+                        value = None
+                        if hasattr(stat, key):
+                            value = getattr(stat, key)
+                        elif stat.kv_cache_stats is not None and hasattr(
+                                stat.kv_cache_stats, key):
+                            value = getattr(stat.kv_cache_stats, key)
+                        elif stat.static_batching_stats is not None and hasattr(
+                                stat.static_batching_stats, key):
+                            value = getattr(stat.static_batching_stats, key)
+                        elif stat.inflight_batching_stats is not None and hasattr(
+                                stat.inflight_batching_stats, key):
+                            value = getattr(stat.inflight_batching_stats, key)
+                        if value is not None:
+                            if key == "timestamp":
+                                value = convert_timestamp_to_seconds(value)
+                            metric.set(value)
+                        else:
+                            pb_utils.Logger.log_warn(
+                                f"Metric \"{key}\" not found.")
+                except Exception as e:
+                    pb_utils.Logger.log_warn(
+                        f"Error while processing metrics: {e}")
+
     def finalize(self):
         """`finalize` is called only once when the model is being unloaded.
         Implementing `finalize` function is optional. This function allows
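metrics_loop above resolves each metric key against the iteration stats in a fixed order: top-level fields first, then kv_cache_stats, then the static- or inflight-batching sub-stats. A self-contained sketch of that lookup order, with types.SimpleNamespace standing in for the executor's stats objects (field names taken from this change; values illustrative):

from types import SimpleNamespace

stat = SimpleNamespace(
    timestamp="01-01-2024 00:00:00",
    cpu_mem_usage=1024,
    kv_cache_stats=SimpleNamespace(max_num_blocks=128, free_num_blocks=100),
    static_batching_stats=None,
    inflight_batching_stats=SimpleNamespace(num_ctx_tokens=256),
)

def resolve(stat, key):
    # Mirror the if/elif chain in metrics_loop: return the first matching
    # attribute, or None when no stats object carries the key.
    for source in (stat, stat.kv_cache_stats, stat.static_batching_stats,
                   stat.inflight_batching_stats):
        if source is not None and hasattr(source, key):
            return getattr(source, key)
    return None

assert resolve(stat, "free_num_blocks") == 100   # from kv_cache_stats
assert resolve(stat, "num_ctx_tokens") == 256    # from inflight_batching_stats
assert resolve(stat, "does_not_exist") is None   # would trigger the warning path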
@@ -587,4 +778,5 @@ def finalize(self):
         self.running = False
         self.awaiter_thread.join()
         self.cancellation_thread.join()
+        self.metrics_thread.join()
         self.executor.shutdown()