
Commit f70899b

Update TensorRT-LLM backend (#643)
1 parent cdc202a commit f70899b

21 files changed, +1338 -131 lines changed
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
<!--
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Running Disaggregated Serving with Triton TensorRT-LLM Backend

## Overview

Disaggregated serving refers to a technique that uses separate GPUs for
running the context and generation phases of LLM inference.

For Triton integration, a BLS model named
[_disaggregated\_serving\_bls_](./disaggregated_serving_bls/1/model.py)
has been created that orchestrates the disaggregated serving pipeline. This
BLS model requires the names of the TRT-LLM models that will be used for
the context and generation phases.

This example assumes access to a system with two GPUs and CUDA_VISIBLE_DEVICES
set to `0,1`.

## Model Repository Setup and Start Server

1. Set up the model repository as instructed in the [LLaMa](../docs/llama.md)
guide.

2. Create context and generation models with the desired tensor-parallel
configuration. We will be using the `context` and `generation` model names for
the context and generation models respectively. The context and generation models
should be copies of the
[tensorrt_llm](../inflight_batcher_llm/tensorrt_llm/) model configuration.

3. Set the `participant_ids` for the context and generation models to `1` and `2` respectively.

4. Set the `gpu_device_ids` for the context and generation models to `0` and `1` respectively.

5. Set the `context_model_name` and `generation_model_name` to `context` and `generation` in the
[disaggregated_serving_bls](./disaggregated_serving_bls/config.pbtxt) model configuration.
(A sketch of the parameter entries from steps 3-5 is shown at the end of this section.)

Your model repository should look like below:

```
disaggregated_serving/
|-- context
|   |-- 1
|   `-- config.pbtxt
|-- disaggregated_serving_bls
|   |-- 1
|   |   `-- model.py
|   `-- config.pbtxt
|-- ensemble
|   |-- 1
|   `-- config.pbtxt
|-- generation
|   |-- 1
|   `-- config.pbtxt
|-- postprocessing
|   |-- 1
|   |   `-- model.py
|   `-- config.pbtxt
`-- preprocessing
    |-- 1
    |   `-- model.py
    `-- config.pbtxt
```

6. Rename the `tensorrt_llm` model in the `ensemble` config.pbtxt file to `disaggregated_serving_bls`.

7. Launch the Triton Server:

```
python3 scripts/launch_triton_server.py --world_size 3 --tensorrt_llm_model_name context,generation --multi-model --disable-spawn-processes
```

> [!NOTE]
>
> The world size should be equal to `tp*pp` of the context model + `tp*pp` of the generation model + 1.
> The additional process is required for the orchestrator. In this example both models use
> `tp=1` and `pp=1`, so the world size is `1 + 1 + 1 = 3`.

8. Send a request to the server:

```
python3 inflight_batcher_llm/client/end_to_end_grpc_client.py -S -p "Machine learning is"
```
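
Putting steps 3-5 together, the values above land in plain Triton `parameters` entries of each
model's `config.pbtxt`. The following is a minimal sketch showing only the fields touched by those
steps; everything else is assumed to stay as in the copied `tensorrt_llm` and
`disaggregated_serving_bls` templates:

```
# context/config.pbtxt (generation/config.pbtxt uses "2" and "1" instead)
parameters: {
  key: "participant_ids"
  value: {
    string_value: "1"
  }
}
parameters: {
  key: "gpu_device_ids"
  value: {
    string_value: "0"
  }
}

# disaggregated_serving_bls/config.pbtxt
parameters: {
  key: "context_model_name"
  value: {
    string_value: "context"
  }
}
parameters: {
  key: "generation_model_name"
  value: {
    string_value: "generation"
  }
}
```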

## Creating Multiple Copies of the Context and Generation Models (Data Parallelism)

You can also create multiple copies of the context and generation models. This can be
achieved by setting the `participant_ids` and `gpu_device_ids` for each instance.

For example, if you have a context model with `tp=2` and you want to create 2
copies of it, you can set `participant_ids` to `1,2;3,4`,
`gpu_device_ids` to `0,1;2,3` (assuming a 4-GPU system), and set the `count`
in the `instance_group` section of the model configuration to 2. This creates 2
copies of the context model, where the first copy runs on GPUs 0 and 1 and the
second copy runs on GPUs 2 and 3.
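
A rough sketch of the corresponding `config.pbtxt` entries for this 2-copy, `tp=2` context model
is shown below; the `instance_group` kind is assumed to follow the copied `tensorrt_llm` template:

```
parameters: {
  key: "participant_ids"
  value: {
    string_value: "1,2;3,4"
  }
}
parameters: {
  key: "gpu_device_ids"
  value: {
    string_value: "0,1;2,3"
  }
}
# kind assumed from the copied tensorrt_llm template
instance_group [
  {
    count: 2
    kind: KIND_CPU
  }
]
```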

## Known Issues

1. Only the C++ version of the backend is supported right now.
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
import json

import triton_python_backend_utils as pb_utils

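
# Helpers for reading typed, optional parameters from the model config's
# "parameters" section; empty, unsubstituted, or unparsable values fall back to None.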
def read_parameter_as_type(value, name, pytype=str):
    if value == "":
        return None
    if value.startswith("${") and value.endswith("}"):
        return None
    if pytype is bool:
        return value.lower() in ["1", "true"]
    try:
        result = pytype(value)
        return result
    except Exception:
        pb_utils.Logger.log_warning(
            f"Could not read parameter '{name}' with value '{value}', will use default."
        )
        return None


def get_parameter(model_config, name, pytype=str):
    if name not in model_config['parameters']:
        return None
    return read_parameter_as_type(
        model_config['parameters'][name]['string_value'], name, pytype)

class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        model_config = json.loads(args['model_config'])
        self.context_model_name = get_parameter(model_config,
                                                "context_model_name")
        self.generation_model_name = get_parameter(model_config,
                                                   "generation_model_name")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
            model_config)

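    # Builds a context-only BLS request against the context TRT-LLM model; the
    # "request_type" parameter asks the backend to run only the prefill (context) phase.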
    def create_context_request(self, request):
        inputs = request.inputs()
        triton_request = pb_utils.InferenceRequest(
            model_name=self.context_model_name,
            inputs=inputs,
            parameters={"request_type": "context_only"},
            requested_output_names=[])
        return triton_request

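    # Builds a generation-only BLS request that forwards the original inputs plus the
    # "context_phase_params" tensor produced by the context model, so the generation
    # model can continue decoding where the context phase left off.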
    def create_generation_request(self, request, context_response):
        inputs = request.inputs()
        context_phase_params = pb_utils.get_output_tensor_by_name(
            context_response, "context_phase_params")
        if context_phase_params is None:
            raise pb_utils.TritonModelException(
                "Context response must have an output named 'context_phase_params'"
            )
        inputs.append(context_phase_params)
        triton_request = pb_utils.InferenceRequest(
            model_name=self.generation_model_name,
            inputs=inputs,
            parameters={"request_type": "generation_only"},
            requested_output_names=[])
        return triton_request

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        for request in requests:
            context_request = self.create_context_request(request)
            context_responses = context_request.exec(decoupled=self.decoupled)
            if self.decoupled:
                context_responses = list(context_responses)
                assert len(
                    context_responses) == 1, "Expected 1 context response"

            if self.decoupled:
                context_response = context_responses[0]
            else:
                context_response = context_responses
            if context_response.has_error():
                raise pb_utils.TritonModelException(
                    f"Context model {self.context_model_name} failed with error: {context_response.error().message()}"
                )
            generation_request = self.create_generation_request(
                request, context_response)

            # TODO(itabrizian): Send the context response to reduce TTFT in decoupled case.
            # It requires adding the generated token to the generation request
            # to avoid sending the first token multiple times.
            responses = generation_request.exec(decoupled=self.decoupled)

            if self.decoupled:
                for response in responses:
                    if response.has_error():
                        raise pb_utils.TritonModelException(
                            f"Generation model {self.generation_model_name} failed with error: {response.error().message()}"
                        )
                    request.get_response_sender().send(response)

                request.get_response_sender().send(
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            else:
                request.get_response_sender().send(
                    responses,
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
