diff --git a/.clang-format b/.clang-format index 1983a9ca..12bb2f11 100644 --- a/.clang-format +++ b/.clang-format @@ -59,6 +59,7 @@ PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 60 PointerAlignment: Left +QualifierAlignment: Right ReflowComments: true SeparateDefinitionBlocks: Always SortIncludes: CaseSensitive diff --git a/README.md b/README.md index d52ec0fa..6b49f7d4 100644 --- a/README.md +++ b/README.md @@ -159,21 +159,19 @@ cd tensorrt_llm/examples/gpt rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2 pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd -# Convert weights from HF Tranformers to FT format -python3 hf_gpt_convert.py -p 8 -i gpt2 -o ./c-model/gpt2 --tensor-parallelism 4 --storage-type float16 +# Convert weights from HF Transformers to TensorRT-LLM checkpoint +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --tp_size 4 \ + --output_dir ./c-model/gpt2/fp16/4-gpu # Build TensorRT engines -python3 build.py --model_dir=./c-model/gpt2/4-gpu/ \ - --world_size=4 \ - --dtype float16 \ - --use_inflight_batching \ - --use_gpt_attention_plugin float16 \ - --paged_kv_cache \ - --use_gemm_plugin float16 \ - --remove_input_padding \ - --hidden_act gelu \ - --parallel_build \ - --output_dir=engines/fp16/4-gpu +trtllm-build --checkpoint_dir ./c-model/gpt2/fp16/4-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --paged_kv_cache enable \ + --gemm_plugin float16 \ + --output_dir engines/fp16/4-gpu ``` ### Create the model repository @@ -191,7 +189,7 @@ and postprocessing models together. - "tensorrt_llm_bls": This model can also be used to chain the preprocessing, tensorrt_llm and postprocessing models together. The BLS model has an optional parameter `accumulate_tokens` which can be used in streaming mode to call the -preprocessing model with all accumulated tokens, instead of only one token. +postprocessing model with all accumulated tokens, instead of only one token. This might be necessary for certain tokenizers. To learn more about ensemble and BLS models, please see the @@ -236,6 +234,7 @@ The following table shows the fields that may to be modified before deployment: | `exclude_input_in_output` | Optional (default=`false`). Set to `true` to only return completion tokens in a response. Set to `false` to return the prompt tokens concatenated with the generated tokens | | `normalize_log_probs` | Optional (default=`true`). Set to `false` to skip normalization of `output_log_probs` | | `enable_chunked_context` | Optional (default=`false`). Set to `true` to enable context chunking. | +| `gpu_device_ids` | Optional (default=unspecified). Comma-separated list of GPU IDs to use for this model. If not provided, the model will use all visible GPUs. | | `decoding_mode` | Optional. Set to one of the following: `{top_k, top_p, top_k_top_p, beam_search}` to select the decoding mode. The `top_k` mode exclusively uses Top-K algorithm for sampling, The `top_p` mode uses exclusively Top-P algorithm for sampling. The top_k_top_p mode employs both Top-K and Top-P algorithms, depending on the runtime sampling params of the request. Note that the `top_k_top_p option` requires more memory and has a longer runtime than using `top_k` or `top_p` individually; therefore, it should be used only when necessary. `beam_search` uses beam search algorithm.
If not specified, the default is to use `top_k_top_p` if `max_beam_width == 1`; otherwise, `beam_search` is used. | *triton_model_repo/postprocessing/config.pbtxt* @@ -275,6 +274,15 @@ cd /tensorrtllm_backend python3 scripts/launch_triton_server.py --world_size=4 --model_repo=/tensorrtllm_backend/triton_model_repo ``` +In order to use multiple TensorRT-LLM models, use the `--multi-model` option. The `--world_size` must be 1 as the TensorRT-LLM backend will dynamically launch TensorRT-LLM workers as needed. + +```bash +cd /tensorrtllm_backend +python3 scripts/launch_triton_server.py --model_repo=/tensorrtllm_backend/triton_model_repo --multi-model +``` + +When using the `--multi-model` option, the Triton model repository can contain multiple TensorRT-LLM models. When running multiple TensorRT-LLM models, the `gpu_device_ids` parameter should be specified in each model's `config.pbtxt` configuration file. It is up to you to ensure there is no overlap between allocated GPU IDs. + When successfully deployed, the server produces logs similar to the following ones. ``` I0919 14:52:10.475738 293 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001 @@ -373,7 +381,7 @@ number of generated tokens is lower than 200. You can have a look at the client code to see how early stopping is achieved. #### Return context logits and/or generation logits -If you want to get context logits and/or generation logits, you need to enable `--gather_context_logits` and/or `--gather_generation_logits` when building the engine (or `--enable gather_all_token_logits` to enable both at the same time). For more setting details about these two flags, please refer to [build.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/gpt/build.py) or [gpt_runtime](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/gpt_runtime.md). +If you want to get context logits and/or generation logits, you need to enable `--gather_context_logits` and/or `--gather_generation_logits` when building the engine (or `--gather_all_token_logits` to enable both at the same time). For more setting details about these two flags, please refer to [build.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/commands/build.py) or [gpt_runtime](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/gpt_runtime.md). After launching the server, you could get the output of logits by passing the corresponding parameters `--return-context-logits` and/or `--return-generation-logits` in the client scripts ([end_to_end_grpc_client.py](./inflight_batcher_llm/client/end_to_end_grpc_client.py) and [inflight_batcher_llm_client.py](./inflight_batcher_llm/client/inflight_batcher_llm_client.py)). For example: ```bash diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt index 4e60a6fc..62caa47c 100644 --- a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt +++ b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt @@ -214,6 +214,17 @@ input [ reshape: { shape: [ ] } optional: true }, + # the unique task ID for the given LoRA. + # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights` and `lora_config` must all be given. + # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`. + # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
+ { + name: "lora_task_id" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer # each of the in / out tensors are first flattened and then concatenated together in the format above. @@ -368,9 +379,39 @@ parameters: { string_value: "${gpu_device_ids}" } } +parameters: { + key: "lora_cache_optimal_adapter_size" + value: { + string_value: "${lora_cache_optimal_adapter_size}" + } +} +parameters: { + key: "lora_cache_max_adapter_size" + value: { + string_value: "${lora_cache_max_adapter_size}" + } +} +parameters: { + key: "lora_cache_gpu_memory_fraction" + value: { + string_value: "${lora_cache_gpu_memory_fraction}" + } +} +parameters: { + key: "lora_cache_host_memory_bytes" + value: { + string_value: "${lora_cache_host_memory_bytes}" + } +} parameters: { key: "decoding_mode" value: { string_value: "${decoding_mode}" } } +parameters: { + key: "worker_path" + value: { + string_value: "/opt/tritonserver/backends/tensorrtllm/triton_tensorrtllm_worker" + } +} diff --git a/ci/L0_backend_trtllm/generate_engines.sh b/ci/L0_backend_trtllm/generate_engines.sh index baf31965..8527f95e 100644 --- a/ci/L0_backend_trtllm/generate_engines.sh +++ b/ci/L0_backend_trtllm/generate_engines.sh @@ -34,7 +34,7 @@ function build_base_model { cd ${GPT_DIR} rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2 pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd - python3 hf_gpt_convert.py -p 8 -i gpt2 -o ./c-model/gpt2 --tensor-parallelism ${NUM_GPUS} --storage-type float16 + python3 convert_checkpoint.py --model_dir gpt2 --dtype float16 --tp_size ${NUM_GPUS} --output_dir ./c-model/gpt2/${NUM_GPUS}-gpu/ cd ${BASE_DIR} } @@ -45,17 +45,13 @@ function build_tensorrt_engine_inflight_batcher { local OUTPUT_DIR=inflight_${NUM_GPUS}_gpu/ # ./c-model/gpt2/ must already exist (it will if build_base_model # has already been run) - python3 build.py --model_dir="${GPT_MODEL_DIR}" \ - --world_size="${NUM_GPUS}" \ - --dtype float16 \ - --use_inflight_batching \ - --use_gpt_attention_plugin float16 \ - --paged_kv_cache \ - --use_gemm_plugin float16 \ - --remove_input_padding \ - --hidden_act gelu \ - --parallel_build \ - --output_dir="${OUTPUT_DIR}" + trtllm-build --checkpoint_dir "${GPT_MODEL_DIR}" \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --paged_kv_cache enable \ + --gemm_plugin float16 \ + --workers "${NUM_GPUS}" \ + --output_dir "${OUTPUT_DIR}" cd ${BASE_DIR} } diff --git a/ci/L0_backend_trtllm/test.sh b/ci/L0_backend_trtllm/test.sh index 5a3be53d..3ed1e8be 100644 --- a/ci/L0_backend_trtllm/test.sh +++ b/ci/L0_backend_trtllm/test.sh @@ -338,6 +338,36 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do kill_server wait_for_server_terminated ${SERVER_PID[@]} + # Multi-model + SERVER_LOG="./${NUM_GPU}gpu_multi_model.log" + run_server "${SERVER_ARGS} --multi-model" + wait_for_server_ready ${SERVER_TIMEOUT} ${SERVER_PID[@]} + if [ "$WAIT_RET" != "0" ]; then + # Cleanup + kill $SERVER_PID > /dev/null 2>&1 || true + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 + fi + set -e + + set -e + python3 ${TOOLS_DIR}/inflight_batcher_llm/end_to_end_test.py \ + --max-input-len=500 \ + --dataset=${DATASET} + + if [ $? 
-ne 0 ]; then + cat $SERVER_LOG + echo -e "\n***\n*** Error executing inflight batching end-to-end test with ${NUM_GPU}GPU(s): line ${LINENO}\n***" + kill_server + wait_for_server_terminated ${SERVER_PID[@]} + RET=1 + fi + set +e + + curl localhost:8002/metrics -o ${NUM_GPU}gpu_IFB_no_stream_metrics.out + kill_server + wait_for_server_terminated ${SERVER_PID[@]} done diff --git a/ci/README.md b/ci/README.md index aec1595d..8a20dfb0 100644 --- a/ci/README.md +++ b/ci/README.md @@ -40,10 +40,10 @@ instructions in [Option 3 Build via CMake](../README.md#option-3-build-via-cmake Run the testing within the Triton container. ```bash -docker run --rm -it --net host --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v /path/to/tensorrtllm_backend:/tensorrtllm_backend nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 bash +docker run --rm -it --net host --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v /path/to/tensorrtllm_backend:/opt/tritonserver/tensorrtllm_backend nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 bash # Change directory to the test and run the test.sh script -cd /tensorrtllm_backend/ci/ +cd /opt/tritonserver/tensorrtllm_backend/ci/ bash -x ./test.sh ``` diff --git a/dockerfile/Dockerfile.trt_llm_backend b/dockerfile/Dockerfile.trt_llm_backend index 552689ad..4f707778 100644 --- a/dockerfile/Dockerfile.trt_llm_backend +++ b/dockerfile/Dockerfile.trt_llm_backend @@ -1,5 +1,5 @@ ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver -ARG BASE_TAG=24.01-py3 +ARG BASE_TAG=24.02-py3 FROM ${BASE_IMAGE}:${BASE_TAG} as base @@ -61,4 +61,7 @@ RUN cd /app/tensorrt_llm/build && pip3 install *.whl # Install TensorRT-LLM backend RUN mkdir /opt/tritonserver/backends/tensorrtllm +ENV LD_LIBRARY_PATH=/opt/tritonserver/backends/tensorrtllm:${LD_LIBRARY_PATH} COPY --from=trt_llm_backend_builder /app/inflight_batcher_llm/build/libtriton_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm +COPY --from=trt_llm_backend_builder /app/inflight_batcher_llm/build/libtriton_tensorrtllm_common.so /opt/tritonserver/backends/tensorrtllm +COPY --from=trt_llm_backend_builder /app/inflight_batcher_llm/build/triton_tensorrtllm_worker /opt/tritonserver/backends/tensorrtllm diff --git a/inflight_batcher_llm/CMakeLists.txt b/inflight_batcher_llm/CMakeLists.txt index 21491dc0..5eeb4a37 100644 --- a/inflight_batcher_llm/CMakeLists.txt +++ b/inflight_batcher_llm/CMakeLists.txt @@ -102,19 +102,19 @@ FetchContent_MakeAvailable(repo-common repo-core repo-backend) configure_file(src/libtriton_tensorrtllm.ldscript libtriton_tensorrtllm.ldscript COPYONLY) -set(SRCS - src/libtensorrtllm.cc - src/work_item.cc - src/work_items_queue.cc - src/model_instance_state.cc - src/model_state.cc - src/utils.cc - src/inference_answer.cc) +set(COMMON_SRCS + src/work_item.cc src/work_items_queue.cc src/model_instance_state.cc + src/model_state.cc src/utils.cc src/inference_answer.cc) -add_library(triton-tensorrt-llm-backend SHARED ${SRCS}) +add_library(triton-tensorrt-llm-common SHARED ${COMMON_SRCS}) -add_library(TritonTensorRTLLMBackend::triton-tensorrt-llm-backend ALIAS - triton-tensorrt-llm-backend) +set(BACKEND_SRCS src/libtensorrtllm.cc src/orchestrator.cc) + +add_library(triton-tensorrt-llm-backend SHARED ${BACKEND_SRCS}) + +set(WORKER_SRCS src/worker.cc) + +add_executable(triton-tensorrt-llm-worker ${WORKER_SRCS}) enable_language(CUDA) @@ -191,21 +191,22 @@ list(APPEND COMMON_HEADER_DIRS ${TORCH_INCLUDE_DIRS} ${TRT_INCLUDE_DIR}) include_directories(${COMMON_HEADER_DIRS}) target_include_directories( 
- triton-tensorrt-llm-backend - PRIVATE ${TRTLLM_DIR}/cpp - ${TRTLLM_DIR}/cpp/include - ${CMAKE_CURRENT_SOURCE_DIR}/src - ${CUDA_INCLUDE_DIRS} - ${CUDNN_ROOT_DIR}/include - ${NCCL_INCLUDE_DIR} - ${3RDPARTY_DIR}/cutlass/include - ${MPI_INCLUDE_PATH} - ${COMMON_HEADER_DIR}) - + triton-tensorrt-llm-common + PUBLIC ${TRTLLM_DIR}/cpp + ${TRTLLM_DIR}/cpp/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CUDA_INCLUDE_DIRS} + ${CUDNN_ROOT_DIR}/include + ${NCCL_INCLUDE_DIR} + ${3RDPARTY_DIR}/cutlass/include + ${MPI_INCLUDE_PATH} + ${COMMON_HEADER_DIR}) + +target_compile_features(triton-tensorrt-llm-common PRIVATE cxx_std_17) target_compile_features(triton-tensorrt-llm-backend PRIVATE cxx_std_17) -target_compile_options( - triton-tensorrt-llm-backend - PRIVATE +target_compile_features(triton-tensorrt-llm-worker PRIVATE cxx_std_17) + +set(COMPILE_OPTIONS $<$,$,$>: -Wall -Wextra @@ -216,6 +217,10 @@ target_compile_options( /D_WIN32_WINNT=0x0A00 /EHsc>) +target_compile_options(triton-tensorrt-llm-common PRIVATE ${COMPILE_OPTIONS}) +target_compile_options(triton-tensorrt-llm-backend PRIVATE ${COMPILE_OPTIONS}) +target_compile_options(triton-tensorrt-llm-worker PRIVATE ${COMPILE_OPTIONS}) + add_library(tensorrt_llm SHARED IMPORTED) set_property( TARGET tensorrt_llm @@ -262,9 +267,9 @@ if(TRITON_ENABLE_METRICS) triton-backend-utils # from repo-backend tensorrt_llm) - target_compile_definitions(triton-tensorrt-llm-backend + target_compile_definitions(triton-tensorrt-llm-common PRIVATE TRITON_ENABLE_METRICS=1) - target_link_libraries(triton-tensorrt-llm-backend + target_link_libraries(triton-tensorrt-llm-common PRIVATE triton-custom-metrics-reporter-library) endif() @@ -290,16 +295,21 @@ if(TRITON_BUILD) endif() # TRITON_BUILD target_link_libraries( - triton-tensorrt-llm-backend - PRIVATE tensorrt_llm - triton-core-serverapi # from repo-core - triton-core-backendapi # from repo-core - triton-core-serverstub # from repo-core - triton-backend-utils # from repo-backend - ${MPI_LIBRARIES} - ${CUDA_LIBRARIES} - nvinfer - nvinfer_plugin_tensorrt_llm) + triton-tensorrt-llm-common + PUBLIC tensorrt_llm + triton-core-serverapi # from repo-core + triton-core-backendapi # from repo-core + triton-core-serverstub # from repo-core + triton-backend-utils # from repo-backend + ${MPI_LIBRARIES} + ${CUDA_LIBRARIES} + nvinfer + nvinfer_plugin_tensorrt_llm) + +target_link_libraries(triton-tensorrt-llm-backend + PRIVATE triton-tensorrt-llm-common) +target_link_libraries(triton-tensorrt-llm-worker + PRIVATE triton-tensorrt-llm-common) FetchContent_Declare( json @@ -308,13 +318,19 @@ FetchContent_Declare( FetchContent_MakeAvailable(json) -target_link_libraries(triton-tensorrt-llm-backend +target_link_libraries(triton-tensorrt-llm-common PRIVATE nlohmann_json::nlohmann_json) if(WIN32) set_target_properties( triton-tensorrt-llm-backend PROPERTIES POSITION_INDEPENDENT_CODE ON OUTPUT_NAME triton_tensorrtllm) + set_target_properties( + triton-tensorrt-llm-worker PROPERTIES POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_tensorrtllm_worker) + set_target_properties( + triton-tensorrt-llm-common PROPERTIES POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_tensorrtllm_common) else() set_target_properties( triton-tensorrt-llm-backend @@ -325,62 +341,14 @@ else() LINK_FLAGS "-Wl,--version-script libtriton_tensorrtllm.ldscript -Wl,-rpath,'$ORIGIN' -Wl,--no-undefined" ) + set_target_properties(triton-tensorrt-llm-worker + PROPERTIES OUTPUT_NAME triton_tensorrtllm_worker) + set_target_properties( + triton-tensorrt-llm-common PROPERTIES 
POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_tensorrtllm_common) endif() if(BUILD_TESTS) enable_testing() add_subdirectory(tests) endif() - -# -# Install -# -include(GNUInstallDirs) -set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/TritonTensorRTLLMBackend) - -install( - TARGETS triton-tensorrt-llm-backend - EXPORT triton-tensorrt-llm-backend-targets - LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/tensorrtllm - RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/tensorrtllm) - -if(TRITON_BUILD) - file( - GLOB - LIBINFER_PLUGIN_TENSORRT_LLM - "${TRTLLM_DIR}/cpp/build/tensorrt_llm/plugins/libnvinfer_plugin_tensorrt_llm.so*" - FOLLOW_SYMLINKS) - install(FILES ${LIBINFER_PLUGIN_TENSORRT_LLM} - DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/tensorrtllm) - - file(GLOB LIBINFER_PLUGIN_TENSORRT_LLM - "${TRTLLM_DIR}/cpp/build/tensorrt_llm/libtensorrt_llm.so*" - FOLLOW_SYMLINKS) - install(FILES ${LIBINFER_PLUGIN_TENSORRT_LLM} - DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/tensorrtllm) -endif() # TRITON_BUILD - -install( - EXPORT triton-tensorrt-llm-backend-targets - FILE TritonTensorRTLLMBackendTargets.cmake - NAMESPACE TritonTensorRTLLMBackend:: - DESTINATION ${INSTALL_CONFIGDIR}) - -include(CMakePackageConfigHelpers) -configure_package_config_file( - ${CMAKE_CURRENT_LIST_DIR}/cmake/TritonTensorRTLLMBackendConfig.cmake.in - ${CMAKE_CURRENT_BINARY_DIR}/TritonTensorRTLLMBackendConfig.cmake - INSTALL_DESTINATION ${INSTALL_CONFIGDIR}) - -install(FILES ${CMAKE_CURRENT_BINARY_DIR}/TritonTensorRTLLMBackendConfig.cmake - DESTINATION ${INSTALL_CONFIGDIR}) - -# -# Export from build tree -# -export( - EXPORT triton-tensorrt-llm-backend-targets - FILE ${CMAKE_CURRENT_BINARY_DIR}/TritonTensorRTLLMBackendTargets.cmake - NAMESPACE TritonTensorRTLLMBackend::) - -export(PACKAGE TritonTensorRTLLMBackend) diff --git a/inflight_batcher_llm/README.md b/inflight_batcher_llm/README.md index c7b385d0..0052a61e 100644 --- a/inflight_batcher_llm/README.md +++ b/inflight_batcher_llm/README.md @@ -187,6 +187,52 @@ python3 tensorrt_llm/examples/hf_lora_convert.py -i Japanese-Alpaca-LoRA-7b-v0 - python3 tensorrt_llm/examples/hf_lora_convert.py -i luotuo-lora-7b-0.1 -o luotuo-lora-7b-0.1-weights --storage-type float16 ``` +### LoRA Cache + +As LoRA weights are passed to the backend, they will be cached in a host cache. As requests are scheduled, those weights will be prefetched to a GPU cache. After a LoRA is loaded into the cache, only `lora_task_id` is needed for inference. + + +Optimal adapter size used to size cache pages (default: 8) +``` +parameters: { + key: "lora_cache_optimal_adapter_size" + value: { + string_value: "${lora_cache_optimal_adapter_size}" + } +} +``` + + +Maximum supported adapter size (default: 64) +``` +parameters: { + key: "lora_cache_max_adapter_size" + value: { + string_value: "${lora_cache_max_adapter_size}" + } +} +``` + +Fraction of GPU memory used for the LoRA cache. Computed as a fraction of the memory left over after engine load, and after KV cache is loaded (default: 0.05) +``` +parameters: { + key: "lora_cache_gpu_memory_fraction" + value: { + string_value: "${lora_cache_gpu_memory_fraction}" + } +} +``` + +Size of host LoRA cache in bytes (default: 1G) +``` +parameters: { + key: "lora_cache_host_memory_bytes" + value: { + string_value: "${lora_cache_host_memory_bytes}" + } +} +``` + Launch tritonserver as described above. Run the Multi-LoRA example by issuing multiple concurrent requests.
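As a rough illustration of the caching behavior described above: the first request for a given adapter supplies the converted weights and a task id so the backend can populate its LoRA cache, while follow-up requests can rely on the cached copy and pass only the task id. This is a sketch, not part of the patch; `--lora-task-id` and `--lora-path` appear in the client changes below, but the `--text`, `--tokenizer-dir` and `--request-output-len` flags are assumed from the client's existing interface and may differ.

```bash
# First request for this adapter: upload the converted LoRA weights and
# register them under task id 1 so the backend can cache them.
python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py \
    --text "What is the capital of France?" \
    --tokenizer-dir /path/to/base/hf/model \
    --request-output-len 64 \
    --lora-path luotuo-lora-7b-0.1-weights \
    --lora-task-id 1

# Subsequent requests for the same adapter: the weights are already cached,
# so only the task id needs to be sent.
python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py \
    --text "What is the capital of France?" \
    --tokenizer-dir /path/to/base/hf/model \
    --request-output-len 64 \
    --lora-task-id 1
```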
diff --git a/inflight_batcher_llm/client/inflight_batcher_llm_client.py b/inflight_batcher_llm/client/inflight_batcher_llm_client.py index c869217b..6311464f 100755 --- a/inflight_batcher_llm/client/inflight_batcher_llm_client.py +++ b/inflight_batcher_llm/client/inflight_batcher_llm_client.py @@ -119,10 +119,10 @@ def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data, beam_width_data, temperature_data, repetition_penalty_data, presence_penalty_data, frequency_penalty_data, streaming_data, end_id, pad_id, prompt_embedding_table_data, - prompt_vocab_size_data, lora_weights_data, lora_config_data, - return_log_probs_data, top_k_data, top_p_data, - draft_ids_data, return_context_logits_data, - return_generation_logits_data): + prompt_vocab_size_data, lora_task_id_data, + lora_weights_data, lora_config_data, return_log_probs_data, + top_k_data, top_p_data, draft_ids_data, + return_context_logits_data, return_generation_logits_data): inputs = [ prepare_tensor("input_ids", input_ids_data), prepare_tensor("input_lengths", input_lengths_data), @@ -142,6 +142,8 @@ def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data, prompt_embedding_table_data), prepare_tensor("prompt_vocab_size", prompt_vocab_size_data) ] + if lora_task_id_data is not None: + inputs += [prepare_tensor("lora_task_id", lora_task_id_data)] if lora_weights_data is not None: inputs += [ prepare_tensor("lora_weights", lora_weights_data), @@ -421,6 +423,11 @@ def callback(user_data, result, error): default='', required=False, help="LoRA weights") + parser.add_argument("--lora-task-id", + type=int, + default=None, + required=False, + help="LoRA task id") parser.add_argument( "--exclude-input-in-output", action="store_true", @@ -488,6 +495,12 @@ def callback(user_data, result, error): default=[], help='The requested output tensors') + parser.add_argument('--model-name', + type=str, + required=False, + default='tensorrt_llm', + help='Specify model name') + FLAGS = parser.parse_args() tokenizer = None @@ -562,6 +575,9 @@ def callback(user_data, result, error): except Exception: lora_config_data = np.load( os.path.join(FLAGS.lora_path, "model.lora_keys.npy")) + lora_task_id_data = None + if FLAGS.lora_task_id is not None and FLAGS.lora_task_id != 0: + lora_task_id_data = np.array([[FLAGS.lora_task_id]], dtype=np.uint64) input_ids_data = np.array(input_ids, dtype=np.int32) input_lengths = [[len(ii)] for ii in input_ids] @@ -614,9 +630,10 @@ def callback(user_data, result, error): beam_width_data, temperature_data, repetition_penalty_data, presence_penalty_data, frequency_penalty_data, streaming_data, end_id_data, pad_id_data, prompt_embedding_table_data, - prompt_vocab_size_data, lora_weights_data, lora_config_data, - return_log_probs_data, top_k_data, top_p_data, draft_ids_data, - return_context_logits_data, return_generation_logits_data) + prompt_vocab_size_data, lora_task_id_data, lora_weights_data, + lora_config_data, return_log_probs_data, top_k_data, top_p_data, + draft_ids_data, return_context_logits_data, + return_generation_logits_data) if FLAGS.requested_outputs: # Must have at least output_ids in requested outputs @@ -679,7 +696,7 @@ def callback(user_data, result, error): ) # Send request triton_client.async_stream_infer( - 'tensorrt_llm', + FLAGS.model_name, inputs, outputs=outputs, request_id=request_id, @@ -690,7 +707,7 @@ def callback(user_data, result, error): if not FLAGS.stop_via_request_cancel: triton_client.async_stream_infer( - 'tensorrt_llm', + FLAGS.model_name, 
stop_inputs, request_id=request_id, parameters={'Streaming': FLAGS.streaming}) @@ -729,7 +746,7 @@ def callback(user_data, result, error): else: # Send request infer_future = triton_client.async_infer( - 'tensorrt_llm', + FLAGS.model_name, inputs, outputs=outputs, request_id=request_id, @@ -746,7 +763,7 @@ def callback(user_data, result, error): infer_future.cancel() else: triton_client.async_infer( - 'tensorrt_llm', + FLAGS.model_name, stop_inputs, request_id=request_id, callback=partial(callback, user_data), diff --git a/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.cc b/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.cc index 764c1ec7..7285a623 100644 --- a/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.cc +++ b/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.cc @@ -59,7 +59,7 @@ const std::vector CustomMetricsReporter::IFB_specific_labels_{ const std::vector CustomMetricsReporter::general_metric_keys_{"Timestamp", "Iteration Counter"}; const std::vector CustomMetricsReporter::general_metric_labels_{"timestamp", "iteration_counter"}; -uint64_t convertTimestampToSeconds(const std::string& ts) +uint64_t convertTimestampToSeconds(std::string const& ts) { std::tm tm = {}; std::stringstream ss(ts); @@ -70,9 +70,9 @@ uint64_t convertTimestampToSeconds(const std::string& ts) return time_in_seconds; } -TritonMetricGroup::TritonMetricGroup(const std::string& metric_family_label, - const std::string& metric_family_description, const std::string& category_label, - const std::vector& json_keys, const std::vector& sub_labels) +TritonMetricGroup::TritonMetricGroup(std::string const& metric_family_label, + std::string const& metric_family_description, std::string const& category_label, + std::vector const& json_keys, std::vector const& sub_labels) : metric_family_label_(metric_family_label) , metric_family_description_(metric_family_description) , category_label_(category_label) @@ -81,14 +81,14 @@ TritonMetricGroup::TritonMetricGroup(const std::string& metric_family_label, { } -TRITONSERVER_Error* TritonMetricGroup::CreateGroup(const std::string& model_name, const uint64_t version) +TRITONSERVER_Error* TritonMetricGroup::CreateGroup(std::string const& model_name, const uint64_t version) { TRITONSERVER_MetricFamily* metric_family = nullptr; RETURN_IF_ERROR(TRITONSERVER_MetricFamilyNew(&metric_family, TRITONSERVER_METRIC_KIND_GAUGE, metric_family_label_.c_str(), metric_family_description_.c_str())); metric_family_.reset(metric_family); - std::vector labels; + std::vector labels; std::unique_ptr model_label( TRITONSERVER_ParameterNew("model", TRITONSERVER_PARAMETER_STRING, model_name.c_str())); std::unique_ptr model_version( @@ -120,13 +120,13 @@ TRITONSERVER_Error* TritonMetricGroup::UpdateGroup(std::vector& values return nullptr; // success } -const std::vector& TritonMetricGroup::JsonKeys() const +std::vector const& TritonMetricGroup::JsonKeys() const { return json_keys_; } TRITONSERVER_Error* CustomMetricsReporter::InitializeReporter( - const std::string& model_name, const uint64_t version, const bool is_v1_model) + std::string const& model_name, const uint64_t version, bool const is_v1_model) { /* REQUEST METRIC GROUP */ request_metric_family_ = std::make_unique( @@ -179,18 +179,18 @@ TRITONSERVER_Error* CustomMetricsReporter::InitializeReporter( return nullptr; // success } -TRITONSERVER_Error* CustomMetricsReporter::UpdateCustomMetrics(const std::string& custom_metrics) +TRITONSERVER_Error* 
CustomMetricsReporter::UpdateCustomMetrics(std::string const& custom_metrics) { triton::common::TritonJson::Value metrics; std::vector members; metrics.Parse(custom_metrics); metrics.Members(&members); - for (const auto& metric_group : metric_groups_) + for (auto const& metric_group : metric_groups_) { std::vector metric_group_keys = metric_group->JsonKeys(); std::vector metric_group_values; - for (const auto& key : metric_group_keys) + for (auto const& key : metric_group_keys) { triton::common::TritonJson::Value value_json; uint64_t value; diff --git a/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.h b/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.h index 1d8eb0fa..d0960178 100644 --- a/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.h +++ b/inflight_batcher_llm/src/custom_metrics_reporter/custom_metrics_reporter.h @@ -44,9 +44,9 @@ namespace triton::backend::inflight_batcher_llm::custom_metrics_reporter class TritonMetricGroup { public: - TritonMetricGroup(const std::string& metric_family_label, const std::string& metric_family_description, - const std::string& category_label, const std::vector& json_keys, - const std::vector& labels); + TritonMetricGroup(std::string const& metric_family_label, std::string const& metric_family_description, + std::string const& category_label, std::vector const& json_keys, + std::vector const& labels); ~TritonMetricGroup(){}; /// Create a new Triton metric family with corresponding metric @@ -57,7 +57,7 @@ class TritonMetricGroup /// \param version The version of the model to provide a metrics /// group for. /// \return a TRITONSERVER_Error indicating success or failure. - TRITONSERVER_Error* CreateGroup(const std::string& model_name, const uint64_t version); + TRITONSERVER_Error* CreateGroup(std::string const& model_name, const uint64_t version); /// Update the Triton metrics associated with this group using /// the parsed TRT LLM backend statistics values. @@ -72,7 +72,7 @@ class TritonMetricGroup /// /// \return A const reference to vector of strings corresponding /// to the JSON keys associated with this group. - const std::vector& JsonKeys() const; + std::vector const& JsonKeys() const; /// Custom deleter for a unique TRITONSERVER_MetricFamily pointer struct MetricFamilyDeleter @@ -139,7 +139,7 @@ class CustomMetricsReporter /// \param is_v1_model Whether the model type is v1 or an inflight /// batching model. /// \return a TRITONSERVER_Error indicating success or failure. - TRITONSERVER_Error* InitializeReporter(const std::string& model, const uint64_t version, const bool is_v1_model); + TRITONSERVER_Error* InitializeReporter(std::string const& model, const uint64_t version, bool const is_v1_model); /// Updates the vector of TritonMetricGroup objects with a /// JSON-formatted statistics string. @@ -147,7 +147,7 @@ class CustomMetricsReporter /// \param statistics A JSON-formatted string of TRT LLM backend /// statistics. /// \return a TRITONSERVER_Error indicating success or failure. 
- TRITONSERVER_Error* UpdateCustomMetrics(const std::string& custom_metrics); + TRITONSERVER_Error* UpdateCustomMetrics(std::string const& custom_metrics); static const std::vector request_keys_; static const std::vector request_labels_; diff --git a/inflight_batcher_llm/src/inference_answer.cc b/inflight_batcher_llm/src/inference_answer.cc index 33b72311..da8e73a8 100644 --- a/inflight_batcher_llm/src/inference_answer.cc +++ b/inflight_batcher_llm/src/inference_answer.cc @@ -33,36 +33,45 @@ static int kBitsinByte = 8; std::vector InferenceAnswer::serialize() const { - std::list packed; + // request ID + // num tensors + // final answer + // err msg + size_t totalSize = 4; - packed.push_back(static_cast(request_id_)); - - packed.push_back(static_cast(response_tensors_.size())); for (auto const& tensor : response_tensors_) { - auto packed_tensor = tensor.serialize(); - packed.push_back(static_cast(packed_tensor.size())); - packed.insert(packed.end(), packed_tensor.begin(), packed_tensor.end()); + totalSize += tensor.serializedSize(); + ++totalSize; } - packed.push_back(final_response_ ? 1 : 0); - - const auto num_elements = (err_msg_.size() + sizeof(int64_t) - 1) / sizeof(int64_t); + const int64_t num_elements = (err_msg_.size() + sizeof(int64_t) - 1) / sizeof(int64_t); + totalSize += num_elements; - packed.push_back(static_cast(err_msg_.size())); + std::vector vpacked(totalSize); + int64_t* ptr = vpacked.data(); + *ptr++ = request_id_; + *ptr++ = static_cast(response_tensors_.size()); + for (auto const& tensor : response_tensors_) + { + auto size = tensor.serializedSize(); + *ptr++ = size; + tensor.serialize(ptr, size); + ptr += size; + } + *ptr++ = final_response_ ? 1 : 0; + *ptr++ = err_msg_.size(); - for (size_t i = 0; i < num_elements; ++i) + for (int64_t i = 0; i < num_elements; ++i) { int64_t buffer = 0; for (size_t j = 0; j < sizeof(int64_t) && (i * sizeof(int64_t) + j) < err_msg_.size(); ++j) { buffer |= static_cast(err_msg_[i * sizeof(int64_t) + j]) << (j * kBitsinByte); } - packed.push_back(buffer); + *ptr++ = buffer; } - std::vector vpacked{ - std::make_move_iterator(std::begin(packed)), std::make_move_iterator(std::end(packed))}; return vpacked; } @@ -82,7 +91,7 @@ std::shared_ptr InferenceAnswer::deserialize(int64_t const* pac answer->final_response_ = *packed_ptr++ != 0; - const auto num_chars = *packed_ptr++; + auto const num_chars = *packed_ptr++; answer->err_msg_.reserve(num_chars); int64_t i = 0; diff --git a/inflight_batcher_llm/src/inference_answer.h b/inflight_batcher_llm/src/inference_answer.h index 77b25717..671cc6fc 100644 --- a/inflight_batcher_llm/src/inference_answer.h +++ b/inflight_batcher_llm/src/inference_answer.h @@ -46,7 +46,7 @@ class InferenceAnswer } InferenceAnswer(uint64_t request_id, std::list const& response_tensors, bool final_response, - const std::string& err_msg) + std::string const& err_msg) : request_id_(request_id) , response_tensors_(response_tensors) , final_response_(final_response) @@ -69,16 +69,16 @@ class InferenceAnswer return response_tensors_; } - const std::string& GetErrorMessage() const + std::string const& GetErrorMessage() const { return err_msg_; } [[nodiscard]] std::vector serialize() const; - static std::shared_ptr deserialize(const std::vector& packed); + static std::shared_ptr deserialize(std::vector const& packed); - static std::shared_ptr deserialize(const int64_t* packed_ptr); + static std::shared_ptr deserialize(int64_t const* packed_ptr); private: uint64_t request_id_; diff --git 
a/inflight_batcher_llm/src/libtensorrtllm.cc b/inflight_batcher_llm/src/libtensorrtllm.cc index 98c6d897..f726ba6a 100644 --- a/inflight_batcher_llm/src/libtensorrtllm.cc +++ b/inflight_batcher_llm/src/libtensorrtllm.cc @@ -32,6 +32,8 @@ #include #include +#include "tensorrt_llm/common/mpiUtils.h" + // Triton headers #include "triton/backend/backend_common.h" #include "triton/core/tritonbackend.h" @@ -40,6 +42,7 @@ // trtllm backend headers #include "model_instance_state.h" #include "model_state.h" +#include "orchestrator.h" #include "work_item.h" #include "work_items_queue.h" @@ -52,6 +55,44 @@ namespace triton::backend::inflight_batcher_llm extern "C" { + + // Global backend state creation + TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) + { + tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE); + + char const* str = std::getenv("TRTLLM_ORCHESTRATOR"); + + if (str && std::atoi(str) != 0) + { + TLLM_LOG_INFO( + "Detected TRTLLM_ORCHESTRATOR environment variable, TRTLLM backend will operate in orchestrator " + "mode."); + auto* orchestrator = new Orchestrator(); + RETURN_IF_ERROR(TRITONBACKEND_BackendSetState(backend, reinterpret_cast(orchestrator))); + } + else + { + RETURN_IF_ERROR(TRITONBACKEND_BackendSetState(backend, nullptr)); + } + + return nullptr; // success + } + + TRITONSERVER_Error* TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) + { + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); + + if (vstate) + { + auto* orchestrator = reinterpret_cast(vstate); + delete orchestrator; + } + + return nullptr; // success + } + // Triton calls TRITONBACKEND_ModelInitialize when a model is loaded // to allow the backend to create any state associated with the model, // and to also examine the model configuration to determine if the @@ -64,7 +105,7 @@ extern "C" // TRITONBACKEND_Model. If anything goes wrong with initialization // of the model state then an error is returned and Triton will fail // to load the model. - const char* cname; + char const* cname; RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname)); const std::string name(cname); @@ -107,11 +148,43 @@ extern "C" RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); ModelState* model_state = reinterpret_cast(vmodelstate); - // Create a ModelInstanceState object and associate it with the - // TRITONBACKEND_ModelInstance. - ModelInstanceState* instance_state; - RETURN_IF_ERROR(ModelInstanceState::Create(model_state, instance, &instance_state)); - RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(instance, reinterpret_cast(instance_state))); + TRITONBACKEND_Backend* backend; + RETURN_IF_ERROR(TRITONBACKEND_ModelBackend(model, &backend)); + + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); + + auto* orchestrator = reinterpret_cast(vstate); + + if (orchestrator) + { + auto const device_ids = model_state->GetDeviceIds(); + int const num_workers = device_ids ? device_ids.value().size() : 1; + + std::string workerPath = model_state->GetWorkerPath(); + MPI_Comm everyone; + MPI_Comm_spawn(workerPath.c_str(), MPI_ARGV_NULL, num_workers, MPI_INFO_NULL, 0, MPI_COMM_SELF, &everyone, + MPI_ERRCODES_IGNORE); + + // The output comm is an intercommunicator so it has some special rules.
+ // The parent must send data with bcast using root = MPI_ROOT (-4) + std::vector packed = model_state->serialize(); + int64_t n = packed.size(); + MPICHECK(MPI_Bcast(&n, 1, MPI_INT64_T, MPI_ROOT, everyone)); + MPICHECK(MPI_Bcast(packed.data(), packed.size(), MPI_INT64_T, MPI_ROOT, everyone)); + + OrchestratorCommunicator* communicator; + RETURN_IF_ERROR(orchestrator->addCommunicator(model_state, instance, everyone, &communicator)); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(instance, reinterpret_cast(communicator))); + } + else + { + // Create a ModelInstanceState object and associate it with the + // TRITONBACKEND_ModelInstance. + ModelInstanceState* instance_state; + RETURN_IF_ERROR(ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(instance, reinterpret_cast(instance_state))); + } return nullptr; // success } @@ -122,10 +195,30 @@ extern "C" // TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) { + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + TRITONBACKEND_Backend* backend; + RETURN_IF_ERROR(TRITONBACKEND_ModelBackend(model, &backend)); + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); + + auto* orchestrator = reinterpret_cast(vstate); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); - ModelInstanceState* instance_state = reinterpret_cast(vstate); - delete instance_state; + + if (orchestrator) + { + auto* communicator = reinterpret_cast(vstate); + communicator->shutdown(); + delete communicator; + } + else + { + ModelInstanceState* instance_state = reinterpret_cast(vstate); + delete instance_state; + } return nullptr; // success } @@ -138,10 +231,32 @@ extern "C" TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute( TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, const uint32_t request_count) { - ModelInstanceState* instance_state; - RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast(&instance_state))); + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + TRITONBACKEND_Backend* backend; + RETURN_IF_ERROR(TRITONBACKEND_ModelBackend(model, &backend)); + + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); + + auto* orchestrator = reinterpret_cast(vstate); + + if (orchestrator) + { + OrchestratorCommunicator* communicator; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast(&communicator))); + + communicator->enqueue(requests, request_count); + } + else + { + ModelInstanceState* instance_state; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast(&instance_state))); + + instance_state->enqueue(requests, request_count); + } - instance_state->enqueue(requests, request_count); return nullptr; // success } diff --git a/inflight_batcher_llm/src/libtriton_tensorrtllm.ldscript b/inflight_batcher_llm/src/libtriton_tensorrtllm.ldscript index eba5e2f1..748714d1 100644 --- a/inflight_batcher_llm/src/libtriton_tensorrtllm.ldscript +++ b/inflight_batcher_llm/src/libtriton_tensorrtllm.ldscript @@ -26,6 +26,5 @@ { global: TRITONBACKEND_*; - *InferenceAnswer*; local: *; }; diff --git a/inflight_batcher_llm/src/model_instance_state.cc b/inflight_batcher_llm/src/model_instance_state.cc index c13bf9e5..7cc7e803 100644 --- a/inflight_batcher_llm/src/model_instance_state.cc +++ 
b/inflight_batcher_llm/src/model_instance_state.cc @@ -25,9 +25,14 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "model_instance_state.h" +#include "utils.h" + +#include "mpi_utils.h" #include "tensorrt_llm/common/mpiUtils.h" +#include + namespace mpi = tensorrt_llm::mpi; namespace triton::backend::inflight_batcher_llm @@ -38,9 +43,9 @@ TRITONSERVER_Error* ModelInstanceState::Create( { try { - *state = new ModelInstanceState(model_state, triton_model_instance); + *state = new ModelInstanceState(model_state, triton_model_instance, MPI_COMM_NULL); } - catch (const std::exception& ex) + catch (std::exception const& ex) { std::string errStr = std::string("unexpected error when creating modelInstanceState: ") + ex.what(); return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); @@ -49,7 +54,26 @@ TRITONSERVER_Error* ModelInstanceState::Create( return nullptr; // success } -ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) +bool ModelInstanceState::Create(ModelState* model_state, MPI_Comm leaderOrchComm, ModelInstanceState** state) +{ + try + { + // No need for a triton model instance, since this worker will communicate its answers + // to the orchestrator which communicates with Triton + TRITONBACKEND_ModelInstance* triton_model_instance = nullptr; + *state = new ModelInstanceState(model_state, triton_model_instance, leaderOrchComm); + } + catch (std::exception const& ex) + { + TLLM_LOG_ERROR("unexpected error when creating modelInstanceState: %s", ex.what()); + return false; + } + + return true; +} + +ModelInstanceState::ModelInstanceState( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, MPI_Comm leaderOrchComm) : model_state_(model_state) , modelInstance_(triton_model_instance) , mHasActiveRequests(false) @@ -100,7 +124,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { maxBeamWidth = model_state_->GetParameter("max_beam_width"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING("max_beam_width is not specified, will use default value of 1"); @@ -111,7 +135,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { maxTokensInPagedKvCache = model_state_->GetParameter("max_tokens_in_paged_kv_cache"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING( @@ -138,7 +162,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo "(must be max_utilization or guaranteed_no_evict)"); } } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_LOG_WARNING(e.what()); } @@ -154,7 +178,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo "(requires building the model with use_paged_context_fmha)."); } } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING("enable_chunked_context is not specified, will be set to false."); @@ -179,7 +203,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { kvCacheFreeGpuMemFraction = model_state_->GetParameter("kv_cache_free_gpu_mem_fraction"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING( @@ -192,7 +216,7 @@ 
ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { enableTrtOverlap = model_state_->GetParameter("enable_trt_overlap"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING("enable_trt_overlap is not specified, will be set to false"); @@ -203,7 +227,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { normalizeLogProbs = model_state_->GetParameter("normalize_log_probs"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING("normalize_log_probs is not specified, will be set to true"); @@ -214,7 +238,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { excludeInputInOutput = model_state_->GetParameter("exclude_input_in_output"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING("exclude_input_in_output is not specified, will be set to false"); @@ -225,7 +249,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { maxAttentionWindow = model_state_->GetParameter("max_attention_window_size"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING( @@ -238,7 +262,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo { enableKVCacheReuse = model_state_->GetParameter("enable_kv_cache_reuse"); } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING("enable_kv_cache_reuse is not specified, will be set to false"); @@ -269,7 +293,7 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo throw std::runtime_error(""); } } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_LOG_WARNING( "decoding_mode parameter is invalid or not specified" @@ -277,6 +301,55 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo "Using default: top_k_top_p if max_beam_width == 1, beam_search otherwise"); } + // parse LoRA / Peft cache parameters + // lora_cache_max_adapter_size + // lora_cache_optimal_adapter_size + // lora_cache_gpu_memory_fraction + // lora_cache_host_memory_bytes + + SizeType maxAdapterSize = 64; + SizeType optimalAdapterSize = 8; + std::optional hostCacheSize = std::nullopt; + std::optional deviceCachePercent = std::nullopt; + + std::string fieldName = "lora_cache_max_adapter_size"; + try + { + maxAdapterSize = model_state_->GetParameter(fieldName); + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING(fieldName + " not set, defaulting to 64"); + } + + fieldName = "lora_cache_optimal_adapter_size"; + try + { + optimalAdapterSize = model_state_->GetParameter(fieldName); + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING(fieldName + " not set, defaulting to 8"); + } + fieldName = "lora_cache_gpu_memory_fraction"; + try + { + deviceCachePercent = model_state_->GetParameter(fieldName); + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING(fieldName + " not set, defaulting to 0.05"); + } + fieldName = "lora_cache_host_memory_bytes"; + try + { + hostCacheSize = model_state_->GetParameter(fieldName); + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING(fieldName + " not set, defaulting to 1GB"); + } + auto const gpuDeviceIds = 
model_state_->GetDeviceIds(); TrtGptModelOptionalParams optionalParams; @@ -290,62 +363,172 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo optionalParams.deviceIds = gpuDeviceIds; optionalParams.decodingMode = decodingMode; + optionalParams.peftCacheManagerConfig.maxAdapterSize = maxAdapterSize; + optionalParams.peftCacheManagerConfig.optimalAdapterSize = optimalAdapterSize; + optionalParams.peftCacheManagerConfig.deviceCachePercent = deviceCachePercent; + optionalParams.peftCacheManagerConfig.hostCacheSize = hostCacheSize; + + // TODO (grclark) find better defaults for these + optionalParams.peftCacheManagerConfig.numEnsureWorkers = ModelInstanceState::kPeftCacheNumEnsureWorkers; + optionalParams.peftCacheManagerConfig.numCopyStreams = ModelInstanceState::kPeftCacheNumCopyStreams; + optionalParams.peftCacheManagerConfig.numPutWorkers = ModelInstanceState::kPeftCacheNumPutWorkers; + mBatchManager = std::make_shared( mModelPath, mTrtGptModelType, maxBeamWidth, schedulerPolicy, - [this](int max_num_requests) { return get_inference_requests(max_num_requests); }, - [this](uint64_t requestId, std::list response_tensors, bool final_response, - const std::string& errMsg) { return sendResponse(requestId, response_tensors, final_response, errMsg); }, - [this]() { return pollStopSignals(); }, [this](const std::string& s) { return logStats(s); }, optionalParams, + [this](int max_num_requests) + { + return mLeaderOrchComm ? get_inference_requests_leader(max_num_requests) + : get_inference_requests(max_num_requests); + }, + [this]( + uint64_t requestId, std::list response_tensors, bool final_response, std::string const& errMsg) + { + return mLeaderOrchComm ? sendResponseLeader(requestId, response_tensors, final_response, errMsg) + : sendResponse(requestId, response_tensors, final_response, errMsg); + }, + [this]() { return pollStopSignals(); }, [this](std::string const& s) { return logStats(s); }, optionalParams, std::nullopt, std::nullopt, excludeInputInOutput); - if (COMM_SESSION.getRank() != 0) + int const rank = COMM_SESSION.getRank(); + // If orchestrator mode and leader rank, need to spawn threads to receive requests/ send responses from/to + // orchestrator + if (rank == 0 && leaderOrchComm != MPI_COMM_NULL) + { + mLeaderOrchComm = std::make_unique(leaderOrchComm, true); + mReceiverThread = std::thread([this]() { return RecvMpiThread(); }); + mSenderThread = std::thread([this]() { return AnsMpiThread(); }); + } + + if (rank != 0 || mLeaderOrchComm) { - while (true) + while (!mModelUnloadRequest.load()) + { + } + + if (mReceiverThread.joinable()) { + mReceiverThread.join(); + } + + if (mSenderThread.joinable()) + { + mSenderThread.join(); } } } -void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count) +void ModelInstanceState::RecvMpiThread() { - std::vector requestsToPush; - uint64_t exec_start_ns = 0; - SET_TIMESTAMP(exec_start_ns); + MPI_Message msg; + MPI_Status status; + int32_t count; + MpiId mpiId; - for (uint32_t r = 0; r < request_count; ++r) + while (true) { - TRITONBACKEND_Request* request = requests[r]; - try - { - auto requestId = utils::getRequestId(request, mRequestIdStrMap); - bool stopRequest = utils::getRequestBooleanInputTensor(request, kStopInputTensorName); + // Blocking is okay: terminate message is expected to arrive here + mLeaderOrchComm->mprobe(0, kMPI_ID_TAG, &msg, &status); + MPICHECK(MPI_Get_count(&status, MPI_UINT64_T, &count)); + TLLM_CHECK(count == 1); + MPICHECK(MPI_Mrecv(&mpiId, count, 
MPI_UINT64_T, &msg, &status)); - if (stopRequest) + // EXIT condition from receiving TERMINATE msg + if (mpiId == MpiId::TERMINATION) + { + MpiMessage message(mpiId); { - if (requestId != 0) - { - // Check if request is in progress or in queue, if not ignore - mWorkItemsQueue->stopWorkItem(requestId); - // Send a response back to client for stop request - utils::sendEnqueueResponse(request); - } - else - { - throw std::runtime_error("Cannot send stop request without specifying a request_id"); - } + std::unique_lock lk(mSenderMutex); + mSenderQueue.push(message); } - else + + mSenderCV.notify_all(); + mModelUnloadRequest.store(true); + TLLM_LOG_INFO("Leader recv thread exiting"); + break; + } + else if (mpiId == MpiId::PENDING_REQUEST) + { + // Prepare receiving data + mLeaderOrchComm->mprobe(0, kMPI_DATA_TAG, &msg, &status); + MPICHECK(MPI_Get_count(&status, MPI_INT64_T, &count)); + std::vector data(count); + MPICHECK(MPI_Mrecv(data.data(), count, MPI_INT64_T, &msg, &status)); + + auto ir = InferenceRequest::deserialize(data.data()); { - requestsToPush.emplace_back(requestId, request); + std::lock_guard lk(mRecRequestsMutex); + mRecvRequests.push(ir); } } - catch (const std::exception& e) + else if (mpiId == MpiId::STOP_REQUEST || mpiId == MpiId::CANCEL_REQUEST) + { + // Prepare receiving data + mLeaderOrchComm->mprobe(0, kMPI_DATA_TAG, &msg, &status); + MPICHECK(MPI_Get_count(&status, MPI_UINT64_T, &count)); + std::vector data(count); + MPICHECK(MPI_Mrecv(data.data(), count, MPI_UINT64_T, &msg, &status)); + + std::unique_lock lk(mStoppedReqIdsMutex); + mStoppedReqIds.insert(data.begin(), data.end()); + } + } +} + +void ModelInstanceState::AnsMpiThread() +{ + while (true) + { + std::unique_lock lk(mSenderMutex); + mSenderCV.wait(lk, [&]() { return (!mSenderQueue.empty()); }); + + auto message = mSenderQueue.front(); + mSenderQueue.pop(); + + if (message.id == MpiId::TERMINATION) + { + mLeaderOrchComm->send(&message.id, 1, MpiType::kUINT64, 0, kMPI_ID_TAG); + TLLM_LOG_INFO("Leader answer thread exiting"); + break; + } + else if (message.id == MpiId::REQUEST_ANSWER) + { + auto& data = std::get(message.data); + auto packed = data.answer->serialize(); + + mLeaderOrchComm->send(&message.id, 1, MpiType::kUINT64, 0, kMPI_ID_TAG); + mLeaderOrchComm->send(packed.data(), packed.size(), MpiType::kINT64, 0, kMPI_DATA_TAG); + } + else if (message.id == MpiId::REQUEST_IN_PROGRESS) { - // In case of error, no work item is added to queue, so response - // callback needs to be called - utils::sendEnqueueResponse(request, e.what()); + auto& data = std::get(message.data); + + mLeaderOrchComm->send(&message.id, 1, MpiType::kUINT64, 0, kMPI_ID_TAG); + mLeaderOrchComm->send(data.ids.data(), data.ids.size(), MpiType::kUINT64, 0, kMPI_DATA_TAG); } } +} + +void ModelInstanceState::SendMessage(MpiMessage&& message) +{ + { + std::unique_lock lk(mSenderMutex); + mSenderQueue.push(std::move(message)); + } + + mSenderCV.notify_all(); +} + +void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count) +{ + std::vector requestsToPush; + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + for (uint32_t r = 0; r < request_count; ++r) + { + TRITONBACKEND_Request* request = requests[r]; + utils::handleTritonRequest(request, mRequestIdStrMap, requestsToPush, *mWorkItemsQueue); + } auto exceptions = mWorkItemsQueue->pushBatch(requestsToPush, exec_start_ns); @@ -363,7 +546,7 @@ void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, const uint32_ } // Return up to 
max_num_requests inference requests. -std::list> ModelInstanceState::get_inference_requests(const int max_num_requests) +std::list> ModelInstanceState::get_inference_requests(int const max_num_requests) { std::list> rval; if (max_num_requests <= 0) @@ -373,7 +556,6 @@ std::list> ModelInstanceState::get_inference_r auto const& commSession = COMM_SESSION; - auto world_size = commSession.getSize(); auto rank = commSession.getRank(); if (rank == 0) { @@ -394,32 +576,12 @@ std::list> ModelInstanceState::get_inference_r std::string warnStr = std::string("request Id ") + std::to_string(workItem->requestId()) + std::string(" has been stopped. Request is ignored."); TLLM_LOG_WARNING(warnStr); - sendTritonResponse(workItem, {}, true, warnStr); + sendTritonResponse(workItem, {}, true, warnStr, *mWorkItemsQueue, modelInstance_); } } } - if (world_size > 1) - { - int64_t num_new_work_items = rval.size(); - mHasActiveRequests = (num_new_work_items > 0 || mBatchManager->getNumActiveRequests() > 0); - if (mHasActiveRequests) - { - commSession.bcastValue(num_new_work_items, 0); - } - - if (num_new_work_items > 0) - { - std::vector packed; - for (auto ir : rval) - { - auto vpacked = ir->serialize(); - packed.push_back(static_cast(vpacked.size())); - packed.insert(packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); - } - commSession.bcast(packed, 0); - } - } + broadcast_inference_requests(rval); } else { @@ -444,8 +606,72 @@ std::list> ModelInstanceState::get_inference_r return rval; } +std::list> ModelInstanceState::get_inference_requests_leader( + int const max_num_requests) +{ + std::list> rval; + if (max_num_requests <= 0) + { + return rval; + } + + std::lock_guard lk(mRecRequestsMutex); + auto const num_requests_to_send = std::min(max_num_requests, (int) mRecvRequests.size()); + + std::vector requests_ids(num_requests_to_send); + + for (int i = 0; i < num_requests_to_send; ++i) + { + auto ir = mRecvRequests.front(); + mRecvRequests.pop(); + + requests_ids[i] = ir->getRequestId(); + + rval.emplace_back(ir); + } + + if (!requests_ids.empty()) + { + MpiMessage message(MpiId::REQUEST_IN_PROGRESS); + message.data = RequestIdsData{std::move(requests_ids)}; + + SendMessage(std::move(message)); + } + + broadcast_inference_requests(rval); + + return rval; +} + +void ModelInstanceState::broadcast_inference_requests(std::list>& rval) +{ + auto const& commSession = COMM_SESSION; + auto world_size = commSession.getSize(); + if (world_size > 1) + { + int64_t num_new_work_items = rval.size(); + mHasActiveRequests = (num_new_work_items > 0 || mBatchManager->getNumActiveRequests() > 0); + if (mHasActiveRequests) + { + commSession.bcastValue(num_new_work_items, 0); + } + + if (num_new_work_items > 0) + { + std::vector packed; + for (auto ir : rval) + { + auto vpacked = ir->serialize(); + packed.push_back(static_cast(vpacked.size())); + packed.insert(packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); + } + commSession.bcast(packed, 0); + } + } +} + void ModelInstanceState::sendResponse( - uint64_t requestId, std::list const& response_tensors, bool final_response, const std::string& errMsg) + uint64_t requestId, std::list const& response_tensors, bool final_response, std::string const& errMsg) { if (COMM_SESSION.getRank() == 0) { @@ -458,23 +684,45 @@ void ModelInstanceState::sendResponse( try { auto workItem = mWorkItemsQueue->getInProgressWorkItem(requestId); - auto tritonErr = sendTritonResponse(workItem, response_tensors, final_response, errMsg); + 
auto tritonErr = sendTritonResponse( + workItem, response_tensors, final_response, errMsg, *mWorkItemsQueue, modelInstance_); LOG_IF_ERROR(tritonErr, errStr); } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_LOG_ERROR(errStr); } } } +void ModelInstanceState::sendResponseLeader( + uint64_t requestId, std::list const& response_tensors, bool final_response, std::string const& errMsg) +{ + // send answer to orchestator + MpiMessage message(MpiId::REQUEST_ANSWER); + + auto answer = std::make_shared(requestId, response_tensors, final_response, errMsg); + message.data = RequestAnswerData{std::move(answer)}; + + SendMessage(std::move(message)); +} + std::unordered_set ModelInstanceState::pollStopSignals() { - auto stoppedReqIds = mWorkItemsQueue->getStoppedReqIds(); + std::unordered_set stoppedReqIds; + if (mLeaderOrchComm) + { + std::unique_lock lk(mStoppedReqIdsMutex); + stoppedReqIds = mStoppedReqIds; + } + else + { + stoppedReqIds = mWorkItemsQueue->getStoppedReqIds(); - // Merge cancelled requests into stopped requests Ids - auto cancelledReqIds = mWorkItemsQueue->getCancelledInProgressReqIds(); - stoppedReqIds.insert(cancelledReqIds.begin(), cancelledReqIds.end()); + // Merge cancelled requests into stopped requests Ids + auto cancelledReqIds = mWorkItemsQueue->getCancelledInProgressReqIds(); + stoppedReqIds.insert(cancelledReqIds.begin(), cancelledReqIds.end()); + } int64_t nStoppedReqIds = static_cast(stoppedReqIds.size()); @@ -509,7 +757,7 @@ std::unordered_set ModelInstanceState::pollStopSignals() return stoppedReqIds; } -void ModelInstanceState::logStats(const std::string& s) +void ModelInstanceState::logStats(std::string const& s) { LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, s.c_str()); #ifdef TRITON_ENABLE_METRICS @@ -518,7 +766,8 @@ void ModelInstanceState::logStats(const std::string& s) } TRITONSERVER_Error* ModelInstanceState::sendTritonResponse(std::shared_ptr workItem, - std::list const& response_tensors, bool final_response, const std::string& errMsg) + std::list const& response_tensors, bool final_response, std::string const& errMsg, + WorkItemsQueue& workItemsQueue, TRITONBACKEND_ModelInstance* model_instance) { TRITONBACKEND_ResponseFactory* response_factory; response_factory = workItem->response_factory(); @@ -530,7 +779,7 @@ TRITONSERVER_Error* ModelInstanceState::sendTritonResponse(std::shared_ptrgetTimestamps().compute_end_ns); - mWorkItemsQueue->markFinished(requestId); + workItemsQueue.markFinished(requestId); } // Check if error @@ -585,7 +834,7 @@ TRITONSERVER_Error* ModelInstanceState::sendTritonResponse(std::shared_ptrreportBaseMetrics(modelInstance_, err), "Error reporting base metrics"); + LOG_IF_ERROR(workItem->reportBaseMetrics(model_instance, err), "Error reporting base metrics"); // Reporting Triton core metrics requires the use of the original TRITONBACKEND_Request. // Therefore we hold off releasing the request until this point. 
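In the hunk above, `sendTritonResponse` becomes a static helper that takes the `WorkItemsQueue` and the `TRITONBACKEND_ModelInstance` as explicit arguments, so the same response path can be reused by the model instance itself and, later in this patch, by the orchestrator's answer thread. A minimal sketch of that dependency-injection refactor, using simplified stand-in types (none of these names are the real Triton API), might look like this:

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Simplified stand-ins for the real WorkItemsQueue and model instance handle.
struct Queue
{
    void markFinished(uint64_t requestId) { std::cout << "finished " << requestId << "\n"; }
};

struct Instance
{
    std::string name;
};

class Responder
{
public:
    // Static helper: every dependency is a parameter, so any owner of a Queue
    // and an Instance can call it, not only the object that created them.
    static void sendResponse(uint64_t requestId, bool finalResponse, Queue& queue, Instance* instance)
    {
        std::cout << "sending response for " << requestId << " via " << instance->name << "\n";
        if (finalResponse)
        {
            queue.markFinished(requestId);
        }
    }
};

int main()
{
    Queue leaderQueue;
    Instance instance{"model_instance_0"};

    // Caller 1: the in-process path owns its queue directly.
    Responder::sendResponse(1, true, leaderQueue, &instance);

    // Caller 2: another component (here standing in for the orchestrator's
    // answer thread) reuses the exact same helper with its own queue.
    Queue orchestratorQueue;
    Responder::sendResponse(2, true, orchestratorQueue, &instance);
    return 0;
}
```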
TRITONBACKEND_RequestRelease(workItem->getTritonInferenceRequest(), TRITONSERVER_REQUEST_RELEASE_ALL); diff --git a/inflight_batcher_llm/src/model_instance_state.h b/inflight_batcher_llm/src/model_instance_state.h index 6b9a3611..131c8aff 100644 --- a/inflight_batcher_llm/src/model_instance_state.h +++ b/inflight_batcher_llm/src/model_instance_state.h @@ -26,13 +26,15 @@ #pragma once -#include +#include #include #include "triton/backend/backend_common.h" #include "triton/core/tritonbackend.h" #include "triton/core/tritonserver.h" +#include "tensorrt_llm/common/mpiUtils.h" + #include "tensorrt_llm/batch_manager/BatchManager.h" #include "tensorrt_llm/batch_manager/GptManager.h" #include "tensorrt_llm/batch_manager/callbacks.h" @@ -42,7 +44,9 @@ #include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" #include "tensorrt_llm/runtime/decodingMode.h" +#include "inference_answer.h" #include "model_state.h" +#include "mpi_utils.h" #include "work_item.h" #include "work_items_queue.h" @@ -50,6 +54,7 @@ #include "custom_metrics_reporter/custom_metrics_reporter.h" #endif +using namespace tensorrt_llm::mpi; using namespace tensorrt_llm::batch_manager; using namespace tensorrt_llm::batch_manager::batch_scheduler; @@ -71,9 +76,22 @@ class ModelInstanceState using TrtGptModelType = tensorrt_llm::batch_manager::TrtGptModelType; public: + // number of cpu workers used to move weights host cache to gpu cache + static constexpr SizeType kPeftCacheNumEnsureWorkers = 4; + // number of cuda streams used for H2D copies of peft cache pages + static constexpr SizeType kPeftCacheNumCopyStreams = 4; + // number of cpu workers used to load weight into host cache + static constexpr SizeType kPeftCacheNumPutWorkers = 4; + + /// @brief Create a ModelInstanceObject when running in non-orchestrator mode static TRITONSERVER_Error* Create( ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, ModelInstanceState** state); + /// @brief Create a ModelInstanceObject for workers when running in orchestrator mode + /// @param leaderOrchComm MPI inter-communicator containing MPI_COMM_WORLD and the parent process. + /// Is only used for the leader rank (0) and is ignored for other ranks. + static bool Create(ModelState* model_state, MPI_Comm leaderOrchComm, ModelInstanceState** state); + virtual ~ModelInstanceState() { // terminate decoupled execution loop @@ -103,24 +121,35 @@ class ModelInstanceState /// @brief Callback passed to GptManager to get new inference requests /// @return Up to max_num_requests inference requests. 
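The header comments above describe `get_inference_requests`, `sendResponse`, `pollStopSignals`, and `logStats` as callbacks handed to `GptManager`. The exact `GptManager` constructor is not part of this diff, so the following is only an illustrative sketch, with a hypothetical stand-in `Manager` type, of how such member functions are typically wired up: capture the instance in lambdas and store them as `std::function` objects.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <list>
#include <string>
#include <unordered_set>

// Stand-in request/response types; the real backend passes InferenceRequest
// and NamedTensor lists instead.
using Request = int;
using ResponseTensors = std::list<int>;

// Hypothetical manager that drives its loop through callbacks, in the spirit
// of the GptManager callbacks documented above (not the real API).
struct Manager
{
    std::function<std::list<Request>(int)> getRequests;
    std::function<void(uint64_t, ResponseTensors const&, bool, std::string const&)> sendResponse;
    std::function<std::unordered_set<uint64_t>()> pollStopSignals;
    std::function<void(std::string const&)> logStats;

    void step()
    {
        auto requests = getRequests(/*maxNumRequests=*/8);
        for (auto const& r : requests)
        {
            sendResponse(static_cast<uint64_t>(r), {}, /*finalResponse=*/true, /*errMsg=*/"");
        }
        logStats("iteration done, stopped=" + std::to_string(pollStopSignals().size()));
    }
};

struct InstanceState
{
    std::list<Request> getInferenceRequests(int /*maxNumRequests*/)
    {
        return {1, 2, 3}; // pretend three requests arrived
    }

    void sendResponse(uint64_t id, ResponseTensors const&, bool, std::string const&)
    {
        std::cout << "response for request " << id << "\n";
    }
};

int main()
{
    InstanceState state;
    Manager manager;
    // Bind the instance the usual way member-function callbacks get registered.
    manager.getRequests = [&state](int n) { return state.getInferenceRequests(n); };
    manager.sendResponse = [&state](uint64_t id, ResponseTensors const& t, bool f, std::string const& e)
    { state.sendResponse(id, t, f, e); };
    manager.pollStopSignals = [] { return std::unordered_set<uint64_t>{}; };
    manager.logStats = [](std::string const& s) { std::cout << s << "\n"; };

    manager.step();
    return 0;
}
```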
- std::list> get_inference_requests(const int max_num_requests); + std::list> get_inference_requests(int const max_num_requests); + std::list> get_inference_requests_leader(int const max_num_requests); /// @brief Callback passed to GptManager to send responses back to client void sendResponse(uint64_t requestId, std::list const& response_tensors, bool final_response, - const std::string& errMsg); + std::string const& errMsg); + void sendResponseLeader(uint64_t requestId, std::list const& response_tensors, bool final_response, + std::string const& errMsg); /// @brief Callback passed to GptManager to get ids of stopped requests std::unordered_set pollStopSignals(); /// @brief Callback passed to GptManager to print stats - void logStats(const std::string& s); + void logStats(std::string const& s); /// @brief Method that sends Triton response back to client - TRITONSERVER_Error* sendTritonResponse(std::shared_ptr workItem, - std::list const& response_tensors, bool final_response, const std::string& errMsg); + static TRITONSERVER_Error* sendTritonResponse(std::shared_ptr workItem, + std::list const& response_tensors, bool final_response, std::string const& errMsg, + WorkItemsQueue& workItemsQueue, TRITONBACKEND_ModelInstance* model_instance); private: /// @brief Constructor - ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance); + ModelInstanceState( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, MPI_Comm model_comm); + + void RecvMpiThread(); + void AnsMpiThread(); + void SendMessage(MpiMessage&& message); + + void broadcast_inference_requests(std::list>& rval); ModelState* model_state_; TRITONBACKEND_ModelInstance* modelInstance_; @@ -128,6 +157,19 @@ class ModelInstanceState TrtGptModelType mTrtGptModelType; std::string mModelPath; + // Only valid for leader-worker ranks + std::unique_ptr mLeaderOrchComm; + std::thread mReceiverThread; + std::queue> mRecvRequests; + std::mutex mRecRequestsMutex; + std::thread mSenderThread; + std::queue mSenderQueue; + std::mutex mSenderMutex; + std::condition_variable mSenderCV; + std::unordered_set mStoppedReqIds; + std::mutex mStoppedReqIdsMutex; + std::atomic mModelUnloadRequest = false; + std::shared_ptr mBatchManager; std::unique_ptr mWorkItemsQueue; diff --git a/inflight_batcher_llm/src/model_state.cc b/inflight_batcher_llm/src/model_state.cc index c0b29763..ba2016f2 100644 --- a/inflight_batcher_llm/src/model_state.cc +++ b/inflight_batcher_llm/src/model_state.cc @@ -26,6 +26,10 @@ #include "model_state.h" +#include "tensorrt_llm/common/mpiUtils.h" + +#include + namespace triton::backend::inflight_batcher_llm { @@ -44,7 +48,7 @@ std::vector csvStrToVecInt(std::string const& str) } TRITONSERVER_Error* ModelState::Create( - TRITONBACKEND_Model* triton_model, const std::string& name, const uint64_t version, ModelState** state) + TRITONBACKEND_Model* triton_model, std::string const& name, const uint64_t version, ModelState** state) { TRITONSERVER_Message* config_message; RETURN_IF_ERROR(TRITONBACKEND_ModelConfig(triton_model, 1 /* config_version */, &config_message)); @@ -55,7 +59,7 @@ TRITONSERVER_Error* ModelState::Create( // nice errors (currently the underlying implementation is // rapidjson... but others could be added). You can use any json // parser you prefer. 
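`broadcast_inference_requests` (shown earlier in this patch) packs every serialized request into one flat `int64_t` buffer, length-prefixing each entry, and broadcasts the request count first (`bcastValue`) followed by the packed payload. Below is a minimal, self-contained sketch of that packing scheme and the matching unpack loop run on the receiving ranks; the `std::vector<int64_t>` payloads stand in for the real `InferenceRequest::serialize()` output, so the function names here are illustrative only.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Pack each serialized request as [size][payload...] into one flat buffer,
// mirroring the loop in broadcast_inference_requests().
std::vector<int64_t> pack(std::vector<std::vector<int64_t>> const& requests)
{
    std::vector<int64_t> packed;
    for (auto const& vpacked : requests)
    {
        packed.push_back(static_cast<int64_t>(vpacked.size()));
        packed.insert(packed.end(), vpacked.begin(), vpacked.end());
    }
    return packed;
}

// Inverse operation on the receiving ranks after the bcast: read the length
// prefix, then copy that many words out as one request.
std::vector<std::vector<int64_t>> unpack(std::vector<int64_t> const& packed, int64_t numRequests)
{
    std::vector<std::vector<int64_t>> requests;
    size_t offset = 0;
    for (int64_t i = 0; i < numRequests; ++i)
    {
        auto const size = static_cast<size_t>(packed[offset++]);
        requests.emplace_back(packed.begin() + offset, packed.begin() + offset + size);
        offset += size;
    }
    return requests;
}

int main()
{
    std::vector<std::vector<int64_t>> requests{{1, 2, 3}, {42}, {7, 7, 7, 7}};
    auto const packed = pack(requests);
    auto const roundTripped = unpack(packed, static_cast<int64_t>(requests.size()));
    std::cout << "round trip ok: " << std::boolalpha << (roundTripped == requests) << "\n";
    return 0;
}
```

Broadcasting the count separately is what lets non-leader ranks know how many length prefixes to expect before they start unpacking.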
- const char* buffer; + char const* buffer; size_t byte_size; RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(config_message, &buffer, &byte_size)); @@ -68,7 +72,7 @@ TRITONSERVER_Error* ModelState::Create( { *state = new ModelState(triton_model, name, version, std::move(model_config)); } - catch (const std::exception& ex) + catch (std::exception const& ex) { std::string errStr = std::string("unexpected error when creating modelState: ") + ex.what(); return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); @@ -98,7 +102,7 @@ void ModelState::LoadParameters() TLLM_LOG_INFO(deviceIdInfo); } } - catch (const std::exception& e) + catch (std::exception const& e) { // If parameter is not specified, just ignore TLLM_LOG_WARNING("gpu_device_ids is not specified, will be automatically set"); @@ -110,7 +114,7 @@ common::TritonJson::Value& ModelState::GetModelConfig() return model_config_; } -const std::string& ModelState::GetModelName() const +std::string const& ModelState::GetModelName() const { return model_name_; } @@ -120,8 +124,77 @@ uint64_t ModelState::GetModelVersion() const return model_version_; } +const std::string ModelState::GetWorkerPath() +{ + std::string workerPath = "/opt/tritonserver/backends/tensorrtllm/triton_tensorrtllm_worker"; + try + { + workerPath = GetParameter("worker_path"); + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING("worker_path is not specified, will use default value"); + } + + return workerPath; +} + +std::vector ModelState::serialize() const +{ + // model name + // model version + // model config + size_t totalSize = 3; + + int nameSize = (model_name_.size() + sizeof(int64_t)) / sizeof(int64_t); + totalSize += nameSize; + + TritonJson::WriteBuffer buffer; + model_config_.Write(&buffer); + + totalSize += buffer.Size(); + + std::vector packed(totalSize); + int64_t* ptr = packed.data(); + + *ptr++ = model_name_.size(); + std::memcpy(ptr, model_name_.c_str(), model_name_.size()); + ptr += nameSize; + + *ptr++ = model_version_; + *ptr++ = buffer.Size(); + std::memcpy(ptr, buffer.Base(), buffer.Size()); + + return packed; +} + +ModelState ModelState::deserialize(int64_t const* packed_ptr) +{ + auto const nameSize = *packed_ptr++; + char const* cname = reinterpret_cast(packed_ptr); + packed_ptr += (nameSize + sizeof(int64_t)) / sizeof(int64_t); + + const uint64_t version = *packed_ptr++; + + auto const jsonSize = *packed_ptr++; + char const* jsonBuffer = reinterpret_cast(packed_ptr); + common::TritonJson::Value model_config; + TRITONSERVER_Error* err = model_config.Parse(jsonBuffer, jsonSize); + if (err) + { + throw std::runtime_error("Failed to parse model config"); + } + + return ModelState{nullptr, cname, version, std::move(model_config)}; +} + +ModelState ModelState::deserialize(std::vector const& packed) +{ + return ModelState::deserialize(packed.data()); +} + template <> -std::string ModelState::GetParameter(const std::string& name) +std::string ModelState::GetParameter(std::string const& name) { TritonJson::Value parameters; TRITONSERVER_Error* err = model_config_.MemberAsObject("parameters", ¶meters); @@ -144,13 +217,13 @@ std::string ModelState::GetParameter(const std::string& name) } template <> -int32_t ModelState::GetParameter(const std::string& name) +int32_t ModelState::GetParameter(std::string const& name) { return std::stoi(GetParameter(name)); } template <> -std::vector ModelState::GetParameter>(const std::string& name) +std::vector ModelState::GetParameter>(std::string const& name) { auto deviceIdsStr = 
GetParameter(name); // Parse as comma delimited string @@ -158,31 +231,31 @@ std::vector ModelState::GetParameter>(const std::s } template <> -uint32_t ModelState::GetParameter(const std::string& name) +uint32_t ModelState::GetParameter(std::string const& name) { return (uint32_t) std::stoul(GetParameter(name)); } template <> -int64_t ModelState::GetParameter(const std::string& name) +int64_t ModelState::GetParameter(std::string const& name) { return std::stoll(GetParameter(name)); } template <> -uint64_t ModelState::GetParameter(const std::string& name) +uint64_t ModelState::GetParameter(std::string const& name) { return std::stoull(GetParameter(name)); } template <> -float ModelState::GetParameter(const std::string& name) +float ModelState::GetParameter(std::string const& name) { return std::stof(GetParameter(name)); } template <> -bool ModelState::GetParameter(const std::string& name) +bool ModelState::GetParameter(std::string const& name) { auto val = GetParameter(name); if (val == "True" || val == "true" || val == "TRUE" || val == "1") diff --git a/inflight_batcher_llm/src/model_state.h b/inflight_batcher_llm/src/model_state.h index df7568b3..7bacccc3 100644 --- a/inflight_batcher_llm/src/model_state.h +++ b/inflight_batcher_llm/src/model_state.h @@ -51,10 +51,10 @@ class ModelState { public: static TRITONSERVER_Error* Create( - TRITONBACKEND_Model* triton_model, const std::string& name, const uint64_t version, ModelState** state); + TRITONBACKEND_Model* triton_model, std::string const& name, const uint64_t version, ModelState** state); template - T GetParameter(const std::string& name) + T GetParameter(std::string const& name) { assert(false); auto dummy = T(); @@ -64,8 +64,9 @@ class ModelState virtual ~ModelState() = default; common::TritonJson::Value& GetModelConfig(); - const std::string& GetModelName() const; + std::string const& GetModelName() const; uint64_t GetModelVersion() const; + const std::string GetWorkerPath(); std::optional> GetDeviceIds() { @@ -77,6 +78,12 @@ class ModelState return is_decoupled_; } + [[nodiscard]] std::vector serialize() const; + + static ModelState deserialize(int64_t const* packed_ptr); + + static ModelState deserialize(std::vector const& packed); + private: const std::string model_name_; uint64_t model_version_; @@ -87,8 +94,11 @@ class ModelState std::optional> gpu_device_ids_; bool is_decoupled_ = false; + void LoadParameters(); + +public: ModelState( - TRITONBACKEND_Model* triton_model, const std::string& name, uint64_t version, TritonJson::Value&& model_config) + TRITONBACKEND_Model* triton_model, std::string const& name, uint64_t version, TritonJson::Value&& model_config) : model_name_(name) , model_version_(version) , model_config_(std::move(model_config)) @@ -98,32 +108,30 @@ class ModelState LoadParameters(); } - - void LoadParameters(); }; template <> -std::string ModelState::GetParameter(const std::string& name); +std::string ModelState::GetParameter(std::string const& name); template <> -int32_t ModelState::GetParameter(const std::string& name); +int32_t ModelState::GetParameter(std::string const& name); template <> -uint32_t ModelState::GetParameter(const std::string& name); +uint32_t ModelState::GetParameter(std::string const& name); template <> -int64_t ModelState::GetParameter(const std::string& name); +int64_t ModelState::GetParameter(std::string const& name); template <> -uint64_t ModelState::GetParameter(const std::string& name); +uint64_t ModelState::GetParameter(std::string const& name); template <> -float 
ModelState::GetParameter(const std::string& name); +float ModelState::GetParameter(std::string const& name); template <> -bool ModelState::GetParameter(const std::string& name); +bool ModelState::GetParameter(std::string const& name); template <> -std::vector ModelState::GetParameter>(const std::string& name); +std::vector ModelState::GetParameter>(std::string const& name); } // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/mpi_utils.h b/inflight_batcher_llm/src/mpi_utils.h new file mode 100644 index 00000000..2b5be45f --- /dev/null +++ b/inflight_batcher_llm/src/mpi_utils.h @@ -0,0 +1,86 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include +#include + +// fwd declarations +struct TRITONBACKEND_Request; + +namespace triton::backend::inflight_batcher_llm +{ + +// fwd declarations +class InferenceAnswer; + +constexpr int32_t kMPI_ID_TAG{127}; +constexpr int32_t kMPI_DATA_TAG{1023}; + +enum class MpiId : uint64_t +{ + PENDING_REQUEST = 1, + REQUEST_IN_PROGRESS = 2, + REQUEST_ANSWER = 3, + STOP_REQUEST = 4, + CANCEL_REQUEST = 5, + TERMINATION = 6, +}; + +struct PendingRequestData +{ + std::vector requests; +}; + +// Used by REQUEST_IN_PROGRESS and CANCEL_REQUEST +struct RequestIdsData +{ + std::vector ids; +}; + +struct RequestAnswerData +{ + std::shared_ptr answer; +}; + +using MpiMessageData = std::variant; + +struct MpiMessage +{ + MpiMessage(MpiId _id) + : id(_id) + { + } + + MpiId id; + + MpiMessageData data; +}; + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/orchestrator.cc b/inflight_batcher_llm/src/orchestrator.cc new file mode 100644 index 00000000..1c643611 --- /dev/null +++ b/inflight_batcher_llm/src/orchestrator.cc @@ -0,0 +1,274 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "orchestrator.h" + +#include "tensorrt_llm/common/mpiUtils.h" + +#include "inference_answer.h" +#include "model_instance_state.h" +#include "utils.h" +#include "work_item.h" + +namespace triton::backend::inflight_batcher_llm +{ + +OrchestratorCommunicator::OrchestratorCommunicator( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, MPI_Comm mpiComm) + : model_state_(model_state) + , modelInstance_(triton_model_instance) +{ + mWorkItemsQueue = std::make_unique(isDecoupled()); + + mMpiComm = std::make_unique(mpiComm, true); + + mSenderThread = std::thread([this]() { SenderThread(); }); + mAnswerThread = std::thread([this]() { AnswerThread(); }); + mPollStopSignalThread = std::thread([this]() { PollStopSignalThread(); }); +} + +void OrchestratorCommunicator::SenderThread() +{ + while (true) + { + std::unique_lock lk(mSenderMutex); + mSenderCV.wait(lk, [&]() { return (!mSenderQueue.empty()); }); + + auto message = mSenderQueue.front(); + mSenderQueue.pop(); + + if (message.id == MpiId::TERMINATION) + { + mMpiComm->send(&message.id, 1, MpiType::kUINT64, 0, kMPI_ID_TAG); + TLLM_LOG_INFO("Orchestrator sender thread exiting"); + break; + } + else if (message.id == MpiId::PENDING_REQUEST) + { + auto& data = std::get(message.data); + + std::vector requestsToPush; + std::vector stopRequestIds; + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + for (auto request : data.requests) + { + bool const isStopRequest + = utils::handleTritonRequest(request, mRequestIdStrMap, requestsToPush, *mWorkItemsQueue); + + if (isStopRequest) + { + stopRequestIds.push_back(utils::getRequestId(request, mRequestIdStrMap)); + } + } + + auto const workItemCb = [this, id = message.id](std::shared_ptr wi) + { + auto packed = wi->getInferenceRequest()->serialize(); + + mMpiComm->send(&id, 1, MpiType::kUINT64, 0, kMPI_ID_TAG); + mMpiComm->send(packed.data(), packed.size(), MpiType::kINT64, 0, kMPI_DATA_TAG); + }; + + auto exceptions = mWorkItemsQueue->pushBatch(requestsToPush, exec_start_ns, workItemCb); 
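`SenderThread` and `AnswerThread` exchange each logical message as two point-to-point transfers: the `MpiId` on `kMPI_ID_TAG`, then a variable-length payload on `kMPI_DATA_TAG`, with the receiver probing to size its buffer before receiving. The sketch below shows that two-tag pattern with plain `MPI_Probe`/`MPI_Recv` on `MPI_COMM_WORLD`; the actual backend goes through its MPI session wrapper and matched probes (`mprobe`/`MPI_Mrecv`), and the tag values here merely reuse the constants from `mpi_utils.h` for illustration.

```cpp
// Build with an MPI wrapper, e.g.: mpicxx two_tag.cc -o two_tag && mpirun -n 2 ./two_tag
#include <mpi.h>

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int kIdTag = 127;    // analogous to kMPI_ID_TAG
constexpr int kDataTag = 1023; // analogous to kMPI_DATA_TAG

int main(int argc, char* argv[])
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    if (rank == 0)
    {
        // Sender: announce the message id, then ship the variable-length payload.
        uint64_t id = 1; // e.g. PENDING_REQUEST
        std::vector<int64_t> payload{10, 20, 30, 40};
        MPI_Send(&id, 1, MPI_UINT64_T, 1, kIdTag, MPI_COMM_WORLD);
        MPI_Send(payload.data(), static_cast<int>(payload.size()), MPI_INT64_T, 1, kDataTag, MPI_COMM_WORLD);
    }
    else if (rank == 1)
    {
        // Receiver: fixed-size id first, then probe the data tag to size the buffer.
        uint64_t id = 0;
        MPI_Recv(&id, 1, MPI_UINT64_T, 0, kIdTag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

        MPI_Status status;
        int count = 0;
        MPI_Probe(0, kDataTag, MPI_COMM_WORLD, &status);
        MPI_Get_count(&status, MPI_INT64_T, &count);

        std::vector<int64_t> payload(count);
        MPI_Recv(payload.data(), count, MPI_INT64_T, 0, kDataTag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        std::cout << "received id " << id << " with " << count << " payload words\n";
    }

    MPI_Finalize();
    return 0;
}
```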
+ + if (!stopRequestIds.empty()) + { + constexpr MpiId id = MpiId::STOP_REQUEST; + mMpiComm->send(&id, 1, MpiType::kUINT64, 0, kMPI_ID_TAG); + mMpiComm->send(stopRequestIds.data(), stopRequestIds.size(), MpiType::kUINT64, 0, kMPI_DATA_TAG); + } + } + else if (message.id == MpiId::CANCEL_REQUEST) + { + auto& data = std::get(message.data); + + mMpiComm->send(&message.id, 1, MpiType::kUINT64, 0, kMPI_ID_TAG); + mMpiComm->send(data.ids.data(), data.ids.size(), MpiType::kUINT64, 0, kMPI_DATA_TAG); + } + } +} + +void OrchestratorCommunicator::AnswerThread() +{ + MPI_Message msg; + MPI_Status status; + int32_t count; + MpiId mpiId; + + while (true) + { + mMpiComm->mprobe(0, kMPI_ID_TAG, &msg, &status); + MPICHECK(MPI_Get_count(&status, MPI_UINT64_T, &count)); + TLLM_CHECK(count == 1); + MPICHECK(MPI_Mrecv(&mpiId, count, MPI_UINT64_T, &msg, &status)); + + if (mpiId == MpiId::TERMINATION) + { + TLLM_LOG_INFO("Orchestrator answer thread exiting"); + break; + } + else if (mpiId == MpiId::REQUEST_IN_PROGRESS) + { + mMpiComm->mprobe(0, kMPI_DATA_TAG, &msg, &status); + MPICHECK(MPI_Get_count(&status, MPI_UINT64_T, &count)); + + std::vector request_ids(count); + MPICHECK(MPI_Mrecv(request_ids.data(), count, MPI_UINT64_T, &msg, &status)); + + for (auto id : request_ids) + { + mWorkItemsQueue->markInProgress(id); + } + + continue; + } + + mMpiComm->mprobe(0, kMPI_DATA_TAG, &msg, &status); + MPICHECK(MPI_Get_count(&status, MPI_INT64_T, &count)); + std::vector data(count); + MPICHECK(MPI_Mrecv(data.data(), count, MPI_INT64_T, &msg, &status)); + + auto answer = InferenceAnswer::deserialize(data.data()); + auto const requestId = answer->GetRequestId(); + + std::string errStr = std::string("Failed to send Triton response for requestId: ") + + utils::getRequestIdStr(requestId, mRequestIdStrMap); + + if (answer->IsFinalResponse()) + { + mRequestIdStrMap.erase(requestId); + } + + try + { + auto workItem = mWorkItemsQueue->getInProgressWorkItem(requestId); + auto tritonErr = ModelInstanceState::sendTritonResponse(workItem, answer->GetTensors(), + answer->IsFinalResponse(), answer->GetErrorMessage(), *mWorkItemsQueue, modelInstance_); + LOG_IF_ERROR(tritonErr, errStr); + } + catch (std::exception const& e) + { + TLLM_LOG_ERROR(errStr); + } + } +} + +void OrchestratorCommunicator::PollStopSignalThread(int const intervalInMs) +{ + while (true) + { + std::this_thread::sleep_for(std::chrono::milliseconds(intervalInMs)); + + if (mShutdownRequest.load()) + { + break; + } + + // Merge cancelled requests into stopped requests Ids + auto cancelledReqIds = mWorkItemsQueue->getCancelledInProgressReqIds(); + + if (cancelledReqIds.empty()) + { + continue; + } + + std::vector cancelledReqIdsVec(cancelledReqIds.begin(), cancelledReqIds.end()); + + MpiMessage message(MpiId::CANCEL_REQUEST); + message.data = RequestIdsData{std::move(cancelledReqIdsVec)}; + + SendMessage(std::move(message)); + } +} + +void OrchestratorCommunicator::SendMessage(MpiMessage&& message) +{ + { + std::unique_lock lk(mSenderMutex); + mSenderQueue.push(std::move(message)); + } + + mSenderCV.notify_all(); +} + +void OrchestratorCommunicator::enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count) +{ + MpiMessage message(MpiId::PENDING_REQUEST); + + std::vector data(requests, requests + request_count); + message.data = PendingRequestData{std::move(data)}; + + SendMessage(std::move(message)); +} + +void OrchestratorCommunicator::shutdown() +{ + MpiMessage message(MpiId::TERMINATION); + + { + std::unique_lock lk(mSenderMutex); + 
mSenderQueue.push(message); + } + + mSenderCV.notify_all(); + mShutdownRequest.store(true); + + if (mSenderThread.joinable()) + { + mSenderThread.join(); + } + if (mAnswerThread.joinable()) + { + mAnswerThread.join(); + } + if (mPollStopSignalThread.joinable()) + { + mPollStopSignalThread.join(); + } +} + +TRITONSERVER_Error* Orchestrator::addCommunicator(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, MPI_Comm mpiComm, OrchestratorCommunicator** communicator) +{ + *communicator = new OrchestratorCommunicator(model_state, triton_model_instance, mpiComm); + + { + std::lock_guard lk(mCommunicatorsMutex); + mCommunicators.insert(*communicator); + } + + return nullptr; // success +} + +void Orchestrator::removeCommunicator(OrchestratorCommunicator* communicator) +{ + std::lock_guard lk(mCommunicatorsMutex); + mCommunicators.erase(communicator); +} + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/orchestrator.h b/inflight_batcher_llm/src/orchestrator.h new file mode 100644 index 00000000..56edc7cb --- /dev/null +++ b/inflight_batcher_llm/src/orchestrator.h @@ -0,0 +1,115 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
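The `SendMessage`/`SenderThread`/`shutdown` trio above is a classic producer/consumer loop guarded by a mutex and condition variable, with a `TERMINATION` message acting as a sentinel that stops the consumer before the threads are joined. The following is a freestanding sketch of that pattern with a simplified `Message` type; it illustrates the synchronization only, not the backend's actual classes.

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>

enum class MsgId { Work, Termination };

struct Message
{
    MsgId id;
    std::string payload;
};

class SenderLoop
{
public:
    SenderLoop() : mThread([this] { run(); }) {}

    ~SenderLoop()
    {
        // Mirror shutdown(): push the sentinel, wake the consumer, then join.
        send({MsgId::Termination, {}});
        mThread.join();
    }

    // Mirror SendMessage(): lock, enqueue, notify.
    void send(Message message)
    {
        {
            std::unique_lock<std::mutex> lk(mMutex);
            mQueue.push(std::move(message));
        }
        mCV.notify_all();
    }

private:
    void run()
    {
        while (true)
        {
            std::unique_lock<std::mutex> lk(mMutex);
            mCV.wait(lk, [this] { return !mQueue.empty(); });
            Message message = std::move(mQueue.front());
            mQueue.pop();
            lk.unlock();

            if (message.id == MsgId::Termination)
            {
                std::cout << "sender thread exiting\n";
                break;
            }
            std::cout << "sending: " << message.payload << "\n";
        }
    }

    std::queue<Message> mQueue;
    std::mutex mMutex;
    std::condition_variable mCV;
    std::thread mThread; // declared last so the queue/mutex/CV exist before run() starts
};

int main()
{
    SenderLoop loop;
    loop.send({MsgId::Work, "request 1"});
    loop.send({MsgId::Work, "request 2"});
    return 0; // destructor performs the shutdown handshake
}
```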
+ +#pragma once + +#include "model_state.h" +#include "mpi_utils.h" +#include "work_items_queue.h" + +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +#include "tensorrt_llm/common/mpiUtils.h" + +#include +#include +#include +#include +#include + +using namespace tensorrt_llm::mpi; + +namespace triton::backend::inflight_batcher_llm +{ + +class OrchestratorCommunicator +{ +public: + OrchestratorCommunicator( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, MPI_Comm mpiComm); + + void enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count); + void shutdown(); + + bool isDecoupled() const + { + return model_state_->IsDecoupled(); + } + +private: + /// @brief Send work items to leader-worker ranks + void SenderThread(); + /// @brief Receive inference answers from leader-worker ranks + void AnswerThread(); + /// @brief Polls at a given interval for stop signals + void PollStopSignalThread(int const invervalInMs = 10); + + void SendMessage(MpiMessage&& message); + +private: + ModelState* model_state_; + TRITONBACKEND_ModelInstance* modelInstance_; + + std::unique_ptr mMpiComm; + + std::unique_ptr mWorkItemsQueue; + + std::thread mSenderThread; + std::queue mSenderQueue; + std::mutex mSenderMutex; + std::condition_variable mSenderCV; + + std::thread mAnswerThread; + std::thread mPollStopSignalThread; + std::atomic mShutdownRequest = false; + + std::unordered_map mRequestIdStrMap; +}; + +// +// Orchestrator +// Singleton class to track communicators +// + +class Orchestrator +{ +public: + Orchestrator() {} + + virtual ~Orchestrator() {} + + TRITONSERVER_Error* addCommunicator(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + MPI_Comm mpiComm, OrchestratorCommunicator** communicator); + void removeCommunicator(OrchestratorCommunicator* communicator); + +private: + std::unordered_set mCommunicators; + + mutable std::mutex mCommunicatorsMutex; +}; + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/utils.cc b/inflight_batcher_llm/src/utils.cc index 7504f7dd..1271ad57 100644 --- a/inflight_batcher_llm/src/utils.cc +++ b/inflight_batcher_llm/src/utils.cc @@ -141,7 +141,7 @@ TRITONSERVER_DataType to_triton_datatype(nvinfer1::DataType data_type) uint64_t getRequestId(TRITONBACKEND_Request* request, std::unordered_map& requestIdStrMap) { - const char* charRequestId; + char const* charRequestId; TRITONBACKEND_RequestId(request, &charRequestId); uint64_t requestId = 0; if (charRequestId != nullptr) @@ -153,7 +153,7 @@ uint64_t getRequestId(TRITONBACKEND_Request* request, std::unordered_map hasher; requestId = hasher(strRequestId); @@ -190,7 +190,7 @@ std::unordered_set getRequestOutputNames(TRITONBACKEND_Request* req LOG_IF_ERROR(TRITONBACKEND_RequestOutputCount(request, &outputCount), "Error getting request output count"); for (size_t i = 0; i < outputCount; ++i) { - const char* name; + char const* name; LOG_IF_ERROR(TRITONBACKEND_RequestOutputName(request, i, &name), "Error getting request output name"); std::string name_s(name); outputNames.insert(std::move(name_s)); @@ -198,7 +198,7 @@ std::unordered_set getRequestOutputNames(TRITONBACKEND_Request* req return outputNames; } -bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::string& inputTensorName) +bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, std::string const& inputTensorName) { // Get stop signal from the request TRITONBACKEND_Input* input; @@ -223,7 +223,7 
@@ bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::str LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, ("ModelInstanceState::getRequestStopSignal: buffer_count = " + std::to_string(buffer_count)).c_str()); - const void* buffer = 0L; + void const* buffer = 0L; uint64_t buffer_byte_size = 0; TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; int64_t memory_type_id = 0; @@ -231,12 +231,12 @@ bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::str assert((memory_type == TRITONSERVER_MEMORY_CPU) || (memory_type == TRITONSERVER_MEMORY_CPU_PINNED)); - bool boolean = *reinterpret_cast(buffer); + bool boolean = *reinterpret_cast(buffer); return boolean; } -void sendEnqueueResponse(TRITONBACKEND_Request* request, const std::string& errMsg) +void sendEnqueueResponse(TRITONBACKEND_Request* request, std::string const& errMsg) { TRITONBACKEND_ResponseFactory* factory_ptr; // Create response factory for this request @@ -255,4 +255,43 @@ void sendEnqueueResponse(TRITONBACKEND_Request* request, const std::string& errM LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryDelete(factory_ptr), "Cannot delete response factory"); } +bool handleTritonRequest(TRITONBACKEND_Request* request, std::unordered_map& requestIdStrMap, + std::vector& requestsToPush, WorkItemsQueue& workItemsQueue) +{ + try + { + auto requestId = utils::getRequestId(request, requestIdStrMap); + bool stopRequest = utils::getRequestBooleanInputTensor(request, kStopInputTensorName); + + if (stopRequest) + { + if (requestId != 0) + { + // Check if request is in progress or in queue, if not ignore + workItemsQueue.stopWorkItem(requestId); + // Send a response back to client for stop request + utils::sendEnqueueResponse(request); + } + else + { + throw std::runtime_error("Cannot send stop request without specifying a request_id"); + } + } + else + { + requestsToPush.emplace_back(requestId, request); + } + + return stopRequest; + } + catch (std::exception const& e) + { + // In case of error, no work item is added to queue, so response + // callback needs to be called + utils::sendEnqueueResponse(request, e.what()); + } + + return false; +} + } // namespace triton::backend::inflight_batcher_llm::utils diff --git a/inflight_batcher_llm/src/utils.h b/inflight_batcher_llm/src/utils.h index 7f52cfb8..2e6dbd31 100644 --- a/inflight_batcher_llm/src/utils.h +++ b/inflight_batcher_llm/src/utils.h @@ -26,6 +26,9 @@ #pragma once +#include "work_item.h" +#include "work_items_queue.h" + #include "NvInfer.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/runtime/tllmLogger.h" @@ -61,11 +64,16 @@ std::string getRequestIdStr(uint64_t requestId, std::unordered_map getRequestOutputNames(TRITONBACKEND_Request* request); /// @brief Get the value of a boolean tensor -bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::string& inputTensorName); +bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, std::string const& inputTensorName); /// @brief For stop requests, or in case of error during enqueue, we need to send a /// response to the client -void sendEnqueueResponse(TRITONBACKEND_Request* request, const std::string& errMsg = ""); +void sendEnqueueResponse(TRITONBACKEND_Request* request, std::string const& errMsg = ""); + +/// @brief Handle a Triton request and add it to the requests to push if applicable +/// @return Is the request a stop request +bool handleTritonRequest(TRITONBACKEND_Request* request, std::unordered_map& requestIdStrMap, + std::vector& 
requestsToPush, WorkItemsQueue& workItemsQueue); } // namespace utils } // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/work_item.cc b/inflight_batcher_llm/src/work_item.cc index 40da847c..432f7347 100644 --- a/inflight_batcher_llm/src/work_item.cc +++ b/inflight_batcher_llm/src/work_item.cc @@ -25,6 +25,9 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "work_item.h" + +#include "utils.h" + #include namespace triton::backend::inflight_batcher_llm @@ -72,7 +75,7 @@ std::shared_ptr WorkItem::getInfe return mInferenceRequest; } -bool WorkItem::hasOutputName(const std::string& outputName) +bool WorkItem::hasOutputName(std::string const& outputName) { return (mRequestOutputNames.find(outputName) != mRequestOutputNames.end()); } @@ -91,9 +94,9 @@ std::shared_ptr WorkItem::createI TRITONBACKEND_Input* input = 0L; TRITONBACKEND_RequestInputByIndex(request, idx, &input); - const char* input_name = 0L; + char const* input_name = 0L; TRITONSERVER_DataType data_type = TRITONSERVER_TYPE_INVALID; - const int64_t* shape = 0L; + int64_t const* shape = 0L; uint32_t dims_count = 0; uint64_t byte_size = 0; uint32_t buffer_count = 0; @@ -116,7 +119,7 @@ std::shared_ptr WorkItem::createI uint64_t buffer_offset = 0; for (int64_t buffer_id = 0; buffer_id < buffer_count; ++buffer_id) { - const void* buffer = 0L; + void const* buffer = 0L; uint64_t buffer_byte_size = 0; TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; int64_t memory_type_id = 0; diff --git a/inflight_batcher_llm/src/work_item.h b/inflight_batcher_llm/src/work_item.h index 1c6dea0c..92983d45 100644 --- a/inflight_batcher_llm/src/work_item.h +++ b/inflight_batcher_llm/src/work_item.h @@ -31,7 +31,6 @@ #include "triton/core/tritonbackend.h" #include "triton/core/tritonserver.h" #include -#include namespace triton::backend::inflight_batcher_llm { @@ -56,7 +55,7 @@ class WorkItem std::shared_ptr getInferenceRequest() const; - bool hasOutputName(const std::string& outputName); + bool hasOutputName(std::string const& outputName); /// timestamp storage for Triton base metrics struct Timestamps diff --git a/inflight_batcher_llm/src/work_items_queue.cc b/inflight_batcher_llm/src/work_items_queue.cc index 7bd2bb1a..b25dce11 100644 --- a/inflight_batcher_llm/src/work_items_queue.cc +++ b/inflight_batcher_llm/src/work_items_queue.cc @@ -46,8 +46,8 @@ void WorkItemsQueue::clear() /// @brief Add a batch of new work item to the queue /// Throws an error if requestId already exists -std::vector> WorkItemsQueue::pushBatch( - std::vector& requestsToPush, uint64_t exec_start_ns) +std::vector> WorkItemsQueue::pushBatch(std::vector& requestsToPush, + uint64_t exec_start_ns, std::function)> const& workItemCb) { std::lock_guard lk(mMutex); std::vector> reqExceptions; @@ -69,8 +69,13 @@ std::vector> WorkItemsQueue::pushBatch( mPendingWorkItemsReqIds.insert(workItem->requestId()); workItem->getTimestamps().exec_start_ns = exec_start_ns; reqExceptions.push_back(nullptr); + + if (workItemCb) + { + workItemCb(workItem); + } } - catch (const std::exception& e) + catch (std::exception const& e) { reqExceptions.emplace_back(std::make_shared(e.what())); } @@ -114,6 +119,29 @@ std::tuple, bool> WorkItemsQueue::pop() return {workItem, stoppedRequest}; } +void WorkItemsQueue::markInProgress(const uint64_t requestId) +{ + std::lock_guard lk(mMutex); + + if (mPendingWorkItemsReqIds.find(requestId) == mPendingWorkItemsReqIds.end()) + { + std::string warnStr + = "Received in-progress notification 
for unknown request ID " + std::to_string(requestId) + ", ignoring"; + TLLM_LOG_WARNING(warnStr); + return; + } + + auto it = std::find_if(mPendingWorkItems.begin(), mPendingWorkItems.end(), + [requestId](std::shared_ptr const& wi) { return wi->requestId() == requestId; }); + + auto workItem = *it; + mPendingWorkItems.erase(it); + mPendingWorkItemsReqIds.erase(requestId); + SET_TIMESTAMP(workItem->getTimestamps().compute_start_ns); + + mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem)); +} + void WorkItemsQueue::markFinished(const uint64_t requestId) { std::lock_guard lk(mMutex); @@ -131,7 +159,6 @@ void WorkItemsQueue::markFinished(const uint64_t requestId) void WorkItemsQueue::stopWorkItem(const uint64_t requestId) { std::lock_guard lk(mMutex); - TLLM_LOG_DEBUG("Stopping request"); if (hasInProgressReqId(requestId) || hasPendingReqId(requestId)) { mStoppedReqIds.emplace(requestId); @@ -149,7 +176,7 @@ std::unordered_set WorkItemsQueue::getCancelledInProgressReqIds() cons std::unordered_set cancelledInProgressReqIds; { std::lock_guard lk(mMutex); - for (const auto& pair : mInProgressWorkItems) + for (auto const& pair : mInProgressWorkItems) { bool is_cancelled = false; TRITONBACKEND_ResponseFactoryIsCancelled(pair.second->response_factory(), &is_cancelled); diff --git a/inflight_batcher_llm/src/work_items_queue.h b/inflight_batcher_llm/src/work_items_queue.h index 3ee06909..ecf35879 100644 --- a/inflight_batcher_llm/src/work_items_queue.h +++ b/inflight_batcher_llm/src/work_items_queue.h @@ -29,7 +29,6 @@ #include "tensorrt_llm/common/logger.h" #include "triton/backend/backend_common.h" #include "triton/core/tritonbackend.h" -#include "utils.h" #include "work_item.h" #include @@ -73,8 +72,8 @@ class WorkItemsQueue /// @brief Add a batch of new work item to the queue /// Throws an error if requestId already exists - std::vector> pushBatch( - std::vector& requestsToPush, uint64_t exec_start_ns); + std::vector> pushBatch(std::vector& requestsToPush, + uint64_t exec_start_ns, std::function)> const& workItemCb = nullptr); /// @brief Get a new work item from the queue, and move it to the list of /// in progress work items if it hasn't been stopped @@ -95,6 +94,10 @@ class WorkItemsQueue return mInProgressWorkItems.at(requestId); } + /// @brief Mark a request as being in progress + /// @param requestId + void markInProgress(const uint64_t requestId); + /// @brief Mark a request as being finished /// @param requestId void markFinished(const uint64_t requestId); diff --git a/inflight_batcher_llm/src/worker.cc b/inflight_batcher_llm/src/worker.cc new file mode 100644 index 00000000..ab898f17 --- /dev/null +++ b/inflight_batcher_llm/src/worker.cc @@ -0,0 +1,59 @@ +#include "model_instance_state.h" + +#include "tensorrt_llm/common/logger.h" + +#include + +using namespace triton::backend::inflight_batcher_llm; + +// This worker is launched from the TRT-LLM Triton backend when using the orchestrator mode +// It is intended to be a shim layer that instantiates a ModelInstanceObject that will +// communicate inference results back to the orchestrator in the Triton backend. +// In this application: +// - MPI_COMM_WORLD contains all workers participating in the model (one per GPU) +// - parentComm is an intercommunicator containing both MPI_COMM_WORLD and the process parent +// (i.e. a TRT-LLM Triton backend). 
+// See https://www.mpi-forum.org/docs/mpi-4.1/mpi41-report/node198.htm#Node198 for +// more information on MPI inter-communicators +int main(int argc, char* argv[]) +{ + MPI_Init(&argc, &argv); + + MPI_Comm parentComm; + MPI_Comm_get_parent(&parentComm); + if (parentComm == MPI_COMM_NULL) + { + TLLM_LOG_ERROR("TRT-LLM worker has no parent!"); + return -1; + } + + int size; + MPI_Comm_remote_size(parentComm, &size); + if (size != 1) + { + TLLM_LOG_ERROR("Parent size is %d, must be 1", size); + return -1; + } + + // Since parentComm is an intercommunicator, input root + // is the rank of the parent process in his group + // (always 0 as the parent size is checked before) + int64_t packedSize; + MPICHECK(MPI_Bcast(&packedSize, 1, MPI_INT64_T, 0, parentComm)); + std::vector packed(packedSize); + MPICHECK(MPI_Bcast(packed.data(), packedSize, MPI_INT64_T, 0, parentComm)); + ModelState modelState = ModelState::deserialize(packed); + + TLLM_LOG_INFO("Worker loading model %s", modelState.GetModelName().c_str()); + + ModelInstanceState* state; + if (!ModelInstanceState::Create(&modelState, parentComm, &state)) + { + return -1; + } + + delete state; + + MPI_Finalize(); + return 0; +} diff --git a/scripts/launch_triton_server.py b/scripts/launch_triton_server.py index 5adac765..e0dcc2ef 100644 --- a/scripts/launch_triton_server.py +++ b/scripts/launch_triton_server.py @@ -1,4 +1,5 @@ import argparse +import os import subprocess import sys from pathlib import Path @@ -55,10 +56,18 @@ def parse_arguments(): parser.add_argument( '--tensorrt_llm_model_name', type=str, - help='Name of the tensorrt_llm Triton model in the repo', + help= + 'Name(s) of the tensorrt_llm Triton model in the repo. Use comma to separate if multiple model names', default='tensorrt_llm', ) + parser.add_argument( + '--multi-model', + action='store_true', + help= + 'Enable support for multiple TRT-LLM models in the Triton model repository' + ) + return parser.parse_args() @@ -71,10 +80,10 @@ def get_cmd(world_size, tritonserver, grpc_port, http_port, metrics_port, cmd += ['--log-verbose=3', f'--log-file={log_file}'] # If rank is not 0, skip loading of models other than `tensorrt_llm_model_name` if (i != 0): - cmd += [ - '--model-control-mode=explicit', - f'--load-model={tensorrt_llm_model_name}' - ] + cmd += ['--model-control-mode=explicit'] + model_names = tensorrt_llm_model_name.split(',') + for name in model_names: + cmd += [f'--load-model={name}'] cmd += [ f'--grpc-port={grpc_port}', f'--http-port={http_port}', f'--metrics-port={metrics_port}', '--disable-auto-complete-config', @@ -98,4 +107,8 @@ def get_cmd(world_size, tritonserver, grpc_port, http_port, metrics_port, cmd = get_cmd(int(args.world_size), args.tritonserver, args.grpc_port, args.http_port, args.metrics_port, args.model_repo, args.log, args.log_file, args.tensorrt_llm_model_name) - subprocess.Popen(cmd) + env = os.environ.copy() + if args.multi_model: + assert args.world_size == 1, 'World size must be 1 when using multi-model. 
Processes will be spawned automatically to run the multi-GPU models' + env['TRTLLM_ORCHESTRATOR'] = '1' + subprocess.Popen(cmd, env=env) diff --git a/tensorrt_llm b/tensorrt_llm index 4bb65f21..66ca3378 160000 --- a/tensorrt_llm +++ b/tensorrt_llm @@ -1 +1 @@ -Subproject commit 4bb65f216f2f1dbddb022f4fdd8925c2856baa58 +Subproject commit 66ca3378c61efa3154ed34a48cfc362351405eef diff --git a/tools/version.txt b/tools/version.txt index 1cf56884..f18e28f0 100644 --- a/tools/version.txt +++ b/tools/version.txt @@ -1 +1 @@ -2cde91a86a99cc30e1de1450bf8d59a295da5cc6 +83ab69738090c7166630fac320d0dfc58182f6cc
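`worker.cc` above is the child side of the orchestrator handshake: it attaches to its parent with `MPI_Comm_get_parent`, checks that the remote group has exactly one process, and receives the packed `ModelState` through two broadcasts over the inter-communicator. The launching side is not part of this diff, so the sketch below is only an assumed illustration of a parent that spawns workers with `MPI_Comm_spawn` and feeds them those broadcasts; the worker binary path, worker count, and packed contents are placeholders.

```cpp
// Hypothetical parent-side counterpart to worker.cc (not taken from the backend sources).
// Build with: mpicxx spawn_parent.cc -o spawn_parent && mpirun -n 1 ./spawn_parent
#include <mpi.h>

#include <cstdint>
#include <vector>

int main(int argc, char* argv[])
{
    MPI_Init(&argc, &argv);

    // Placeholder worker binary and per-model world size.
    char const* workerPath = "/opt/tritonserver/backends/tensorrtllm/triton_tensorrtllm_worker";
    int const numWorkers = 2;

    MPI_Comm workerComm = MPI_COMM_NULL;
    std::vector<int> errCodes(numWorkers);
    // MPI_COMM_SELF has local size 1, which matches the remote-size check in worker.cc.
    MPI_Comm_spawn(workerPath, MPI_ARGV_NULL, numWorkers, MPI_INFO_NULL,
        /*root=*/0, MPI_COMM_SELF, &workerComm, errCodes.data());

    // Stand-in for ModelState::serialize(); the children deserialize what they receive.
    std::vector<int64_t> packed{/*nameSize=*/0, /*version=*/1, /*jsonSize=*/0};
    int64_t packedSize = static_cast<int64_t>(packed.size());

    // On an inter-communicator broadcast the sending root passes MPI_ROOT, while the
    // workers pass 0 (the parent's rank), matching the two MPI_Bcast calls in worker.cc.
    MPI_Bcast(&packedSize, 1, MPI_INT64_T, MPI_ROOT, workerComm);
    MPI_Bcast(packed.data(), static_cast<int>(packedSize), MPI_INT64_T, MPI_ROOT, workerComm);

    MPI_Finalize();
    return 0;
}
```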