
Commit 75b0964

Update TensorRT-LLM backend (triton-inference-server#480)
* Update TensorRT-LLM backend
1 parent b93b948 · commit 75b0964

12 files changed: 255 additions, 183 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -311,6 +311,7 @@ The following table shows the fields that may to be modified before deployment:
 | `lora_cache_max_adapter_size` | Optional (default=64) Used to set the minimum size of a cache page. Pages must be at least large enough to fit a single module, single later adapter_size `maxAdapterSize` row of weights. |
 | `lora_cache_gpu_memory_fraction` | Optional (default=0.05) Fraction of GPU memory used for LoRA cache. Computed as a fraction of left over memory after engine load, and after KV cache is loaded |
 | `lora_cache_host_memory_bytes` | Optional (default=1G) Size of host LoRA cache in bytes |
+| `gpu_weights_percent` | Optional (default=1.0). Set to a number between 0.0 and 1.0 to specify the percentage of weights that reside on GPU instead of CPU and streaming load during runtime. Values less than 1.0 are only supported for an engine built with `weight_streaming` on. |

 *triton_model_repo/postprocessing/config.pbtxt*

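The new `gpu_weights_percent` row documents a value that is templated into the TensorRT-LLM model's `config.pbtxt` (see the matching parameter added in the config.pbtxt diff further down). As a rough illustration of how that placeholder gets filled, the sketch below validates a value and substitutes it with a plain string replacement; the file path and the substitution approach are assumptions made for this example, not the repository's own templating workflow.

```python
# Illustrative sketch only: validate a gpu_weights_percent value and substitute it
# into the "${gpu_weights_percent}" placeholder in a config.pbtxt. The path and the
# plain string replacement are assumptions made for this example.
from pathlib import Path


def set_gpu_weights_percent(config_path: str, value: float) -> None:
    if not 0.0 <= value <= 1.0:
        raise ValueError("gpu_weights_percent must be between 0.0 and 1.0")
    config = Path(config_path)
    config.write_text(
        config.read_text().replace("${gpu_weights_percent}", str(value)))


# Keep half of the weights on GPU; values below 1.0 require an engine built with
# weight_streaming enabled.
set_gpu_weights_percent("triton_model_repo/tensorrt_llm/config.pbtxt", 0.5)
```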
all_models/inflight_batcher_llm/tensorrt_llm/1/model.py

Lines changed: 30 additions & 22 deletions
@@ -134,7 +134,7 @@ def get_output_config_from_request(request, exclude_input_from_output):
     return trtllm.OutputConfig(**kwargs)


-def get_speculative_decoding_config_from_request(request):
+def get_external_draft_tokens_config_from_request(request):
     kwargs = {}
     draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids')
     if draft_input_ids is not None:
@@ -146,7 +146,7 @@ def get_speculative_decoding_config_from_request(request):
         request, 'draft_acceptance_threshold')
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     if len(kwargs) > 0:
-        return trtllm.SpeculativeDecodingConfig(**kwargs)
+        return trtllm.ExternalDraftTokensConfig(**kwargs)
     return None


@@ -211,7 +211,7 @@ def convert_request(request, exclude_input_from_output, decoupled):
     sampling_config = get_sampling_config_from_request(request)
     output_config = get_output_config_from_request(request,
                                                    exclude_input_from_output)
-    speculative_decoding_config = get_speculative_decoding_config_from_request(
+    external_draft_tokens_config = get_external_draft_tokens_config_from_request(
         request)
     prompt_tuning_config = get_prompt_tuning_config_from_request(request)
     lora_config = get_lora_config_from_request(request)
@@ -220,7 +220,7 @@ def convert_request(request, exclude_input_from_output, decoupled):
         **inputs,
         sampling_config=sampling_config,
         output_config=output_config,
-        speculative_decoding_config=speculative_decoding_config,
+        external_draft_tokens_config=external_draft_tokens_config,
         prompt_tuning_config=prompt_tuning_config,
         lora_config=lora_config,
     )
@@ -295,18 +295,18 @@ def convert_batching_type(gpt_model_type: str):
 def convert_decoding_mode(decoding_mode: str):
     if decoding_mode is None:
         return None
-    elif decoding_mode == "none":
-        return trtllm.DecodingMode.NONE
+    elif decoding_mode == "auto":
+        return trtllm.DecodingMode.Auto()
     elif decoding_mode == "top_k":
-        return trtllm.DecodingMode.TOP_K
+        return trtllm.DecodingMode.TopK()
     elif decoding_mode == "top_p":
-        return trtllm.DecodingMode.TOP_P
+        return trtllm.DecodingMode.TopP()
     elif decoding_mode == "top_k_top_p":
-        return trtllm.DecodingMode.TOP_K_TOP_P
+        return trtllm.DecodingMode.TopKTopP()
     elif decoding_mode == "beam_search":
-        return trtllm.DecodingMode.BEAM_SEARCH
+        return trtllm.DecodingMode.BeamSearch()
     elif decoding_mode == "medusa":
-        return trtllm.DecodingMode.MEDUSA
+        return trtllm.DecodingMode.Medusa()
     raise pb_utils.TritonModelException(
         f"decoding_mode value of '{decoding_mode}' is not supported.")

@@ -384,6 +384,19 @@ def get_peft_cache_config(self, model_config):
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.PeftCacheConfig(**kwargs)

+    def get_decoding_config(self, model_config):
+        kwargs = {
+            "medusa_choices":
+            parse_medusa_choices(get_parameter(model_config,
+                                               "medusa_choices")),
+            "decoding_mode":
+            convert_decoding_mode(get_parameter(model_config,
+                                                "decoding_mode")),
+        }
+        print(kwargs)
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        return trtllm.DecodingConfig(**kwargs)
+
     def get_executor_config(self, model_config):
         kwargs = {
             "max_beam_width":
@@ -403,12 +416,8 @@ def get_executor_config(self, model_config):
             self.get_parallel_config(model_config),
             "peft_cache_config":
             self.get_peft_cache_config(model_config),
-            "medusa_choices":
-            parse_medusa_choices(get_parameter(model_config,
-                                               "medusa_choices")),
-            "decoding_mode":
-            convert_decoding_mode(get_parameter(model_config,
-                                                "decoding_mode")),
+            "decoding_config":
+            self.get_decoding_config(model_config),
         }
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.ExecutorConfig(**kwargs)
@@ -553,22 +562,21 @@ def awaiter_loop(self):
                 with self.lock:
                     del self.triton_id_to_req_id[triton_id]
                     del self.req_id_to_response_sender[req_id]
+                # Remove local reference so response_sender can be cleaned properly.
+                del response_sender
         # TODO: Read stats: https://jirasw.nvidia.com/browse/TRTLLM-563

     def cancellation_loop(self):
         """Checks if any pending requests have been cancelled."""
         while self.running:
             time.sleep(self.cancellation_check_period_ms / 1000.0)
             with self.lock:
-                cancelled_ids = []
                 for req_id, (triton_id, response_sender
                              ) in self.req_id_to_response_sender.items():
                     if response_sender.is_cancelled():
                         self.executor.cancel_request(req_id)
-                        cancelled_ids.append((req_id, triton_id))
-                for req_id, triton_id in cancelled_ids:
-                    del self.triton_id_to_req_id[triton_id]
-                    del self.req_id_to_response_sender[req_id]
+                    # Remove local reference so response_sender can be cleaned properly.
+                    del response_sender

     def finalize(self):
         """`finalize` is called only once when the model is being unloaded.

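Taken together, the model.py changes rename the speculative-decoding request helper to match the new `ExternalDraftTokensConfig` API, switch `DecodingMode` from module-level constants to factory methods, and move `decoding_mode`/`medusa_choices` into a `DecodingConfig` that is passed to `ExecutorConfig` as `decoding_config`. The sketch below condenses that mapping; the import path is an assumption, while the class, method, and keyword names are taken from the diff above.

```python
# Condensed sketch of the decoding API after this commit. Import path assumed;
# DecodingMode factory methods and the DecodingConfig/ExecutorConfig keywords
# mirror the names shown in the diff.
import tensorrt_llm.bindings.executor as trtllm

_DECODING_MODES = {
    # Factory methods returning DecodingMode instances replace the old
    # module-level constants such as trtllm.DecodingMode.TOP_K.
    "auto": trtllm.DecodingMode.Auto,
    "top_k": trtllm.DecodingMode.TopK,
    "top_p": trtllm.DecodingMode.TopP,
    "top_k_top_p": trtllm.DecodingMode.TopKTopP,
    "beam_search": trtllm.DecodingMode.BeamSearch,
    "medusa": trtllm.DecodingMode.Medusa,
}


def convert_decoding_mode(decoding_mode):
    """Map a config string to a DecodingMode instance, or None if unset."""
    if decoding_mode is None:
        return None
    try:
        return _DECODING_MODES[decoding_mode]()
    except KeyError:
        raise ValueError(
            f"decoding_mode value of '{decoding_mode}' is not supported.")


# decoding_mode and medusa_choices now travel inside a DecodingConfig instead of
# being top-level ExecutorConfig arguments.
decoding_config = trtllm.DecodingConfig(
    decoding_mode=convert_decoding_mode("top_k_top_p"))
executor_config = trtllm.ExecutorConfig(max_beam_width=1,
                                        decoding_config=decoding_config)
```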
all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 6 additions & 0 deletions
@@ -513,3 +513,9 @@ parameters: {
     string_value: "${medusa_choices}"
   }
 }
+parameters: {
+  key: "gpu_weights_percent"
+  value: {
+    string_value: "${gpu_weights_percent}"
+  }
+}

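For context on how a `parameters` entry like this reaches Python code: the Triton Python backend passes the model configuration to the model as JSON, with each parameter exposed under `parameters.<key>.string_value`. The helper below mirrors the `get_parameter` pattern that model.py uses for other keys; treating `${...}` as an unfilled placeholder and casting this particular key to `float` are assumptions for illustration.

```python
# Minimal sketch: reading a config.pbtxt parameter from the Triton Python backend's
# JSON model config. Placeholder handling and the float cast are assumptions.
import json


def get_parameter(model_config, name, cast=str):
    value = model_config.get("parameters", {}).get(name, {}).get("string_value")
    if value is None or value.startswith("${"):
        return None  # parameter missing or left as an unfilled template placeholder
    return cast(value)


model_config = json.loads(
    '{"parameters": {"gpu_weights_percent": {"string_value": "0.5"}}}')
print(get_parameter(model_config, "gpu_weights_percent", float))  # -> 0.5
```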
all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py

Lines changed: 1 addition & 5 deletions
@@ -91,17 +91,13 @@ def validate(self):

         num_draft_tokens = _single_value(self.num_draft_tokens)
         stream = _single_value(self.stream)
-        gen_logits = _single_value(self.return_generation_logits)
+        _single_value(self.return_generation_logits)
         context_logits = _single_value(self.return_context_logits)

         if num_draft_tokens:
             _validate_that(
                 not stream,
                 "streaming is not supported with speculative decoding")
-            _validate_that(
-                not gen_logits,
-                "generation logits are not supported with speculative decoding"
-            )
             _validate_that(
                 not context_logits,
                 "context logits are not supported with speculative decoding")

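The net effect of this decode.py change is that requesting generation logits alongside draft tokens is no longer rejected, while streaming and context logits remain incompatible with speculative decoding. The sketch below illustrates the relaxed check with simplified stand-ins; the `Request` fields and helper definitions here are assumptions for illustration, not the actual BLS classes.

```python
# Simplified stand-in for the BLS validation after this change: generation logits
# are accepted with draft tokens; streaming and context logits are still rejected.
from dataclasses import dataclass


def _validate_that(condition, message):
    if not condition:
        raise ValueError(message)


@dataclass
class Request:
    num_draft_tokens: int = 0
    stream: bool = False
    return_generation_logits: bool = False
    return_context_logits: bool = False

    def validate(self):
        if self.num_draft_tokens:
            _validate_that(
                not self.stream,
                "streaming is not supported with speculative decoding")
            _validate_that(
                not self.return_context_logits,
                "context logits are not supported with speculative decoding")


# Now accepted: draft tokens together with generation logits.
Request(num_draft_tokens=4, return_generation_logits=True).validate()
```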
all_models/tests/test_python_backend.py

Lines changed: 17 additions & 19 deletions
@@ -291,12 +291,12 @@ def test_convert_request(triton_request: MockTritonRequest):
     assert (converted.embedding_bias == torch.tensor([0., 0., 0.])).all()
     assert converted.logits_post_processor_name is None

-    assert isinstance(converted.speculative_decoding_config,
-                      trtllm.SpeculativeDecodingConfig)
-    assert converted.speculative_decoding_config.tokens == [0, 1]
-    assert (converted.speculative_decoding_config.logits == torch.tensor(
+    assert isinstance(converted.external_draft_tokens_config,
+                      trtllm.ExternalDraftTokensConfig)
+    assert converted.external_draft_tokens_config.tokens == [0, 1]
+    assert (converted.external_draft_tokens_config.logits == torch.tensor(
         [[1.0, 2.0], [3.0, 4.0]])).all()
-    assert converted.speculative_decoding_config.acceptance_threshold == 1.0
+    assert converted.external_draft_tokens_config.acceptance_threshold == 1.0

     assert isinstance(converted.prompt_tuning_config,
                       trtllm.PromptTuningConfig)
@@ -345,7 +345,7 @@ def test_convert_request_minimal(triton_request_minimal: MockTritonRequest):
     assert converted.bad_words is None
     assert converted.embedding_bias is None
     assert converted.logits_post_processor_name is None
-    assert converted.speculative_decoding_config is None
+    assert converted.external_draft_tokens_config is None
     assert converted.prompt_tuning_config is None
     assert converted.lora_config is None

@@ -468,14 +468,12 @@ def test_convert_batching_type():

 def test_convert_decoding_mode():
     assert convert_decoding_mode(None) is None
-    assert convert_decoding_mode("none") == trtllm.DecodingMode.NONE
-    assert convert_decoding_mode("top_k") == trtllm.DecodingMode.TOP_K
-    assert convert_decoding_mode("top_p") == trtllm.DecodingMode.TOP_P
-    assert convert_decoding_mode(
-        "top_k_top_p") == trtllm.DecodingMode.TOP_K_TOP_P
-    assert convert_decoding_mode(
-        "beam_search") == trtllm.DecodingMode.BEAM_SEARCH
-    assert convert_decoding_mode("medusa") == trtllm.DecodingMode.MEDUSA
+    assert convert_decoding_mode("auto").isAuto()
+    assert convert_decoding_mode("top_k").isTopK()
+    assert convert_decoding_mode("top_p").isTopP()
+    assert convert_decoding_mode("top_k_top_p").isTopKandTopP()
+    assert convert_decoding_mode("beam_search").isBeamSearch()
+    assert convert_decoding_mode("medusa").isMedusa()
     with pytest.raises(
             Exception,
             match="decoding_mode value of 'other' is not supported"):
@@ -490,7 +488,7 @@ def model_config() -> Dict:
         "normalize_log_probs": "false",
         "gpt_model_type": "inflight_batching",
         "medusa_choices": "{1, 2, 3, 4}, {5, 6, 7}",
-        "decoding_mode": "top_k_top_p",
+        "decoding_mode": "medusa",
         "batch_scheduler_policy": "max_utilization",
         "enable_kv_cache_reuse": "true",
         "max_tokens_in_paged_kv_cache": "1",
@@ -516,8 +514,8 @@ def test_get_executor_config(model_config: Dict):
     assert config.enable_chunked_context == True
     assert config.normalize_log_probs == False
     assert config.batching_type == trtllm.BatchingType.INFLIGHT
-    assert config.medusa_choices == [[1, 2, 3, 4], [5, 6, 7]]
-    assert config.decoding_mode == trtllm.DecodingMode.TOP_K_TOP_P
+    assert config.decoding_config.medusa_choices == [[1, 2, 3, 4], [5, 6, 7]]
+    assert config.decoding_config.decoding_mode.isMedusa()
     assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
     assert config.kv_cache_config.enable_block_reuse == True
     assert config.kv_cache_config.max_tokens == 1
@@ -556,8 +554,8 @@ def test_get_executor_config_minimal():
     assert config.enable_chunked_context == False
     assert config.normalize_log_probs == True
     assert config.batching_type == trtllm.BatchingType.INFLIGHT
-    assert config.medusa_choices is None
-    assert config.decoding_mode is None
+    assert config.decoding_config.decoding_mode is None
+    assert config.decoding_config.medusa_choices is None
     assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
     assert config.kv_cache_config.enable_block_reuse == False
     assert config.kv_cache_config.max_tokens is None

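The updated tests read decoding settings from `config.decoding_config` and use the new `is*()` predicates instead of comparing against module-level constants. A condensed sketch of what they exercise is below; the import path is an assumption, while the attribute and method names come from the diff above.

```python
# Sketch of the assertions' shape: decoding settings now live on
# ExecutorConfig.decoding_config rather than on ExecutorConfig directly.
import tensorrt_llm.bindings.executor as trtllm

config = trtllm.ExecutorConfig(decoding_config=trtllm.DecodingConfig(
    decoding_mode=trtllm.DecodingMode.Medusa(),
    medusa_choices=[[1, 2, 3, 4], [5, 6, 7]],
))

assert config.decoding_config.decoding_mode.isMedusa()
assert config.decoding_config.medusa_choices == [[1, 2, 3, 4], [5, 6, 7]]
```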
dockerfile/Dockerfile.trt_llm_backend

Lines changed: 2 additions & 4 deletions
@@ -1,5 +1,5 @@
 ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver
-ARG BASE_TAG=24.03-py3
+ARG BASE_TAG=24.04-py3

 FROM ${BASE_IMAGE}:${BASE_TAG} as base

@@ -37,9 +37,7 @@ RUN bash /tmp/install_mpi4py.sh && rm /tmp/install_mpi4py.sh
 # Use "pypi" (default) for x86_64 arch and "src_non_cxx11_abi" for aarch64 arch
 ARG TORCH_INSTALL_TYPE="pypi"
 COPY tensorrt_llm/docker/common/install_pytorch.sh install_pytorch.sh
-# Apply PyTorch patch for supporting compiling with CUDA 12.4 from source codes.
-COPY tensorrt_llm/docker/common/pytorch_pr_116072.patch /tmp/pytorch_pr_116072.patch
-RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh /tmp/pytorch_pr_116072.patch
+RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh

 FROM dev as trt_llm_builder
