Commit f51f50c

Update TensorRT-LLM backend main branch (triton-inference-server#264)
* Update TensorRT-LLM backend
1 parent 3a61c37 commit f51f50c

File tree: 9 files changed, +121 -35 lines changed

README.md

Lines changed: 13 additions & 12 deletions

@@ -196,7 +196,7 @@ cp tensorrt_llm/examples/gpt/engines/fp16/4-gpu/* triton_model_repo/tensorrt_llm
 ```

 ### Modify the model configuration
-The following table shows the fields that need to be modified before deployment:
+The following table shows the fields that may need to be modified before deployment:

 *triton_model_repo/preprocessing/config.pbtxt*

@@ -209,17 +209,18 @@ The following table shows the fields that need to be modified before deployment:

 | Name | Description
 | :----------------------: | :-----------------------------: |
-| `decoupled` | Controls streaming. Decoupled mode must be set to `True` if using the streaming option from the client. |
-| `max_beam_width` | The maximum beam width that any request may ask for when using beam search |
-| `gpt_model_type` | Set to `inflight_fused_batching` when enabling in-flight batching support. To disable in-flight batching, set to `V1` |
-| `gpt_model_path` | Path to the TensorRT-LLM engines for deployment. In this example, the path should be set to `/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1` as the tensorrtllm_backend directory will be mounted to `/tensorrtllm_backend` within the container |
-| `max_tokens_in_paged_kv_cache` | The maximum size of the KV cache in number of tokens |
-| `max_attention_window_size` | When using techniques like sliding window attention, the maximum number of tokens that are attended to generate one token. Defaults to maximum sequence length |
-| `batch_scheduler_policy` | Set to `max_utilization` to greedily pack as many requests as possible in each current in-flight batching iteration. This maximizes the throughput but may result in overheads due to request pause/resume if KV cache limits are reached during execution. Set to `guaranteed_no_evict` to guarantee that a started request is never paused.|
-| `kv_cache_free_gpu_mem_fraction` | Set to a number between 0 and 1 to indicate the maximum fraction of GPU memory (after loading the model) that may be used for KV cache|
-| `max_num_sequences` | Maximum number of sequences that the in-flight batching scheme can maintain state for. Defaults to `max_batch_size` if `enable_trt_overlap` is `false` and to `2 * max_batch_size` if `enable_trt_overlap` is `true`, where `max_batch_size` is the TRT engine maximum batch size.
-| `enable_trt_overlap` | Set to `true` to partition available requests into 2 'microbatches' that can be run concurrently to hide exposed CPU runtime |
-| `exclude_input_in_output` | Set to `true` to only return completion tokens in a response. Set to `false` to return the prompt tokens concatenated with the generated tokens |
+| `gpt_model_type` | Mandatory. Set to `inflight_fused_batching` when enabling in-flight batching support. To disable in-flight batching, set to `V1`. |
+| `gpt_model_path` | Mandatory. Path to the TensorRT-LLM engines for deployment. In this example, the path should be set to `/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1` as the tensorrtllm_backend directory will be mounted to `/tensorrtllm_backend` within the container. |
+| `batch_scheduler_policy` | Mandatory. Set to `max_utilization` to greedily pack as many requests as possible into each in-flight batching iteration. This maximizes throughput but may incur overhead from request pause/resume if KV cache limits are reached during execution. Set to `guaranteed_no_evict` to guarantee that a started request is never paused. |
+| `decoupled` | Optional (default=`false`). Controls streaming. Decoupled mode must be set to `True` if using the streaming option from the client. |
+| `max_beam_width` | Optional (default=1). The maximum beam width that any request may ask for when using beam search. |
+| `max_tokens_in_paged_kv_cache` | Optional (default=unspecified). The maximum size of the KV cache in number of tokens. If unspecified, the value is interpreted as 'infinite'. The KV cache allocation is the minimum of `max_tokens_in_paged_kv_cache` and the value derived from `kv_cache_free_gpu_mem_fraction` below. |
+| `max_attention_window_size` | Optional (default=max_sequence_length). When using techniques like sliding window attention, the maximum number of tokens that are attended to in order to generate one token. The default attends to all tokens in the sequence. |
+| `kv_cache_free_gpu_mem_fraction` | Optional (default=0.85). Set to a number between 0 and 1 to indicate the maximum fraction of GPU memory (after loading the model) that may be used for KV cache. |
+| `max_num_sequences` | Optional (default=`max_batch_size` if `enable_trt_overlap` is `false`, and `2 * max_batch_size` if `enable_trt_overlap` is `true`, where `max_batch_size` is the TRT engine maximum batch size). Maximum number of sequences that the in-flight batching scheme can maintain state for. |
+| `enable_trt_overlap` | Optional (default=`true`). Set to `true` to partition available requests into 2 'microbatches' that can be run concurrently to hide exposed CPU runtime. |
+| `exclude_input_in_output` | Optional (default=`false`). Set to `true` to only return completion tokens in a response. Set to `false` to return the prompt tokens concatenated with the generated tokens. |
+| `normalize_log_probs` | Optional (default=`true`). Set to `false` to skip normalization of `output_log_probs`. |

 *triton_model_repo/postprocessing/config.pbtxt*
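For reference, once the `${...}` placeholders in *triton_model_repo/tensorrt_llm/config.pbtxt* are filled in, the mandatory fields from the table above might look as follows. This is a minimal sketch with illustrative values taken from the table, not a configuration shipped with the repository:

```
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "inflight_fused_batching"
  }
}
parameters: {
  key: "gpt_model_path"
  value: {
    string_value: "/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1"
  }
}
parameters: {
  key: "batch_scheduler_policy"
  value: {
    string_value: "guaranteed_no_evict"
  }
}
```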

all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 6 additions & 0 deletions

@@ -294,3 +294,9 @@ parameters: {
     string_value: "${enable_kv_cache_reuse}"
   }
 }
+parameters: {
+  key: "normalize_log_probs"
+  value: {
+    string_value: "${normalize_log_probs}"
+  }
+}
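At deployment time the `${normalize_log_probs}` placeholder above is replaced with a concrete value, like the other templated parameters in this file. As an illustration only (not a value shipped with the repository), a filled-in block that disables normalization of `output_log_probs`, per the README table above, would read:

```
parameters: {
  key: "normalize_log_probs"
  value: {
    string_value: "false"
  }
}
```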

inflight_batcher_llm/client/README.md

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
# Sample TRT-LLM backend clients
Three sample TRT-LLM Triton clients are provided with the TRT-LLM Triton backend implementation.
* `e2e_grpc_speculative_decoding_client.py`: Demonstrates how to orchestrate two independent TRT-LLM models - a draft model and a target model - to achieve faster inferencing using speculative decoding. The high-level design has the client make a call to the draft model requesting a certain number of draft tokens, and then associate those draft tokens with a request to the target model. The target model returns some number of completion tokens, internally leveraging the draft tokens to speed up inference. The client wraps these back-to-back calls to the draft and target models in a loop to complete the full generation (a rough sketch of this loop follows this listing).
Example command:
```
python3 speculative_decoding_test.py --max-input-len 200 \
    --dataset ${LOGDIR}/prompts.csv \
    --url-draft ${DRAFT_MODEL_URL} \
    --url-target ${TARGET_MODEL_URL}
```

* `end_to_end_grpc_client.py`: Demonstrates sending a single request to a tritonserver running an ensemble of a preprocessor (tokenizer), the TRT-LLM model and a postprocessor (detokenizer), and getting a completion back from it.
Example command:
```
python3 end_to_end_grpc_client.py \
    --streaming --output-len 10 \
    --prompt "The only thing we have to fear is"
```

* `inflight_batcher_llm_client.py`: Isolates queries and responses to the TRT-LLM model alone. Invokes the tokenizer and detokenizer in the client script, i.e. outside the server running inference.
Example command:
```
python3 inflight_batcher_llm_client.py \
    --tokenizer-dir ${TOKENIZER_PATH} \
    --tokenizer-type ${TOKENIZER_TYPE} \
    --input-tokens-csv=${LOGDIR}/prompts.csv \
    --output-tokens-csv=${LOGDIR}/completions.csv
```
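To make the draft/target orchestration described in the client README above more concrete, here is a rough sketch of the loop using the Triton gRPC client. It is not the shipped `e2e_grpc_speculative_decoding_client.py`; the model name (`tensorrt_llm`) and the tensor names and layouts (`input_ids`, `input_lengths`, `request_output_len`, `draft_input_ids`, `output_ids`) are assumptions about a typical TRT-LLM model configuration and may need to be adapted to the actual config.pbtxt in use:

```
import numpy as np
import tritonclient.grpc as grpcclient


def _infer(client, tensors, model_name="tensorrt_llm"):
    # Build InferInput objects from a dict of int32 numpy arrays and run one request.
    inputs = []
    for name, arr in tensors.items():
        inp = grpcclient.InferInput(name, list(arr.shape), "INT32")
        inp.set_data_from_numpy(arr)
        inputs.append(inp)
    return client.infer(model_name=model_name, inputs=inputs)


def speculative_generate(prompt_ids, max_new_tokens, num_draft_tokens, url_draft, url_target):
    draft = grpcclient.InferenceServerClient(url=url_draft)
    target = grpcclient.InferenceServerClient(url=url_target)
    tokens = list(prompt_ids)
    target_len = len(prompt_ids) + max_new_tokens

    while len(tokens) < target_len:
        ids = np.array([tokens], dtype=np.int32)
        lengths = np.array([[len(tokens)]], dtype=np.int32)

        # 1) Ask the cheap draft model for a few candidate tokens.
        d = _infer(draft, {
            "input_ids": ids,
            "input_lengths": lengths,
            "request_output_len": np.array([[num_draft_tokens]], dtype=np.int32),
        })
        # Assumes output_ids is laid out [batch, beam, sequence] and includes the input tokens.
        draft_ids = d.as_numpy("output_ids")[0, 0, len(tokens):]

        # 2) Hand the draft tokens to the target model, which verifies them and
        #    returns the accepted tokens plus at least one token of its own.
        t = _infer(target, {
            "input_ids": ids,
            "input_lengths": lengths,
            "request_output_len": np.array([[num_draft_tokens + 1]], dtype=np.int32),
            "draft_input_ids": np.array([draft_ids], dtype=np.int32),
        })
        new_ids = t.as_numpy("output_ids")[0, 0, len(tokens):]

        if new_ids.size == 0:
            break  # nothing accepted or generated; avoid looping forever
        tokens.extend(int(x) for x in new_ids)

    return tokens[:target_len]
```

This illustrates the control flow only; consult the shipped clients for the full set of inputs and options.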

inflight_batcher_llm/src/model_instance_state.cc

Lines changed: 41 additions & 15 deletions

@@ -27,6 +27,10 @@

 #include "model_instance_state.h"

+#include "tensorrt_llm/common/mpiUtils.h"
+
+namespace mpi = tensorrt_llm::mpi;
+
 namespace triton::backend::inflight_batcher_llm
 {

@@ -181,6 +185,17 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo
         TLLM_LOG_WARNING("enable_trt_overlap is not specified, will be set to true");
     }

+    bool normalizeLogProbs = true;
+    try
+    {
+        normalizeLogProbs = model_state_->GetParameter<bool>("normalize_log_probs");
+    }
+    catch (const std::exception& e)
+    {
+        // If parameter is not specified, just ignore
+        TLLM_LOG_WARNING("normalize_log_probs is not specified, will be set to true");
+    }
+
     bool excludeInputInOutput = false;
     try
     {
@@ -223,6 +238,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo
     optionalParams.kvCacheConfig.maxAttentionWindow = maxAttentionWindow;
     optionalParams.kvCacheConfig.enableBlockReuse = enableKVCacheReuse;
     optionalParams.enableTrtOverlap = enableTrtOverlap;
+    optionalParams.normalizeLogProbs = normalizeLogProbs;

     mBatchManager = std::make_shared<GptManager>(
         mModelPath, mTrtGptModelType, maxBeamWidth, schedulerPolicy,
@@ -232,7 +248,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo
         [this]() { return pollStopSignals(); }, [this](const std::string& s) { return logStats(s); }, optionalParams,
         std::nullopt, std::nullopt, excludeInputInOutput);

-    if (getCommWorldRank() != 0)
+    if (COMM_SESSION.getRank() != 0)
     {
         while (true)
         {
@@ -272,7 +288,7 @@ void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, const uint32_
         TRITONBACKEND_Request* request = requests[r];
         try
         {
-            auto requestId = utils::getRequestId(request);
+            auto requestId = utils::getRequestId(request, mRequestIdStrMap);
             bool stopRequest = utils::getRequestBooleanInputTensor(request, kStopInputTensorName);

             if (stopRequest)
@@ -326,8 +342,10 @@ std::list<std::shared_ptr<InferenceRequest>> ModelInstanceState::get_inference_r
         return rval;
     }

-    auto world_size = getCommWorldSize();
-    auto rank = getCommWorldRank();
+    auto const& commSession = COMM_SESSION;
+
+    auto world_size = commSession.getSize();
+    auto rank = commSession.getRank();
     if (rank == 0)
     {
         auto numPendingWorkItems = mWorkItemsQueue->numPendingWorkItems();
@@ -355,7 +373,7 @@ std::list<std::shared_ptr<InferenceRequest>> ModelInstanceState::get_inference_r
         if (world_size > 1)
         {
             int64_t num_new_work_items = rval.size();
-            bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD);
+            commSession.bcast(num_new_work_items, 0);

             if (num_new_work_items > 0)
             {
@@ -366,19 +384,19 @@ std::list<std::shared_ptr<InferenceRequest>> ModelInstanceState::get_inference_r
                     packed.push_back(static_cast<int64_t>(vpacked.size()));
                     packed.insert(packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
                 }
-                bcast(packed, 0, COMM_WORLD);
+                commSession.bcast(packed, 0);
             }
         }
     }
     else
     {
         // subordinate ranks hang until master rank sends work
         int64_t num_new_work_items;
-        bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD);
+        commSession.bcast(num_new_work_items, 0);
         if (num_new_work_items > 0)
         {
             std::vector<int64_t> packed;
-            bcast(packed, 0, COMM_WORLD);
+            commSession.bcast(packed, 0);
             int64_t* packed_ptr = packed.data();
             for (int64_t count = 0; count < num_new_work_items; ++count)
             {
@@ -395,9 +413,14 @@ std::list<std::shared_ptr<InferenceRequest>> ModelInstanceState::get_inference_r
 void ModelInstanceState::sendResponse(
     uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response, const std::string& errMsg)
 {
-    if (getCommWorldRank() == 0)
+    if (COMM_SESSION.getRank() == 0)
     {
-        std::string errStr = std::string("Failed to send Triton response for requestId: ") + std::to_string(requestId);
+        std::string errStr = std::string("Failed to send Triton response for requestId: ")
+            + utils::getRequestIdStr(requestId, mRequestIdStrMap);
+        if (final_response)
+        {
+            mRequestIdStrMap.erase(requestId);
+        }
         try
         {
             auto workItem = mWorkItemsQueue->getInProgressWorkItem(requestId);
@@ -421,31 +444,34 @@ std::unordered_set<uint64_t> ModelInstanceState::pollStopSignals()

     int64_t nStoppedReqIds = static_cast<int64_t>(stoppedReqIds.size());

-    if (getCommWorldSize() > 1)
+    auto const& commSession = COMM_SESSION;
+
+    if (commSession.getSize() > 1)
     {
         // Broadcast number of stopped requests
-        bcast(&nStoppedReqIds, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD);
+        commSession.bcast(nStoppedReqIds, 0);

         if (nStoppedReqIds > 0)
         {
             // Broadcast stopped requests Ids
-            if (getCommWorldRank() == 0)
+            if (commSession.getRank() == 0)
             {
                 // Store the requestIds in a contiguous vector
                 std::vector<uint64_t> stoppedReqIdsVec(stoppedReqIds.begin(), stoppedReqIds.end());
-                bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0, COMM_WORLD);
+                commSession.bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), mpi::MpiType::kUINT64, 0);
             }
             else
             {
                 std::vector<uint64_t> stoppedReqIdsVec(nStoppedReqIds);
-                bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0, COMM_WORLD);
+                commSession.bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), mpi::MpiType::kUINT64, 0);
                 // Store the requestIds in the set
                 stoppedReqIds.clear();
                 std::copy(stoppedReqIdsVec.begin(), stoppedReqIdsVec.end(),
                     std::inserter(stoppedReqIds, stoppedReqIds.end()));
             }
         }
     }
+
     return stoppedReqIds;
 }

inflight_batcher_llm/src/model_instance_state.h

Lines changed: 3 additions & 2 deletions

@@ -28,6 +28,7 @@
 #define _GLIBCXX_USE_CXX11_ABI 0

 #include <nlohmann/json.hpp>
+#include <unordered_map>

 #include "triton/backend/backend_common.h"
 #include "triton/core/tritonbackend.h"
@@ -40,15 +41,13 @@
 #include "tensorrt_llm/batch_manager/kvCacheConfig.h"
 #include "tensorrt_llm/batch_manager/namedTensor.h"
 #include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
-#include "tensorrt_llm/common/mpiUtils.h"

 #include "model_state.h"
 #include "work_item.h"
 #include "work_items_queue.h"

 using namespace tensorrt_llm::batch_manager;
 using namespace tensorrt_llm::batch_manager::batch_scheduler;
-using namespace tensorrt_llm::mpi;

 namespace triton::backend::inflight_batcher_llm
 {
@@ -126,6 +125,8 @@ class ModelInstanceState

     std::shared_ptr<GptManager> mBatchManager;
     std::unique_ptr<WorkItemsQueue> mWorkItemsQueue;
+
+    std::unordered_map<uint64_t, std::string> mRequestIdStrMap;
 };

 } // namespace triton::backend::inflight_batcher_llm

inflight_batcher_llm/src/utils.cc

Lines changed: 22 additions & 3 deletions

@@ -139,7 +139,7 @@ TRITONSERVER_DataType to_triton_datatype(nvinfer1::DataType data_type)
     }
 }

-uint64_t getRequestId(TRITONBACKEND_Request* request)
+uint64_t getRequestId(TRITONBACKEND_Request* request, std::unordered_map<uint64_t, std::string>& requestIdStrMap)
 {
     const char* charRequestId;
     TRITONBACKEND_RequestId(request, &charRequestId);
@@ -155,15 +155,34 @@ uint64_t getRequestId(TRITONBACKEND_Request* request)
             }
             catch (const std::exception& e)
             {
-                std::string err = std::string("Invalid requestId, must be uint64_t. Got ") + strRequestId;
-                throw std::runtime_error(err);
+                std::hash<std::string> hasher;
+                requestId = hasher(strRequestId);
+
+                // Check for hash collisions
+                // If the requestId already exists in the map with a different string, increment the ID and check again
+                for (auto it = requestIdStrMap.find(requestId); it != requestIdStrMap.end() && it->second != strRequestId;
+                     it = requestIdStrMap.find(requestId))
+                {
+                    requestId++;
+                }
             }
+            requestIdStrMap.insert({requestId, strRequestId});
         }
     }

     return requestId;
 }

+std::string getRequestIdStr(uint64_t requestId, std::unordered_map<uint64_t, std::string> const& requestIdStrMap)
+{
+    auto it = requestIdStrMap.find(requestId);
+    if (it != requestIdStrMap.end())
+    {
+        return it->second;
+    }
+    return std::to_string(requestId);
+}
+
 std::unordered_set<std::string> getRequestOutputNames(TRITONBACKEND_Request* request)
 {
     std::unordered_set<std::string> outputNames;

inflight_batcher_llm/src/utils.h

Lines changed: 6 additions & 2 deletions

@@ -50,9 +50,13 @@ nvinfer1::DataType to_trt_datatype(TRITONSERVER_DataType data_type);
 /// @brief Convert TRT datatype to Triton datatype
 TRITONSERVER_DataType to_triton_datatype(nvinfer1::DataType data_type);

-/// @brief get the requestId of the request
+/// @brief get the requestId of the request and update requestIdStrMap
 /// @return Returns 0 if not specified. Throws an error if request_id cannot be convert to uint64_t
-uint64_t getRequestId(TRITONBACKEND_Request* request);
+uint64_t getRequestId(TRITONBACKEND_Request* request, std::unordered_map<uint64_t, std::string>& requestIdStrMap);
+
+/// @brief get the original requestId string from the uint64_t requestId
+/// @return If uint64_t id is not present in requestIdStrMap, returns std::to_string(requestId)
+std::string getRequestIdStr(uint64_t requestId, std::unordered_map<uint64_t, std::string> const& requestIdStrMap);

 /// @brief Get the requested output names
 std::unordered_set<std::string> getRequestOutputNames(TRITONBACKEND_Request* request);

tensorrt_llm

Submodule tensorrt_llm updated 211 files

tools/version.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+e880735c6b44a0cf74ae1d37d23d529c86deb65d

0 commit comments
