
Commit 06f63fe

* Update
* Update doc for lfs usage
* Update TensorRT-LLM submodule
1 parent 99de6ed commit 06f63fe

File tree: 21 files changed, +721 / -167 lines


README.md

Lines changed: 1 addition & 2 deletions
@@ -101,9 +101,8 @@ don't need by removing the corresponding flags.
 ```bash
 # Update the submodules
 cd tensorrtllm_backend
-git submodule update --init --recursive
 git lfs install
-git lfs pull
+git submodule update --init --recursive
 
 # Use the Dockerfile to build the backend in a container
 # For x86_64

all_models/inflight_batcher_llm/ensemble/config.pbtxt

Lines changed: 9 additions & 1 deletion
@@ -125,7 +125,7 @@ output [
   {
     name: "text_output"
     data_type: TYPE_STRING
-    dims: [ -1, -1 ]
+    dims: [ -1 ]
   }
 ]
 ensemble_scheduling {
@@ -229,6 +229,10 @@ ensemble_scheduling {
         key: "output_ids"
         value: "_TOKENS_BATCH"
       }
+      output_map {
+        key: "sequence_length"
+        value: "_SEQUENCE_LENGTH"
+      }
     },
     {
       model_name: "postprocessing"
@@ -237,6 +241,10 @@ ensemble_scheduling {
         key: "TOKENS_BATCH"
         value: "_TOKENS_BATCH"
       }
+      input_map {
+        key: "SEQUENCE_LENGTH"
+        value: "_SEQUENCE_LENGTH"
+      }
       output_map {
         key: "OUTPUT"
         value: "text_output"

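The two new tensor maps route the raw model's `sequence_length` output through the ensemble into the postprocessor. As a quick wiring check after editing the config, one can parse it with the Triton model-config protobuf; this is a hedged sketch, not part of the commit, and it assumes any `${...}` template variables in the file have already been filled in:

```python
# Sanity-check sketch (not from this commit): print each ensemble step's
# tensor maps to confirm the new sequence_length wiring.
from google.protobuf import text_format
from tritonclient.grpc import model_config_pb2

config = model_config_pb2.ModelConfig()
with open("all_models/inflight_batcher_llm/ensemble/config.pbtxt") as f:
    text_format.Parse(f.read(), config)

for step in config.ensemble_scheduling.step:
    # Expect tensorrt_llm to map sequence_length -> _SEQUENCE_LENGTH and
    # postprocessing to map SEQUENCE_LENGTH <- _SEQUENCE_LENGTH.
    print(step.model_name, dict(step.input_map), dict(step.output_map))
```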
all_models/inflight_batcher_llm/postprocessing/1/model.py

Lines changed: 10 additions & 5 deletions
@@ -109,12 +109,16 @@ def execute(self, requests):
             tokens_batch = pb_utils.get_input_tensor_by_name(
                 request, 'TOKENS_BATCH').as_numpy()
 
+            # Get sequence length
+            sequence_lengths = pb_utils.get_input_tensor_by_name(
+                request, 'SEQUENCE_LENGTH').as_numpy()
+
             # Reshape Input
             # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]])
             # tokens_batch = tokens_batch.T
 
             # Postprocessing output data.
-            outputs = self._postprocessing(tokens_batch)
+            outputs = self._postprocessing(tokens_batch, sequence_lengths)
 
             # Create output tensors. You need pb_utils.Tensor
             # objects to create pb_utils.InferenceResponse.
@@ -144,10 +148,11 @@ def finalize(self):
         """
         print('Cleaning up...')
 
-    def _postprocessing(self, tokens_batch):
+    def _postprocessing(self, tokens_batch, sequence_lengths):
         outputs = []
-        for beam_tokens in tokens_batch:
-            for tokens in beam_tokens:
-                output = self.tokenizer.decode(tokens)
+        for batch_idx, beam_tokens in enumerate(tokens_batch):
+            for beam_idx, tokens in enumerate(beam_tokens):
+                seq_len = sequence_lengths[batch_idx][beam_idx]
+                output = self.tokenizer.decode(tokens[:seq_len])
                 outputs.append(output.encode('utf8'))
         return outputs

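The reason for threading `sequence_lengths` through is that `output_ids` arrives padded per beam, so slicing with `[:seq_len]` keeps pad/end tokens out of the decoded text. A minimal standalone sketch of the same trimming idea, with made-up token IDs and GPT-2 used purely as a stand-in tokenizer:

```python
# Standalone sketch of the trimming introduced above; all values are illustrative.
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

pad_id = 0
tokens_batch = np.array([[[15496, 11, 995, pad_id, pad_id]]])  # [batch, beam, tokens]
sequence_lengths = np.array([[3]])                             # [batch, beam]

outputs = []
for batch_idx, beam_tokens in enumerate(tokens_batch):
    for beam_idx, tokens in enumerate(beam_tokens):
        seq_len = sequence_lengths[batch_idx][beam_idx]
        # Without [:seq_len] the trailing pad tokens would be decoded as well.
        outputs.append(tokenizer.decode(tokens[:seq_len]))

print(outputs)
```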
all_models/inflight_batcher_llm/postprocessing/config.pbtxt

Lines changed: 6 additions & 1 deletion
@@ -32,13 +32,18 @@ input [
     name: "TOKENS_BATCH"
     data_type: TYPE_INT32
     dims: [ -1, -1 ]
+  },
+  {
+    name: "SEQUENCE_LENGTH"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
   }
 ]
 output [
   {
     name: "OUTPUT"
     data_type: TYPE_STRING
-    dims: [ -1, -1 ]
+    dims: [ -1 ]
   }
 ]

all_models/inflight_batcher_llm/preprocessing/1/model.py

Lines changed: 13 additions & 14 deletions
@@ -29,9 +29,7 @@
 from typing import List
 
 import numpy as np
-import torch
 import triton_python_backend_utils as pb_utils
-from torch.nn.utils.rnn import pad_sequence
 from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
 
 
@@ -135,12 +133,10 @@ def execute(self, requests):
             # Create output tensors. You need pb_utils.Tensor
             # objects to create pb_utils.InferenceResponse.
             input_id_tensor = pb_utils.Tensor(
-                'INPUT_ID',
-                np.array(input_id).astype(self.input_id_dtype))
+                'INPUT_ID', input_id.astype(self.input_id_dtype))
             request_input_len_tensor = pb_utils.Tensor(
                 'REQUEST_INPUT_LEN',
-                np.array(request_input_len).astype(
-                    self.request_input_len_dtype))
+                request_input_len.astype(self.request_input_len_dtype))
             request_output_len_tensor = pb_utils.Tensor(
                 'REQUEST_OUTPUT_LEN', request_output_len)
             bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words)
@@ -176,16 +172,19 @@ def _create_request(self, query):
             query : batch string (2D numpy array)
         """
         start_ids = [
-            torch.IntTensor(self.tokenizer.encode(s[0].decode()))
+            np.array(self.tokenizer.encode(s[0].decode())).astype(int)
             for s in query
         ]
-        start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
-
-        start_ids = pad_sequence(start_ids,
-                                 batch_first=True,
-                                 padding_value=self.pad_id)
-        # input_len = min(start_lengths)
-        #attn_mask = torch.ones((batch_size, input_len, input_len)).tril()
+        start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int)
+
+        max_len = 0
+        for seq in start_ids:
+            max_len = max(max_len, seq.shape[0])
+        start_ids = np.stack([
+            np.pad(seq, (0, max_len - seq.shape[0]),
+                   'constant',
+                   constant_values=(0, self.pad_id)) for seq in start_ids
+        ])
 
         return start_ids, start_lengths

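Dropping the torch dependency means `pad_sequence` is re-implemented with `np.pad`; the behaviour should match `pad_sequence(..., batch_first=True, padding_value=self.pad_id)`. A tiny standalone sketch of that replacement, with a placeholder `pad_id`:

```python
# Standalone sketch of the numpy padding that replaces torch's pad_sequence.
import numpy as np

pad_id = 2  # placeholder; the model uses self.pad_id
start_ids = [np.array([5, 6, 7]), np.array([8, 9])]

max_len = max(seq.shape[0] for seq in start_ids)
padded = np.stack([
    np.pad(seq, (0, max_len - seq.shape[0]), 'constant',
           constant_values=(0, pad_id)) for seq in start_ids
])
print(padded)  # [[5 6 7]
               #  [8 9 2]]
```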
all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 6 additions & 1 deletion
@@ -144,6 +144,11 @@ output [
     name: "output_ids"
     data_type: TYPE_INT32
     dims: [ -1, -1 ]
+  },
+  {
+    name: "sequence_length"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
   }
 ]
 instance_group [
@@ -167,7 +172,7 @@ parameters: {
 parameters: {
   key: "gpt_model_type"
   value: {
-    string_value: "inflight_fused_batching"
+    string_value: "${batching_strategy}"
   }
 }
 parameters: {

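With the hard-coded value replaced by a `${batching_strategy}` placeholder, the config now needs to be templated before Triton loads it. A minimal, hypothetical fill step (not a script shipped in this repo) using `string.Template`, reusing the previously hard-coded value as the example:

```python
# Hypothetical placeholder-fill step, shown only to illustrate the templating.
from pathlib import Path
from string import Template

path = Path("all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt")
filled = Template(path.read_text()).safe_substitute(
    batching_strategy="inflight_fused_batching")  # value used before this change
path.write_text(filled)
```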
dockerfile/Dockerfile.trt_llm_backend

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ RUN pip uninstall -y tensorrt
 
 FROM base as dev
 
+ENV SHINIT_FILE=${BASH_ENV}
+
 # Download & install internal TRT release
 COPY tensorrt_llm/docker/common/install_tensorrt.sh /tmp/
 RUN bash /tmp/install_tensorrt.sh && rm /tmp/install_tensorrt.sh

inflight_batcher_llm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,6 @@ set_ifndef(TRT_INCLUDE_DIR /usr/include/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu)
 
 set(TRT_LIB nvinfer)
 find_library_create_target(${TRT_LIB} nvinfer SHARED ${TRT_LIB_DIR})
-find_library_create_target(nvuffparser nvparsers SHARED ${TRT_LIB_DIR})
 
 file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" VERSION_STRINGS
      REGEX "#define NV_TENSORRT_.*")
@@ -311,6 +310,7 @@ target_link_libraries(
   triton-core-serverstub # from repo-core
   triton-backend-utils # from repo-backend
   ${MPI_LIBRARIES}
+  ${CUDA_LIBRARIES}
   nvinfer
   nvinfer_plugin_tensorrt_llm)

inflight_batcher_llm/client/inflight_batcher_llm_client.py

Lines changed: 40 additions & 33 deletions
@@ -74,7 +74,8 @@ def prepare_tensor(name, input, protocol):
 
 
 def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data,
-                   beam_width_data, temperature_data, streaming_data):
+                   beam_width_data, temperature_data, streaming_data, end_id,
+                   pad_id):
     protocol = 'grpc'
     inputs = [
         prepare_tensor("input_ids", input_ids_data, protocol),
@@ -84,6 +85,8 @@ def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data,
         prepare_tensor("beam_width", beam_width_data, protocol),
         prepare_tensor("temperature", temperature_data, protocol),
         prepare_tensor("streaming", streaming_data, protocol),
+        prepare_tensor("end_id", end_id, protocol),
+        prepare_tensor("pad_id", pad_id, protocol),
     ]
 
     return inputs
@@ -118,8 +121,9 @@ def callback(user_data, result, error):
         user_data._completed_requests.put(result)
         if (FLAGS.streaming):
             output_ids = result.as_numpy('output_ids')
-            tokens = list(output_ids[0][0])
-            print(tokens, flush=True)
+            if output_ids != None:
+                tokens = list(output_ids[0][0])
+                print(tokens, flush=True)
 
 
 if __name__ == "__main__":
@@ -275,6 +279,8 @@ def callback(user_data, result, error):
     tokenizer.pad_token = tokenizer.eos_token
     pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
     end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]
+    end_id_data = np.array([[end_id]], dtype=np.uint32)
+    pad_id_data = np.array([[pad_id]], dtype=np.uint32)
 
     input_ids = [tokenizer.encode(FLAGS.text)]
     input_ids_data = np.array(input_ids, dtype=np.int32)
@@ -291,7 +297,8 @@ def callback(user_data, result, error):
 
     inputs = prepare_inputs(input_ids_data, input_lengths_data,
                             request_output_len_data, beam_width_data,
-                            temperature_data, streaming_data)
+                            temperature_data, streaming_data, end_id_data,
+                            pad_id_data)
 
     if FLAGS.stop_after_ms > 0:
         stop_inputs = prepare_stop_signals()
@@ -300,17 +307,18 @@ def callback(user_data, result, error):
 
     request_id = FLAGS.request_id
 
-    expected_output_ids = [
-        input_ids[0] + [
-            21221, 290, 257, 4255, 379, 262, 1957, 7072, 11, 4689, 347, 2852,
-            2564, 494, 13, 679
-        ]
+    expected_output_ids = input_ids[0] + [
+        21221, 290, 257, 4255, 379, 262, 1957, 7072, 11, 4689, 347, 2852, 2564,
+        494, 13, 679
     ]
+
     if FLAGS.streaming:
         actual_output_ids = [input_ids[0]]
     else:
         actual_output_ids = []
 
+    sequence_lengths = []
+
     user_data = UserData()
     with grpcclient.InferenceServerClient(
             url=FLAGS.url,
@@ -361,17 +369,12 @@ def callback(user_data, result, error):
                     print(result)
                 else:
                     output_ids = result.as_numpy('output_ids')
-
+                    sequence_lengths = result.as_numpy('sequence_length')
                     if output_ids is not None:
-                        if (FLAGS.streaming):
-                            # Only one beam is supported
-                            tokens = list(output_ids[0][0])
-                            actual_output_ids[
-                                0] = actual_output_ids[0] + tokens
-                        else:
-                            for beam_output_ids in output_ids[0]:
-                                tokens = list(beam_output_ids)
-                                actual_output_ids.append(tokens)
+                        # Only one beam is supported
+                        tokens = list(output_ids[0][0])
+                        actual_output_ids[
+                            0] = actual_output_ids[0] + tokens
                     else:
                         print("Got cancellation response from server")
             else:
@@ -408,12 +411,13 @@ def callback(user_data, result, error):
                         print(result)
                     else:
                         output_ids = result.as_numpy('output_ids')
+                        sequence_lengths = result.as_numpy('sequence_length')
                         if output_ids is not None:
                             for beam_output_ids in output_ids[0]:
                                 tokens = list(beam_output_ids)
                                 actual_output_ids.append(tokens)
                         else:
-                            print("Got response for cancellation request")
+                            print("Got cancellation response from server")
 
                 processed_count = processed_count + 1
         except Exception as e:
@@ -422,18 +426,21 @@ def callback(user_data, result, error):
 
     passed = True
 
-    print("output_ids = ", actual_output_ids)
-    output_ids = np.array(actual_output_ids)
-    output_ids = output_ids.reshape(
-        (output_ids.size, )).tolist()[input_ids_data.shape[1]:]
-    output_text = tokenizer.decode(output_ids)
-    print(f'Input: {FLAGS.text}')
-    print(f'Output: {output_text}')
-    if (FLAGS.check_output):
-        passed = (actual_output_ids == expected_output_ids)
-        print("expected_output_ids = ", expected_output_ids)
-        print("\n=====")
-        print("PASS!" if passed else "FAIL!")
-        print("=====")
+    for beam in range(FLAGS.beam_width):
+        seq_len = sequence_lengths[0][
+            beam] if not FLAGS.streaming else len(actual_output_ids[beam])
+        output_ids_w_prompt = actual_output_ids[beam][:seq_len]
+        output_ids_wo_prompt = output_ids_w_prompt[input_ids_data.
+                                                   shape[1]:]
+        output_text = tokenizer.decode(output_ids_wo_prompt)
+        print(f'Input: {FLAGS.text}')
+        print(f'Output beam {beam}: {output_text}')
+        if (FLAGS.check_output and beam == 0):
+            passed = (output_ids_w_prompt == expected_output_ids)
+            print("output_ids = ", output_ids_w_prompt)
+            print("expected_output_ids = ", expected_output_ids)
+            print("\n=====")
+            print("PASS!" if passed else "FAIL!")
+            print("=====")
 
     sys.exit(not passed)

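For reference, a hedged sketch of how the two new raw-model inputs become Triton tensors on the client side; `prepare_tensor` below is a stand-in for the helper already defined in inflight_batcher_llm_client.py, and the 50256 IDs are GPT-2-style placeholders rather than values from the commit:

```python
# Sketch of building the new end_id / pad_id inputs (gRPC only), assuming a
# local helper equivalent to the client's prepare_tensor.
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

def prepare_tensor(name, data):
    # Stand-in for the client script's prepare_tensor helper.
    tensor = grpcclient.InferInput(name, list(data.shape),
                                   np_to_triton_dtype(data.dtype))
    tensor.set_data_from_numpy(data)
    return tensor

end_id, pad_id = 50256, 50256  # placeholder GPT-2-style token ids
end_id_data = np.array([[end_id]], dtype=np.uint32)
pad_id_data = np.array([[pad_id]], dtype=np.uint32)

inputs = [
    prepare_tensor("end_id", end_id_data),
    prepare_tensor("pad_id", pad_id_data),
]
```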