diff --git a/README.md b/README.md
index dcc12f92..9d9e0088 100644
--- a/README.md
+++ b/README.md
@@ -67,8 +67,11 @@ The below commands will build the same Triton TRT-LLM container as the one on th
 ```bash
 # Prepare the TRT-LLM base image using the dockerfile from tensorrtllm_backend.
 cd tensorrtllm_backend
+git lfs install
+git submodule update --init --recursive
+
 # Specify the build args for the dockerfile.
-BASE_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3
+BASE_IMAGE=nvcr.io/nvidia/pytorch:24.04-py3
 TRT_VERSION=10.0.1.6
 TRT_URL_x86=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
 TRT_URL_ARM=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.ubuntu-22.04.aarch64-gnu.cuda-12.4.tar.gz
diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
index 1488d99d..3bbf86d1 100644
--- a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
+++ b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
@@ -311,6 +311,11 @@ def convert_decoding_mode(decoding_mode: str):
         f"decoding_mode value of '{decoding_mode}' is not supported.")
 
 
+def convert_timestamp_to_seconds(timestamp: str):
+    return int(
+        datetime.datetime.strptime(timestamp, "%m-%d-%Y %H:%M:%S").timestamp())
+
+
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model
     that is created must have "TritonPythonModel" as the class name.
@@ -422,6 +427,155 @@ def get_executor_config(self, model_config):
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.ExecutorConfig(**kwargs)
 
+    def create_metrics(self, model: str, version: str, is_v1_model: bool):
+        self.request_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_request_metrics",
+            description="TRT LLM request metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        self.runtime_memory_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_runtime_memory_metrics",
+            description="TRT LLM runtime memory metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        self.kv_cache_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_kv_cache_block_metrics",
+            description="TRT LLM KV cache block metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        model_type = "v1" if is_v1_model else "inflight_batcher"
+        self.model_type_metric_family = pb_utils.MetricFamily(
+            name=f"nv_trt_llm_{model_type}_metrics",
+            description=f"TRT LLM {model_type}-specific metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        self.general_metric_family = pb_utils.MetricFamily(
+            name="nv_trt_llm_general_metrics",
+            description="General TRT LLM metrics",
+            kind=pb_utils.MetricFamily.GAUGE,
+        )
+        common_labels = {"model": model, "version": version}
+        self.all_metrics = {
+            # Request metrics
+            "num_active_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "active",
+                **common_labels
+            }),
+            "max_num_active_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "max",
+                **common_labels
+            }),
+            "num_scheduled_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "scheduled",
+                **common_labels
+            }),
+            "num_context_requests":
+            self.request_metric_family.Metric(labels={
+                "request_type": "context",
+                **common_labels
+            }),
+            # Runtime metrics
+            "cpu_mem_usage":
+            self.runtime_memory_metric_family.Metric(labels={
+                "memory_type": "cpu",
+                **common_labels
+            }),
+            "gpu_mem_usage":
+            self.runtime_memory_metric_family.Metric(labels={
+                "memory_type": "gpu",
+                **common_labels
+            }),
+            "pinned_mem_usage":
+            self.runtime_memory_metric_family.Metric(labels={
+                "memory_type": "pinned",
+                **common_labels
+            }),
+            # KV cache metrics
+            "max_num_blocks":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "max",
+                **common_labels
+            }),
+            "free_num_blocks":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "free",
+                **common_labels
+            }),
+            "used_num_blocks":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "used",
+                **common_labels
+            }),
+            "tokens_per_block":
+            self.kv_cache_metric_family.Metric(labels={
+                "kv_cache_block_type": "tokens_per",
+                **common_labels
+            }),
+            # General metrics
+            "timestamp":
+            self.general_metric_family.Metric(labels={
+                "general_type": "timestamp",
+                **common_labels
+            }),
+            "iter":
+            self.general_metric_family.Metric(labels={
+                "general_type": "iteration_counter",
+                **common_labels
+            }),
+        }
+        if is_v1_model:
+            self.all_metrics.update({
+                "num_ctx_tokens":
+                self.model_type_metric_family.Metric(labels={
+                    "v1_specific_metric": "total_context_tokens",
+                    **common_labels
+                }),
+                "num_gen_tokens":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "v1_specific_metric": "total_generation_tokens",
+                        **common_labels
+                    }),
+                "empty_gen_slots":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "v1_specific_metric": "empty_generation_slots",
+                        **common_labels
+                    }),
+            })
+        else:
+            self.all_metrics.update({
+                "num_ctx_tokens":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric":
+                        "total_context_tokens",
+                        **common_labels
+                    }),
+                "num_gen_requests":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric":
+                        "generation_requests",
+                        **common_labels
+                    }),
+                "micro_batch_id":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric": "micro_batch_id",
+                        **common_labels
+                    }),
+                "num_paused_requests":
+                self.model_type_metric_family.Metric(
+                    labels={
+                        "inflight_batcher_specific_metric": "paused_requests",
+                        **common_labels
+                    }),
+            })
+
     def initialize(self, args):
         """`initialize` is called only once when the model is being loaded.
         Implementing `initialize` function is optional. This function allows
@@ -453,22 +607,30 @@ def initialize(self, args):
             model_config)
         self.cancellation_check_period_ms = get_parameter(
             model_config, "cancellation_check_period_ms", int) or 100
+        self.stats_check_period_ms = get_parameter(
+            model_config, "stats_check_period_ms", int) or 100
 
         if not self.decoupled:
             raise pb_utils.TritonModelException(
                 "Please enable decoupled transaction policy in the model configuration to serve this model"
             )
 
+        self.create_metrics(args["model_name"],
+                            args["model_version"],
+                            is_v1_model=executor_config.batching_type ==
+                            trtllm.BatchingType.STATIC)
         self.triton_id_to_req_id = {}
         self.req_id_to_response_sender = {}
         self.lock = Lock()
         self.running = False
         self.awaiter_thread = Thread(target=self.awaiter_loop)
         self.cancellation_thread = Thread(target=self.cancellation_loop)
+        self.metrics_thread = Thread(target=self.metrics_loop)
         if self.executor.can_enqueue_requests():
             self.running = True
             self.awaiter_thread.start()
             self.cancellation_thread.start()
+            self.metrics_thread.start()
         else:
             # In leader mode, worker ranks will wait here until leader is done.
             self.executor.shutdown()
@@ -564,7 +726,6 @@ def awaiter_loop(self):
                     del self.req_id_to_response_sender[req_id]
                # Remove local reference so response_sender can be cleaned properly.
                del response_sender
-            # TODO: Read stats: https://jirasw.nvidia.com/browse/TRTLLM-563
 
     def cancellation_loop(self):
         """Checks if any pending requests have been cancelled."""
@@ -578,6 +739,36 @@ def cancellation_loop(self):
                    # Remove local reference so response_sender can be cleaned properly.
                    del response_sender
 
+    def metrics_loop(self):
+        """Updates triton metrics using stats from the executor."""
+        while self.running:
+            time.sleep(self.stats_check_period_ms / 1000.0)
+            for stat in self.executor.get_latest_iteration_stats():
+                try:
+                    for key, metric in self.all_metrics.items():
+                        value = None
+                        if hasattr(stat, key):
+                            value = getattr(stat, key)
+                        elif stat.kv_cache_stats is not None and hasattr(
+                                stat.kv_cache_stats, key):
+                            value = getattr(stat.kv_cache_stats, key)
+                        elif stat.static_batching_stats is not None and hasattr(
+                                stat.static_batching_stats, key):
+                            value = getattr(stat.static_batching_stats, key)
+                        elif stat.inflight_batching_stats is not None and hasattr(
+                                stat.inflight_batching_stats, key):
+                            value = getattr(stat.inflight_batching_stats, key)
+                        if value is not None:
+                            if key == "timestamp":
+                                value = convert_timestamp_to_seconds(value)
+                            metric.set(value)
+                        else:
+                            pb_utils.Logger.log_warn(
+                                f"Metric \"{key}\" not found.")
+                except Exception as e:
+                    pb_utils.Logger.log_warn(
+                        f"Error while processing metrics: {e}")
+
     def finalize(self):
         """`finalize` is called only once when the model is being unloaded.
         Implementing `finalize` function is optional. This function allows
@@ -587,4 +778,5 @@ def finalize(self):
         self.running = False
         self.awaiter_thread.join()
         self.cancellation_thread.join()
+        self.metrics_thread.join()
         self.executor.shutdown()
diff --git a/all_models/tests/test_python_backend.py b/all_models/tests/test_python_backend.py
index 35054653..2a4756a7 100644
--- a/all_models/tests/test_python_backend.py
+++ b/all_models/tests/test_python_backend.py
@@ -572,3 +572,8 @@ def test_get_executor_config_minimal():
     assert config.iter_stats_max_iterations == 1000
     assert config.request_stats_max_iterations == 0
     assert config.logits_post_processor_map is None
+
+
+def test_convert_timestamp_to_seconds():
+    assert convert_timestamp_to_seconds("01-01-1970 00:00:00") == 0
+    assert convert_timestamp_to_seconds("05-17-2024 23:28:39") == 1715988519
diff --git a/ci/L0_backend_trtllm/custom_metrics_verification_tests.py b/ci/L0_backend_trtllm/custom_metrics_verification_tests.py
index b163a030..ad6c539f 100644
--- a/ci/L0_backend_trtllm/custom_metrics_verification_tests.py
+++ b/ci/L0_backend_trtllm/custom_metrics_verification_tests.py
@@ -64,8 +64,18 @@ def _parse_log_file(self, filename):
         with open(filename) as log_file:
             for line in reversed(list(log_file)):
                 if "Active Request Count" in line:
-                    json_format = re.sub(r"^.*?{", "{", line)
-                    return json.loads(json_format)
+                    match = re.search(r'({.*})', line)
+                    if match:
+                        json_string = match.group(1)
+                        json_string = json_string.replace('\\"', '"')
+                        try:
+                            return json.loads(json_string)
+                        except json.JSONDecodeError as e:
+                            raise Exception("Error parsing the JSON string: ",
+                                            e)
+                    else:
+                        raise Exception("No JSON found in the log file")
+
 
     def _parse_triton_metrics(self, filename, is_v1):
         curl_counts = {}
diff --git a/ci/L0_backend_trtllm/test.sh b/ci/L0_backend_trtllm/test.sh
index 47b44b92..b947971a 100644
--- a/ci/L0_backend_trtllm/test.sh
+++ b/ci/L0_backend_trtllm/test.sh
@@ -38,6 +38,7 @@ BASE_METRICS_VERIFICATION_LOG="base_metrics_verification.log"
 CUSTOM_METRICS_VERIFICATION_TEST=custom_metrics_verification_tests.py
 CUSTOM_METRICS_VERIFICATION_LOG="custom_metrics_verification.log"
 SERVER_PID=0
+SLEEP_DURATION=3
 
 # Force environment to use python version 3
 apt update -q=2 \
@@ -237,7 +238,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
     fi
     set +e
 
+    # Make sure the metrics are retrieved after the server has updated the metrics internally
+    sleep ${SLEEP_DURATION}
     curl localhost:8002/metrics -o ${NUM_GPU}gpu_v1_no_stream_metrics.out
+
     kill_server
     wait_for_server_terminated ${SERVER_PID[@]}
 
@@ -285,7 +289,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
     fi
     set +e
 
+    # Make sure the metrics are retrieved after the server has updated the metrics internally
+    sleep ${SLEEP_DURATION}
     curl localhost:8002/metrics -o ${NUM_GPU}gpu_IFB_no_stream_metrics.out
+
     kill_server
     wait_for_server_terminated ${SERVER_PID[@]}
 
@@ -342,7 +349,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
        fi
        set +e
 
+        # Make sure the metrics are retrieved after the server has updated the metrics internally
+        sleep ${SLEEP_DURATION}
        curl localhost:8002/metrics -o ${NUM_GPU}gpu_multi_model_metrics.out
+
        kill_server
        wait_for_server_terminated ${SERVER_PID[@]}
    fi
@@ -375,7 +385,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
     fi
     set +e
 
+    # Make sure the metrics are retrieved after the server has updated the metrics internally
+    sleep ${SLEEP_DURATION}
     curl localhost:8002/metrics -o ${NUM_GPU}gpu_IFB_stream_metrics.out
+
     kill_server
     wait_for_server_terminated ${SERVER_PID[@]}
 
diff --git a/dockerfile/Dockerfile.triton.trt_llm_backend b/dockerfile/Dockerfile.triton.trt_llm_backend
index 0d69122e..524ca41a 100644
--- a/dockerfile/Dockerfile.triton.trt_llm_backend
+++ b/dockerfile/Dockerfile.triton.trt_llm_backend
@@ -57,24 +57,19 @@ ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:${LD_LIBRARY_PATH}
 ENV TRT_ROOT=/usr/local/tensorrt
 
 FROM install_dependencies as tensorrt_llm_build
-
-ARG TENSORRT_LLM_REPO=https://github.com/NVIDIA/TensorRT-LLM.git
-ARG TENSORRT_LLM_REPO_TAG=main
-
 RUN pip3 install --no-cache-dir \
     cmake \
     polygraphy==0.49.0 \
     mpi4py==3.1.5
 
-WORKDIR /workspace/
-RUN git clone --recurse-submodules --branch ${TENSORRT_LLM_REPO_TAG} ${TENSORRT_LLM_REPO} tenosrrt_llm
-
-WORKDIR /workspace/tenosrrt_llm
-RUN python3 scripts/build_wheel.py --trt_root /usr/local/tensorrt
+WORKDIR /workspace
+COPY scripts scripts
+COPY tensorrt_llm tensorrt_llm
+RUN cd tensorrt_llm && python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" --clean --job_count 18 && cd ..
 
 FROM install_dependencies as base
 WORKDIR /tmp
-COPY --from=tensorrt_llm_build /workspace/tenosrrt_llm/build/tensorrt_llm*whl .
+COPY --from=tensorrt_llm_build /workspace/tensorrt_llm/build/tensorrt_llm*whl .
 RUN pip3 install --no-cache-dir --extra-index-url https://pypi.nvidia.com tensorrt_llm*.whl
diff --git a/dockerfile/Dockerfile.trt_llm_backend b/dockerfile/Dockerfile.trt_llm_backend
index 9cbf3761..de520c6d 100644
--- a/dockerfile/Dockerfile.trt_llm_backend
+++ b/dockerfile/Dockerfile.trt_llm_backend
@@ -1,5 +1,5 @@
 ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver
-ARG BASE_TAG=24.04-py3
+ARG BASE_TAG=24.05-py3
 
 FROM ${BASE_IMAGE}:${BASE_TAG} as base
 
diff --git a/tensorrt_llm b/tensorrt_llm
index 2a115dae..9691e12b 160000
--- a/tensorrt_llm
+++ b/tensorrt_llm
@@ -1 +1 @@
-Subproject commit 2a115dae84f13daaa54727534daa837c534eceb4
+Subproject commit 9691e12bce7ae1c126c435a049eb516eb119486c
diff --git a/tools/version.txt b/tools/version.txt
index 61104ab7..dfc4d6d2 100644
--- a/tools/version.txt
+++ b/tools/version.txt
@@ -1 +1 @@
-bb75970fe2f21b2cb9a7d231010540397f6dfd79
+73b896d12a81662027fa6746ab3ed99450150e18
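
For reference, the metric families registered in `create_metrics` above are exposed on Triton's Prometheus endpoint, which the CI script scrapes at `localhost:8002/metrics`. Below is a minimal sketch (not part of the change set) of listing those samples; it assumes a locally running server on the default metrics port, and the sample line shown in the comment is illustrative only.

```python
# Minimal sketch: scrape the Triton metrics endpoint and print only the
# TRT-LLM metric families added by this change. Assumes the default metrics
# port 8002 used in ci/L0_backend_trtllm/test.sh; adjust the URL as needed.
import urllib.request

METRICS_URL = "http://localhost:8002/metrics"  # assumption: default metrics port

# Family names as registered in create_metrics(); the model-type family is
# either *_v1_metrics or *_inflight_batcher_metrics depending on batching mode.
FAMILIES = (
    "nv_trt_llm_request_metrics",
    "nv_trt_llm_runtime_memory_metrics",
    "nv_trt_llm_kv_cache_block_metrics",
    "nv_trt_llm_general_metrics",
    "nv_trt_llm_v1_metrics",
    "nv_trt_llm_inflight_batcher_metrics",
)

with urllib.request.urlopen(METRICS_URL) as response:
    body = response.read().decode("utf-8")

# Example of a line this prints (label values depend on the deployed model):
# nv_trt_llm_kv_cache_block_metrics{kv_cache_block_type="used",model="tensorrt_llm",version="1"} 0
for line in body.splitlines():
    if line.startswith(FAMILIES):
        print(line)
```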