Commit 37ed967

Update TensorRT-LLM backend (triton-inference-server#142)
* Update TensorRT-LLM backend
1 parent 0b2c6a8 commit 37ed967

File tree

15 files changed: 324 additions & 295 deletions


all_models/gpt/preprocessing/1/model.py

Lines changed: 13 additions & 15 deletions
@@ -1,5 +1,4 @@
 # -*- coding: utf-8 -*-
-import csv
 import json
 from typing import List
 

@@ -164,30 +163,29 @@ def _create_request(self, query):
 
         return start_ids, start_lengths
 
-    def _to_word_list_format(self, word_dict: List[List[str]]):
+    def _to_word_list_format(self, word_lists: List[List[str | bytes]]):
         '''
-        format of word_dict
-            len(word_dict) should be same to batch_size
-            word_dict[i] means the words for batch i
-            len(word_dict[i]) must be 1, which means it only contains 1 string
-            This string can contains several sentences and split by ",".
-            For example, if word_dict[2] = " I am happy, I am sad", then this function will return
-            the ids for two short sentences " I am happy" and " I am sad".
+        word_lists format:
+            len(word_lists) == batch_size
+            word_lists[i] means the words associated to batch item i. A "word" may actually be any string. Like "lorem" or "lorem ipsum".
         '''
         assert self.tokenizer != None, "need to set tokenizer"
 
+        if word_lists is None:
+            # Return an empty array of shape (1,2,0)
+            return np.empty([1, 2, 0], dtype="int32")
+
         flat_ids = []
         offsets = []
-        for word_dict_item in word_dict:
+        for word_list in word_lists:
             item_flat_ids = []
             item_offsets = []
 
-            if isinstance(word_dict_item[0], bytes):
-                word_dict_item = [word_dict_item[0].decode()]
+            for word in word_list:
+                if isinstance(word, bytes):
+                    word = word.decode()
 
-            words = list(csv.reader(word_dict_item))[0]
-            for word in words:
-                ids = self.tokenizer.encode(word)
+                ids = self.tokenizer.encode(word, add_special_tokens=False)
 
                 if len(ids) == 0:
                     continue
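For readers skimming the diff: the hunk above replaces the old CSV-splitting behavior (one comma-separated string per batch item) with direct per-word encoding. The padding and stacking at the tail of the function fall outside this hunk, so the sketch below is a hedged standalone reconstruction, not the repository's code: `to_word_list_format` and `encode` are stand-in names, and the pad values (0 for ids, -1 for offsets) and the final (batch_size, 2, max_len) shape are assumptions inferred from the (1, 2, 0) empty-array return.

```python
# Hedged sketch of the new word-list encoding flow, assuming the function tail
# pads ids with 0 and offsets with -1, then stacks to (batch_size, 2, max_len).
from typing import Callable, List, Optional, Union

import numpy as np


def to_word_list_format(
    word_lists: Optional[List[List[Union[str, bytes]]]],
    encode: Callable[[str], List[int]],
) -> np.ndarray:
    if word_lists is None:
        # Mirrors the diff: empty tensor of shape (1, 2, 0)
        return np.empty([1, 2, 0], dtype="int32")

    flat_ids, offsets = [], []
    for word_list in word_lists:
        item_flat_ids: List[int] = []
        item_offsets: List[int] = []
        for word in word_list:
            if isinstance(word, bytes):
                word = word.decode()
            # e.g. encode = lambda w: tok.encode(w, add_special_tokens=False)
            ids = encode(word)
            if len(ids) == 0:
                continue
            item_flat_ids += ids
            item_offsets.append(len(item_flat_ids))  # cumulative end offset per word
        flat_ids.append(item_flat_ids)
        offsets.append(item_offsets)

    # Pad every batch item to the longest id list so the arrays stack cleanly.
    pad_to = max([len(ids) for ids in flat_ids] + [1])
    padded_ids = [ids + [0] * (pad_to - len(ids)) for ids in flat_ids]
    padded_offs = [offs + [-1] * (pad_to - len(offs)) for offs in offsets]
    return np.array([padded_ids, padded_offs], dtype="int32").transpose(1, 0, 2)
```

Under these assumptions, `to_word_list_format([["lorem", "lorem ipsum"]], enc)` yields one batch row whose first plane holds the concatenated token ids and whose second plane marks where each word's ids end.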

all_models/inflight_batcher_llm/preprocessing/1/model.py

Lines changed: 10 additions & 17 deletions
@@ -24,7 +24,6 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import csv
 import json
 from typing import List
 

@@ -224,35 +223,29 @@ def _create_request(self, query):
 
         return start_ids, start_lengths
 
-    def _to_word_list_format(self, word_dict: List[List[str]]):
+    def _to_word_list_format(self, word_lists: List[List[str | bytes]]):
         '''
-        format of word_dict
-            len(word_dict) should be same to batch_size
-            word_dict[i] means the words for batch i
-            len(word_dict[i]) must be 1, which means it only contains 1 string
-            This string can contains several sentences and split by ",".
-            For example, if word_dict[2] = " I am happy, I am sad", then this function will return
-            the ids for two short sentences " I am happy" and " I am sad".
+        word_lists format:
+            len(word_lists) == batch_size
+            word_lists[i] means the words associated to batch item i. A "word" may actually be any string. Like "lorem" or "lorem ipsum".
         '''
         assert self.tokenizer != None, "need to set tokenizer"
 
-        if word_dict is None:
+        if word_lists is None:
             # Return an empty array of shape (1,2,0)
             return np.empty([1, 2, 0], dtype="int32")
 
         flat_ids = []
         offsets = []
-        for word_dict_item in word_dict:
+        for word_list in word_lists:
             item_flat_ids = []
             item_offsets = []
 
-            if isinstance(word_dict_item[0], bytes):
-                word_dict_item = [word_dict_item[0].decode()]
-
-            words = list(csv.reader(word_dict_item))[0]
-            for word in words:
-                ids = self.tokenizer.encode(word)
+            for word in word_list:
+                if isinstance(word, bytes):
+                    word = word.decode()
 
+                ids = self.tokenizer.encode(word, add_special_tokens=False)
                 if len(ids) == 0:
                     continue
 
ci/L0_backend_trtllm/test.sh

Lines changed: 4 additions & 4 deletions
@@ -136,15 +136,15 @@ if [ "$WAIT_RET" != "0" ]; then
 fi
 
 set -e
-python3 ${TOOLS_DIR}/inflight_batcher_llm/identity_test.py \
+python3 ${TOOLS_DIR}/inflight_batcher_llm/benchmark_core_model.py \
     --max-input-len=500 \
     dataset \
     --dataset=${DATASET} \
     --tokenizer-dir=${TOKENIZER_DIR}
 
 if [ $? -ne 0 ]; then
     cat $SERVER_LOG
-    echo -e "\n***\n*** Error executing inflight batching identity test: line ${LINENO}\n***"
+    echo -e "\n***\n*** Error executing inflight batching benchmark_core_model: line ${LINENO}\n***"
     RET=1
 fi
 set +e
@@ -180,14 +180,14 @@ if [ "$WAIT_RET" != "0" ]; then
 fi
 
 set -e
-python3 ${TOOLS_DIR}/inflight_batcher_llm/identity_test.py \
+python3 ${TOOLS_DIR}/inflight_batcher_llm/benchmark_core_model.py \
     --max-input-len=500 \
     --dataset=${DATASET} \
     --tokenizer-dir=${TOKENIZER_DIR}
 
 if [ $? -ne 0 ]; then
     cat $SERVER_LOG
-    echo -e "\n***\n*** Error executing inflight batching identity test: line ${LINENO}\n***"
+    echo -e "\n***\n*** Error executing inflight batching benchmark_core_model: line ${LINENO}\n***"
     RET=1
 fi
 set +e

ci/README.md

Lines changed: 5 additions & 5 deletions
@@ -47,7 +47,7 @@ cd /tensorrtllm_backend/ci/<test directory>
 bash -x ./test.sh
 ```
 
-## Run the e2e/identity test to benchmark
+## Run the e2e/benchmark_core_model to benchmark
 
 These two tests are ran in the [L0_backend_trtllm](./L0_backend_trtllm/)
 test. Below are the instructions to run the tests manually.
@@ -89,17 +89,17 @@ Expected outputs
 [INFO] Total Latency: 11099.243 ms
 ```
 
-### Identity test
+### benchmark_core_model
 
-[Identity test script](../tools/inflight_batcher_llm/identity_test.py)
-sends requests directly to the deployed `tensorrt_llm` model, the identity test
+[benchmark_core_model script](../tools/inflight_batcher_llm/benchmark_core_model.py)
+sends requests directly to the deployed `tensorrt_llm` model, the benchmark_core_model
 latency indicates the inference latency of TensorRT-LLM, not including the
 pre/post-processing latency which is usually handled by a third-party library
 such as HuggingFace.
 
 ```bash
 cd tools/inflight_batcher_llm
-python3 identity_test.py dataset --dataset <dataset path>
+python3 benchmark_core_model.py dataset --dataset <dataset path>
 ```
 
 Expected outputs

dockerfile/Dockerfile.trt_llm_backend

Lines changed: 10 additions & 2 deletions
@@ -15,8 +15,16 @@ RUN pip uninstall -y tensorrt
 
 FROM base as dev
 
-ENV SHINIT_FILE=${BASH_ENV}
-
+ARG TRT_VER="9.1.0.4"
+ENV TRT_VER=$TRT_VER
+ARG CUDA_VER="12.2"
+ENV CUDA_VER=$CUDA_VER
+ARG CUDNN_VER="8.9.4.25-1+cuda12.2"
+ENV CUDNN_VER=$CUDNN_VER
+ARG NCCL_VER="2.18.3-1+cuda12.2"
+ENV NCCL_VER=$NCCL_VER
+ARG CUBLAS_VER="12.2.5.6-1"
+ENV CUBLAS_VER=$CUBLAS_VER
 # Download & install internal TRT release
 COPY tensorrt_llm/docker/common/install_tensorrt.sh /tmp/
 RUN bash /tmp/install_tensorrt.sh && rm /tmp/install_tensorrt.sh
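Splitting each pinned version into an `ARG`/`ENV` pair keeps the defaults visible in the Dockerfile while allowing overrides at build time and exposing the values to `install_tensorrt.sh` through the environment. A hypothetical override, not a command taken from this repo, would look like `docker build --build-arg TRT_VER=9.1.0.4 -f dockerfile/Dockerfile.trt_llm_backend .`.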

inflight_batcher_llm/README.md

Lines changed: 6 additions & 6 deletions
@@ -132,7 +132,7 @@ You will find that the generation process is stopped early and therefore the num
 
 You can have a look at the client code to see how early stopping is achieved.
 
-## Run the e2e/identity test to benchmark
+## Run the e2e/benchmark_core_model to benchmark
 
 ### End to end test
 End to end test script sends requests to deployed ensemble model.
End to end test script sends requests to deployed ensemble model.
@@ -156,11 +156,11 @@ Expected outputs
 [INFO] Total Latency: 11099.243 ms
 ```
 
-### Identity test
+### benchmark core model
 
-Identity test script sends requests directly to deployed tensorrt_llm model, the identity test latency indicates the inference latency of TensorRT-LLM, not including the pre/post-processing latency which is usually handled by a third-party library such as HuggingFace.
+benchmark_core_model script sends requests directly to deployed tensorrt_llm model, the benchmark core model latency indicates the inference latency of TensorRT-LLM, not including the pre/post-processing latency which is usually handled by a third-party library such as HuggingFace.
 
-Identity test can generate traffic from 2 sources.
+benchmark_core_model can generate traffic from 2 sources.
 1 - dataset (json file containning prompts and optional responses)
 2 - token normal distribution (user specified input, output seqlen)
 

@@ -171,11 +171,11 @@ cd tools/inflight_batcher_llm
 ```
 Example: Run dataset with 10 req/sec requested rate with provided tokenizer.
 ```
-python3 identity_test.py -i grpc --request_rate 10 dataset --dataset <dataset path> --tokenizer_dir <> --tokenizer_type <>
+python3 benchmark_core_model.py -i grpc --request_rate 10 dataset --dataset <dataset path> --tokenizer_dir <> --tokenizer_type <> --num_requests 5000
 ```
 Example: Generate I/O seqlen tokens with input normal distribution with mean_seqlen=128, stdev=10. Output normal distribution with mean_seqlen=20, stdev=2. Set stdev=0 to get constant seqlens.
 ```
-python3 identity_test.py -i grpc --request_rate 10 token_norm_dist --input_mean 128 --input_stdev 5 --output_mean 20 --output_stdev 2 --num_requests 5000
+python3 benchmark_core_model.py -i grpc --request_rate 10 token_norm_dist --input_mean 128 --input_stdev 5 --output_mean 20 --output_stdev 2 --num_requests 5000
 ```
 Expected outputs
 ```

inflight_batcher_llm/client/end_to_end_grpc_client.py

Lines changed: 0 additions & 2 deletions
@@ -47,8 +47,6 @@ def test(triton_client, prompt, request_id, repetition_penalty,
     input0 = [[prompt]]
     input0_data = np.array(input0).astype(object)
    output0_len = np.ones_like(input0).astype(np.uint32) * FLAGS.output_len
-    bad_words_list = np.array([bad_words], dtype=object)
-    stop_words_list = np.array([stop_words], dtype=object)
     streaming = [[FLAGS.streaming]]
     streaming_data = np.array(streaming, dtype=bool)
     beam_width = [[FLAGS.beam_width]]
