Commit
Update TensorRT-LLM backend (triton-inference-server#512)
kaiyux authored Jun 25, 2024
1 parent 62cd00f commit ada5799
Showing 9 changed files with 235 additions and 17 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -67,8 +67,11 @@ The below commands will build the same Triton TRT-LLM container as the one on th
```bash
# Prepare the TRT-LLM base image using the dockerfile from tensorrtllm_backend.
cd tensorrtllm_backend
git lfs install
git submodule update --init --recursive

# Specify the build args for the dockerfile.
BASE_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3
BASE_IMAGE=nvcr.io/nvidia/pytorch:24.04-py3
TRT_VERSION=10.0.1.6
TRT_URL_x86=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
TRT_URL_ARM=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.ubuntu-22.04.aarch64-gnu.cuda-12.4.tar.gz
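For context, the variables above feed the base-image build as docker build arguments. A hypothetical invocation is sketched below; the authoritative command and the exact build-arg names live in the README lines that follow this hunk and in dockerfile/Dockerfile.triton.trt_llm_backend, and the names and image tag assumed here may differ from them.

```bash
# Hypothetical sketch only: the build-arg names are assumed to mirror the shell
# variables above, and the tag trtllm_base is a placeholder.
docker build -t trtllm_base \
  --build-arg BASE_IMAGE="${BASE_IMAGE}" \
  --build-arg TRT_VERSION="${TRT_VERSION}" \
  --build-arg TRT_URL_x86="${TRT_URL_x86}" \
  --build-arg TRT_URL_ARM="${TRT_URL_ARM}" \
  -f dockerfile/Dockerfile.triton.trt_llm_backend .
```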
194 changes: 193 additions & 1 deletion all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
@@ -311,6 +311,11 @@ def convert_decoding_mode(decoding_mode: str):
f"decoding_mode value of '{decoding_mode}' is not supported.")


def convert_timestamp_to_seconds(timestamp: str):
return int(
datetime.datetime.strptime(timestamp, "%m-%d-%Y %H:%M:%S").timestamp())


class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
@@ -422,6 +427,155 @@ def get_executor_config(self, model_config):
kwargs = {k: v for k, v in kwargs.items() if v is not None}
return trtllm.ExecutorConfig(**kwargs)

def create_metrics(self, model: str, version: str, is_v1_model: bool):
self.request_metric_family = pb_utils.MetricFamily(
name="nv_trt_llm_request_metrics",
description="TRT LLM request metrics",
kind=pb_utils.MetricFamily.GAUGE,
)
self.runtime_memory_metric_family = pb_utils.MetricFamily(
name="nv_trt_llm_runtime_memory_metrics",
description="TRT LLM runtime memory metrics",
kind=pb_utils.MetricFamily.GAUGE,
)
self.kv_cache_metric_family = pb_utils.MetricFamily(
name="nv_trt_llm_kv_cache_block_metrics",
description="TRT LLM KV cache block metrics",
kind=pb_utils.MetricFamily.GAUGE,
)
model_type = "v1" if is_v1_model else "inflight_batcher"
self.model_type_metric_family = pb_utils.MetricFamily(
name=f"nv_trt_llm_{model_type}_metrics",
description=f"TRT LLM {model_type}-specific metrics",
kind=pb_utils.MetricFamily.GAUGE,
)
self.general_metric_family = pb_utils.MetricFamily(
name="nv_trt_llm_general_metrics",
description="General TRT LLM metrics",
kind=pb_utils.MetricFamily.GAUGE,
)
common_labels = {"model": model, "version": version}
self.all_metrics = {
# Request metrics
"num_active_requests":
self.request_metric_family.Metric(labels={
"request_type": "active",
**common_labels
}),
"max_num_active_requests":
self.request_metric_family.Metric(labels={
"request_type": "max",
**common_labels
}),
"num_scheduled_requests":
self.request_metric_family.Metric(labels={
"request_type": "scheduled",
**common_labels
}),
"num_context_requests":
self.request_metric_family.Metric(labels={
"request_type": "context",
**common_labels
}),
# Runtime metrics
"cpu_mem_usage":
self.runtime_memory_metric_family.Metric(labels={
"memory_type": "cpu",
**common_labels
}),
"gpu_mem_usage":
self.runtime_memory_metric_family.Metric(labels={
"memory_type": "gpu",
**common_labels
}),
"pinned_mem_usage":
self.runtime_memory_metric_family.Metric(labels={
"memory_type": "pinned",
**common_labels
}),
# KV cache metrics
"max_num_blocks":
self.kv_cache_metric_family.Metric(labels={
"kv_cache_block_type": "max",
**common_labels
}),
"free_num_blocks":
self.kv_cache_metric_family.Metric(labels={
"kv_cache_block_type": "free",
**common_labels
}),
"used_num_blocks":
self.kv_cache_metric_family.Metric(labels={
"kv_cache_block_type": "used",
**common_labels
}),
"tokens_per_block":
self.kv_cache_metric_family.Metric(labels={
"kv_cache_block_type": "tokens_per",
**common_labels
}),
# General metrics
"timestamp":
self.general_metric_family.Metric(labels={
"general_type": "timestamp",
**common_labels
}),
"iter":
self.general_metric_family.Metric(labels={
"general_type": "iteration_counter",
**common_labels
}),
}
if is_v1_model:
self.all_metrics.update({
"num_ctx_tokens":
self.model_type_metric_family.Metric(labels={
"v1_specific_metric": "total_context_tokens",
**common_labels
}),
"num_gen_tokens":
self.model_type_metric_family.Metric(
labels={
"v1_specific_metric": "total_generation_tokens",
**common_labels
}),
"empty_gen_slots":
self.model_type_metric_family.Metric(
labels={
"v1_specific_metric": "empty_generation_slots",
**common_labels
}),
})
else:
self.all_metrics.update({
"num_ctx_tokens":
self.model_type_metric_family.Metric(
labels={
"inflight_batcher_specific_metric":
"total_context_tokens",
**common_labels
}),
"num_gen_requests":
self.model_type_metric_family.Metric(
labels={
"inflight_batcher_specific_metric":
"generation_requests",
**common_labels
}),
"micro_batch_id":
self.model_type_metric_family.Metric(
labels={
"inflight_batcher_specific_metric": "micro_batch_id",
**common_labels
}),
"num_paused_requests":
self.model_type_metric_family.Metric(
labels={
"inflight_batcher_specific_metric": "paused_requests",
**common_labels
}),
})

def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Implementing `initialize` function is optional. This function allows
@@ -453,22 +607,30 @@ def initialize(self, args):
model_config)
self.cancellation_check_period_ms = get_parameter(
model_config, "cancellation_check_period_ms", int) or 100
self.stats_check_period_ms = get_parameter(
model_config, "stats_check_period_ms", int) or 100

if not self.decoupled:
raise pb_utils.TritonModelException(
"Please enable decoupled transaction policy in the model configuration to serve this model"
)

self.create_metrics(args["model_name"],
args["model_version"],
is_v1_model=executor_config.batching_type ==
trtllm.BatchingType.STATIC)
self.triton_id_to_req_id = {}
self.req_id_to_response_sender = {}
self.lock = Lock()
self.running = False
self.awaiter_thread = Thread(target=self.awaiter_loop)
self.cancellation_thread = Thread(target=self.cancellation_loop)
self.metrics_thread = Thread(target=self.metrics_loop)
if self.executor.can_enqueue_requests():
self.running = True
self.awaiter_thread.start()
self.cancellation_thread.start()
self.metrics_thread.start()
else:
# In leader mode, worker ranks will wait here until leader is done.
self.executor.shutdown()
@@ -564,7 +726,6 @@ def awaiter_loop(self):
del self.req_id_to_response_sender[req_id]
# Remove local reference so response_sender can be cleaned properly.
del response_sender
# TODO: Read stats: https://jirasw.nvidia.com/browse/TRTLLM-563

def cancellation_loop(self):
"""Checks if any pending requests have been cancelled."""
@@ -578,6 +739,36 @@ def cancellation_loop(self):
# Remove local reference so response_sender can be cleaned properly.
del response_sender

def metrics_loop(self):
"""Updates triton metrics using stats from the executor."""
while self.running:
time.sleep(self.stats_check_period_ms / 1000.0)
for stat in self.executor.get_latest_iteration_stats():
try:
for key, metric in self.all_metrics.items():
value = None
if hasattr(stat, key):
value = getattr(stat, key)
elif stat.kv_cache_stats is not None and hasattr(
stat.kv_cache_stats, key):
value = getattr(stat.kv_cache_stats, key)
elif stat.static_batching_stats is not None and hasattr(
stat.static_batching_stats, key):
value = getattr(stat.static_batching_stats, key)
elif stat.inflight_batching_stats is not None and hasattr(
stat.inflight_batching_stats, key):
value = getattr(stat.inflight_batching_stats, key)
if value is not None:
if key == "timestamp":
value = convert_timestamp_to_seconds(value)
metric.set(value)
else:
pb_utils.Logger.log_warn(
f"Metric \"{key}\" not found.")
except Exception as e:
pb_utils.Logger.log_warn(
f"Error while processing metrics: {e}")

def finalize(self):
"""`finalize` is called only once when the model is being unloaded.
Implementing `finalize` function is optional. This function allows
@@ -587,4 +778,5 @@ def finalize(self):
self.running = False
self.awaiter_thread.join()
self.cancellation_thread.join()
self.metrics_thread.join()
self.executor.shutdown()
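The gauge families registered in create_metrics are exported through Triton's standard Prometheus-style metrics endpoint, which the CI script further down scrapes on port 8002. A minimal spot check, assuming that default metrics port and that the metric names keep the nv_trt_llm prefix used above:

```bash
# List only the TRT-LLM gauges added by this change (default metrics port 8002 assumed).
curl -s localhost:8002/metrics | grep "^nv_trt_llm"
```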
5 changes: 5 additions & 0 deletions all_models/tests/test_python_backend.py
@@ -572,3 +572,8 @@ def test_get_executor_config_minimal():
assert config.iter_stats_max_iterations == 1000
assert config.request_stats_max_iterations == 0
assert config.logits_post_processor_map is None


def test_convert_timestamp_to_seconds():
assert convert_timestamp_to_seconds("01-01-1970 00:00:00") == 0
assert convert_timestamp_to_seconds("05-17-2024 23:28:39") == 1715988519
14 changes: 12 additions & 2 deletions ci/L0_backend_trtllm/custom_metrics_verification_tests.py
@@ -64,8 +64,18 @@ def _parse_log_file(self, filename):
with open(filename) as log_file:
for line in reversed(list(log_file)):
if "Active Request Count" in line:
json_format = re.sub(r"^.*?{", "{", line)
return json.loads(json_format)
match = re.search(r'({.*})', line)
if match:
    json_string = match.group(1).replace('\\"', '"')
else:
    raise Exception("No JSON found in the log file")

try:
    return json.loads(json_string)
except json.JSONDecodeError as e:
    raise Exception("Error parsing the JSON string: ", e)

def _parse_triton_metrics(self, filename, is_v1):
curl_counts = {}
13 changes: 13 additions & 0 deletions ci/L0_backend_trtllm/test.sh
@@ -38,6 +38,7 @@ BASE_METRICS_VERIFICATION_LOG="base_metrics_verification.log"
CUSTOM_METRICS_VERIFICATION_TEST=custom_metrics_verification_tests.py
CUSTOM_METRICS_VERIFICATION_LOG="custom_metrics_verification.log"
SERVER_PID=0
SLEEP_DURATION=3

# Force environment to use python version 3
apt update -q=2 \
@@ -237,7 +238,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
fi
set +e

# Make sure the metrics are retrieved after the server has updated them internally
sleep ${SLEEP_DURATION}
curl localhost:8002/metrics -o ${NUM_GPU}gpu_v1_no_stream_metrics.out

kill_server
wait_for_server_terminated ${SERVER_PID[@]}

@@ -285,7 +289,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
fi
set +e

# Make sure the metrics are retrieved after the server has updated them internally
sleep ${SLEEP_DURATION}
curl localhost:8002/metrics -o ${NUM_GPU}gpu_IFB_no_stream_metrics.out

kill_server
wait_for_server_terminated ${SERVER_PID[@]}

@@ -342,7 +349,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
fi
set +e

# Make sure the metrics are retrieved after the server has updated them internally
sleep ${SLEEP_DURATION}
curl localhost:8002/metrics -o ${NUM_GPU}gpu_multi_model_metrics.out

kill_server
wait_for_server_terminated ${SERVER_PID[@]}
fi
@@ -375,7 +385,10 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
fi
set +e

# Make sure the metrics are retrieved after the server has updated them internally
sleep ${SLEEP_DURATION}
curl localhost:8002/metrics -o ${NUM_GPU}gpu_IFB_stream_metrics.out

kill_server
wait_for_server_terminated ${SERVER_PID[@]}

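The fixed sleep added above gives the new metrics_loop, which polls the executor every stats_check_period_ms (100 ms by default), time to publish at least one update before the scrape. A polling alternative is sketched below purely as an assumption-laden variant, not what the script does: it waits until the first nv_trt_llm gauge appears, reusing the same endpoint and port as the curl calls above; the retry count and interval are arbitrary.

```bash
# Wait (up to ~15 s) for the TRT-LLM gauges to appear instead of sleeping a fixed duration.
for _ in $(seq 1 30); do
  if curl -s localhost:8002/metrics | grep -q "^nv_trt_llm"; then
    break
  fi
  sleep 0.5
done
```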
15 changes: 5 additions & 10 deletions dockerfile/Dockerfile.triton.trt_llm_backend
@@ -57,24 +57,19 @@ ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:${LD_LIBRARY_PATH}
ENV TRT_ROOT=/usr/local/tensorrt

FROM install_dependencies as tensorrt_llm_build

ARG TENSORRT_LLM_REPO=https://github.com/NVIDIA/TensorRT-LLM.git
ARG TENSORRT_LLM_REPO_TAG=main

RUN pip3 install --no-cache-dir \
cmake \
polygraphy==0.49.0 \
mpi4py==3.1.5

WORKDIR /workspace/
RUN git clone --recurse-submodules --branch ${TENSORRT_LLM_REPO_TAG} ${TENSORRT_LLM_REPO} tenosrrt_llm

WORKDIR /workspace/tenosrrt_llm
RUN python3 scripts/build_wheel.py --trt_root /usr/local/tensorrt
WORKDIR /workspace
COPY scripts scripts
COPY tensorrt_llm tensorrt_llm
RUN cd tensorrt_llm && python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" --clean --job_count 18 && cd ..

FROM install_dependencies as base

WORKDIR /tmp
COPY --from=tensorrt_llm_build /workspace/tenosrrt_llm/build/tensorrt_llm*whl .
COPY --from=tensorrt_llm_build /workspace/tensorrt_llm/build/tensorrt_llm*whl .

RUN pip3 install --no-cache-dir --extra-index-url https://pypi.nvidia.com tensorrt_llm*.whl
2 changes: 1 addition & 1 deletion dockerfile/Dockerfile.trt_llm_backend
@@ -1,5 +1,5 @@
ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver
ARG BASE_TAG=24.04-py3
ARG BASE_TAG=24.05-py3

FROM ${BASE_IMAGE}:${BASE_TAG} as base

2 changes: 1 addition & 1 deletion tensorrt_llm
Submodule tensorrt_llm updated 94 files
+44 −0 benchmarks/cpp/gptManagerBenchmark.cpp
+1 −0 benchmarks/python/build.py
+2 −2 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+2 −2 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+3 −3 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
+2 −2 cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+2 −2 cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+2 −2 cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
+2 −2 cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a
+2 −2 cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
+3 −3 cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
+2 −2 cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
+2 −2 cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
+2 −2 cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib
+3 −2 cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
+1 −1 ...rt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt
+1 −1 ...rMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
+0 −0 ...nsorrt_llm/kernels/decoderMaskedMultiheadAttention/instantiation/decoderMaskedMultiheadAttention104_bf16.cu
+53 −13 cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
+4 −2 cpp/tensorrt_llm/pybind/executor/bindings.cpp
+3 −0 cpp/tensorrt_llm/runtime/medusaModule.cpp
+40 −10 cpp/tests/kernels/mixtureOfExpertsTest.cu
+1 −1 cpp/tests/resources/scripts/build_medusa_engines.py
+1 −1 docker/Dockerfile.multi
+2 −2 docker/common/install_pytorch.sh
+4 −5 docker/common/install_tensorrt.sh
+2 −2 docs/source/reference/support-matrix.md
+4 −3 docs/source/release-notes.md
+161 −1 docs/source/speculative_decoding.md
+1 −1 examples/baichuan/requirements.txt
+1 −1 examples/bloom/requirements.txt
+1 −1 examples/chatglm/requirements.txt
+2 −2 examples/cogvlm/convert_checkpoint.py
+1 −1 examples/dbrx/requirements.txt
+1 −1 examples/falcon/requirements.txt
+2 −2 examples/gemma/convert_checkpoint.py
+1 −1 examples/gemma/requirements.txt
+1 −1 examples/gpt/requirements.txt
+1 −1 examples/gptj/requirements.txt
+3 −3 examples/gptneox/README.md
+3 −4 examples/gptneox/convert_checkpoint.py
+1 −1 examples/gptneox/requirements.txt
+1 −1 examples/grok/requirements.txt
+1 −1 examples/high-level-api/requirements.txt
+1 −1 examples/internlm/requirements.txt
+2 −2 examples/llama/README.md
+1 −1 examples/llama/requirements.txt
+5 −13 examples/mamba/README.md
+2 −1 examples/mamba/requirements.txt
+30 −5 examples/medusa/README.md
+15 −98 examples/medusa/convert_checkpoint.py
+1 −1 examples/medusa/requirements.txt
+1 −1 examples/mixtral/requirements.txt
+5 −3 examples/mmlu.py
+1 −1 examples/mpt/requirements.txt
+1 −1 examples/nemotron/requirements.txt
+1 −1 examples/opt/requirements.txt
+13 −17 examples/phi/README.md
+7 −10 examples/phi/convert_checkpoint.py
+0 −63 examples/phi/postprocess_quant_checkpoint.py
+1 −1 examples/phi/requirements.txt
+18 −1 examples/quantization/quantize.py
+1 −1 examples/quantization/requirements.txt
+1 −1 examples/qwen/requirements.txt
+1 −1 examples/qwenvl/requirements.txt
+1 −1 examples/recurrentgemma/requirements.txt
+1 −1 examples/run.py
+1 −1 examples/skywork/requirements.txt
+1 −1 examples/smaug/requirements.txt
+1 −1 examples/whisper/requirements.txt
+5 −4 requirements.txt
+2 −2 tensorrt_llm/auto_parallel/parallelization.py
+5 −0 tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py
+3 −2 tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py
+1 −15 tensorrt_llm/commands/build.py
+1 −4 tensorrt_llm/models/__init__.py
+1 −0 tensorrt_llm/models/gemma/model.py
+93 −53 tensorrt_llm/models/generation_mixin.py
+39 −30 tensorrt_llm/models/llama/convert.py
+62 −35 tensorrt_llm/models/mamba/model.py
+66 −29 tensorrt_llm/models/medusa/weight.py
+26 −7 tensorrt_llm/models/modeling_utils.py
+76 −5 tensorrt_llm/models/phi3/convert.py
+107 −30 tensorrt_llm/models/phi3/model.py
+0 −14 tensorrt_llm/models/phi3/phi3small/__init__.py
+0 −257 tensorrt_llm/models/phi3/phi3small/model.py
+2 −100 tensorrt_llm/models/phi3/split_weights.py
+23 −44 tensorrt_llm/models/recurrentgemma/model.py
+10 −2 tensorrt_llm/quantization/layers.py
+13 −7 tensorrt_llm/quantization/quantize.py
+101 −21 tensorrt_llm/quantization/quantize_by_modelopt.py
+1 −1 tensorrt_llm/version.py
+6 −4 tests/model/test_mamba.py
+1 −1 tests/test_llama_conversion.sh
2 changes: 1 addition & 1 deletion tools/version.txt
@@ -1 +1 @@
bb75970fe2f21b2cb9a7d231010540397f6dfd79
73b896d12a81662027fa6746ab3ed99450150e18
