Commit cdc202a

Update TensorRT-LLM backend (#638)

1 parent c104768 commit cdc202a

File tree

15 files changed: +292 −69 lines

all_models/inflight_batcher_llm/ensemble/config.pbtxt

Lines changed: 54 additions & 1 deletion
@@ -187,6 +187,47 @@ input [
     data_type: TYPE_FP32
     dims: [ -1 ]
     optional: true
+  },
+  # The unique task ID for the given LoRA.
+  # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given.
+  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
+  # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
+  {
+    name: "lora_task_id"
+    data_type: TYPE_UINT64
+    dims: [ 1 ]
+    optional: true
+  },
+  # Weights for a LoRA adapter, shape [ num_lora_modules_layers, D x Hi + Ho x D ],
+  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer.
+  # Each of the in / out tensors is first flattened and then concatenated together in the format above.
+  # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
+  {
+    name: "lora_weights"
+    data_type: TYPE_FP16
+    dims: [ -1, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  # Module identifier (same size as the first dimension of lora_weights).
+  # See LoraModule::ModuleType for the module ID mapping:
+  #
+  #   "attn_qkv": 0     # combined qkv adapter
+  #   "attn_q": 1       # q adapter
+  #   "attn_k": 2       # k adapter
+  #   "attn_v": 3       # v adapter
+  #   "attn_dense": 4   # adapter for the dense layer in attention
+  #   "mlp_h_to_4h": 5  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: up projection
+  #   "mlp_4h_to_h": 6  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: down projection
+  #   "mlp_gate": 7     # for llama2, adapter for the gated mlp layer after attention / RMSNorm: gate
+  #
+  # The last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ].
+  {
+    name: "lora_config"
+    data_type: TYPE_INT32
+    dims: [ -1, 3 ]
+    optional: true
+    allow_ragged_batch: true
   }
 ]
 output [

@@ -430,7 +471,19 @@ ensemble_scheduling {
       input_map {
         key: "prompt_table_extra_ids"
         value: "_OUT_PROMPT_TABLE_EXTRA_IDS"
-      }
+      },
+      input_map {
+        key: "lora_task_id",
+        value: "lora_task_id"
+      },
+      input_map {
+        key: "lora_weights",
+        value: "lora_weights"
+      },
+      input_map {
+        key: "lora_config",
+        value: "lora_config"
+      },
       output_map {
         key: "output_ids"
         value: "_TOKENS_BATCH"

all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py

Lines changed: 15 additions & 6 deletions
@@ -90,6 +90,9 @@ class Request:
     random_seed: Optional[np.ndarray] = None
     presence_penalty: Optional[np.ndarray] = None
     frequency_penalty: Optional[np.ndarray] = None
+    lora_task_id: Optional[np.ndarray] = None
+    lora_weights: Optional[np.ndarray] = None
+    lora_config: Optional[np.ndarray] = None

     def validate(self):
         _validate_non_empty(self.text_input, "text_input is required")

@@ -263,6 +266,8 @@ def _spec_generate(

         draft_request = None
         if num_draft_tokens > 0:
+            request.min_length = np.array([num_draft_tokens],
+                                          dtype=np.int32)
             draft_response: GenerationResponse = self._draft_generate_non_streaming(
                 cur_preproc, request, num_draft_tokens)
             seq_len: int = draft_response.sequence_length[0][0]

@@ -275,12 +280,16 @@
             draft_logits = draft_response.generation_logits[0][0]

             input_draft_tokens = draft_output_ids[len(input_ids):seq_len]
-            draft_request = DraftRequest(
-                draft_input_ids=np.expand_dims(input_draft_tokens, 0))
-            if request.use_draft_logits is not None and request.use_draft_logits[
-                    0]:
-                draft_request.draft_logits = np.expand_dims(
-                    draft_logits[-len(input_draft_tokens):], 0)
+            if len(input_draft_tokens) > 0:
+                draft_request = DraftRequest(
+                    draft_input_ids=np.expand_dims(input_draft_tokens, 0))
+                if request.use_draft_logits is not None and request.use_draft_logits[
+                        0]:
+                    draft_request.draft_logits = np.expand_dims(
+                        draft_logits[-len(input_draft_tokens):], 0)
+            else:
+                draft_request = DraftRequest()
+                request.min_length = None
         else:
             draft_request = DraftRequest()
         target_response = self._generate_non_streaming(

all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/triton_decoder.py

Lines changed: 6 additions & 0 deletions
@@ -108,6 +108,9 @@ def __init__(self,
             "embedding_bias_weights",
             "num_draft_tokens",
             "use_draft_logits",
+            "lora_task_id",
+            "lora_weights",
+            "lora_config",
         ]

         self.__undo_reshape_whitelist = {

@@ -409,6 +412,9 @@ def _get_llm_tensors_from_request(
             "stream": "streaming",
             "prompt_embedding_table": "prompt_embedding_table",
             "prompt_vocab_size": "prompt_vocab_size",
+            "lora_task_id": "lora_task_id",
+            "lora_weights": "lora_weights",
+            "lora_config": "lora_config",
         }
         batch_size = request.text_input.shape[0]
         tensors = self.create_triton_tensors(request, name_map)

all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt

Lines changed: 42 additions & 0 deletions
@@ -215,6 +215,48 @@ input [
     dims: [ 1 ]
     reshape: { shape: [ ] }
     optional: true
+  },
+  # The unique task ID for the given LoRA.
+  # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given.
+  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
+  # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
+  {
+    name: "lora_task_id"
+    data_type: TYPE_UINT64
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  # Weights for a LoRA adapter, shape [ num_lora_modules_layers, D x Hi + Ho x D ],
+  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer.
+  # Each of the in / out tensors is first flattened and then concatenated together in the format above.
+  # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
+  {
+    name: "lora_weights"
+    data_type: TYPE_FP16
+    dims: [ -1, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  # Module identifier (same size as the first dimension of lora_weights).
+  # See LoraModule::ModuleType for the module ID mapping:
+  #
+  #   "attn_qkv": 0     # combined qkv adapter
+  #   "attn_q": 1       # q adapter
+  #   "attn_k": 2       # k adapter
+  #   "attn_v": 3       # v adapter
+  #   "attn_dense": 4   # adapter for the dense layer in attention
+  #   "mlp_h_to_4h": 5  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: up projection
+  #   "mlp_4h_to_h": 6  # for llama2, adapter for the gated mlp layer after attention / RMSNorm: down projection
+  #   "mlp_gate": 7     # for llama2, adapter for the gated mlp layer after attention / RMSNorm: gate
+  #
+  # The last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ].
+  {
+    name: "lora_config"
+    data_type: TYPE_INT32
+    dims: [ -1, 3 ]
+    optional: true
+    allow_ragged_batch: true
   }
 ]
 output [
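
For reference, a short numpy sketch of how a single adapter's tensors could be packed into the layout the comments above describe. The module choice, layer index, and sizes are made-up illustrations, not values from this commit; real weights come from the trained adapter.

import numpy as np

# Illustrative sizes only.
D = 8           # adapter_size (R value)
Hi = 768        # hidden_size_in
Ho = 768        # hidden_size_out
layer_idx = 0
ATTN_Q = 1      # module id per the mapping above

lora_in = np.zeros((D, Hi), dtype=np.float16)    # "in" adapter weights for this module/layer
lora_out = np.zeros((Ho, D), dtype=np.float16)   # "out" adapter weights for this module/layer

# Flatten each tensor and concatenate: the row length is D x Hi + Ho x D.
row = np.concatenate([lora_in.flatten(), lora_out.flatten()])
lora_weights = row[np.newaxis, :]                # [ num_lora_modules_layers = 1, D*Hi + Ho*D ]

# Matching lora_config row: [ module_id, layer_idx, adapter_size ].
lora_config = np.array([[ATTN_Q, layer_idx, D]], dtype=np.int32)

Each module/layer pair contributes one row to lora_weights and one matching row to lora_config.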

ci/L0_backend_trtllm/base_metrics_verification_tests.py

Lines changed: 10 additions & 2 deletions
@@ -25,13 +25,16 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import json
+import os
 import sys
 from collections import defaultdict

 import numpy as np
 import requests

-sys.path.append("/opt/tritonserver/tensorrtllm_backend/tools/utils")
+BACKEND_ROOT = os.environ.get('BACKEND_ROOT',
+                              "/opt/tritonserver/tensorrtllm_backend")
+sys.path.append(os.path.join(BACKEND_ROOT, "tools/utils"))
 import unittest

 import utils

@@ -75,9 +78,14 @@ def _run_infer(self, client, prompts, output_lens):
                 utils.prepare_tensor("bad_words", bad_words_list, "http"),
                 utils.prepare_tensor("stop_words", stop_words_list, "http"),
             ]
+            # Request minimal outputs
+            outputs = utils.prepare_outputs("http")

             async_requests.append(
-                client.async_infer(model_name, inputs, request_id=str(i)))
+                client.async_infer(model_name,
+                                   inputs,
+                                   outputs=outputs,
+                                   request_id=str(i)))

         try:
             utils.get_http_results(async_requests)

ci/L0_backend_trtllm/generate_engines.sh

Lines changed: 4 additions & 3 deletions
@@ -25,9 +25,10 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-BASE_DIR=/opt/tritonserver/tensorrtllm_backend/ci/L0_backend_trtllm
-GPT_DIR=/opt/tritonserver/tensorrtllm_backend/tensorrt_llm/examples/gpt
-TRTLLM_DIR=/opt/tritonserver/tensorrtllm_backend/tensorrt_llm/
+BACKEND_ROOT=${BACKEND_ROOT:='/opt/tritonserver/tensorrtllm_backend'}
+BASE_DIR=${BACKEND_ROOT}/ci/L0_backend_trtllm
+GPT_DIR=${BACKEND_ROOT}/tensorrt_llm/examples/gpt
+TRTLLM_DIR=${BACKEND_ROOT}/tensorrt_llm/

 function build_base_model {
     local NUM_GPUS=$1
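
Both CI changes replace the hard-coded /opt/tritonserver/tensorrtllm_backend path with a BACKEND_ROOT override. A hypothetical sketch of driving the CI scripts from Python with a non-default checkout; the override path and relative script paths are examples (assuming the repo root as working directory), and whether the test module runs this way depends on the surrounding CI harness.

import os
import subprocess

# Example override; the default remains /opt/tritonserver/tensorrtllm_backend.
env = dict(os.environ, BACKEND_ROOT="/workspace/tensorrtllm_backend")

# Engine generation derives BASE_DIR/GPT_DIR/TRTLLM_DIR from BACKEND_ROOT.
subprocess.run(["bash", "ci/L0_backend_trtllm/generate_engines.sh"], env=env, check=True)

# The metrics test appends ${BACKEND_ROOT}/tools/utils to sys.path before importing utils.
subprocess.run(
    ["python3", "-m", "unittest", "base_metrics_verification_tests"],
    cwd="ci/L0_backend_trtllm", env=env, check=True)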
