
Commit 0b2c6a8

Update TensorRT-LLM Backend (triton-inference-server#117)
* Update TensorRT-LLM Triton Backend
1 parent 7a92137 commit 0b2c6a8

File tree

13 files changed: +630 -253 lines changed


all_models/inflight_batcher_llm/ensemble/config.pbtxt

Lines changed: 28 additions & 8 deletions

@@ -26,7 +26,7 @@
 
 name: "ensemble"
 platform: "ensemble"
-max_batch_size: 128
+max_batch_size: ${triton_max_batch_size}
 input [
   {
     name: "text_input"
@@ -42,11 +42,13 @@ input [
     name: "bad_words"
     data_type: TYPE_STRING
     dims: [ -1 ]
+    optional: true
   },
   {
     name: "stop_words"
     data_type: TYPE_STRING
     dims: [ -1 ]
+    optional: true
   },
   {
     name: "end_id"
@@ -60,12 +62,6 @@ input [
     dims: [ 1 ]
     optional: true
   },
-  {
-    name: "embedding_bias"
-    data_type: TYPE_FP16
-    dims: [ -1 ]
-    optional: true
-  },
   {
     name: "top_k"
     data_type: TYPE_UINT32
@@ -137,6 +133,18 @@ input [
     data_type: TYPE_UINT32
     dims: [ 1 ]
     optional: true
+  },
+  {
+    name: "embedding_bias_words"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+    optional: true
+  },
+  {
+    name: "embedding_bias_weights"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+    optional: true
   }
 ]
 output [
@@ -167,6 +175,14 @@ ensemble_scheduling {
      key: "STOP_WORDS_DICT"
      value: "stop_words"
    }
+   input_map {
+     key: "EMBEDDING_BIAS_WORDS"
+     value: "embedding_bias_words"
+   }
+   input_map {
+     key: "EMBEDDING_BIAS_WEIGHTS"
+     value: "embedding_bias_weights"
+   }
    output_map {
      key: "REQUEST_INPUT_LEN"
      value: "_REQUEST_INPUT_LEN"
@@ -187,6 +203,10 @@ ensemble_scheduling {
      key: "BAD_WORDS_IDS"
      value: "_BAD_WORDS_IDS"
    }
+   output_map {
+     key: "EMBEDDING_BIAS"
+     value: "_EMBEDDING_BIAS"
+   }
  },
  {
    model_name: "tensorrt_llm"
@@ -213,7 +233,7 @@ ensemble_scheduling {
    }
    input_map {
      key: "embedding_bias"
-     value: "embedding_bias"
+     value: "_EMBEDDING_BIAS"
    }
    input_map {
      key: "runtime_top_k"

all_models/inflight_batcher_llm/postprocessing/config.pbtxt

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@
 
 name: "postprocessing"
 backend: "python"
-max_batch_size: 128
+max_batch_size: ${triton_max_batch_size}
 input [
   {
     name: "TOKENS_BATCH"

all_models/inflight_batcher_llm/preprocessing/1/model.py

Lines changed: 86 additions & 12 deletions

@@ -78,17 +78,26 @@ def initialize(self, args):
             add_special_tokens=False)[0]
 
         # Parse model output configs and convert Triton types to numpy types
-        input_names = [
+        output_names = [
             "INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS"
         ]
+        input_names = ["EMBEDDING_BIAS_WORDS", "EMBEDDING_BIAS_WEIGHTS"]
         for input_name in input_names:
             setattr(
                 self,
                 input_name.lower() + "_dtype",
                 pb_utils.triton_string_to_numpy(
-                    pb_utils.get_output_config_by_name(
+                    pb_utils.get_input_config_by_name(
                         model_config, input_name)['data_type']))
 
+        for output_name in output_names:
+            setattr(
+                self,
+                output_name.lower() + "_dtype",
+                pb_utils.triton_string_to_numpy(
+                    pb_utils.get_output_config_by_name(
+                        model_config, output_name)['data_type']))
+
     def execute(self, requests):
         """`execute` must be implemented in every Python model. `execute`
         function receives a list of pb_utils.InferenceRequest as the only
@@ -113,23 +122,54 @@ def execute(self, requests):
 
         # Every Python backend must iterate over everyone of the requests
        # and create a pb_utils.InferenceResponse for each of them.
+        logger = pb_utils.Logger
         for idx, request in enumerate(requests):
             # Get input tensors
             query = pb_utils.get_input_tensor_by_name(request,
                                                       'QUERY').as_numpy()
+            batch_dim = query.shape[0]
+            if batch_dim != 1:
+
+                err_str = "Inflight batching backend expects requests with batch size of 1."
+                logger.log_error(err_str)
+                responses.append(
+                    pb_utils.InferenceResponse(
+                        output_tensors=[],
+                        error=pb_utils.TritonError(err_str)))
+                continue
+
             request_output_len = pb_utils.get_input_tensor_by_name(
                 request, 'REQUEST_OUTPUT_LEN').as_numpy()
 
             bad_words_dict = pb_utils.get_input_tensor_by_name(
-                request, 'BAD_WORDS_DICT').as_numpy()
+                request, 'BAD_WORDS_DICT')
+            if bad_words_dict is not None:
+                bad_words_dict = bad_words_dict.as_numpy()
+
             stop_words_dict = pb_utils.get_input_tensor_by_name(
-                request, 'STOP_WORDS_DICT').as_numpy()
+                request, 'STOP_WORDS_DICT')
+            if stop_words_dict is not None:
+                stop_words_dict = stop_words_dict.as_numpy()
+
+            embedding_bias_words = pb_utils.get_input_tensor_by_name(
+                request, 'EMBEDDING_BIAS_WORDS')
+            if embedding_bias_words is not None:
+                embedding_bias_words = embedding_bias_words.as_numpy()
+
+            embedding_bias_weights = pb_utils.get_input_tensor_by_name(
+                request, 'EMBEDDING_BIAS_WEIGHTS')
+            if embedding_bias_weights is not None:
+                embedding_bias_weights = embedding_bias_weights.as_numpy()
 
             # Preprocessing input data.
             input_id, request_input_len = self._create_request(query)
             bad_words = self._to_word_list_format(bad_words_dict)
             stop_words = self._to_word_list_format(stop_words_dict)
 
+            embedding_bias = self._get_embedding_bias(
+                embedding_bias_words, embedding_bias_weights,
+                self.embedding_bias_weights_dtype)
+
             # Create output tensors. You need pb_utils.Tensor
             # objects to create pb_utils.InferenceResponse.
             input_id_tensor = pb_utils.Tensor(
@@ -142,17 +182,13 @@ def execute(self, requests):
             bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words)
             stop_words_ids_tensor = pb_utils.Tensor('STOP_WORDS_IDS',
                                                     stop_words)
+            embedding_bias_tensor = pb_utils.Tensor('EMBEDDING_BIAS',
+                                                    embedding_bias)
 
-            # Create InferenceResponse. You can set an error here in case
-            # there was a problem with handling this inference request.
-            # Below is an example of how you can set errors in inference
-            # response:
-            #
-            # pb_utils.InferenceResponse(
-            #    output_tensors=..., TritonError("An error occurred"))
             inference_response = pb_utils.InferenceResponse(output_tensors=[
                 input_id_tensor, bad_words_ids_tensor, stop_words_ids_tensor,
-                request_input_len_tensor, request_output_len_tensor
+                request_input_len_tensor, request_output_len_tensor,
+                embedding_bias_tensor
             ])
             responses.append(inference_response)
 
@@ -200,6 +236,10 @@ def _to_word_list_format(self, word_dict: List[List[str]]):
         '''
         assert self.tokenizer != None, "need to set tokenizer"
 
+        if word_dict is None:
+            # Return an empty array of shape (1,2,0)
+            return np.empty([1, 2, 0], dtype="int32")
+
         flat_ids = []
         offsets = []
         for word_dict_item in word_dict:
@@ -232,3 +272,37 @@ def _to_word_list_format(self, word_dict: List[List[str]]):
 
         return np.array([flat_ids, offsets], dtype="int32").transpose(
             (1, 0, 2))
+
+    def _get_embedding_bias(self, embedding_bias_words, embedding_bias_weights,
+                            bias_dtype):
+
+        assert self.tokenizer != None, "need to set tokenizer"
+
+        if embedding_bias_words is None or embedding_bias_weights is None:
+            return np.empty([1, 0], dtype=self.embedding_bias_weights_dtype)
+
+        batch_embedding_bias = []
+        for words, weights in zip(embedding_bias_words,
+                                  embedding_bias_weights):
+
+            vocab_size = self.tokenizer.vocab_size
+            embedding_bias = [0.] * vocab_size
+
+            assert len(words) == len(
+                weights
+            ), "Embedding bias words must have same dimension as embedding bias weights"
+
+            for word, weight in zip(words, weights):
+                if isinstance(word, bytes):
+                    word = word.decode()
+                ids = self.tokenizer.encode(word)
+
+                if len(ids) == 0:
+                    continue
+
+                for id in ids:
+                    embedding_bias[id] += weight
+
+            batch_embedding_bias.append(np.array(embedding_bias))
+
+        return np.array(batch_embedding_bias, dtype=bias_dtype)
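
For context, the new _get_embedding_bias helper turns the per-request (words, weights) pairs into a dense, vocab-sized FP32 vector: every token id produced by tokenizing a bias word gets that word's weight added to it. The snippet below is a self-contained sketch of that expansion with a toy word-to-id table standing in for the real Hugging Face tokenizer; the table and the vocabulary size are illustrative assumptions.

import numpy as np

# Illustrative stand-in for the real tokenizer: a fixed word -> token-id table.
TOY_VOCAB = {"deep": [7, 12], "shallow": [3], "learning": [5]}
VOCAB_SIZE = 16

def get_embedding_bias(words, weights, bias_dtype=np.float32):
    """Expand (words, weights) pairs into a [batch, vocab_size] bias tensor,
    mirroring the preprocessor's _get_embedding_bias logic."""
    if words is None or weights is None:
        return np.empty([1, 0], dtype=bias_dtype)

    batch_bias = []
    for word_row, weight_row in zip(words, weights):
        bias = np.zeros(VOCAB_SIZE, dtype=bias_dtype)
        for word, weight in zip(word_row, weight_row):
            # Add the word's weight at every token id it maps to.
            for token_id in TOY_VOCAB.get(word, []):
                bias[token_id] += weight
        batch_bias.append(bias)
    return np.stack(batch_bias)

# One request (batch size 1): boost "deep", penalize "shallow".
bias = get_embedding_bias([["deep", "shallow"]], [[5.0, -5.0]])
print(bias.shape)              # (1, 16)
print(bias[0, 7], bias[0, 3])  # 5.0 -5.0

When either input is missing, the real helper returns an empty [1, 0] array, which is why the downstream embedding_bias input can remain optional.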

all_models/inflight_batcher_llm/preprocessing/config.pbtxt

Lines changed: 22 additions & 3 deletions

@@ -26,27 +26,41 @@
 
 name: "preprocessing"
 backend: "python"
-max_batch_size: 128
+max_batch_size: ${triton_max_batch_size}
 input [
   {
     name: "QUERY"
     data_type: TYPE_STRING
     dims: [ -1 ]
   },
+  {
+    name: "REQUEST_OUTPUT_LEN"
+    data_type: TYPE_UINT32
+    dims: [ -1 ]
+  },
   {
     name: "BAD_WORDS_DICT"
     data_type: TYPE_STRING
     dims: [ -1 ]
+    optional: true
   },
   {
     name: "STOP_WORDS_DICT"
     data_type: TYPE_STRING
     dims: [ -1 ]
+    optional: true
   },
   {
-    name: "REQUEST_OUTPUT_LEN"
-    data_type: TYPE_UINT32
+    name: "EMBEDDING_BIAS_WORDS"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+    optional: true
+  },
+  {
+    name: "EMBEDDING_BIAS_WEIGHTS"
+    data_type: TYPE_FP32
     dims: [ -1 ]
+    optional: true
   }
 ]
 output [
@@ -70,6 +84,11 @@ output [
     data_type: TYPE_INT32
     dims: [ 2, -1 ]
   },
+  {
+    name: "EMBEDDING_BIAS"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  },
   {
     name: "REQUEST_OUTPUT_LEN"
     data_type: TYPE_UINT32

all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 19 additions & 3 deletions

@@ -26,17 +26,23 @@
 
 name: "tensorrt_llm"
 backend: "tensorrtllm"
-max_batch_size: 128
+max_batch_size: ${triton_max_batch_size}
 
 model_transaction_policy {
   decoupled: ${decoupled_mode}
 }
 
+dynamic_batching {
+    preferred_batch_size: [ ${triton_max_batch_size} ]
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+
 input [
   {
     name: "input_ids"
     data_type: TYPE_INT32
     dims: [ -1 ]
+    allow_ragged_batch: true
   },
   {
     name: "input_lengths"
@@ -68,18 +74,21 @@ input [
     data_type: TYPE_INT32
     dims: [ 2, -1 ]
     optional: true
+    allow_ragged_batch: true
   },
   {
     name: "bad_words_list"
     data_type: TYPE_INT32
     dims: [ 2, -1 ]
     optional: true
+    allow_ragged_batch: true
   },
   {
     name: "embedding_bias"
-    data_type: TYPE_FP16
+    data_type: TYPE_FP32
     dims: [ -1 ]
     optional: true
+    allow_ragged_batch: true
   },
   {
     name: "beam_width"
@@ -161,6 +170,7 @@ input [
     data_type: TYPE_FP16
     dims: [ -1, -1 ]
     optional: true
+    allow_ragged_batch: true
   },
   {
     name: "prompt_vocab_size"
@@ -191,7 +201,7 @@ instance_group [
 parameters: {
   key: "max_beam_width"
   value: {
-    string_value: "1"
+    string_value: "${max_beam_width}"
   }
 }
 parameters: {
@@ -218,6 +228,12 @@ parameters: {
     string_value: "${max_tokens_in_paged_kv_cache}"
   }
 }
+parameters: {
+  key: "max_kv_cache_length"
+  value: {
+    string_value: "${max_kv_cache_length}"
+  }
+}
 parameters: {
   key: "batch_scheduler_policy"
   value: {
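
All of the ${...} placeholders introduced above (triton_max_batch_size, decoupled_mode, max_queue_delay_microseconds, max_beam_width, max_kv_cache_length, and so on) must be replaced with concrete values before Triton can load the model repository; the backend repository provides tools/fill_template.py for that step. The snippet below is an equivalent minimal sketch using Python's string.Template; the values are illustrative assumptions, and the full config contains additional placeholders beyond the ones touched by this diff.

from pathlib import Path
from string import Template

# Illustrative values; pick them to match the engine you built and deployed.
substitutions = {
    "triton_max_batch_size": "64",
    "decoupled_mode": "false",
    "max_queue_delay_microseconds": "100",
    "max_beam_width": "1",
    "max_tokens_in_paged_kv_cache": "8192",
    "max_kv_cache_length": "4096",
}

config_path = Path("all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt")
text = config_path.read_text()

# The ${name} placeholders match Python's string.Template syntax, so
# safe_substitute fills the keys we know about and leaves any remaining
# placeholders (for example the engine path parameters) untouched for a
# later pass.
config_path.write_text(Template(text).safe_substitute(substitutions))

Because safe_substitute leaves unknown placeholders alone, the file can be filled in several passes or by the repository's own helper instead.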
