Update TensorRT-LLM backend (triton-inference-server#504)
Co-authored-by: Kaiyu Xie <[email protected]>
Shixiaowei02 and kaiyux authored Jun 18, 2024
1 parent 566b4ff commit 62cd00f
Showing 29 changed files with 745 additions and 63 deletions.
43 changes: 43 additions & 0 deletions README.md
@@ -637,6 +637,49 @@ nv_inference_compute_output_duration_us{model="tensorrt_llm",version="1"} 0
nv_inference_pending_request_count{model="tensorrt_llm",version="1"} 0
```

## Multi-instance Support

The TensorRT-LLM backend relies on MPI to coordinate the execution of a model across multiple GPUs
and nodes. Two different modes are currently supported for running a model across multiple GPUs:

1. [Leader mode](#leader-mode)
2. [Orchestrator mode](#orchestrator-mode)

### Leader Mode

In leader mode, the TensorRT-LLM backend spawns one Triton Server process for every
GPU. The process with rank 0 is the leader process. The other Triton Server processes
do not return from the `TRITONBACKEND_ModelInstanceInitialize` call, which avoids
port collisions and allows the leader process to receive requests.

An overview of this mode is shown in the diagram below:

![Leader Mode Overview](./images/leader-mode.png)

This mode works well with [Slurm](https://slurm.schedmd.com) deployments since
it does not use
[MPI_Comm_spawn](https://www.open-mpi.org/doc/v4.1/man3/MPI_Comm_spawn.3.php).
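
For illustration only, the rank split at the heart of this mode can be sketched with
`mpi4py` (assumptions: `mpi4py` is installed and the script is launched with something
like `mpirun -n <num_gpus>`; this is a sketch of the pattern, not the backend's actual
implementation):

```python
# leader_mode_sketch.py -- illustrative sketch, not backend code.
import time

from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()

if rank == 0:
    # The rank-0 process plays the role of the leader: it would finish
    # initialization, bind the HTTP/GRPC ports, and accept requests.
    print(f"rank {rank}: leader, ready to serve requests")
else:
    # The other ranks block, mirroring how non-leader Triton Server processes
    # never return from TRITONBACKEND_ModelInstanceInitialize.
    print(f"rank {rank}: worker, blocking so no extra server ports are opened")
    while True:
        time.sleep(60)
```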

### Orchestrator Mode

In orchestrator mode, the TensorRT-LLM backend spawns a single Triton Server process
that acts as an orchestrator and spawns one Triton Server process for every
GPU that each model requires. This mode is mainly used when serving multiple models
with the TensorRT-LLM backend. In this mode, the `MPI` world size must be one, as
the TRT-LLM backend automatically creates new workers as needed. An overview
of this mode is shown in the diagram below:

![Orchestrator Mode Overview](./images/orchestrator-mode.png)

Since this mode uses
[MPI_Comm_spawn](https://www.open-mpi.org/doc/v4.1/man3/MPI_Comm_spawn.3.php),
it might not work properly with [Slurm](https://slurm.schedmd.com) deployments.
Additionally, it currently works only for single-node deployments.
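
For illustration only, the dynamic spawning this mode relies on can be sketched with
`mpi4py` (assumptions: `mpi4py` is installed, the MPI installation supports
`MPI_Comm_spawn`, and `worker.py` is a hypothetical worker script; the backend's
actual worker management is internal to TRT-LLM):

```python
# orchestrator_sketch.py -- illustrative sketch, not backend code.
import sys

from mpi4py import MPI

# The orchestrator starts with an MPI world size of one and spawns one worker
# process per GPU that the model needs (gpus_needed is a made-up value).
gpus_needed = 2
intercomm = MPI.COMM_SELF.Spawn(sys.executable,
                                args=["worker.py"],
                                maxprocs=gpus_needed)

# ... the orchestrator would now forward requests to the spawned workers ...
intercomm.Disconnect()
```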

### Running Multiple Instances of the LLaMa Model

Please refer to [Running Multiple Instances of the LLaMa Model](docs/llama_multi_instance.md)
for more information on running multiple instances of the LLaMa model in different configurations.

## Testing the TensorRT-LLM Backend
Please follow the guide in [`ci/README.md`](ci/README.md) to see how to run
the tests for the TensorRT-LLM backend.
3 changes: 2 additions & 1 deletion all_models/gpt/postprocessing/1/model.py
@@ -35,7 +35,8 @@ def initialize(self, args):
legacy=False,
padding_side="left",
trust_remote_code=True)
self.tokenizer.pad_token = self.tokenizer.eos_token
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

# Parse model output configs
output_config = pb_utils.get_output_config_by_name(
4 changes: 3 additions & 1 deletion all_models/gpt/preprocessing/1/model.py
@@ -38,7 +38,9 @@ def initialize(self, args):
padding_side='left',
legacy=False,
trust_remote_code=True)
self.tokenizer.pad_token = self.tokenizer.eos_token
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

self.pad_id = self.tokenizer.encode(self.tokenizer.pad_token,
add_special_tokens=False)[0]

26 changes: 26 additions & 0 deletions all_models/inflight_batcher_llm/ensemble/config.pbtxt
@@ -33,6 +33,12 @@ input [
data_type: TYPE_STRING
dims: [ -1 ]
},
{
name: "decoder_text_input"
data_type: TYPE_STRING
dims: [ -1 ]
optional: true
},
{
name: "max_tokens"
data_type: TYPE_INT32
@@ -207,6 +213,10 @@ ensemble_scheduling {
key: "QUERY"
value: "text_input"
}
input_map {
key: "DECODER_QUERY"
value: "decoder_text_input"
}
input_map {
key: "REQUEST_OUTPUT_LEN"
value: "max_tokens"
@@ -243,6 +253,14 @@ ensemble_scheduling {
key: "INPUT_ID"
value: "_INPUT_ID"
}
output_map {
key: "REQUEST_DECODER_INPUT_LEN"
value: "_REQUEST_DECODER_INPUT_LEN"
}
output_map {
key: "DECODER_INPUT_ID"
value: "_DECODER_INPUT_ID"
}
output_map {
key: "REQUEST_OUTPUT_LEN"
value: "_REQUEST_OUTPUT_LEN"
@@ -275,10 +293,18 @@ ensemble_scheduling {
key: "input_ids"
value: "_INPUT_ID"
}
input_map {
key: "decoder_input_ids"
value: "_DECODER_INPUT_ID"
}
input_map {
key: "input_lengths"
value: "_REQUEST_INPUT_LEN"
}
input_map {
key: "decoder_input_lengths"
value: "_REQUEST_DECODER_INPUT_LEN"
}
input_map {
key: "request_output_len"
value: "_REQUEST_OUTPUT_LEN"
3 changes: 2 additions & 1 deletion all_models/inflight_batcher_llm/postprocessing/1/model.py
@@ -82,7 +82,8 @@ def initialize(self, args):
legacy=False,
padding_side='left',
trust_remote_code=True)
self.tokenizer.pad_token = self.tokenizer.eos_token
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

# Parse model output configs
output_config = pb_utils.get_output_config_by_name(
33 changes: 29 additions & 4 deletions all_models/inflight_batcher_llm/preprocessing/1/model.py
@@ -84,7 +84,9 @@ def initialize(self, args):
trust_remote_code=True)
if isinstance(self.tokenizer, T5Tokenizer):
self.tokenizer_bos_id = self.tokenizer.sp_model.bos_id()
self.tokenizer.pad_token = self.tokenizer.eos_token

if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

self.tokenizer_end_id = self.tokenizer.encode(
self.tokenizer.eos_token, add_special_tokens=False)[0]
@@ -93,7 +95,8 @@

# Parse model output configs and convert Triton types to numpy types
output_names = [
"INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS",
"INPUT_ID", "DECODER_INPUT_ID", "REQUEST_INPUT_LEN",
"REQUEST_DECODER_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS",
"OUT_END_ID", "OUT_PAD_ID"
]
input_names = ["EMBEDDING_BIAS_WORDS", "EMBEDDING_BIAS_WEIGHTS"]
@@ -142,6 +145,11 @@ def execute(self, requests):
# Get input tensors
query = pb_utils.get_input_tensor_by_name(request,
'QUERY').as_numpy()
decoder_query = pb_utils.get_input_tensor_by_name(
request, 'DECODER_QUERY')
if decoder_query is not None:
decoder_query = decoder_query.as_numpy()

batch_dim = query.shape[0]
if batch_dim != 1:

@@ -194,6 +202,15 @@ def execute(self, requests):

# Preprocessing input data.
input_id, request_input_len = self._create_request(query)
print(input_id)
print(request_input_len)
if decoder_query is not None:
decoder_input_id, request_decoder_input_len = self._create_request(
decoder_query)
else:
decoder_input_id = pad_id * np.ones((1, 1), np.int32)
request_decoder_input_len = 1 * np.ones((1, 1), np.int32)

bad_words = self._to_word_list_format(bad_words_dict)
stop_words = self._to_word_list_format(stop_words_dict)

@@ -208,6 +225,13 @@ def execute(self, requests):
request_input_len_tensor = pb_utils.Tensor(
'REQUEST_INPUT_LEN',
request_input_len.astype(self.request_input_len_dtype))
decoder_input_id_tensor = pb_utils.Tensor(
'DECODER_INPUT_ID',
decoder_input_id.astype(self.decoder_input_id_dtype))
request_decoder_input_len_tensor = pb_utils.Tensor(
'REQUEST_DECODER_INPUT_LEN',
request_decoder_input_len.astype(
self.request_decoder_input_len_dtype))
request_output_len_tensor = pb_utils.Tensor(
'REQUEST_OUTPUT_LEN', request_output_len)
bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words)
@@ -221,8 +245,9 @@ def execute(self, requests):
np.array(pad_id, dtype=np.int32))

inference_response = pb_utils.InferenceResponse(output_tensors=[
input_id_tensor, bad_words_ids_tensor, stop_words_ids_tensor,
request_input_len_tensor, request_output_len_tensor,
input_id_tensor, decoder_input_id_tensor, bad_words_ids_tensor,
stop_words_ids_tensor, request_input_len_tensor,
request_decoder_input_len_tensor, request_output_len_tensor,
embedding_bias_tensor, end_id_tensor, pad_id_tensor
])
responses.append(inference_response)
16 changes: 16 additions & 0 deletions all_models/inflight_batcher_llm/preprocessing/config.pbtxt
@@ -33,6 +33,12 @@ input [
data_type: TYPE_STRING
dims: [ -1 ]
},
{
name: "DECODER_QUERY"
data_type: TYPE_STRING
dims: [ -1 ]
optional: true
},
{
name: "REQUEST_OUTPUT_LEN"
data_type: TYPE_INT32
@@ -86,6 +92,16 @@ output [
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "DECODER_INPUT_ID"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "REQUEST_DECODER_INPUT_LEN"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "BAD_WORDS_IDS"
data_type: TYPE_INT32
20 changes: 20 additions & 0 deletions all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
@@ -62,6 +62,20 @@ input [
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
reshape: { shape: [ ] }
},
{
name: "draft_logits"
data_type: TYPE_FP32
@@ -368,6 +382,12 @@ parameters: {
string_value: "${engine_dir}"
}
}
parameters: {
key: "encoder_model_path"
value: {
string_value: "${encoder_engine_dir}"
}
}
parameters: {
key: "max_tokens_in_paged_kv_cache"
value: {
@@ -58,6 +58,7 @@ def _single_value(data: Optional[np.ndarray]):
@dataclass
class Request:
text_input: np.ndarray = np.array([])
decoder_text_input: np.ndarray = None
max_tokens: np.ndarray = np.array([])
bad_words: Optional[np.ndarray] = None
stop_words: Optional[np.ndarray] = None
@@ -112,7 +113,9 @@ class DraftRequest:
@dataclass
class PreprocResponse:
input_ids: np.ndarray = np.array([])
decoder_input_ids: np.ndarray = None
input_lengths: np.ndarray = np.array([])
decoder_input_lengths: np.ndarray = None
bad_words_list: Optional[np.ndarray] = None
stop_words_list: Optional[np.ndarray] = None
embedding_bias: Optional[np.ndarray] = None
@@ -129,6 +132,8 @@ def with_new_inputs(cls,
if input_ids is not None else other.input_ids),
input_lengths=(input_lengths if input_lengths is not None else
other.input_lengths),
decoder_input_ids=other.decoder_input_ids,
decoder_input_lengths=other.decoder_input_lengths,
bad_words_list=other.bad_words_list,
stop_words_list=other.stop_words_list,
end_id=other.end_id,
@@ -50,7 +50,9 @@ def __init__(self,

self._preproc_outputs = [
"INPUT_ID",
"DECODER_INPUT_ID",
"REQUEST_INPUT_LEN",
"REQUEST_DECODER_INPUT_LEN",
"BAD_WORDS_IDS",
"STOP_WORDS_IDS",
"EMBEDDING_BIAS",
@@ -73,6 +75,7 @@ def __init__(self,

self.input_names = [
"text_input",
"decoder_text_input",
"max_tokens",
"bad_words",
"stop_words",
@@ -217,6 +220,7 @@ def preprocess(self, request: Request) -> PreprocResponse:
def _get_preproc_tensors(self, request: Request):
name_map = {
"text_input": "QUERY",
"decoder_text_input": "DECODER_QUERY",
"max_tokens": "REQUEST_OUTPUT_LEN",
"bad_words": "BAD_WORDS_DICT",
"stop_words": "STOP_WORDS_DICT",
@@ -230,7 +234,9 @@ def _get_preproc_tensors(self, request: Request):
def _get_preproc_response(self, triton_output):
name_map = {
"INPUT_ID": "input_ids",
"DECODER_INPUT_ID": "decoder_input_ids",
"REQUEST_INPUT_LEN": "input_lengths",
"REQUEST_DECODER_INPUT_LEN": "decoder_input_lengths",
"BAD_WORDS_IDS": "bad_words_list",
"STOP_WORDS_IDS": "stop_words_list",
"EMBEDDING_BIAS": "embedding_bias",
@@ -303,6 +309,7 @@ def _get_llm_tensors(self,
def _get_tensors_from_preproc(self, preproc: PreprocResponse):
name_map = {
"input_ids": "input_ids",
"decoder_input_ids": "decoder_input_ids",
"input_lengths": "input_lengths",
"bad_words_list": "bad_words_list",
"stop_words_list": "stop_words_list",
6 changes: 6 additions & 0 deletions all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt
@@ -38,6 +38,12 @@ input [
data_type: TYPE_STRING
dims: [ -1 ]
},
{
name: "decoder_text_input"
data_type: TYPE_STRING
dims: [ -1 ]
optional: true
},
{
name: "max_tokens"
data_type: TYPE_INT32