Update TensorRT-LLM backend (triton-inference-server#380)
* Update TensorRT-LLM backend
kaiyux authored Mar 19, 2024
1 parent 8d6748c commit da59830
Showing 33 changed files with 1,561 additions and 330 deletions.
1 change: 1 addition & 0 deletions .clang-format
@@ -59,6 +59,7 @@ PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
QualifierAlignment: Right
ReflowComments: true
SeparateDefinitionBlocks: Always
SortIncludes: CaseSensitive
38 changes: 23 additions & 15 deletions README.md
@@ -159,21 +159,19 @@ cd tensorrt_llm/examples/gpt
rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2
pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd

# Convert weights from HF Transformers to FT format
python3 hf_gpt_convert.py -p 8 -i gpt2 -o ./c-model/gpt2 --tensor-parallelism 4 --storage-type float16
# Convert weights from HF Transformers to TensorRT-LLM checkpoint
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --tp_size 4 \
        --output_dir ./c-model/gpt2/fp16/4-gpu

# Build TensorRT engines
python3 build.py --model_dir=./c-model/gpt2/4-gpu/ \
        --world_size=4 \
        --dtype float16 \
        --use_inflight_batching \
        --use_gpt_attention_plugin float16 \
        --paged_kv_cache \
        --use_gemm_plugin float16 \
        --remove_input_padding \
        --hidden_act gelu \
        --parallel_build \
        --output_dir=engines/fp16/4-gpu
trtllm-build --checkpoint_dir ./c-model/gpt2/fp16/4-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --paged_kv_cache enable \
        --gemm_plugin float16 \
        --output_dir engines/fp16/4-gpu
```

### Create the model repository
@@ -191,7 +189,7 @@ and postprocessing models together.
- "tensorrt_llm_bls": This model can also be used to chain the preprocessing,
tensorrt_llm and postprocessing models together. The BLS model has an optional
parameter `accumulate_tokens` which can be used in streaming mode to call the
preprocessing model with all accumulated tokens, instead of only one token.
postprocessing model with all accumulated tokens, instead of only one token.
This might be necessary for certain tokenizers.
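As a hedged illustration (it assumes the repository's `tools/fill_template.py` helper and the default `triton_model_repo` layout; the value shown is an example, not a recommendation), `accumulate_tokens` could be switched on for the BLS model like this:

```bash
# Illustrative only: fill the accumulate_tokens placeholder in the BLS config;
# the remaining ${...} placeholders still need to be filled as usual.
python3 tools/fill_template.py -i triton_model_repo/tensorrt_llm_bls/config.pbtxt accumulate_tokens:True
```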

To learn more about ensemble and BLS models, please see the
@@ -236,6 +234,7 @@ The following table shows the fields that may need to be modified before deployment:
| `exclude_input_in_output` | Optional (default=`false`). Set to `true` to only return completion tokens in a response. Set to `false` to return the prompt tokens concatenated with the generated tokens |
| `normalize_log_probs` | Optional (default=`true`). Set to `false` to skip normalization of `output_log_probs` |
| `enable_chunked_context` | Optional (default=`false`). Set to `true` to enable context chunking. |
| `gpu_device_ids` | Optional (default=unspecified). Comma-separated list of GPU IDs to use for this model. If not provided, the model will use all visible GPUs. |
| `decoding_mode` | Optional. Set to one of `{top_k, top_p, top_k_top_p, beam_search}` to select the decoding mode. The `top_k` mode uses only the Top-K sampling algorithm, and the `top_p` mode uses only the Top-P sampling algorithm. The `top_k_top_p` mode employs both Top-K and Top-P algorithms, depending on the runtime sampling parameters of the request. Note that the `top_k_top_p` option requires more memory and has a longer runtime than `top_k` or `top_p` individually, so it should be used only when necessary. `beam_search` uses the beam search algorithm. If not specified, the default is `top_k_top_p` if `max_beam_width == 1`; otherwise, `beam_search` is used. |
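A few of the optional fields above could be filled with a hedged one-liner like the following (this assumes the repository's `tools/fill_template.py` helper and that these fields are exposed as `${...}` placeholders in the template, as `gpu_device_ids` and `decoding_mode` are in this commit; the values are illustrative only):

```bash
# Illustrative only: fill selected optional fields of the tensorrt_llm model config.
python3 tools/fill_template.py -i triton_model_repo/tensorrt_llm/config.pbtxt \
    exclude_input_in_output:false,enable_chunked_context:false,decoding_mode:top_k_top_p
```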

*triton_model_repo/postprocessing/config.pbtxt*
@@ -275,6 +274,15 @@ cd /tensorrtllm_backend
python3 scripts/launch_triton_server.py --world_size=4 --model_repo=/tensorrtllm_backend/triton_model_repo
```

To use multiple TensorRT-LLM models, use the `--multi-model` option. `--world_size` must be 1, as the TensorRT-LLM backend will dynamically launch TensorRT-LLM workers as needed.

```bash
cd /tensorrtllm_backend
python3 scripts/launch_triton_server.py --model_repo=/tensorrtllm_backend/triton_model_repo --multi-model
```

When using the `--multi-model` option, the Triton model repository can contain multiple TensorRT-LLM models. In that case, the `gpu_device_ids` parameter should be specified in each model's `config.pbtxt` configuration file, and it is up to you to ensure there is no overlap between the allocated GPU IDs.
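As a hedged sketch (the model directory names here are hypothetical, and it assumes each model's template exposes `gpu_device_ids` as a `${gpu_device_ids}` placeholder, as the `tensorrt_llm` config in this commit does), two models could be pinned to disjoint GPUs like this:

```bash
# Illustrative only: give each TensorRT-LLM model its own GPU so the allocations do not overlap.
python3 tools/fill_template.py -i triton_model_repo/tensorrt_llm_model_a/config.pbtxt gpu_device_ids:0
python3 tools/fill_template.py -i triton_model_repo/tensorrt_llm_model_b/config.pbtxt gpu_device_ids:1
```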

When successfully deployed, the server produces logs similar to the following.
```
I0919 14:52:10.475738 293 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001
@@ -373,7 +381,7 @@ number of generated tokens is lower than 200. You can have a look at the
client code to see how early stopping is achieved.

#### Return context logits and/or generation logits
If you want to get context logits and/or generation logits, you need to enable `--gather_context_logits` and/or `--gather_generation_logits` when building the engine (or `--enable gather_all_token_logits` to enable both at the same time). For more setting details about these two flags, please refer to [build.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/gpt/build.py) or [gpt_runtime](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/gpt_runtime.md).
If you want to get context logits and/or generation logits, you need to enable `--gather_context_logits` and/or `--gather_generation_logits` when building the engine (or `--gather_all_token_logits` to enable both at the same time). For more setting details about these two flags, please refer to [build.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/commands/build.py) or [gpt_runtime](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/gpt_runtime.md).
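As a hedged sketch, the `trtllm-build` invocation shown earlier could be extended with these flags (everything else unchanged; this is illustrative rather than a recommended configuration):

```bash
trtllm-build --checkpoint_dir ./c-model/gpt2/fp16/4-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --paged_kv_cache enable \
        --gemm_plugin float16 \
        --gather_context_logits \
        --gather_generation_logits \
        --output_dir engines/fp16/4-gpu
```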

After launching the server, you could get the output of logits by passing the corresponding parameters `--return-context-logits` and/or `--return-generation-logits` in the client scripts ([end_to_end_grpc_client.py](./inflight_batcher_llm/client/end_to_end_grpc_client.py) and [inflight_batcher_llm_client.py](./inflight_batcher_llm/client/inflight_batcher_llm_client.py)). For example:
```bash
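# Hedged sketch: the concrete command from the README is collapsed in this
# diff view. Only the two logits flags named above are taken from the text;
# the tokenizer argument and its path are placeholders/assumptions.
python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py \
    --tokenizer-dir /path/to/tokenizer \
    --return-context-logits \
    --return-generation-logits
```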
41 changes: 41 additions & 0 deletions all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
@@ -214,6 +214,17 @@ input [
    reshape: { shape: [ ] }
    optional: true
  },
  # The unique task ID for the given LoRA.
  # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights` and `lora_config` must all be given.
  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
  # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
  {
    name: "lora_task_id"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # Weights for a LoRA adapter, shape [ num_lora_modules_layers, D x Hi + Ho x D ],
  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer.
  # Each of the in / out tensors is first flattened and then concatenated together in the format above.
@@ -368,9 +379,39 @@ parameters: {
string_value: "${gpu_device_ids}"
}
}
parameters: {
key: "lora_cache_optimal_adapter_size"
value: {
string_value: "${lora_cache_optimal_adapter_size}"
}
}
parameters: {
key: "lora_cache_max_adapter_size"
value: {
string_value: "${lora_cache_max_adapter_size}"
}
}
parameters: {
key: "lora_cache_gpu_memory_fraction"
value: {
string_value: "${lora_cache_gpu_memory_fraction}"
}
}
parameters: {
key: "lora_cache_host_memory_bytes"
value: {
string_value: "${lora_cache_host_memory_bytes}"
}
}
parameters: {
key: "decoding_mode"
value: {
string_value: "${decoding_mode}"
}
}
parameters: {
key: "worker_path"
value: {
string_value: "/opt/tritonserver/backends/tensorrtllm/triton_tensorrtllm_worker"
}
}
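For reference, a hedged sketch of how the new placeholders above might be substituted (assuming the repository's `tools/fill_template.py` helper; the numeric values are purely illustrative, not tuned recommendations):

```bash
# Illustrative only: substitute the LoRA-cache and decoding-mode placeholders.
python3 tools/fill_template.py -i triton_model_repo/tensorrt_llm/config.pbtxt \
    lora_cache_optimal_adapter_size:8,lora_cache_max_adapter_size:64,lora_cache_gpu_memory_fraction:0.05,lora_cache_host_memory_bytes:1073741824,decoding_mode:top_k_top_p
```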
20 changes: 8 additions & 12 deletions ci/L0_backend_trtllm/generate_engines.sh
@@ -34,7 +34,7 @@ function build_base_model {
    cd ${GPT_DIR}
    rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2
    pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd
    python3 hf_gpt_convert.py -p 8 -i gpt2 -o ./c-model/gpt2 --tensor-parallelism ${NUM_GPUS} --storage-type float16
    python3 convert_checkpoint.py --model_dir gpt2 --dtype float16 --tp_size ${NUM_GPUS} --output_dir ./c-model/gpt2/${NUM_GPUS}-gpu/
    cd ${BASE_DIR}
}

@@ -45,17 +45,13 @@ function build_tensorrt_engine_inflight_batcher {
    local OUTPUT_DIR=inflight_${NUM_GPUS}_gpu/
    # ./c-model/gpt2/ must already exist (it will if build_base_model
    # has already been run)
    python3 build.py --model_dir="${GPT_MODEL_DIR}" \
        --world_size="${NUM_GPUS}" \
        --dtype float16 \
        --use_inflight_batching \
        --use_gpt_attention_plugin float16 \
        --paged_kv_cache \
        --use_gemm_plugin float16 \
        --remove_input_padding \
        --hidden_act gelu \
        --parallel_build \
        --output_dir="${OUTPUT_DIR}"
    trtllm-build --checkpoint_dir "${GPT_MODEL_DIR}" \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --paged_kv_cache enable \
        --gemm_plugin float16 \
        --workers "${NUM_GPUS}" \
        --output_dir "${OUTPUT_DIR}"
    cd ${BASE_DIR}
}

30 changes: 30 additions & 0 deletions ci/L0_backend_trtllm/test.sh
@@ -338,6 +338,36 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
    kill_server
    wait_for_server_terminated ${SERVER_PID[@]}

    # Multi-model
    SERVER_LOG="./${NUM_GPU}gpu_multi_model.log"
    run_server "${SERVER_ARGS} --multi-model"
    wait_for_server_ready ${SERVER_TIMEOUT} ${SERVER_PID[@]}
    if [ "$WAIT_RET" != "0" ]; then
        # Cleanup
        kill $SERVER_PID > /dev/null 2>&1 || true
        echo -e "\n***\n*** Failed to start $SERVER\n***"
        cat $SERVER_LOG
        exit 1
    fi

    set -e
    python3 ${TOOLS_DIR}/inflight_batcher_llm/end_to_end_test.py \
        --max-input-len=500 \
        --dataset=${DATASET}

    if [ $? -ne 0 ]; then
        cat $SERVER_LOG
        echo -e "\n***\n*** Error executing inflight batching end-to-end test with ${NUM_GPU}GPU(s): line ${LINENO}\n***"
        kill_server
        wait_for_server_terminated ${SERVER_PID[@]}
        RET=1
    fi
    set +e

    curl localhost:8002/metrics -o ${NUM_GPU}gpu_IFB_no_stream_metrics.out
    kill_server
    wait_for_server_terminated ${SERVER_PID[@]}

done

4 changes: 2 additions & 2 deletions ci/README.md
@@ -40,10 +40,10 @@ instructions in [Option 3 Build via CMake](../README.md#option-3-build-via-cmake
Run the testing within the Triton container.

```bash
docker run --rm -it --net host --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v /path/to/tensorrtllm_backend:/tensorrtllm_backend nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 bash
docker run --rm -it --net host --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v /path/to/tensorrtllm_backend:/opt/tritonserver/tensorrtllm_backend nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 bash

# Change directory to the test and run the test.sh script
cd /tensorrtllm_backend/ci/<test directory>
cd /opt/tritonserver/tensorrtllm_backend/ci/<test directory>
bash -x ./test.sh
```

5 changes: 4 additions & 1 deletion dockerfile/Dockerfile.trt_llm_backend
@@ -1,5 +1,5 @@
ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver
ARG BASE_TAG=24.01-py3
ARG BASE_TAG=24.02-py3

FROM ${BASE_IMAGE}:${BASE_TAG} as base

@@ -61,4 +61,7 @@ RUN cd /app/tensorrt_llm/build && pip3 install *.whl

# Install TensorRT-LLM backend
RUN mkdir /opt/tritonserver/backends/tensorrtllm
ENV LD_LIBRARY_PATH=/opt/tritonserver/backends/tensorrtllm:${LD_LIBRARY_PATH}
COPY --from=trt_llm_backend_builder /app/inflight_batcher_llm/build/libtriton_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm
COPY --from=trt_llm_backend_builder /app/inflight_batcher_llm/build/libtriton_tensorrtllm_common.so /opt/tritonserver/backends/tensorrtllm
COPY --from=trt_llm_backend_builder /app/inflight_batcher_llm/build/triton_tensorrtllm_worker /opt/tritonserver/backends/tensorrtllm