
Commit 87100b0

Update TensorRT-LLM backend (#655)
1 parent 91c07d3 commit 87100b0

26 files changed: +229 -64 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ DECOUPLED_MODE=false
 
 python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE}
 python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/preprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},preprocessing_instance_count:${INSTANCE_COUNT}
-python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},engine_dir:${ENGINE_DIR},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MS},batching_strategy:inflight_fused_batching,max_queue_size:${MAX_QUEUE_SIZE}
+python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},engine_dir:${ENGINE_DIR},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MS},batching_strategy:inflight_fused_batching,max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:TYPE_FP16
 python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},postprocessing_instance_count:${INSTANCE_COUNT},max_queue_size:${MAX_QUEUE_SIZE}
 python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},bls_instance_count:${INSTANCE_COUNT}
 ```
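
For context, the fill-template script substitutes each `${placeholder}` in the target config.pbtxt with the value given as comma-separated `key:value` pairs, which is how the new `encoder_input_features_data_type:TYPE_FP16` entry lands in the tensorrt_llm config. A minimal sketch of that substitution, assuming a simplified command-line format (the real `tools/fill_template.py` may parse its arguments differently):

```python
# Minimal sketch of ${placeholder} substitution in a pbtxt template.
# Assumption: the real tools/fill_template.py may differ in argument parsing.
import sys

def fill_template(pbtxt_path: str, substitutions: str) -> None:
    # substitutions look like "key1:value1,key2:value2,..."
    pairs = dict(item.split(":", 1) for item in substitutions.split(","))
    text = open(pbtxt_path).read()
    for key, value in pairs.items():
        # replace every ${key} occurrence with the supplied value
        text = text.replace("${" + key + "}", value)
    open(pbtxt_path, "w").write(text)

if __name__ == "__main__":
    # e.g. python3 fill_template_sketch.py model/config.pbtxt \
    #      "triton_max_batch_size:64,encoder_input_features_data_type:TYPE_FP16"
    fill_template(sys.argv[1], sys.argv[2])
```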

all_models/inflight_batcher_llm/postprocessing/1/model.py

Lines changed: 1 addition & 7 deletions
@@ -132,13 +132,7 @@ def execute(self, requests):
             for batch_idx, beam_tokens in enumerate(token_batch):
                 for beam_idx, tokens in enumerate(beam_tokens):
                     seq_len = sequence_lengths[idx][batch_idx][beam_idx]
-                    # Exclude fake ids in multimodal models
-                    fake_id_len = 0
-                    for i in range(seq_len):
-                        if tokens[i] < self.tokenizer.vocab_size:
-                            fake_id_len = i
-                            break
-                    list_of_tokens.append(tokens[fake_id_len:seq_len])
+                    list_of_tokens.append(tokens[:seq_len])
                 req_idx_offset += 1
 
             req_idx_offsets.append(req_idx_offset)
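
The removed block trimmed leading multimodal "fake" ids before decoding; after this change the postprocessor decodes every beam straight up to its reported sequence length. A tiny illustration of the new slicing behavior (the tokenizer name is only an example):

```python
from transformers import AutoTokenizer

# Illustration only: decode all ids up to seq_len, as the postprocessor now does.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the sketch
tokens = [15496, 11, 995, 0, 0]  # output ids padded past the valid length
seq_len = 3                      # valid length reported in sequence_lengths
text = tokenizer.decode(tokens[:seq_len])  # only the first seq_len ids are decoded
print(text)
```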

all_models/inflight_batcher_llm/preprocessing/1/model.py

Lines changed: 38 additions & 8 deletions
@@ -24,12 +24,16 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import base64
+import io
 import json
 import os
 from typing import List
 
 import numpy as np
+import requests
 import triton_python_backend_utils as pb_utils
+from PIL import Image
 from transformers import AutoProcessor, AutoTokenizer, T5Tokenizer
 
 

@@ -659,18 +663,28 @@ def __init__(self,
                  vision_model_processor,
                  preprocessor_model_config={}):
         # import libraries that are only relevant for multimodal models
-        import requests
         import torch
-        from PIL import Image
         from torch.utils.dlpack import from_dlpack
 
-        from tensorrt_llm._utils import str_dtype_to_torch
+        # NOTE: Due to the behavior of MPI initialization, it is recommended to avoid using import tensorrt_llm
+        # except for the specific modules tensorrt_llm and multimodal_encoders.
+        # As a result, the function str_dtype_to_torch has been copied directly from tensorrt_llm._utils.
+        _str_to_torch_dtype_dict = dict(
+            bfloat16=torch.bfloat16,
+            float16=torch.float16,
+            float32=torch.float32,
+            int64=torch.int64,
+            int32=torch.int32,
+            int8=torch.int8,
+            bool=torch.bool,
+            fp8=torch.float8_e4m3fn,
+        )
+
+        def str_dtype_to_torch(dtype):
+            ret = _str_to_torch_dtype_dict.get(dtype)
+            assert ret is not None, f'Unsupported dtype: {dtype}'
+            return ret
 
-        # create method for loading image from urls
-        self.load_images_from_urls = lambda img_urls: [
-            Image.open(requests.get(img_url.decode(), stream=True).raw)
-            for img_url in img_urls
-        ]
         self.load_images_tensor = lambda tensor: tensor if not hasattr(
             tensor, 'to_dlpack') else from_dlpack(tensor.to_dlpack())
 

@@ -695,6 +709,22 @@ def __init__(self,
         self.vision_model_processor = vision_model_processor
         self.vision_model_type = vision_model_type
 
+    def load_images_from_urls(self, img_urls):
+        images = []
+        for img_url in img_urls:
+            img_url = img_url.decode()
+            if img_url.startswith("data:image/jpeg;base64,"):
+                image_base64 = img_url.split(",")[1]
+                # Decode the base64 string
+                image_data = base64.b64decode(image_base64)
+                # Create a BytesIO object from the decoded data
+                image_buffer = io.BytesIO(image_data)
+                images.append(Image.open(image_buffer))
+            else:
+                images.append(
+                    Image.open(requests.get(img_url, stream=True).raw))
+        return images
+
     def process(self, queries, img_urls=None, image_bytes=None):
         vision_processed_tensors = {}
         if img_urls is not None or image_bytes is not None:
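
With the new load_images_from_urls method, a request can carry either an http(s) image URL or an inline `data:image/jpeg;base64,...` URI. A small client-side sketch of building such a URI (the file name is hypothetical):

```python
import base64

# Build a base64 JPEG data URI that the preprocessor above can decode.
with open("example.jpg", "rb") as f:  # hypothetical local image
    encoded = base64.b64encode(f.read()).decode("utf-8")
image_url = "data:image/jpeg;base64," + encoded

# The model receives the value as bytes (hence img_url.decode() in the diff),
# so the string would be sent through the image-URL input tensor as UTF-8 bytes.
image_url_bytes = image_url.encode("utf-8")
```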

all_models/inflight_batcher_llm/tensorrt_llm/1/model.py

Lines changed: 11 additions & 7 deletions
@@ -193,6 +193,8 @@ def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
         request, 'beam_search_diversity_rate', batch_size, batch_index)
     kwargs['early_stopping'] = get_input_scalar_by_name(
         request, 'early_stopping', batch_size, batch_index)
+    kwargs['num_return_sequences'] = get_input_scalar_by_name(
+        request, 'num_return_sequences', batch_size, batch_index) or 1
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     return trtllm.SamplingConfig(**kwargs)
 

@@ -336,9 +338,6 @@ def convert_request(request, exclude_input_from_output, decoupled):
             raise pb_utils.TritonModelException(
                 "Streaming is only supported in decoupled mode.")
 
-        inputs['num_return_sequences'] = get_input_scalar_by_name(
-            request, 'num_return_sequences', batch_size, batch_index) or 1
-
         inputs['end_id'] = get_input_scalar_by_name(request, 'end_id',
                                                     batch_size, batch_index)
         inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id',

@@ -364,7 +363,7 @@ def convert_request(request, exclude_input_from_output, decoupled):
         # if request doesn't specify exclude_input_from_output, try to use the parameter
         output_config.exclude_input_from_output = (
             exclude_input_from_output
-            if exclude_input_from_output is not None else false)
+            if exclude_input_from_output is not None else False)
     else:
         output_config.exclude_input_from_output = req_exclude_input_from_output
 

@@ -642,7 +641,11 @@ def get_extended_runtime_perf_knob_config(self, model_config):
             "multi_block_mode":
             get_parameter(model_config, "multi_block_mode", bool),
             "enable_context_fmha_fp32_acc":
-            get_parameter(model_config, "enable_context_fmha_fp32_acc", bool)
+            get_parameter(model_config, "enable_context_fmha_fp32_acc", bool),
+            "cuda_graph_mode":
+            get_parameter(model_config, "cuda_graph_mode", bool),
+            "cuda_graph_cache_size":
+            get_parameter(model_config, "cuda_graph_cache_size", int),
         }
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.ExtendedRuntimePerfKnobConfig(**kwargs)

@@ -1000,8 +1003,9 @@ def execute(self, requests):
 
                 self.req_id_to_request_data[req_id] = RequestData(
                     triton_req_id, triton_user_id, batch_index,
-                    len(batch_indices), executor_request.num_return_sequences,
-                    0, 0, triton_request.get_response_sender())
+                    len(batch_indices),
+                    executor_request.sampling_config.num_return_sequences, 0,
+                    0, triton_request.get_response_sender())
                 self.triton_req_id_to_req_ids[triton_req_id].add(req_id)
                 input_len = len(
                     executor_request.input_token_ids
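
In this file, num_return_sequences now enters through the executor sampling config instead of the request inputs, and the bookkeeping later reads it back from `executor_request.sampling_config`. A rough sketch mirroring get_sampling_config_from_request, assuming `trtllm` refers to the TensorRT-LLM executor bindings as it does in the backend:

```python
import tensorrt_llm.bindings.executor as trtllm

# Mirror of the kwargs construction in the diff above (values hard-coded for the sketch).
kwargs = {
    "beam_width": 1,
    "num_return_sequences": 2,  # now part of SamplingConfig, no longer a request input
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
sampling_config = trtllm.SamplingConfig(**kwargs)

# Later bookkeeping reads the value back from the request's sampling config:
print(sampling_config.num_return_sequences)  # 2
```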

all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 13 additions & 1 deletion
@@ -48,7 +48,7 @@ input [
   },
   {
     name: "encoder_input_features"
-    data_type: TYPE_FP16
+    data_type: ${encoder_input_features_data_type}
     dims: [ -1, -1 ]
     allow_ragged_batch: true
     optional: true

@@ -648,6 +648,18 @@ parameters: {
     string_value: "${multi_block_mode}"
   }
 }
+parameters: {
+  key: "cuda_graph_mode"
+  value: {
+    string_value: "${cuda_graph_mode}"
+  }
+}
+parameters: {
+  key: "cuda_graph_cache_size"
+  value: {
+    string_value: "${cuda_graph_cache_size}"
+  }
+}
 parameters: {
   key: "speculative_decoding_fast_logits"
   value: {
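
The two new parameters stay as plain strings in the config; get_extended_runtime_perf_knob_config in the model.py diff above converts them to bool and int before building the perf-knob config. A minimal sketch of that conversion, with a simplified stand-in for the backend's get_parameter helper:

```python
# Minimal sketch of turning pbtxt string parameters into typed perf-knob settings.
# `get_parameter` below is a simplified stand-in for the backend helper.
def get_parameter(model_config, name, dtype=str):
    value = model_config["parameters"].get(name, {}).get("string_value", "")
    if value in ("", "${" + name + "}"):  # unset or unsubstituted placeholder
        return None
    if dtype is bool:
        return value.lower() in ("true", "1")
    return dtype(value)

model_config = {
    "parameters": {
        "cuda_graph_mode": {"string_value": "true"},
        "cuda_graph_cache_size": {"string_value": "4"},
    }
}

kwargs = {
    "cuda_graph_mode": get_parameter(model_config, "cuda_graph_mode", bool),
    "cuda_graph_cache_size": get_parameter(model_config, "cuda_graph_cache_size", int),
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
print(kwargs)  # {'cuda_graph_mode': True, 'cuda_graph_cache_size': 4}
```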

all_models/multimodal/multimodal_encoders/1/model.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ def triton_string_to_torch(dtype):
         "TYPE_FP16": torch.float16,
         "TYPE_FP32": torch.float32,
         "TYPE_FP64": torch.float64,
+        "TYPE_BF16": torch.bfloat16
     }
     return type_map[dtype]
 
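
The added mapping lets a bfloat16 encoder engine be described via TYPE_BF16 (matching the ${encoder_input_features_data_type} placeholder in the config below). A quick check of the mapping, with a hypothetical feature shape:

```python
import torch

# Same mapping as triton_string_to_torch above, exercised for the new TYPE_BF16 entry.
type_map = {
    "TYPE_FP16": torch.float16,
    "TYPE_FP32": torch.float32,
    "TYPE_FP64": torch.float64,
    "TYPE_BF16": torch.bfloat16,
}
features = torch.zeros((1, 16, 32), dtype=type_map["TYPE_BF16"])  # hypothetical shape
print(features.dtype)  # torch.bfloat16
```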

all_models/multimodal/multimodal_encoders/config.pbtxt

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ output [
   # Output for visual encoders of type mllama
   {
     name: "ENCODER_INPUT_FEATURES"
-    data_type: TYPE_FP16
+    data_type: ${encoder_input_features_data_type}
     dims: [ -1, -1 ]
   },
   {

ci/L0_backend_trtllm/test.sh

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
     replace_config_tags '${engine_dir}' "${MODEL_DIR}/tensorrt_llm/1/inflight_${NUM_GPU}_gpu/" "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
     replace_config_tags '${max_queue_delay_microseconds}' "50000" "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
     replace_config_tags '${triton_backend}' "tensorrtllm" "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
+    replace_config_tags '${encoder_input_features_data_type}' "TYPE_FP16" "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
     replace_config_tags '${triton_max_batch_size}' "128" "${MODEL_DIR}/postprocessing/config.pbtxt"
     replace_config_tags '${tokenizer_dir}' "${TOKENIZER_DIR}/" "${MODEL_DIR}/postprocessing/config.pbtxt"
     replace_config_tags '${postprocessing_instance_count}' '1' "${MODEL_DIR}/postprocessing/config.pbtxt"

dockerfile/Dockerfile.triton.trt_llm_backend

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,6 @@
 ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:24.10-py3-min
 ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.10-py3
-ARG NVRTC_VER=12.6.68-1
+ARG NVRTC_VER=12.6.77-1
 ARG TRT_VER=10.6.0.26
 ARG RELEASE_URL_TRT_x86=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-${TRT_VER}.Linux.x86_64-gnu.cuda-12.6.tar.gz
 ARG RELEASE_URL_TRT_ARM=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-${TRT_VER}.ubuntu-24.04.aarch64-gnu.cuda-12.6.tar.gz

@@ -29,6 +29,8 @@ COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torch-2.5.0a0+
 COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torchgen /usr/local/lib/python3.10/dist-packages/torchgen
 COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torchvision /usr/local/lib/python3.10/dist-packages/torchvision
 COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torchvision-0.20.0a0.dist-info /usr/local/lib/python3.10/dist-packages/torchvision-0.20.0a0.dist-info
+COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/setuptools /usr/local/lib/python3.10/dist-packages/setuptools
+COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/setuptools-70.3.0.dist-info /usr/local/lib/python3.10/dist-packages/setuptools-70.3.0.dist-info
 
 # Might not need to copy cusparseLt in the future once it's included in DLFW cuda container
 COPY --from=pytorch_image /usr/local/cuda/lib64/libcusparseLt* /usr/local/cuda/lib64/

@@ -109,6 +111,8 @@ COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torch-2.5.0a0+
 COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torchgen /usr/local/lib/python3.10/dist-packages/torchgen
 COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torchvision /usr/local/lib/python3.10/dist-packages/torchvision
 COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torchvision-0.20.0a0.dist-info /usr/local/lib/python3.10/dist-packages/torchvision-0.20.0a0.dist-info
+COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/setuptools /usr/local/lib/python3.10/dist-packages/setuptools
+COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/setuptools-70.3.0.dist-info /usr/local/lib/python3.10/dist-packages/setuptools-70.3.0.dist-info
 
 # Might not need to copy cusparseLt in the future once it's included in DLFW cuda container
 COPY --from=pytorch_image /usr/local/cuda/lib64/libcusparseLt* /usr/local/cuda/lib64/

docs/baichuan.md

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ python3 tools/fill_template.py -i baichuan_ifb/preprocessing/config.pbtxt tokeni
 python3 tools/fill_template.py -i baichuan_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},triton_max_batch_size:64,postprocessing_instance_count:1
 python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False
 python3 tools/fill_template.py -i baichuan_ifb/ensemble/config.pbtxt triton_max_batch_size:64
-python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0
+python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16
 ```
 
 * Launch server

@@ -178,7 +178,7 @@ python3 tools/fill_template.py -i baichuan_ifb/preprocessing/config.pbtxt tokeni
 python3 tools/fill_template.py -i baichuan_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},triton_max_batch_size:64,postprocessing_instance_count:1
 python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:True
 python3 tools/fill_template.py -i baichuan_ifb/ensemble/config.pbtxt triton_max_batch_size:64
-python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0
+python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16
 
 pip install SentencePiece
 # please add `trust_remote_code=True` in tokenizer of preprocessing and postprocessing. Considering the security, we don't add it by default.
