From 171ed05830ff7e0a8057134f08802c7e3044b7b3 Mon Sep 17 00:00:00 2001 From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Date: Fri, 1 Dec 2023 22:33:58 +0800 Subject: [PATCH] Update TensorRT-LLM backend (#180) --- .gitignore | 1 - .pre-commit-config.yaml | 1 + README.md | 23 +- .../postprocessing/1/model.py | 9 +- .../postprocessing/config.pbtxt | 9 +- .../preprocessing/1/model.py | 10 +- .../preprocessing/config.pbtxt | 9 +- .../tensorrt_llm_bls/1/model.py | 369 +++++ .../tensorrt_llm_bls/config.pbtxt | 191 +++ docs/baichuan.md | 412 ++++++ docs/llama.md | 346 +++++ inflight_batcher_llm/CMakeLists.txt | 7 +- .../client/end_to_end_grpc_client.py | 36 +- .../client/inflight_batcher_llm_client.py | 16 +- inflight_batcher_llm/src/libtensorrtllm.cc | 1301 +---------------- .../src/model_instance_state.cc | 514 +++++++ .../src/model_instance_state.h | 130 ++ inflight_batcher_llm/src/model_state.cc | 158 ++ inflight_batcher_llm/src/model_state.h | 113 ++ inflight_batcher_llm/src/utils.cc | 219 +++ inflight_batcher_llm/src/utils.h | 64 + inflight_batcher_llm/src/work_item.cc | 155 ++ inflight_batcher_llm/src/work_item.h | 75 + inflight_batcher_llm/src/work_items_queue.cc | 150 ++ inflight_batcher_llm/src/work_items_queue.h | 115 ++ tensorrt_llm | 2 +- 26 files changed, 3115 insertions(+), 1320 deletions(-) create mode 100644 all_models/inflight_batcher_llm/tensorrt_llm_bls/1/model.py create mode 100755 all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt create mode 100644 docs/baichuan.md create mode 100644 docs/llama.md create mode 100644 inflight_batcher_llm/src/model_instance_state.cc create mode 100644 inflight_batcher_llm/src/model_instance_state.h create mode 100644 inflight_batcher_llm/src/model_state.cc create mode 100644 inflight_batcher_llm/src/model_state.h create mode 100644 inflight_batcher_llm/src/utils.cc create mode 100644 inflight_batcher_llm/src/utils.h create mode 100644 inflight_batcher_llm/src/work_item.cc create mode 100644 inflight_batcher_llm/src/work_item.h create mode 100644 inflight_batcher_llm/src/work_items_queue.cc create mode 100644 inflight_batcher_llm/src/work_items_queue.h diff --git a/.gitignore b/.gitignore index a7116d55..f4c2f069 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,5 @@ build/ *.so *.egg-info/ .coverage -*.csv *.onnx tmp/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4fffede7..caca92b3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,5 +44,6 @@ repos: rev: v2.2.4 hooks: - id: codespell + exclude: tools/dataset/ args: - --skip=".git,tensorrt_llm" diff --git a/README.md b/README.md index 9fb34fd0..7ded5cf8 100644 --- a/README.md +++ b/README.md @@ -162,16 +162,16 @@ python3 build.py --model_dir=./c-model/gpt2/4-gpu/ \ ### Create the model repository -There are four models in the [`all_models/inflight_batcher_llm`](./all_models/inflight_batcher_llm/) +There are five models in the [`all_models/inflight_batcher_llm`](./all_models/inflight_batcher_llm/) directory that will be used in this example: - "preprocessing": This model is used for tokenizing, meaning the conversion from prompts(string) to input_ids(list of ints). - "tensorrt_llm": This model is a wrapper of your TensorRT-LLM model and is used for inferencing - "postprocessing": This model is used for de-tokenizing, meaning the conversion from output_ids(list of ints) to outputs(string). 
-- "ensemble": This model is used to chain the three models above together: -preprocessing -> tensorrt_llm -> postprocessing +- "ensemble": This model can be used to chain the preprocessing, tensorrt_llm and postprocessing models together. +- "tensorrt_llm_bls": This model can also be used to chain the preprocessing, tensorrt_llm and postprocessing models together. The BLS model has an optional parameter `accumulate_tokens` which can be used in streaming mode to call the preprocessing model with all accumulated tokens, instead of only one token. This might be necessary for certain tokenizers. -To learn more about ensemble model, please see -[here](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models). +To learn more about ensemble and BLS models, please see the +[Ensemble Models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models) and [Business Logic Scripting](https://github.com/triton-inference-server/python_backend#business-logic-scripting) sections of the Triton Inference Server documentation. ```bash # Create the model repository that will be used by the Triton server @@ -258,8 +258,8 @@ environment/container: curl -X POST localhost:8000/v2/models/${MODEL_NAME}/generate -d '{"{PARAM1_KEY}": "{PARAM1_VALUE}", ... }' ``` -In the case of the models used in this example, you can replace MODEL_NAME with `ensemble`. Examining the -ensemble model's config.pbtxt file, you can see that 4 parameters are required to generate a response +In the case of the models used in this example, you can replace MODEL_NAME with `ensemble` or `tensorrt_llm_bls`. Examining the +`ensemble` and `tensorrt_llm_bls` model's config.pbtxt file, you can see that 4 parameters are required to generate a response for this model: - "text_input": Input text to generate a response from @@ -272,6 +272,11 @@ Therefore, we can query the server in the following way: ```bash curl -X POST localhost:8000/v2/models/ensemble/generate -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": ""}' ``` +if using the `ensemble` model or +``` +curl -X POST localhost:8000/v2/models/tensorrt_llm_bls/generate -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": ""}' +``` +if using the `tensorrt_llm_bls` model. 
Which should return a result similar to (formatted for readability): ```json @@ -292,7 +297,7 @@ You can send requests to the "tensorrt_llm" model with the provided as following: ```bash -python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 --tokenizer_dir /workspace/tensorrtllm_backend/tensorrt_llm/examples/gpt/gpt2 +python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 --tokenizer-dir /workspace/tensorrtllm_backend/tensorrt_llm/examples/gpt/gpt2 ``` The result should be similar to the following: @@ -323,7 +328,7 @@ Soyer was a member of the French Academy of Sciences and You can also stop the generation process early by using the `--stop-after-ms` option to send a stop request after a few milliseconds: ```bash -python inflight_batcher_llm/client/inflight_batcher_llm_client.py --stop-after-ms 200 --request-output-len 200 --tokenizer_dir /workspace/tensorrtllm_backend/tensorrt_llm/examples/gpt/gpt2 +python inflight_batcher_llm/client/inflight_batcher_llm_client.py --stop-after-ms 200 --request-output-len 200 --tokenizer-dir /workspace/tensorrtllm_backend/tensorrt_llm/examples/gpt/gpt2 ``` You will find that the generation process is stopped early and therefore the number of generated tokens is lower than 200. diff --git a/all_models/inflight_batcher_llm/postprocessing/1/model.py b/all_models/inflight_batcher_llm/postprocessing/1/model.py index dc6d14eb..dd2ea43f 100644 --- a/all_models/inflight_batcher_llm/postprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/postprocessing/1/model.py @@ -57,6 +57,11 @@ def initialize(self, args): 'string_value'] tokenizer_type = model_config['parameters']['tokenizer_type'][ 'string_value'] + self.skip_special_tokens = model_config['parameters'].get( + 'skip_special_tokens', + {'string_value': "true"})['string_value'].lower() in [ + 'true', '1', 't', 'y', 'yes' + ] if tokenizer_type == 't5': self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, @@ -168,6 +173,8 @@ def _postprocessing(self, tokens_batch, sequence_lengths): for batch_idx, beam_tokens in enumerate(tokens_batch): for beam_idx, tokens in enumerate(beam_tokens): seq_len = sequence_lengths[batch_idx][beam_idx] - output = self.tokenizer.decode(tokens[:seq_len]) + output = self.tokenizer.decode( + tokens[:seq_len], + skip_special_tokens=self.skip_special_tokens) outputs.append(output.encode('utf8')) return outputs diff --git a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt index e6d6894e..4b11f75b 100755 --- a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt +++ b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt @@ -81,9 +81,16 @@ parameters { } } +parameters { + key: "skip_special_tokens" + value: { + string_value: "True" + } +} + instance_group [ { - count: 1 + count: ${postprocessing_instance_count} kind: KIND_CPU } ] diff --git a/all_models/inflight_batcher_llm/preprocessing/1/model.py b/all_models/inflight_batcher_llm/preprocessing/1/model.py index 9150b4c3..edc5d319 100644 --- a/all_models/inflight_batcher_llm/preprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/preprocessing/1/model.py @@ -58,6 +58,11 @@ def initialize(self, args): 'string_value'] tokenizer_type = model_config['parameters']['tokenizer_type'][ 'string_value'] + self.add_special_tokens = model_config['parameters'].get( + 'add_special_tokens', + {'string_value': "false"})['string_value'].lower() in [ + 'true', '1', 't', 'y', 'yes' + ] 
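+        # Note: config.pbtxt parameters are plain strings, so the optional
+        # 'add_special_tokens' flag above is parsed from its string form and
+        # treated as disabled unless explicitly set to a truthy value.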
if tokenizer_type == 't5': self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, @@ -207,7 +212,10 @@ def _create_request(self, query): query : batch string (2D numpy array) """ start_ids = [ - np.array(self.tokenizer.encode(s[0].decode())).astype(int) + np.array( + self.tokenizer.encode( + s[0].decode(), + add_special_tokens=self.add_special_tokens)).astype(int) for s in query ] start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int) diff --git a/all_models/inflight_batcher_llm/preprocessing/config.pbtxt b/all_models/inflight_batcher_llm/preprocessing/config.pbtxt index 134e9b81..0d863e6a 100644 --- a/all_models/inflight_batcher_llm/preprocessing/config.pbtxt +++ b/all_models/inflight_batcher_llm/preprocessing/config.pbtxt @@ -110,9 +110,16 @@ parameters { } } +parameters { + key: "add_special_tokens" + value: { + string_value: "False" + } +} + instance_group [ { - count: 1 + count: ${preprocessing_instance_count} kind: KIND_CPU } ] diff --git a/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/model.py b/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/model.py new file mode 100644 index 00000000..77c1f76e --- /dev/null +++ b/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/model.py @@ -0,0 +1,369 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
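+#
+# This Python backend model implements the "tensorrt_llm_bls" orchestrator:
+# it accepts the same inputs as the ensemble model, issues BLS inference
+# requests to the preprocessing, tensorrt_llm and postprocessing models in
+# turn, and returns the post-processed text (plus optional log probs) to the
+# client, in either decoupled (streaming) or non-decoupled mode.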
+ +import json +import traceback + +import numpy as np +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + + def initialize(self, args): + + # Parse model configs + model_config = json.loads(args['model_config']) + + params = model_config['parameters'] + + accumulate_tokens_str = '' + if 'accumulate_tokens' in params: + accumulate_tokens_str = params['accumulate_tokens']['string_value'] + + self.accumulate_tokens = accumulate_tokens_str.lower() in [ + 'true', 'yes', '1', 't' + ] + + self.decoupled = pb_utils.using_decoupled_model_transaction_policy( + model_config) + + self.logger = pb_utils.Logger + + self.bls_input_tensor_names = [ + "text_input", "max_tokens", "bad_words", "stop_words", "end_id", + "pad_id", "top_k", "top_p", "temperature", "length_penalty", + "repetition_penalty", "min_length", "presence_penalty", + "random_seed", "return_log_probs", "beam_width", "stream", + "prompt_embedding_table", "prompt_vocab_size", + "embedding_bias_words", "embedding_bias_weights" + ] + + self.preproc_input_to_bls_input_map = { + "QUERY": "text_input", + "REQUEST_OUTPUT_LEN": "max_tokens", + "BAD_WORDS_DICT": "bad_words", + "STOP_WORDS_DICT": "stop_words", + "EMBEDDING_BIAS_WORDS": "embedding_bias_words", + "EMBEDDING_BIAS_WEIGHTS": "embedding_bias_weights" + } + + self.preproc_output_to_trtllm_input_map = { + "INPUT_ID": "input_ids", + "REQUEST_INPUT_LEN": "input_lengths", + "REQUEST_OUTPUT_LEN": "request_output_len", + "BAD_WORDS_IDS": "bad_words_list", + "STOP_WORDS_IDS": "stop_words_list", + "EMBEDDING_BIAS": "embedding_bias", + } + + self.trtllm_input_to_bls_input_map = { + "end_id": "end_id", + "pad_id": "pad_id", + "beam_width": "beam_width", + "runtime_top_k": "top_k", + "runtime_top_p": "top_p", + "len_penalty": "length_penalty", + "repetition_penalty": "repetition_penalty", + "min_length": "min_length", + "presence_penalty": "presence_penalty", + "random_seed": "random_seed", + "return_log_probs": "return_log_probs", + "streaming": "stream", + "prompt_embedding_table": "prompt_embedding_table", + "prompt_vocab_size": "prompt_vocab_size", + } + + self.trtllm_output_to_postproc_input_map = { + "output_ids": "TOKENS_BATCH", + "sequence_length": "SEQUENCE_LENGTH", + "cum_log_probs": "CUM_LOG_PROBS", + "output_log_probs": "OUTPUT_LOG_PROBS", + } + + self.postproc_output_to_bls_output_map = { + "OUTPUT": "text_output", + "OUT_CUM_LOG_PROBS": "cum_log_probs", + "OUT_OUTPUT_LOG_PROBS": "output_log_probs", + } + + def _get_bls_input_tensors_map(self, request): + + bls_input_tensors_map = {} + for input_tensor_name in self.bls_input_tensor_names: + tensor = pb_utils.get_input_tensor_by_name(request, + input_tensor_name) + if tensor != None: + bls_input_tensors_map[input_tensor_name] = tensor + + return bls_input_tensors_map + + def _get_preproc_input_tensors(self, bls_input_tensors_map): + + preproc_input_tensors = [] + + for preproc_name, bls_name in self.preproc_input_to_bls_input_map.items( + ): + + if bls_name in bls_input_tensors_map: + tensor = bls_input_tensors_map[bls_name] + # Change the name to what the preprocessor expects + preproc_input_tensors.append( + pb_utils.Tensor(preproc_name, tensor.as_numpy())) + + return preproc_input_tensors + + def _get_trtllm_input_tensors(self, bls_input_tensors_map, + preproc_output_tensors): + + trtllm_input_tensors = [] + + # Set input tensors from preprocessor outputs + for preproc_output_tensor in preproc_output_tensors: + + trtllm_tensor_name = self.preproc_output_to_trtllm_input_map[ + preproc_output_tensor.name()] + 
trtllm_input_tensors.append( + pb_utils.Tensor(trtllm_tensor_name, + preproc_output_tensor.as_numpy())) + + # Set input tensors from bls inputs + for trtllm_name, bls_name in self.trtllm_input_to_bls_input_map.items( + ): + + if bls_name in bls_input_tensors_map: + tensor = bls_input_tensors_map[bls_name] + # Change the name to what the preprocessor expects + trtllm_input_tensors.append( + pb_utils.Tensor(trtllm_name, tensor.as_numpy())) + + return trtllm_input_tensors + + def _get_postproc_input_tensors(self, tokens, trtllm_output_tensors): + + postproc_input_tensors = [] + + for trtllm_output_tensor in trtllm_output_tensors: + + # If in decoupled mode, option to append new tokens to existing tokens before calling postprocessor + # This might be needed for some tokenizers + # Note that in that case, the client must overwrite previously received output text + if (self.accumulate_tokens and self.decoupled + and trtllm_output_tensor.name() == "output_ids"): + + new_tokens = trtllm_output_tensor.as_numpy() + if new_tokens.ndim != 3: + raise pb_utils.TritonModelException( + "Expected output_ids tensor to have 3 dims.") + if new_tokens.shape[0] != 1: + raise pb_utils.TritonModelException( + "Expected output_ids tensor to have batch size of 1") + if new_tokens.shape[1] != 1: + raise pb_utils.TritonModelException( + "Accumulation of tokens is only implemented for beam width = 1" + ) + + tokens = new_tokens if (tokens is None) else np.concatenate( + (tokens, new_tokens), axis=2) + + # output ids + postproc_output_ids_name = self.trtllm_output_to_postproc_input_map[ + "output_ids"] + postproc_input_tensors.append( + pb_utils.Tensor(postproc_output_ids_name, tokens)) + + # sequence length + np_seq_len_tensor = np.array([[tokens.shape[2]]], + dtype=np.int32) + postproc_seq_len_name = self.trtllm_output_to_postproc_input_map[ + "sequence_length"] + postproc_input_tensors.append( + pb_utils.Tensor(postproc_seq_len_name, np_seq_len_tensor)) + + # Set input tensors from trtllm outputs + for trtllm_output_tensor in trtllm_output_tensors: + + # output_ids and sequence_length were handled earlier + if (self.accumulate_tokens and self.decoupled + and (trtllm_output_tensor.name() == "output_ids" + or trtllm_output_tensor.name() == "sequence_length")): + continue + + postproc_tensor_name = self.trtllm_output_to_postproc_input_map[ + trtllm_output_tensor.name()] + + postproc_input_tensors.append( + pb_utils.Tensor(postproc_tensor_name, + trtllm_output_tensor.as_numpy())) + + return tokens, postproc_input_tensors + + def _get_bls_output_tensors(self, postproc_output_tensors): + + bls_output_tensors = [] + + # Set input tensors from trtllm outputs + for postproc_output_tensor in postproc_output_tensors: + + bls_tensor_name = self.postproc_output_to_bls_output_map[ + postproc_output_tensor.name()] + bls_output_tensors.append( + pb_utils.Tensor(bls_tensor_name, + postproc_output_tensor.as_numpy())) + + return bls_output_tensors + + def execute(self, requests): + + responses = [] + bls_response_sender = None + + for request in requests: + + #Get the response sender for the BLS + if self.decoupled: + bls_response_sender = request.get_response_sender() + + try: + # Get the bls input tensors + bls_input_tensors_map = self._get_bls_input_tensors_map( + request) + + #Check the batch dimension + for name, tensor in bls_input_tensors_map.items(): + batch_dim = tensor.as_numpy().shape[0] + + if batch_dim != 1: + + err_str = "Inflight batching backend expects requests with batch size of 1." 
+ self.logger.log_error(err_str) + raise pb_utils.TritonModelException(err_str) + + # Create the preprocessor input tensors + preproc_input_tensors = self._get_preproc_input_tensors( + bls_input_tensors_map) + + preproc_request = pb_utils.InferenceRequest( + model_name="preprocessing", + inputs=preproc_input_tensors, + requested_output_names=list( + self.preproc_output_to_trtllm_input_map.keys())) + + #Execute preprocessor + preproc_response = preproc_request.exec() + + if preproc_response.has_error(): + raise pb_utils.TritonModelException( + preproc_response.error().message()) + + # Create the trtllm input tensors + trtllm_input_tensors = self._get_trtllm_input_tensors( + bls_input_tensors_map, preproc_response.output_tensors()) + + trtllm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + inputs=trtllm_input_tensors, + requested_output_names=list( + self.trtllm_output_to_postproc_input_map.keys())) + + #Execute trtllm + trtllm_responses = trtllm_request.exec( + decoupled=self.decoupled) + + if not self.decoupled: + trtllm_responses = [trtllm_responses] + + tokens = None + + #Loop over the trtllm responses + for trtllm_response in trtllm_responses: + + if trtllm_response.has_error(): + raise pb_utils.TritonModelException( + trtllm_response.error().message()) + + trtllm_output_tensors = trtllm_response.output_tensors() + + tokens, postproc_input_tensors = self._get_postproc_input_tensors( + tokens, trtllm_output_tensors) + + postproc_request = pb_utils.InferenceRequest( + model_name="postprocessing", + inputs=postproc_input_tensors, + requested_output_names=list( + self.postproc_output_to_bls_output_map.keys())) + + #Execute postprocessor + postproc_response = postproc_request.exec() + + if postproc_response.has_error(): + raise pb_utils.TritonModelException( + postproc_response.error().message()) + + # Create the BLS response + bls_output_tensors = self._get_bls_output_tensors( + postproc_response.output_tensors()) + + bls_response = pb_utils.InferenceResponse( + output_tensors=bls_output_tensors) + + if self.decoupled: + bls_response_sender.send(bls_response) + else: + responses.append(bls_response) + + # All responses have been sent, set final flag + if self.decoupled: + bls_response_sender.send( + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + + except Exception: + + self.logger.log_error(traceback.format_exc()) + # If encountering an error, send a response with err msg + error_response = pb_utils.InferenceResponse( + output_tensors=[], + error=pb_utils.TritonError(traceback.format_exc())) + + if self.decoupled: + bls_response_sender.send(error_response) + bls_response_sender.send( + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + else: + responses.append(error_response) + + if self.decoupled: + return None + else: + assert len(responses) == len(requests) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') diff --git a/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt b/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt new file mode 100755 index 00000000..b0d3934a --- /dev/null +++ b/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt @@ -0,0 +1,191 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "tensorrt_llm_bls" +backend: "python" +max_batch_size: ${triton_max_batch_size} + +model_transaction_policy { + decoupled: ${decoupled_mode} +} + +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "embedding_bias_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "embedding_bias_weights" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "cum_log_probs" + 
data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + } +] + +parameters: { + key: "accumulate_tokens" + value: { + string_value: "${accumulate_tokens}" + } +} + +instance_group [ + { + count: ${bls_instance_count} + kind : KIND_CPU + } +] diff --git a/docs/baichuan.md b/docs/baichuan.md new file mode 100644 index 00000000..4cf96ed0 --- /dev/null +++ b/docs/baichuan.md @@ -0,0 +1,412 @@ + +## End to end workflow to run baichuan + +* build engine + +```bash +export HF_BAICHUAN_MODEL=Baichuan-13B-Chat/ +python build.py --model_dir ${HF_BAICHUAN_MODEL} \ + --dtype float16 \ + --remove_input_padding \ + --use_gpt_attention_plugin float16 \ + --enable_context_fmha \ + --use_gemm_plugin float16 \ + --output_dir /tmp/baichuan/13B/trt_engines/fp16/1-gpu/ \ + --paged_kv_cache \ + --max_batch_size 64 + +[11/29/2023-08:20:34] [TRT] [I] Total Host Persistent Memory: 77008 +[11/29/2023-08:20:34] [TRT] [I] Total Device Persistent Memory: 0 +[11/29/2023-08:20:34] [TRT] [I] Total Scratch Memory: 1342439424 +[11/29/2023-08:20:34] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 690 steps to complete. +[11/29/2023-08:20:34] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 25.5938ms to assign 11 blocks to 690 nodes requiring 6308236288 bytes. +[11/29/2023-08:20:34] [TRT] [I] Total Activation Memory: 6308236288 +[11/29/2023-08:20:35] [TRT] [I] Total Weights Memory: 26529804072 +[11/29/2023-08:20:35] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +64, now: CPU 56027, GPU 28529 (MiB) +[11/29/2023-08:20:35] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +72, now: CPU 56027, GPU 28601 (MiB) +[11/29/2023-08:20:35] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 1250 MiB, GPU 41088 MiB +[11/29/2023-08:20:35] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +25301, now: CPU 0, GPU 25301 (MiB) +[11/29/2023-08:20:44] [TRT] [I] [MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 81260 MiB +[11/29/2023-08:20:44] [TRT-LLM] [I] Total time of building baichuan_float16_tp1_rank0.engine: 00:00:37 +[11/29/2023-08:20:44] [TRT-LLM] [I] Config saved to /tmp/baichuan/13B/trt_engines/fp16/1-gpu/config.json. +[11/29/2023-08:20:45] [TRT-LLM] [I] Serializing engine to /tmp/baichuan/13B/trt_engines/fp16/1-gpu/baichuan_float16_tp1_rank0.engine... +[11/29/2023-08:21:35] [TRT-LLM] [I] Engine serialized. 
Total time: 00:00:49 +[11/29/2023-08:21:36] [TRT-LLM] [I] Timing cache serialized to /tmp/baichuan/13B/trt_engines/fp16/1-gpu/model.cache +[11/29/2023-08:21:36] [TRT-LLM] [I] Total time of building all 1 engines: 00:05:00 +``` + +* Prepare configs + +```bash +cp all_models/inflight_batcher_llm/ baichuan_ifb -r + +python3 tools/fill_template.py -i baichuan_ifb/preprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},tokenizer_type:auto,triton_max_batch_size:64,preprocessing_instance_count:1 +python3 tools/fill_template.py -i baichuan_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},tokenizer_type:auto,triton_max_batch_size:64,postprocessing_instance_count:1 +python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False +python3 tools/fill_template.py -i baichuan_ifb/ensemble/config.pbtxt triton_max_batch_size:64 +python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_kv_cache_length:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,batching_strategy:inflight_batching,max_queue_delay_microseconds:600 +``` + +* Launch server + +```bash +pip pinstall SentencePiece +python3 scripts/launch_triton_server.py --world_size 1 --model_repo=baichuan_ifb/ +``` + +this setting requires about 35GB + +```bash +nvidia-smi + +Wed Nov 29 08:33:50 2023 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+======================+======================| +| 0 NVIDIA H100 PCIe On | 00000000:41:00.0 Off | 0 | +| N/A 43C P0 81W / 350W | 34743MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ +``` + +If you encounter error + +```bash +I1129 08:28:33.267969 15088 model_lifecycle.cc:818] successfully loaded 'tensorrt_llm_bls' +I1129 08:28:33.928915 15088 pb_stub.cc:325] Failed to initialize Python stub: ValueError: Tokenizer class BaichuanTokenizer does not exist or is not currently imported. + +At: + /home/bhsueh/.local/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py(748): from_pretrained + /home/scratch.bhsueh_sw_1/workspace/TensorRT-LLM/tllm_backend_nvbug/baichuan_ifb/preprocessing/1/model.py(66): initialize + +I1129 08:28:33.928991 15088 pb_stub.cc:325] Failed to initialize Python stub: ValueError: Tokenizer class BaichuanTokenizer does not exist or is not currently imported. 
+
+At:
+  /home/bhsueh/.local/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py(748): from_pretrained
+  /home/scratch.bhsueh_sw_1/workspace/TensorRT-LLM/tllm_backend_nvbug/baichuan_ifb/postprocessing/1/model.py(65): initialize
+
+E1129 08:28:34.285773 15088 backend_model.cc:634] ERROR: Failed to create instance: ValueError: Tokenizer class BaichuanTokenizer does not exist or is not currently imported.
+
+At:
+  /home/bhsueh/.local/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py(748): from_pretrained
+  /home/scratch.bhsueh_sw_1/workspace/TensorRT-LLM/tllm_backend_nvbug/baichuan_ifb/postprocessing/1/model.py(65): initialize
+
+E1129 08:28:34.285879 15088 model_lifecycle.cc:621] failed to load 'postprocessing' version 1: Internal: ValueError: Tokenizer class BaichuanTokenizer does not exist or is not currently imported.
+
+At:
+  /home/bhsueh/.local/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py(748): from_pretrained
+  /home/scratch.bhsueh_sw_1/workspace/TensorRT-LLM/tllm_backend_nvbug/baichuan_ifb/postprocessing/1/model.py(65): initialize
+
+I1129 08:28:34.285894 15088 model_lifecycle.cc:756] failed to load 'postprocessing'
+E1129 08:28:34.304925 15088 backend_model.cc:634] ERROR: Failed to create instance: ValueError: Tokenizer class BaichuanTokenizer does not exist or is not currently imported.
+
+At:
+  /home/bhsueh/.local/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py(748): from_pretrained
+  /home/scratch.bhsueh_sw_1/workspace/TensorRT-LLM/tllm_backend_nvbug/baichuan_ifb/preprocessing/1/model.py(66): initialize
+
+E1129 08:28:34.305028 15088 model_lifecycle.cc:621] failed to load 'preprocessing' version 1: Internal: ValueError: Tokenizer class BaichuanTokenizer does not exist or is not currently imported.
+
+At:
+  /home/bhsueh/.local/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py(748): from_pretrained
+  /home/scratch.bhsueh_sw_1/workspace/TensorRT-LLM/tllm_backend_nvbug/baichuan_ifb/preprocessing/1/model.py(66): initialize
+
+I1129 08:28:34.305052 15088 model_lifecycle.cc:756] failed to load 'preprocessing'
+```
+
+Please add `trust_remote_code=True` to the tokenizer creation in the preprocessing and postprocessing models. For security reasons, this flag is not enabled by default.
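+
+As a quick check outside of Triton, the sketch below (illustrative only, not part of this repository) verifies that the Baichuan tokenizer loads once remote code is trusted. The same `trust_remote_code=True` keyword argument is what needs to be added where `self.tokenizer` is created in the `initialize()` methods of the preprocessing and postprocessing models:
+
+```python
+from transformers import AutoTokenizer
+
+# HF_BAICHUAN_MODEL points at the local checkout used to build the engine.
+tokenizer = AutoTokenizer.from_pretrained("Baichuan-13B-Chat/",
+                                          trust_remote_code=True)
+print(tokenizer("What is machine learning?")["input_ids"][:8])
+```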
+ +* Send request + +```bash +curl -X POST localhost:8000/v2/models/ensemble/generate -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": "", "pad_id": 2, "end_id": 2}' + +{"cum_log_probs":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"\nMachine learning is a subset of artificial intelligence (AI) that focuses on the"} +``` + +* Send request with bad_words and stop_words + +```bash +curl -X POST localhost:8000/v2/models/ensemble/generate -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "intelligence", "stop_words": "focuses", "pad_id": 2, "end_id": 2}' + +{"cum_log_probs":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"\nMachine learning is a subset of artificial intelligent (AI) that focuses"} +``` + +* Send request by `inflight_batcher_llm_client.py` (Remember to add `trust_remote_code=True` in tokenizer of `inflight_batcher_llm_client.py`) + +```bash +python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 --tokenizer-dir ${HF_BAICHUAN_MODEL} + +========= +Input sequence: [16448, 677, 5611, 31136, 21309, 4746, 31125, 694, 1033, 653, 8808, 754, 650] +Got completed request +Input: Born in north-east France, Soyer trained as a +Output beam 0: . He became the chef at the Reform Club, and later at the Vegetarian Restaurant, where he pioneered the use of vegetables in fine dining. He also wrote a number of books, including The London Art of Cookery (1858), The Modern Housekeeper (1861), and The Compleat Housekeeper (1862). +Soyer was a strong supporter of the British National Rifle Association, and was a member of the organisation's council. He was also a member of the Reform Club, the Athenaeum, and the Rifle Club. He died in London in 1904. +Soyer was born in the village of Montigny-lès-Cormeilles, in the department of Aisne, France. 
He was the son of a baker, and was educated in the +Output sequence: [16814, 677, 5621, 1412, 4514, 678, 2835, 677, 31106, 53, 60, 57, 59, 79, 1057, 3142, 656, 16814, 772, 656, 15824, 4305, 31125, 680, 2384, 772, 656, 9592, 1161, 8480, 13550, 807, 31125, 1238, 742, 11135, 2521, 656, 1226, 679, 8431, 3392, 677, 4816, 8946, 79, 1057, 982, 4251, 650, 1697, 679, 3594, 31125, 1516, 776, 2835, 2409, 679, 7782, 1620, 762, 53, 60, 57, 60, 1098, 776, 8753, 2542, 17655, 762, 53, 60, 58, 53, 1098, 680, 776, 1127, 1596, 658, 2542, 17655, 762, 53, 60, 58, 54, 31145, 79, 5, 31131, 1033, 653, 796, 650, 2427, 23747, 679, 656, 3681, 2024, 751, 19422, 2790, 728, 31125, 680, 796, 650, 2736, 679, 656, 1625, 4859, 31155, 31114, 7284, 79, 1057, 796, 982, 650, 2736, 679, 656, 15824, 4305, 31125, 656, 1996, 1179, 4302, 784, 31125, 680, 656, 751, 19422, 4305, 79, 1057, 4357, 677, 2835, 677, 31106, 53, 61, 52, 56, 79, 5, 31131, 1033, 653, 796, 4204, 677, 656, 6730, 679, 5136, 942, 31124, 31136, 31115, 16987, 31136, 31133, 908, 31107, 22542, 31125, 677, 656, 1664, 2049, 679, 703, 667, 1024, 31125, 4746, 79, 1057, 796, 656, 3652, 679, 650, 675, 3034, 31125, 680, 796, 18735, 677, 656] +``` + +* Run test on dataset + +``` +python3 tools/inflight_batcher_llm/end_to_end_test.py --dataset ci/L0_backend_trtllm/simple_data.json --max-input-len 500 + +[INFO] Start testing on 13 prompts. +[INFO] Functionality test succeed. +[INFO] Warm up for benchmarking. +[INFO] Start benchmarking on 13 prompts. +[INFO] Total Latency: 1598.328 ms +``` + +* run with decouple mode (streaming) + +```bash +cp all_models/inflight_batcher_llm/ baichuan_ifb -r + +python3 tools/fill_template.py -i baichuan_ifb/preprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},tokenizer_type:auto,triton_max_batch_size:64,preprocessing_instance_count:1 +python3 tools/fill_template.py -i baichuan_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},tokenizer_type:auto,triton_max_batch_size:64,postprocessing_instance_count:1 +python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:True +python3 tools/fill_template.py -i baichuan_ifb/ensemble/config.pbtxt triton_max_batch_size:64 +python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_kv_cache_length:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,batching_strategy:inflight_batching,max_queue_delay_microseconds:600 + +pip pinstall SentencePiece +# please add `trust_remote_code=True` in tokenizer of preprocessing and postprocessing. Considering the security, we don't add it by default. +python3 scripts/launch_triton_server.py --world_size 1 --model_repo=baichuan_ifb/ + +python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 --tokenizer-dir ${HF_BAICHUAN_MODEL} --streaming +``` + +
+ The result would be like + + +```bash +========= +Input sequence: [16448, 677, 5611, 31136, 21309, 4746, 31125, 694, 1033, 653, 8808, 754, 650] +[16814] +[677] +[5621] +[1412] +[4514] +[678] +[2835] +[677] +[31106] +[53] +[60] +[57] +[59] +[79] +[1057] +[3142] +[656] +[16814] +[772] +[656] +[15824] +[4305] +[31125] +[680] +[2384] +[772] +[656] +[9592] +[1161] +[8480] +[13550] +[807] +[31125] +[1238] +[742] +[11135] +[2521] +[656] +[1226] +[679] +[8431] +[3392] +[677] +[4816] +[8946] +[79] +[1057] +[982] +[4251] +[650] +[1697] +[679] +[3594] +[31125] +[1516] +[776] +[2835] +[2409] +[679] +[7782] +[1620] +[762] +[53] +[60] +[57] +[60] +[1098] +[776] +[8753] +[2542] +[17655] +[762] +[53] +[60] +[58] +[53] +[1098] +[680] +[776] +[1127] +[1596] +[658] +[2542] +[17655] +[762] +[53] +[60] +[58] +[54] +[31145] +[79] +[5] +[31131] +[1033] +[653] +[796] +[650] +[2427] +[23747] +[679] +[656] +[3681] +[2024] +[751] +[19422] +[2790] +[728] +[31125] +[680] +[796] +[650] +[2736] +[679] +[656] +[1625] +[4859] +[31155] +[31114] +[7284] +[79] +[1057] +[796] +[982] +[650] +[2736] +[679] +[656] +[15824] +[4305] +[31125] +[656] +[1996] +[1179] +[4302] +[784] +[31125] +[680] +[656] +[751] +[19422] +[4305] +[79] +[1057] +[4357] +[677] +[2835] +[677] +[31106] +[53] +[61] +[52] +[56] +[79] +[5] +[31131] +[1033] +[653] +[796] +[4204] +[677] +[656] +[6730] +[679] +[5136] +[942] +[31124] +[31136] +[31115] +[16987] +[31136] +[31133] +[908] +[31107] +[22542] +[31125] +[677] +[656] +[1664] +[2049] +[679] +[703] +[667] +[1024] +[31125] +[4746] +[79] +[1057] +[796] +[656] +[3652] +[679] +[650] +[675] +[3034] +[31125] +[680] +[796] +[18735] +[677] +[656] +Input: Born in north-east France, Soyer trained as a +Output beam 0: chef in Paris before moving to London in 1857. He became the chef at the Reform Club, and later at the Vegetarian Restaurant, where he pioneered the use of vegetables in fine dining. He also wrote a number of books, including The London Art of Cookery (1858), The Modern Housekeeper (1861), and The Compleat Housekeeper (1862). +Soyer was a strong supporter of the British National Rifle Association, and was a member of the organisation's council. He was also a member of the Reform Club, the Athenaeum, and the Rifle Club. He died in London in 1904. +Soyer was born in the village of Montigny-lès-Cormeilles, in the department of Aisne, France. 
He was the son of a baker, and was educated in the +Output sequence: [16448, 677, 5611, 31136, 21309, 4746, 31125, 694, 1033, 653, 8808, 754, 650, 16814, 677, 5621, 1412, 4514, 678, 2835, 677, 31106, 53, 60, 57, 59, 79, 1057, 3142, 656, 16814, 772, 656, 15824, 4305, 31125, 680, 2384, 772, 656, 9592, 1161, 8480, 13550, 807, 31125, 1238, 742, 11135, 2521, 656, 1226, 679, 8431, 3392, 677, 4816, 8946, 79, 1057, 982, 4251, 650, 1697, 679, 3594, 31125, 1516, 776, 2835, 2409, 679, 7782, 1620, 762, 53, 60, 57, 60, 1098, 776, 8753, 2542, 17655, 762, 53, 60, 58, 53, 1098, 680, 776, 1127, 1596, 658, 2542, 17655, 762, 53, 60, 58, 54, 31145, 79, 5, 31131, 1033, 653, 796, 650, 2427, 23747, 679, 656, 3681, 2024, 751, 19422, 2790, 728, 31125, 680, 796, 650, 2736, 679, 656, 1625, 4859, 31155, 31114, 7284, 79, 1057, 796, 982, 650, 2736, 679, 656, 15824, 4305, 31125, 656, 1996, 1179, 4302, 784, 31125, 680, 656, 751, 19422, 4305, 79, 1057, 4357, 677, 2835, 677, 31106, 53, 61, 52, 56, 79, 5, 31131, 1033, 653, 796, 4204, 677, 656, 6730, 679, 5136, 942, 31124, 31136, 31115, 16987, 31136, 31133, 908, 31107, 22542, 31125, 677, 656, 1664, 2049, 679, 703, 667, 1024, 31125, 4746, 79, 1057, 796, 656, 3652, 679, 650, 675, 3034, 31125, 680, 796, 18735, 677, 656] +``` + +
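+
+The streaming example above drives the `tensorrt_llm` model directly through `inflight_batcher_llm_client.py`. The `tensorrt_llm_bls` model configured earlier (with `accumulate_tokens:True`) can be streamed in the same way; below is a minimal Python sketch of such a client. It is illustrative only: it mirrors the flow of `inflight_batcher_llm/client/end_to_end_grpc_client.py` and assumes `tritonclient[grpc]` is installed and the gRPC endpoint is on the default port 8001. Optional inputs such as `end_id` and `pad_id` can be added in the same way as the ones shown.
+
+```python
+import queue
+from functools import partial
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+
+def callback(user_data, result, error):
+    # Streamed responses (or errors) are pushed onto a queue by the client thread.
+    user_data.put(error if error is not None else result)
+
+responses = queue.Queue()
+
+text_input = np.array([["What is machine learning?"]], dtype=object)
+max_tokens = np.array([[64]], dtype=np.int32)
+streaming = np.array([[True]], dtype=bool)
+
+inputs = [
+    grpcclient.InferInput("text_input", [1, 1], "BYTES"),
+    grpcclient.InferInput("max_tokens", [1, 1], "INT32"),
+    grpcclient.InferInput("stream", [1, 1], "BOOL"),
+]
+inputs[0].set_data_from_numpy(text_input)
+inputs[1].set_data_from_numpy(max_tokens)
+inputs[2].set_data_from_numpy(streaming)
+
+client = grpcclient.InferenceServerClient(url="localhost:8001")
+client.start_stream(callback=partial(callback, responses))
+client.async_stream_infer("tensorrt_llm_bls", inputs, request_id="1")
+client.stop_stream()
+
+final_text = ""
+while not responses.empty():
+    item = responses.get()
+    if isinstance(item, Exception):
+        raise item
+    # With accumulate_tokens:True every streamed response carries the full text
+    # generated so far, so the latest response overwrites the previous one.
+    final_text = item.as_numpy("text_output")[0].decode("utf-8")
+print(final_text)
+```
+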
+ + +* Run several requests at the same time + +```bash +echo '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": "", "pad_id": 2, "end_id": 2}' > tmp.txt +printf '%s\n' {1..20} | xargs -I % -P 20 curl -X POST localhost:8000/v2/models/ensemble/generate -d @tmp.txt +``` diff --git a/docs/llama.md b/docs/llama.md new file mode 100644 index 00000000..434da499 --- /dev/null +++ b/docs/llama.md @@ -0,0 +1,346 @@ + +## End to end workflow to run llama + +* build engine + +```bash +export HF_LLAMA_MODEL=llama-7b-hf/ +python build.py --model_dir ${HF_LLAMA_MODEL} \ + --dtype float16 \ + --remove_input_padding \ + --use_gpt_attention_plugin float16 \ + --enable_context_fmha \ + --use_gemm_plugin float16 \ + --output_dir /tmp/llama/7B/trt_engines/fp16/1-gpu/ \ + --paged_kv_cache \ + --max_batch_size 64 +``` + +* Prepare configs + +```bash +cp all_models/inflight_batcher_llm/ llama_ifb -r + +python3 tools/fill_template.py -i llama_ifb/preprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},tokenizer_type:llama,triton_max_batch_size:64,preprocessing_instance_count:1 +python3 tools/fill_template.py -i llama_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},tokenizer_type:llama,triton_max_batch_size:64,postprocessing_instance_count:1 +python3 tools/fill_template.py -i llama_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False +python3 tools/fill_template.py -i llama_ifb/ensemble/config.pbtxt triton_max_batch_size:64 +python3 tools/fill_template.py -i llama_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:/tmp/llama/7B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_kv_cache_length:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,batching_strategy:inflight_batching,max_queue_delay_microseconds:600 +``` + +* Launch server + +```bash +pip pinstall SentencePiece +python3 scripts/launch_triton_server.py --world_size 1 --model_repo=llama_ifb/ +``` + +this setting requires about 25GB + +```bash +nvidia-smi + +Wed Nov 29 08:51:30 2023 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 PCIe On | 00000000:41:00.0 Off | 0 | +| N/A 40C P0 79W / 350W | 25169MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ +``` + +* Send request + +```bash +curl -X POST localhost:8000/v2/models/ensemble/generate -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": "", "pad_id": 2, "end_id": 2}' + +{"cum_log_probs":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"Machine learning is a subfield of artificial intelligence that focuses on the development of algorithms that can learn"} +``` + +* Send request with bad_words and stop_words + +```bash +curl -X POST localhost:8000/v2/models/ensemble/generate -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "intelligence", "stop_words": "focuses", "pad_id": 2, "end_id": 2}' + +{"cum_log_probs":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"Machine learning is a subfield of artificial Intelligence (AI) that focuses"} +``` + +* Send request by `inflight_batcher_llm_client.py` + +```bash +python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 --tokenizer-dir ${HF_LLAMA_MODEL} + +========= +[[1, 19298, 297, 6641, 29899, 23027, 3444, 29892, 1105, 7598, 16370, 408, 263]] +Got completed request +Input: Born in north-east France, Soyer trained as a +Output beam 0: 850. He was the first chef to be hired by the newly opened Delmonico’s restaurant, where he worked for 10 years. He then opened his own restaurant, which was a huge success. 
+Soyer was a prolific writer and his books include The Gastronomic Regenerator (1854), The Gastronomic Regenerator and Cookery for the People (1855), The Cuisine of To-day (1859), The Cuisine of To-morrow (1864), The Cuisine of the Future (1867), The Cuisine of the Future (1873), The Cuisine of the Future (1874), The Cuisine of the Future (1875), The Cuisine of the Future (1876), The +output_ids = [14547, 297, 3681, 322, 4517, 1434, 8401, 304, 1570, 3088, 297, 29871, 29896, 29947, 29945, 29900, 29889, 940, 471, 278, 937, 14547, 304, 367, 298, 2859, 491, 278, 15141, 6496, 5556, 3712, 1417, 30010, 29879, 27144, 29892, 988, 540, 3796, 363, 29871, 29896, 29900, 2440, 29889, 940, 769, 6496, 670, 1914, 27144, 29892, 607, 471, 263, 12176, 2551, 29889, 13, 6295, 7598, 471, 263, 410, 29880, 928, 9227, 322, 670, 8277, 3160, 450, 402, 7614, 4917, 293, 2169, 759, 1061, 313, 29896, 29947, 29945, 29946, 511, 450, 402, 7614, 4917, 293, 2169, 759, 1061, 322, 17278, 708, 363, 278, 11647, 313, 29896, 29947, 29945, 29945, 511, 450, 315, 4664, 457, 310, 1763, 29899, 3250, 313, 29896, 29947, 29945, 29929, 511, 450, 315, 4664, 457, 310, 1763, 29899, 26122, 313, 29896, 29947, 29953, 29946, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29953, 29955, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29941, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29946, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29945, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29953, 511, 450] +``` + +* Run test on dataset + +``` +python3 tools/inflight_batcher_llm/end_to_end_test.py --dataset ci/L0_backend_trtllm/simple_data.json --max-input-len 500 + +[INFO] Start testing on 13 prompts. +[INFO] Functionality test succeed. +[INFO] Warm up for benchmarking. +[INFO] Start benchmarking on 13 prompts. +[INFO] Total Latency: 962.179 ms +``` + + + +* run with decouple mode (streaming) + +```bash +cp all_models/inflight_batcher_llm/ llama_ifb -r + +python3 tools/fill_template.py -i llama_ifb/preprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},tokenizer_type:llama,triton_max_batch_size:64,preprocessing_instance_count:1 +python3 tools/fill_template.py -i llama_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},tokenizer_type:llama,triton_max_batch_size:64,postprocessing_instance_count:1 +python3 tools/fill_template.py -i llama_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:True +python3 tools/fill_template.py -i llama_ifb/ensemble/config.pbtxt triton_max_batch_size:64 +python3 tools/fill_template.py -i llama_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:/tmp/llama/7B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_kv_cache_length:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,batching_strategy:inflight_batching,max_queue_delay_microseconds:600 + +pip pinstall SentencePiece +python3 scripts/launch_triton_server.py --world_size 1 --model_repo=llama_ifb/ + +python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 --tokenizer-dir ${HF_LLAMA_MODEL} --streaming +``` + +
+ The result would be like + + +```bash +========= +Input sequence: [1, 19298, 297, 6641, 29899, 23027, 3444, 29892, 1105, 7598, 16370, 408, 263] +[14547] +[297] +[3681] +[322] +[4517] +[1434] +[8401] +[304] +[1570] +[3088] +[297] +[29871] +[29896] +[29947] +[29945] +[29900] +[29889] +[940] +[471] +[278] +[937] +[14547] +[304] +[367] +[298] +[2859] +[491] +[278] +[15141] +[6496] +[5556] +[3712] +[1417] +[30010] +[29879] +[27144] +[29892] +[988] +[540] +[3796] +[363] +[29871] +[29896] +[29900] +[2440] +[29889] +[940] +[769] +[6496] +[670] +[1914] +[27144] +[29892] +[607] +[471] +[263] +[12176] +[2551] +[29889] +[13] +[6295] +[7598] +[471] +[263] +[410] +[29880] +[928] +[9227] +[322] +[670] +[8277] +[3160] +[450] +[402] +[7614] +[4917] +[293] +[2169] +[759] +[1061] +[313] +[29896] +[29947] +[29945] +[29946] +[511] +[450] +[402] +[7614] +[4917] +[293] +[2169] +[759] +[1061] +[322] +[17278] +[708] +[363] +[278] +[11647] +[313] +[29896] +[29947] +[29945] +[29945] +[511] +[450] +[315] +[4664] +[457] +[310] +[1763] +[29899] +[3250] +[313] +[29896] +[29947] +[29945] +[29929] +[511] +[450] +[315] +[4664] +[457] +[310] +[1763] +[29899] +[26122] +[313] +[29896] +[29947] +[29953] +[29946] +[511] +[450] +[315] +[4664] +[457] +[310] +[278] +[16367] +[313] +[29896] +[29947] +[29953] +[29955] +[511] +[450] +[315] +[4664] +[457] +[310] +[278] +[16367] +[313] +[29896] +[29947] +[29955] +[29941] +[511] +[450] +[315] +[4664] +[457] +[310] +[278] +[16367] +[313] +[29896] +[29947] +[29955] +[29946] +[511] +[450] +[315] +[4664] +[457] +[310] +[278] +[16367] +[313] +[29896] +[29947] +[29955] +[29945] +[511] +[450] +[315] +[4664] +[457] +[310] +[278] +[16367] +[313] +[29896] +[29947] +[29955] +[29953] +[511] +[450] +Input: Born in north-east France, Soyer trained as a +Output beam 0: chef in Paris and London before moving to New York in 1850. He was the first chef to be hired by the newly opened Delmonico’s restaurant, where he worked for 10 years. He then opened his own restaurant, which was a huge success. 
+Soyer was a prolific writer and his books include The Gastronomic Regenerator (1854), The Gastronomic Regenerator and Cookery for the People (1855), The Cuisine of To-day (1859), The Cuisine of To-morrow (1864), The Cuisine of the Future (1867), The Cuisine of the Future (1873), The Cuisine of the Future (1874), The Cuisine of the Future (1875), The Cuisine of the Future (1876), The +Output sequence: [1, 19298, 297, 6641, 29899, 23027, 3444, 29892, 1105, 7598, 16370, 408, 263, 14547, 297, 3681, 322, 4517, 1434, 8401, 304, 1570, 3088, 297, 29871, 29896, 29947, 29945, 29900, 29889, 940, 471, 278, 937, 14547, 304, 367, 298, 2859, 491, 278, 15141, 6496, 5556, 3712, 1417, 30010, 29879, 27144, 29892, 988, 540, 3796, 363, 29871, 29896, 29900, 2440, 29889, 940, 769, 6496, 670, 1914, 27144, 29892, 607, 471, 263, 12176, 2551, 29889, 13, 6295, 7598, 471, 263, 410, 29880, 928, 9227, 322, 670, 8277, 3160, 450, 402, 7614, 4917, 293, 2169, 759, 1061, 313, 29896, 29947, 29945, 29946, 511, 450, 402, 7614, 4917, 293, 2169, 759, 1061, 322, 17278, 708, 363, 278, 11647, 313, 29896, 29947, 29945, 29945, 511, 450, 315, 4664, 457, 310, 1763, 29899, 3250, 313, 29896, 29947, 29945, 29929, 511, 450, 315, 4664, 457, 310, 1763, 29899, 26122, 313, 29896, 29947, 29953, 29946, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29953, 29955, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29941, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29946, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29945, 511, 450, 315, 4664, 457, 310, 278, 16367, 313, 29896, 29947, 29955, 29953, 511, 450] +``` + +
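+
+In-flight batching can also be exercised from Python. The sketch below is illustrative only (it assumes the `requests` package, the HTTP endpoint on the default port 8000, and the non-streaming `ensemble` setup from earlier in this document); it sends 20 identical generate requests concurrently, while the curl/xargs one-liner in the next section does the same from the shell:
+
+```python
+import concurrent.futures
+
+import requests
+
+URL = "http://localhost:8000/v2/models/ensemble/generate"
+PAYLOAD = {
+    "text_input": "What is machine learning?",
+    "max_tokens": 20,
+    "bad_words": "",
+    "stop_words": "",
+    "pad_id": 2,
+    "end_id": 2,
+}
+
+def send(_):
+    # Each worker posts an independent generate request; the server-side
+    # in-flight batcher schedules them together.
+    return requests.post(URL, json=PAYLOAD, timeout=60).json()["text_output"]
+
+with concurrent.futures.ThreadPoolExecutor(max_workers=20) as pool:
+    for text in pool.map(send, range(20)):
+        print(text)
+```
+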
+ + +* Run several requests at the same time + +```bash +echo '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": "", "pad_id": 2, "end_id": 2}' > tmp.txt +printf '%s\n' {1..20} | xargs -I % -P 20 curl -X POST localhost:8000/v2/models/ensemble/generate -d @tmp.txt +``` diff --git a/inflight_batcher_llm/CMakeLists.txt b/inflight_batcher_llm/CMakeLists.txt index 73b63253..390d7dbf 100644 --- a/inflight_batcher_llm/CMakeLists.txt +++ b/inflight_batcher_llm/CMakeLists.txt @@ -147,7 +147,11 @@ endif() # configure_file(src/libtriton_tensorrtllm.ldscript libtriton_tensorrtllm.ldscript COPYONLY) -add_library(triton-tensorrt-llm-backend SHARED src/libtensorrtllm.cc) + +set(SRCS src/libtensorrtllm.cc src/work_item.cc src/work_items_queue.cc + src/model_instance_state.cc src/model_state.cc src/utils.cc) + +add_library(triton-tensorrt-llm-backend SHARED ${SRCS}) if(TRITON_BUILD) add_custom_target(trtllm_target DEPENDS tensorrt_llm_build) @@ -251,6 +255,7 @@ target_compile_options( -Wall -Wextra -Wno-unused-parameter + -Wno-deprecated-declarations -Wno-type-limits> $<$:/Wall /D_WIN32_WINNT=0x0A00 diff --git a/inflight_batcher_llm/client/end_to_end_grpc_client.py b/inflight_batcher_llm/client/end_to_end_grpc_client.py index 1f77b839..5d543ad3 100644 --- a/inflight_batcher_llm/client/end_to_end_grpc_client.py +++ b/inflight_batcher_llm/client/end_to_end_grpc_client.py @@ -35,14 +35,12 @@ def callback(user_data, result, error): user_data._completed_requests.put(error) else: user_data._completed_requests.put(result) - output = result.as_numpy('text_output') - print(output, flush=True) def test(triton_client, prompt, request_id, repetition_penalty, presence_penalty, temperatuure, stop_words, bad_words, embedding_bias_words, embedding_bias_weights): - model_name = "ensemble" + model_name = FLAGS.model_name input0 = [[prompt]] input0_data = np.array(input0).astype(object) @@ -120,6 +118,7 @@ def test(triton_client, prompt, request_id, repetition_penalty, triton_client.stop_stream() # Parse the responses + output_text = "" while True: try: result = user_data._completed_requests.get(block=False) @@ -130,7 +129,18 @@ def test(triton_client, prompt, request_id, repetition_penalty, print("Received an error from server:") print(result) else: - result.as_numpy('text_output') + output = result.as_numpy('text_output') + if FLAGS.streaming and FLAGS.beam_width == 1: + new_output = output[0].decode("utf8") + if FLAGS.overwrite_output_text: + output_text = new_output + else: + output_text += new_output + else: + print(output, flush=True) + + if FLAGS.streaming and FLAGS.beam_width == 1: + print(output_text) if __name__ == '__main__': @@ -152,6 +162,13 @@ def test(triton_client, prompt, request_id, repetition_penalty, type=str, required=True, help='Input prompt.') + + parser.add_argument('--model-name', + type=str, + required=False, + default="ensemble", + help='Name of the Triton model to send request to') + parser.add_argument( "-S", "--streaming", @@ -213,7 +230,7 @@ def test(triton_client, prompt, request_id, repetition_penalty, parser.add_argument('--request-id', type=str, - default='1', + default='', required=False, help='The request_id for the stop request') @@ -237,6 +254,15 @@ def test(triton_client, prompt, request_id, repetition_penalty, default=[], help='The biased words weights') + parser.add_argument( + '--overwrite-output-text', + action="store_true", + required=False, + default=False, + help= + 'In streaming mode, overwrite previously received output text instead 
of appending to it' + ) + FLAGS = parser.parse_args() if FLAGS.url is None: FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" diff --git a/inflight_batcher_llm/client/inflight_batcher_llm_client.py b/inflight_batcher_llm/client/inflight_batcher_llm_client.py index c4900eeb..ddeb44ea 100755 --- a/inflight_batcher_llm/client/inflight_batcher_llm_client.py +++ b/inflight_batcher_llm/client/inflight_batcher_llm_client.py @@ -180,9 +180,11 @@ def callback(user_data, result, error): user_data._completed_requests.put(result) if (FLAGS.streaming): output_ids = result.as_numpy('output_ids') + seq_lens = result.as_numpy('sequence_length') if output_ids != None: - tokens = list(output_ids[0][0]) - print(tokens, flush=True) + if seq_lens == None or seq_lens[0][0] > 0: + tokens = list(output_ids[0][0]) + print(tokens, flush=True) if __name__ == "__main__": @@ -367,7 +369,7 @@ def callback(user_data, result, error): help='Specify tokenizer type') parser.add_argument('--request-id', type=str, - default='1', + default='', required=False, help='The request_id for the stop request') @@ -625,9 +627,11 @@ def callback(user_data, result, error): sequence_lengths = result.as_numpy('sequence_length') if output_ids is not None: # Only one beam is supported - tokens = list(output_ids[0][0]) - actual_output_ids[ - 0] = actual_output_ids[0] + tokens + if sequence_lengths == None or sequence_lengths[0][ + 0] > 0: + tokens = list(output_ids[0][0]) + actual_output_ids[ + 0] = actual_output_ids[0] + tokens else: print("Got cancellation response from server") else: diff --git a/inflight_batcher_llm/src/libtensorrtllm.cc b/inflight_batcher_llm/src/libtensorrtllm.cc index aa206cdb..7081b269 100644 --- a/inflight_batcher_llm/src/libtensorrtllm.cc +++ b/inflight_batcher_llm/src/libtensorrtllm.cc @@ -33,373 +33,26 @@ #include #include +// Triton headers #include "triton/backend/backend_common.h" -#include "triton/backend/backend_input_collector.h" -#include "triton/backend/backend_model.h" -#include "triton/backend/backend_model_instance.h" -#include "triton/backend/backend_output_responder.h" #include "triton/core/tritonbackend.h" #include "triton/core/tritonserver.h" -#include "tensorrt_llm/batch_manager/GptManager.h" -#include "tensorrt_llm/batch_manager/NamedTensor.h" -#include "tensorrt_llm/batch_manager/callbacks.h" -#include "tensorrt_llm/batch_manager/inferenceRequest.h" -#include "tensorrt_llm/batch_manager/kvCacheConfig.h" -#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/mpiUtils.h" -#include "tensorrt_llm/plugins/api/tllmPlugin.h" -#include "tensorrt_llm/runtime/tllmLogger.h" - -#include -using namespace ::triton::common; // TritonJson +// trtllm backend headers +#include "model_instance_state.h" +#include "model_state.h" +#include "work_item.h" +#include "work_items_queue.h" #ifdef TRITON_ENABLE_METRICS #include "metrics/triton_metrics.h" #endif -// -// Mockup of LLM inflight batcher based on triton 'minimal' backend example -// - -using namespace tensorrt_llm::batch_manager; -using namespace tensorrt_llm::runtime; -using namespace tensorrt_llm::mpi; -using namespace std::placeholders; // for _1, _2 etc. 
- -// template class inflight_batcher::batch_manager::GPTManager; - -namespace triton -{ -namespace backend -{ -namespace inflight_batcher_llm -{ - -inline static const std::string kStopInputTensorName = "stop"; -inline static const std::string kStreamingInputTensorName = "streaming"; - -bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::string& inputTensorName) -{ - // Get stop signal from the request - TRITONBACKEND_Input* input; - TRITONSERVER_Error* error = TRITONBACKEND_RequestInput(request, inputTensorName.c_str(), &input); - if (error) - { - // If the user does not provide input "stop", then regard the request as - // unstopped - std::string msg - = "ModelInstanceState::getRequestBooleanInputTensor: user " - "did not not provide " - + inputTensorName + " input for the request"; - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, msg.c_str()); - return false; - } - - uint64_t input_byte_size = 0; - uint32_t buffer_count = 0; - TRITONBACKEND_InputProperties(input, nullptr, nullptr, nullptr, nullptr, &input_byte_size, &buffer_count); - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - ("ModelInstanceState::getRequestStopSignal: buffer_count = " + std::to_string(buffer_count)).c_str()); - - const void* buffer = 0L; - uint64_t buffer_byte_size = 0; - TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; - int64_t memory_type_id = 0; - TRITONBACKEND_InputBuffer(input, 0, &buffer, &buffer_byte_size, &memory_type, &memory_type_id); - - assert((memory_type == TRITONSERVER_MEMORY_CPU) || (memory_type == TRITONSERVER_MEMORY_CPU_PINNED)); - - bool boolean = *reinterpret_cast(buffer); - - return boolean; -} - -nvinfer1::DataType to_trt_datatype(TRITONSERVER_DataType data_type) -{ - if (data_type == TRITONSERVER_TYPE_INVALID) - { - assert(false); - } - else if (data_type == TRITONSERVER_TYPE_BOOL) - { - return nvinfer1::DataType::kBOOL; - } - else if (data_type == TRITONSERVER_TYPE_UINT8) - { - return nvinfer1::DataType::kUINT8; - } - else if (data_type == TRITONSERVER_TYPE_UINT16) - { - assert(false); - } - else if (data_type == TRITONSERVER_TYPE_UINT32) - { - return nvinfer1::DataType::kINT32; - } - else if (data_type == TRITONSERVER_TYPE_UINT64) - { - return nvinfer1::DataType::kINT64; - } - else if (data_type == TRITONSERVER_TYPE_INT8) - { - return nvinfer1::DataType::kINT8; - } - else if (data_type == TRITONSERVER_TYPE_INT16) - { - assert(false); - } - else if (data_type == TRITONSERVER_TYPE_INT32) - { - return nvinfer1::DataType::kINT32; - } - else if (data_type == TRITONSERVER_TYPE_INT64) - { - return nvinfer1::DataType::kINT64; - } - else if (data_type == TRITONSERVER_TYPE_FP16) - { - return nvinfer1::DataType::kHALF; - } - else if (data_type == TRITONSERVER_TYPE_FP32) - { - return nvinfer1::DataType::kFLOAT; - } - else if (data_type == TRITONSERVER_TYPE_FP64) - { - assert(false); - } - else if (data_type == TRITONSERVER_TYPE_BYTES) - { - return nvinfer1::DataType::kINT8; - } - else if (data_type == TRITONSERVER_TYPE_BF16) - { - return nvinfer1::DataType::kBF16; - } - else - { - assert(false); - } - return nvinfer1::DataType(0); -} - -TRITONSERVER_DataType to_triton_datatype(nvinfer1::DataType data_type) -{ - if (data_type == nvinfer1::DataType::kBOOL) - { - return TRITONSERVER_TYPE_BOOL; - } - else if (data_type == nvinfer1::DataType::kUINT8) - { - return TRITONSERVER_TYPE_UINT8; - } - else if (data_type == nvinfer1::DataType::kHALF) - { - return TRITONSERVER_TYPE_BF16; - } - else if (data_type == nvinfer1::DataType::kINT8) - { - return TRITONSERVER_TYPE_INT8; - } - else 
if (data_type == nvinfer1::DataType::kINT32) - { - return TRITONSERVER_TYPE_INT32; - } - else if (data_type == nvinfer1::DataType::kINT64) - { - return TRITONSERVER_TYPE_INT64; - } - else if (data_type == nvinfer1::DataType::kFLOAT) - { - return TRITONSERVER_TYPE_FP32; - } - else if (data_type == nvinfer1::DataType::kBF16) - { - return TRITONSERVER_TYPE_BF16; - } - else - { - return TRITONSERVER_TYPE_INVALID; - } -} - -///////////// - -// -// ModelState -// -// State associated with a model that is using this backend. An object -// of this class is created and associated with each -// TRITONBACKEND_Model. -// -class ModelState -{ -public: - static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model, ModelState** state); - - template - T GetParameter(const std::string& name) - { - assert(false); - } - - virtual ~ModelState() = default; - -#ifdef TRITON_ENABLE_METRICS - TRITONSERVER_Error* InitMetrics(const std::string& model_name, const uint64_t version, const bool is_v1_model); - TRITONSERVER_Error* UpdateMetrics(const std::string& statistics); -#endif - common::TritonJson::Value& GetModelConfig(); - -private: -#ifdef TRITON_ENABLE_METRICS - std::unique_ptr triton_metrics_; -#endif - common::TritonJson::Value model_config_; - std::shared_ptr mTrtLogger{}; - - ModelState(TRITONBACKEND_Model* triton_model, TritonJson::Value&& model_config) - : model_config_(std::move(model_config)) - { - mTrtLogger = std::make_shared(); - initTrtLlmPlugins(mTrtLogger.get()); -#ifdef TRITON_ENABLE_METRICS - triton_metrics_ = std::make_unique(); -#endif - } -}; - -TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) -{ - TRITONSERVER_Message* config_message; - RETURN_IF_ERROR(TRITONBACKEND_ModelConfig(triton_model, 1 /* config_version */, &config_message)); - - // We can get the model configuration as a json string from - // config_message, parse it with our favorite json parser to create - // DOM that we can access when we need to example the - // configuration. We use TritonJson, which is a wrapper that returns - // nice errors (currently the underlying implementation is - // rapidjson... but others could be added). You can use any json - // parser you prefer. 
- const char* buffer; - size_t byte_size; - RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(config_message, &buffer, &byte_size)); - - common::TritonJson::Value model_config; - TRITONSERVER_Error* err = model_config.Parse(buffer, byte_size); - RETURN_IF_ERROR(TRITONSERVER_MessageDelete(config_message)); - RETURN_IF_ERROR(err); - - try - { - *state = new ModelState(triton_model, std::move(model_config)); - } - catch (const std::exception& ex) - { - std::string errStr = std::string("unexpected error when creating modelState: ") + ex.what(); - return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); - } - - return nullptr; // success -} - -common::TritonJson::Value& ModelState::GetModelConfig() -{ - return model_config_; -} - -template <> -std::string ModelState::GetParameter(const std::string& name) -{ - TritonJson::Value parameters; - TRITONSERVER_Error* err = model_config_.MemberAsObject("parameters", ¶meters); - if (err != nullptr) - { - throw std::runtime_error("Model config doesn't have a parameters section"); - TRITONSERVER_ErrorDelete(err); - } - TritonJson::Value value; - std::string str_value; - err = parameters.MemberAsObject(name.c_str(), &value); - if (err != nullptr) - { - std::string errStr = "Cannot find parameter with name: " + name; - throw std::runtime_error(errStr); - TRITONSERVER_ErrorDelete(err); - } - value.MemberAsString("string_value", &str_value); - return str_value; -} - -template <> -int32_t ModelState::GetParameter(const std::string& name) -{ - return std::stoi(GetParameter(name)); -} - -template <> -uint32_t ModelState::GetParameter(const std::string& name) -{ - return (uint32_t) std::stoul(GetParameter(name)); -} - -template <> -int64_t ModelState::GetParameter(const std::string& name) -{ - return std::stoll(GetParameter(name)); -} - -template <> -uint64_t ModelState::GetParameter(const std::string& name) -{ - return std::stoull(GetParameter(name)); -} - -template <> -float ModelState::GetParameter(const std::string& name) -{ - return std::stof(GetParameter(name)); -} - -template <> -bool ModelState::GetParameter(const std::string& name) -{ - auto val = GetParameter(name); - if (val == "True" || val == "true" || val == "TRUE" || val == "1") - { - return true; - } - else if (val == "False" || val == "false" || val == "FALSE" || val == "0") - { - return false; - } - else - { - std::string err = "Cannot convert " + val + " to a boolean."; - throw std::runtime_error(err); - } -} - -#ifdef TRITON_ENABLE_METRICS -TRITONSERVER_Error* ModelState::InitMetrics( - const std::string& model_name, const uint64_t version, const bool is_v1_model) -{ - RETURN_IF_ERROR(triton_metrics_->InitMetrics(model_name, version, is_v1_model)); - return nullptr; // success -} -TRITONSERVER_Error* ModelState::UpdateMetrics(const std::string& statistics) +namespace triton::backend::inflight_batcher_llm { - RETURN_IF_ERROR(triton_metrics_->UpdateMetrics(statistics)); - return nullptr; // success -} -#endif extern "C" { - // Triton calls TRITONBACKEND_ModelInitialize when a model is loaded // to allow the backend to create any state associated with the model, // and to also examine the model configuration to determine if the @@ -446,935 +99,6 @@ extern "C" return nullptr; // success } -} // extern "C" - -///////////// - -// Class holding all infos regarding a single work item. -// This includes the original request, associated response factor -// and state. 
-class WorkItem -{ -public: - WorkItem(TRITONBACKEND_Request* request, bool isDecoupled) - { - uint64_t requestId = (rand() % INT64_MAX) + 1; - Initialize(request, requestId, isDecoupled); - } - - WorkItem(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) - { - Initialize(request, requestId, isDecoupled); - } - - WorkItem(std::shared_ptr ir, uint64_t RequestId) - : mInferenceRequest(ir) - , mRequestId(RequestId) - { - factory_ptr_ = nullptr; - } - - ~WorkItem() - { - if (factory_ptr_ != nullptr) - { - TRITONBACKEND_ResponseFactoryDelete(factory_ptr_); - } - } - - TRITONBACKEND_ResponseFactory* response_factory() - { - assert(factory_ptr_ != nullptr); - return factory_ptr_; - } - - uint64_t requestId() const - { - return mRequestId; - } - - std::shared_ptr getInferenceRequest() const - { - return mInferenceRequest; - } - - bool hasOutputName(const std::string& outputName) - { - return (mRequestOutputNames.find(outputName) != mRequestOutputNames.end()); - } - -private: - void Initialize(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) - { - mRequestId = requestId; - mInferenceRequest = createInferenceRequest(request, requestId, isDecoupled); - mRequestOutputNames = getRequestOutputNames(request); - - // Create response factory for this request - TRITONBACKEND_ResponseFactoryNew(&factory_ptr_, request); - } - - // Convert info from original backend request to data structures defined in - // common/common.h - std::shared_ptr createInferenceRequest( - TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) - { - auto inferenceRequest = std::make_shared(requestId); - - // Extract input tensors - std::map input_tensors; - uint32_t num_inputs; - LOG_IF_ERROR(TRITONBACKEND_RequestInputCount(request, &num_inputs), "Error getting input count"); - for (uint32_t idx = 0; idx < num_inputs; ++idx) - { - TRITONBACKEND_Input* input = 0L; - TRITONBACKEND_RequestInputByIndex(request, idx, &input); - - const char* input_name = 0L; - TRITONSERVER_DataType data_type = TRITONSERVER_TYPE_INVALID; - const int64_t* shape = 0L; - uint32_t dims_count = 0; - uint64_t byte_size = 0; - uint32_t buffer_count = 0; - TRITONBACKEND_InputProperties( - input, &input_name, &data_type, &shape, &dims_count, &byte_size, &buffer_count); - - if (std::string(input_name) == "START" || std::string(input_name) == "CORRID" - || std::string(input_name) == "END" || std::string(input_name) == kStopInputTensorName - || std::string(input_name) == kStreamingInputTensorName) - { - continue; - } - - std::vector shapev; - for (uint32_t i = 0; i < dims_count; ++i) - { - shapev.push_back(shape[i]); - } - - NamedTensor t(to_trt_datatype(data_type), shapev, input_name); - uint64_t buffer_offset = 0; - for (int64_t buffer_id = 0; buffer_id < buffer_count; ++buffer_id) - { - const void* buffer = 0L; - uint64_t buffer_byte_size = 0; - TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; - int64_t memory_type_id = 0; - TRITONBACKEND_InputBuffer(input, buffer_id, &buffer, &buffer_byte_size, &memory_type, &memory_type_id); - assert((memory_type == TRITONSERVER_MEMORY_CPU) || (memory_type == TRITONSERVER_MEMORY_CPU_PINNED)); - // TODO: Do we need to handle GPU mem input buffers?? 
- std::memcpy(static_cast(t.tensor->data()) + buffer_offset, buffer, buffer_byte_size); - buffer_offset += buffer_byte_size; - } - - inferenceRequest->emplaceInputTensor(t.name, std::move(t.tensor)); - } - - bool streamingFlag = getRequestBooleanInputTensor(request, kStreamingInputTensorName); - inferenceRequest->setIsStreaming(streamingFlag); - - if (streamingFlag && !isDecoupled) - { - throw std::runtime_error( - "Streaming is only supported if model is " - "deployed using decoupled mode."); - } - - return inferenceRequest; - } - - std::unordered_set getRequestOutputNames(TRITONBACKEND_Request* request) - { - std::unordered_set outputNames; - uint32_t outputCount; - LOG_IF_ERROR(TRITONBACKEND_RequestOutputCount(request, &outputCount), "Error getting request output count"); - for (size_t i = 0; i < outputCount; ++i) - { - const char* name; - LOG_IF_ERROR(TRITONBACKEND_RequestOutputName(request, i, &name), "Error getting request output name"); - std::string name_s(name); - outputNames.insert(std::move(name_s)); - } - return outputNames; - } - - std::shared_ptr mInferenceRequest; - TRITONBACKEND_ResponseFactory* factory_ptr_; - uint64_t mRequestId; - std::unordered_set mRequestOutputNames; -}; - -/// @brief Thread-safe queue of work items - -class WorkItemsQueue -{ -public: - void clear() - { - std::lock_guard lk(mMutex); - mPendingWorkItems.clear(); - mPendingWorkItemsReqIds.clear(); - mInProgressWorkItems.clear(); - mStoppedReqIds.clear(); - } - - // Note: this function only be called under a lock - bool hasInProgressReqId(const uint64_t reqId) const - { - return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end()); - } - - // Note: this function only be called under a lock - bool hasPendingReqId(const uint64_t reqId) const - { - return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end()); - } - - /// @brief Add a batch of new work item to the queue - /// Throws an error if requestId already exists - std::vector> pushBatch( - std::vector>& requestsToPush, bool isDecoupled) - { - std::lock_guard lk(mMutex); - std::vector> reqExceptions; - for (auto& [requestId, request] : requestsToPush) - { - if (requestId != 0 && (hasInProgressReqId(requestId) || hasPendingReqId(requestId))) - { - std::string errStr - = "requestId " + std::to_string(requestId) + " is already in progress, request is ignored."; - reqExceptions.emplace_back(std::make_shared(errStr)); - } - else - { - auto workItem = requestId != 0 ? 
std::make_shared(request, requestId, isDecoupled) - : std::make_shared(request, isDecoupled); - mPendingWorkItems.push_back(workItem); - mPendingWorkItemsReqIds.insert(workItem->requestId()); - reqExceptions.push_back(nullptr); - } - } - return reqExceptions; - } - - /// @brief Add a new work item to the queue - /// Throws an error if requestId already exists - void push(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) - { - std::lock_guard lk(mMutex); - if (hasInProgressReqId(requestId) || hasPendingReqId(requestId)) - { - std::string errStr - = "requestId " + std::to_string(requestId) + " is already in progress, request is ignored."; - throw std::runtime_error(errStr); - } - else - { - auto workItem = std::make_shared(request, requestId, isDecoupled); - mPendingWorkItems.push_back(workItem); - mPendingWorkItemsReqIds.insert(workItem->requestId()); - } - } - - void push(TRITONBACKEND_Request* request, bool isDecoupled) - { - std::lock_guard lk(mMutex); - auto workItem = std::make_shared(request, isDecoupled); - mPendingWorkItems.push_back(workItem); - mPendingWorkItemsReqIds.insert(workItem->requestId()); - } - - /// @brief Get a new work item from the queue, and move it to the list of - /// in progress work items if it hasn't been stopped - /// @return A tuple of the workItem and a boolean flag indicating if the work - /// item has been marked in progress - /// In case the queue is empty, return nullptr - - std::tuple, bool> pop() - { - std::lock_guard lk(mMutex); - if (mPendingWorkItems.empty()) - { - return {nullptr, false}; - } - - auto workItem = mPendingWorkItems.front(); - mPendingWorkItems.pop_front(); - mPendingWorkItemsReqIds.erase(workItem->requestId()); - - // Check if work item has been stopped - bool is_stopped = mStoppedReqIds.count(workItem->requestId()); - - // Check if the Triton request has been cancelled - bool is_cancelled = false; - TRITONBACKEND_ResponseFactoryIsCancelled(workItem->response_factory(), &is_cancelled); - - bool stoppedRequest = false; - if (!is_stopped && !is_cancelled) - { - mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem)); - } - else - { - mStoppedReqIds.erase(workItem->requestId()); - stoppedRequest = true; - } - - return {workItem, stoppedRequest}; - } - - size_t numPendingWorkItems() const - { - std::lock_guard lk(mMutex); - return mPendingWorkItems.size(); - } - - std::shared_ptr getInProgressWorkItem(uint64_t requestId) - { - std::lock_guard lk(mMutex); - return mInProgressWorkItems.at(requestId); - } - - /// @brief Mark a request as being finished - /// @param requestId - void markFinished(const uint64_t requestId) - { - std::lock_guard lk(mMutex); - if (hasInProgressReqId(requestId)) - { - mInProgressWorkItems.erase(requestId); - } - - if (mStoppedReqIds.find(requestId) != mStoppedReqIds.end()) - { - mStoppedReqIds.erase(requestId); - } - } - - // Stop a request by adding the request Id to a set - // The set of stopped request id is used by the poll callback - // and the pop function - void stopWorkItem(const uint64_t requestId) - { - std::lock_guard lk(mMutex); - TLLM_LOG_DEBUG("Stopping request"); - if (hasInProgressReqId(requestId) || hasPendingReqId(requestId)) - { - mStoppedReqIds.emplace(requestId); - } - else - { - std::string errStr = std::string("Received stop request for requestId ") + std::to_string(requestId) - + std::string(" but it's not active (might be completed already)."); - throw std::runtime_error(errStr); - } - } - - std::unordered_set getStoppedReqIds() const - { - 
std::lock_guard lk(mMutex); - return mStoppedReqIds; - } - - std::unordered_set getCancelledInProgressReqIds() const - { - std::unordered_set cancelledInProgressReqIds; - { - std::lock_guard lk(mMutex); - for (const auto& pair : mInProgressWorkItems) - { - bool is_cancelled = false; - TRITONBACKEND_ResponseFactoryIsCancelled(pair.second->response_factory(), &is_cancelled); - if (is_cancelled) - { - cancelledInProgressReqIds.emplace(pair.first); - } - } - } - return cancelledInProgressReqIds; - } - -private: - /// Queue of work items - std::list> mPendingWorkItems; - /// requestIds of work items in the queue - std::set mPendingWorkItemsReqIds; - - /// work items currently in progress - std::unordered_map> mInProgressWorkItems; - - /// ids of the work items that have been stopped - std::unordered_set mStoppedReqIds; - - mutable std::mutex mMutex; -}; - -// -// ModelInstanceState -// -// State associated with a model instance. An object of this class is -// created and associated with each -// TRITONBACKEND_ModelInstance. ModelInstanceState is derived from -// -class ModelInstanceState -{ -public: - static TRITONSERVER_Error* Create( - ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, ModelInstanceState** state); - - virtual ~ModelInstanceState() - { - // terminate decoupled execution loop - { - mWorkItemsQueue.clear(); - } - } - - // Get the state of the model that corresponds to this instance. - ModelState* StateForModel() const - { - return model_state_; - } - - bool isDecoupled() const - { - return mIsDecoupled; - } - - uint64_t getRequestId(TRITONBACKEND_Request* request) - { - const char* charRequestId; - TRITONBACKEND_RequestId(request, &charRequestId); - uint64_t requestId = 0; - if (charRequestId != nullptr) - { - std::string strRequestId(charRequestId); - if (!strRequestId.empty()) - { - try - { - requestId = stoul(strRequestId); - } - catch (const std::exception& e) - { - std::string err = std::string("Invalid requestId, must be uint64_t. 
Got ") + strRequestId; - throw std::runtime_error(err); - } - } - } - - return requestId; - } - - // For stop requests, or in case of error during enqueue, we need to send a - // response to the client - void sendEnqueueResponse(TRITONBACKEND_Request* request, const std::string& errMsg = "") - { - TRITONBACKEND_ResponseFactory* factory_ptr; - // Create response factory for this request - LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request), "Cannot create response factory"); - - TRITONSERVER_Error* err = nullptr; - if (!errMsg.empty()) - { - TLLM_LOG_ERROR(errMsg); - err = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errMsg.c_str()); - } - TRITONBACKEND_Response* response; - LOG_IF_ERROR(TRITONBACKEND_ResponseNewFromFactory(&response, factory_ptr), "Cannot create response"); - LOG_IF_ERROR( - TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), "Cannot send response"); - LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryDelete(factory_ptr), "Cannot delete response factory"); - } - - void enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count, bool isDecoupled) - { - std::vector> requestsToPush; - - for (uint32_t r = 0; r < request_count; ++r) - { - TRITONBACKEND_Request* request = requests[r]; - try - { - auto requestId = getRequestId(request); - bool stopRequest = getRequestBooleanInputTensor(request, kStopInputTensorName); - - if (stopRequest) - { - if (requestId != 0) - { - // Check if request is in progress or in queue, if not ignore - mWorkItemsQueue.stopWorkItem(requestId); - // Send a response back to client for stop request - sendEnqueueResponse(request); - } - else - { - throw std::runtime_error("Cannot send stop request without specifying a request_id"); - } - } - else - { - requestsToPush.emplace_back(requestId, request); - } - } - catch (const std::exception& e) - { - // In case of error, no work item is added to queue, so response - // callback needs to be called - sendEnqueueResponse(request, e.what()); - } - } - - auto exceptions = mWorkItemsQueue.pushBatch(requestsToPush, isDecoupled); - - for (uint32_t r = 0; r < requestsToPush.size(); ++r) - { - auto request = requestsToPush.at(r).second; - auto e = exceptions.at(r); - if (e) - { - sendEnqueueResponse(request, e->what()); - } - } - - return; - } - - // Return up to max_num_requests inference requests. - std::list> get_inference_requests(const int max_num_requests) - { - std::list> rval; - if (max_num_requests <= 0) - { - return rval; - } - - auto world_size = getCommWorldSize(); - auto rank = getCommWorldRank(); - if (rank == 0) - { - auto numPendingWorkItems = mWorkItemsQueue.numPendingWorkItems(); - // Loop over the pending work items and include at most `max_num_requests` - for (size_t i = 0; i < numPendingWorkItems && rval.size() < max_num_requests; ++i) - { - auto [workItem, stoppedRequest] = mWorkItemsQueue.pop(); - - if (workItem) - { - if (!stoppedRequest) - { - rval.emplace_back(workItem->getInferenceRequest()); - } - else - { - std::string warnStr = std::string("request Id ") + std::to_string(workItem->requestId()) - + std::string(" has been stopped. 
Request is ignored."); - TLLM_LOG_WARNING(warnStr); - sendTritonResponse(workItem, {}, true, warnStr); - } - } - } - - if (world_size > 1) - { - int64_t num_new_work_items = rval.size(); - bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD); - - if (num_new_work_items > 0) - { - std::vector packed; - for (auto ir : rval) - { - auto vpacked = ir->serialize(); - packed.push_back(static_cast(vpacked.size())); - packed.insert( - packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); - } - bcast(packed, 0, COMM_WORLD); - } - } - } - else - { - // subordinate ranks hang until master rank sends work - int64_t num_new_work_items; - bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD); - if (num_new_work_items > 0) - { - std::vector packed; - bcast(packed, 0, COMM_WORLD); - int64_t* packed_ptr = packed.data(); - for (int64_t count = 0; count < num_new_work_items; ++count) - { - int64_t n = *(packed_ptr++); - auto ir = InferenceRequest::deserialize(packed_ptr); - packed_ptr += n; - rval.emplace_back(ir); - } - } - } - return rval; - } - - TRITONSERVER_Error* sendTritonResponse(std::shared_ptr workItem, - std::list const& response_tensors, bool final_response, const std::string& errMsg) - { - TRITONBACKEND_ResponseFactory* response_factory; - response_factory = workItem->response_factory(); - - TRITONBACKEND_Response* response; - RETURN_IF_ERROR(TRITONBACKEND_ResponseNewFromFactory(&response, response_factory)); - - auto requestId = workItem->requestId(); - if (final_response) - { - mWorkItemsQueue.markFinished(requestId); - } - - // Check if error - TRITONSERVER_Error* err = nullptr; - if (!errMsg.empty()) - { - std::string errStr = "Encountered error for requestId " + std::to_string(requestId) + ": " + errMsg; - TLLM_LOG_ERROR(errStr); - - bool is_cancelled = false; - TRITONBACKEND_ResponseFactoryIsCancelled(response_factory, &is_cancelled); - - auto err_code = is_cancelled ? TRITONSERVER_ERROR_CANCELLED : TRITONSERVER_ERROR_INTERNAL; - - err = TRITONSERVER_ErrorNew(err_code, errStr.c_str()); - final_response = true; - } - else - { - for (auto it = response_tensors.begin(); it != response_tensors.end(); ++it) - { - auto tensor = *it; - if (!workItem->hasOutputName(tensor.name)) - { - continue; - } - auto shape = tensor.tensor->getShape(); // returns std::vectorint64_t> - std::vector vshape(shape.nbDims); - for (int i = 0; i < vshape.size(); ++i) - { - vshape[i] = shape.d[i]; - } - - TRITONBACKEND_Output* output; - RETURN_IF_ERROR(TRITONBACKEND_ResponseOutput(response, &output, tensor.name.c_str(), - to_triton_datatype(tensor.tensor->getDataType()), vshape.data(), shape.nbDims)); - - uint64_t buffersize = tensor.tensor->getSizeInBytes(); - void* buffer = 0L; - TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; - int64_t memory_type_id = 0; - RETURN_IF_ERROR(TRITONBACKEND_OutputBuffer(output, &buffer, buffersize, &memory_type, &memory_type_id)); - if (memory_type != TRITONSERVER_MEMORY_CPU && memory_type != TRITONSERVER_MEMORY_CPU_PINNED) - { - std::string errStr = "Triton failed to allocate output buffer on CPU"; - err = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); - break; - } - std::memcpy(buffer, tensor.tensor->data(), buffersize); - } - } - - RETURN_IF_ERROR( - TRITONBACKEND_ResponseSend(response, final_response ? 
TRITONSERVER_RESPONSE_COMPLETE_FINAL : 0, err)); - - return nullptr; - } - - void sendResponse(uint64_t requestId, std::list const& response_tensors, bool final_response, - const std::string& errMsg) - { - if (getCommWorldRank() == 0) - { - std::string errStr - = std::string("Failed to send Triton response for requestId: ") + std::to_string(requestId); - try - { - auto workItem = mWorkItemsQueue.getInProgressWorkItem(requestId); - auto tritonErr = sendTritonResponse(workItem, response_tensors, final_response, errMsg); - LOG_IF_ERROR(tritonErr, errStr); - } - catch (const std::exception& e) - { - TLLM_LOG_ERROR(errStr); - } - } - } - - std::unordered_set pollStopSignals() - { - auto stoppedReqIds = mWorkItemsQueue.getStoppedReqIds(); - - // Merge cancelled requests into stopped requests Ids - auto cancelledReqIds = mWorkItemsQueue.getCancelledInProgressReqIds(); - stoppedReqIds.insert(cancelledReqIds.begin(), cancelledReqIds.end()); - - int64_t nStoppedReqIds = static_cast(stoppedReqIds.size()); - - if (getCommWorldSize() > 1) - { - // Broadcast number of stopped requests - bcast(&nStoppedReqIds, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD); - - if (nStoppedReqIds > 0) - { - // Broadcast stopped requests Ids - if (getCommWorldRank() == 0) - { - // Store the requestIds in a contiguous vector - std::vector stoppedReqIdsVec(stoppedReqIds.begin(), stoppedReqIds.end()); - bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0, COMM_WORLD); - } - else - { - std::vector stoppedReqIdsVec(nStoppedReqIds); - bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0, COMM_WORLD); - // Store the requestIds in the set - stoppedReqIds.clear(); - std::copy(stoppedReqIdsVec.begin(), stoppedReqIdsVec.end(), - std::inserter(stoppedReqIds, stoppedReqIds.end())); - } - } - } - return stoppedReqIds; - } - - void logStats(const std::string& s) - { - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, s.c_str()); -#ifdef TRITON_ENABLE_METRICS - LOG_IF_ERROR(model_state_->UpdateMetrics(s), "Failed updating metrics"); -#endif - } - -private: - ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) - : model_state_(model_state) - , mIsDecoupled(false) - { - // Note: std::string::compare fails this test (always return non-zero - // value). Using old school strcmp instead. - if (model_state_->GetParameter("gpt_model_type") == "V1" - || model_state_->GetParameter("gpt_model_type") == "v1") - { - mTrtGptModelType = TrtGptModelType::V1; - } - else if (model_state_->GetParameter("gpt_model_type") == "inflight_batching") - { - mTrtGptModelType = TrtGptModelType::InflightBatching; - } - else if (model_state_->GetParameter("gpt_model_type") == "inflight_fused_batching") - { - mTrtGptModelType = TrtGptModelType::InflightFusedBatching; - } - else - { - throw std::runtime_error( - "Invalid gpt_model_type. Must be " - "v1/inflight_batching/inflight_fused_batching."); - } - - // Check if model is in decoupled mode: - triton::common::TritonJson::Value transaction_policy; - model_state_->GetModelConfig().MemberAsObject("model_transaction_policy", &transaction_policy); - transaction_policy.MemberAsBool("decoupled", &mIsDecoupled); - - // Note: std::string::compare fails this test (always return non-zero - // value). Using old school strcmp instead. 
- mModelPath = model_state_->GetParameter("gpt_model_path"); - auto configPath = mModelPath + "/config.json"; - std::ifstream jsonStream(configPath); - - auto constexpr allowExceptions = true; - auto constexpr ingoreComments = true; - auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ingoreComments); - - int32_t maxBeamWidth = 1; - try - { - maxBeamWidth = model_state_->GetParameter("max_beam_width"); - } - catch (const std::exception& e) - { - // If parameter is not specified, just ignore - TLLM_LOG_WARNING("max_beam_width is not specified, will use default value of 1"); - } - - std::optional maxTokensInPagedKvCache = std::nullopt; - try - { - maxTokensInPagedKvCache = model_state_->GetParameter("max_tokens_in_paged_kv_cache"); - } - catch (const std::exception& e) - { - // If parameter is not specified, just ignore - TLLM_LOG_WARNING( - "max_tokens_in_paged_kv_cache is not specified, will " - "use default value"); - } - - auto schedulerPolicy = batch_scheduler::SchedulerPolicy::GUARANTEED_NO_EVICT; - try - { - std::string schedulerPolicyStr = model_state_->GetParameter("batch_scheduler_policy"); - if (schedulerPolicyStr == "max_utilization") - { - schedulerPolicy = batch_scheduler::SchedulerPolicy::MAX_UTILIZATION; - } - else if (schedulerPolicyStr == "guaranteed_no_evict") - { - schedulerPolicy = batch_scheduler::SchedulerPolicy::GUARANTEED_NO_EVICT; - } - else - { - throw std::runtime_error( - "batch_scheduler_policy parameter was not found or is invalid " - "(must be max_utilization or guaranteed_no_evict)"); - } - } - catch (const std::exception& e) - { - TLLM_LOG_WARNING(e.what()); - } - - if (mIsDecoupled && schedulerPolicy != batch_scheduler::SchedulerPolicy::GUARANTEED_NO_EVICT) - { - TLLM_LOG_WARNING( - "The batch scheduler policy will be set to guaranteed_no_evict" - "since the backend operates in decoupled mode"); - schedulerPolicy = batch_scheduler::SchedulerPolicy::GUARANTEED_NO_EVICT; - } - - std::optional kvCacheFreeGpuMemFraction = std::nullopt; - try - { - kvCacheFreeGpuMemFraction = model_state_->GetParameter("kv_cache_free_gpu_mem_fraction"); - } - catch (const std::exception& e) - { - // If parameter is not specified, just ignore - TLLM_LOG_WARNING( - "kv_cache_free_gpu_mem_fraction is not specified, will use default value of 0.85 or " - "max_tokens_in_paged_kv_cache"); - } - - std::optional maxNumSequences = std::nullopt; - try - { - maxNumSequences = model_state_->GetParameter("max_num_sequences"); - } - catch (const std::exception& e) - { - // If parameter is not specified, just ignore - TLLM_LOG_WARNING("max_num_sequences is not specified, will be set to the TRT engine max_batch_size"); - } - - bool enableTrtOverlap = true; - try - { - enableTrtOverlap = model_state_->GetParameter("enable_trt_overlap"); - } - catch (const std::exception& e) - { - // If parameter is not specified, just ignore - TLLM_LOG_WARNING("enable_trt_overlap is not specified, will be set to true"); - } - - bool excludeInputInOutput = false; - try - { - excludeInputInOutput = model_state_->GetParameter("exclude_input_in_output"); - } - catch (const std::exception& e) - { - // If parameter is not specified, just ignore - TLLM_LOG_WARNING("exclude_input_in_output is not specified, will be set to false"); - } - - std::optional maxKvCacheLength = std::nullopt; - try - { - maxKvCacheLength = model_state_->GetParameter("max_kv_cache_length"); - } - catch (const std::exception& e) - { - // If parameter is not specified, just ignore - TLLM_LOG_WARNING( - "max_kv_cache_length is 
not specified, will " - "use default value"); - } - - TrtGptModelOptionalParams optionalParams; - optionalParams.maxNumSequences = maxNumSequences; - optionalParams.kvCacheConfig.maxTokens = maxTokensInPagedKvCache; - optionalParams.kvCacheConfig.freeGpuMemoryFraction = kvCacheFreeGpuMemFraction; - optionalParams.kvCacheConfig.maxKvCacheLength = maxKvCacheLength; - optionalParams.enableTrtOverlap = enableTrtOverlap; - - mBatchManager = std::make_shared( - mModelPath, mTrtGptModelType, maxBeamWidth, schedulerPolicy, - [this](int max_num_requests) { return get_inference_requests(max_num_requests); }, - [this](uint64_t requestId, std::list response_tensors, bool final_response, - const std::string& errMsg) - { return sendResponse(requestId, response_tensors, final_response, errMsg); }, - [this]() { return pollStopSignals(); }, [this](const std::string& s) { return logStats(s); }, - optionalParams, std::nullopt, std::nullopt, excludeInputInOutput); - - if (getCommWorldRank() != 0) - { - while (true) - { - } - } - } - - ModelState* model_state_; - - // - // inflight batcher is a decoupled design. - // It uses response factory objects to decouple responses from requests. - // - // New requests are added to mWorkItems list. This list is processed - // in an infinite loop run by a worker thread. Requests take multiple - // iterations to complete, and number of iterations is not known in - // advance. To facilitate this, we use response factory objects to - // decouple requests and responses. - // - TrtGptModelType mTrtGptModelType; - std::string mModelPath; - bool mIsDecoupled; - - std::shared_ptr mBatchManager; - - WorkItemsQueue mWorkItemsQueue; -}; - -TRITONSERVER_Error* ModelInstanceState::Create( - ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, ModelInstanceState** state) -{ - try - { - *state = new ModelInstanceState(model_state, triton_model_instance); - } - catch (const std::exception& ex) - { - std::string errStr = std::string("unexpected error when creating modelInstanceState: ") + ex.what(); - return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); - } - - return nullptr; // success -} - -extern "C" -{ - // Triton calls TRITONBACKEND_ModelInstanceInitialize when a model // instance is created to allow the backend to initialize any state // associated with the instance. @@ -1412,13 +136,6 @@ extern "C" return nullptr; // success } -} // extern "C" - -///////////// - -extern "C" -{ - // When Triton calls TRITONBACKEND_ModelInstanceExecute it is required // that a backend create a response for each request in the batch. A // response may be the output tensors required for that request or may @@ -1445,6 +162,4 @@ extern "C" } // extern "C" -} // namespace inflight_batcher_llm -} // namespace backend -} // namespace triton +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/model_instance_state.cc b/inflight_batcher_llm/src/model_instance_state.cc new file mode 100644 index 00000000..1947091b --- /dev/null +++ b/inflight_batcher_llm/src/model_instance_state.cc @@ -0,0 +1,514 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#define _GLIBCXX_USE_CXX11_ABI 0 + +#include "model_instance_state.h" + +namespace triton::backend::inflight_batcher_llm +{ + +TRITONSERVER_Error* ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, ModelInstanceState** state) +{ + try + { + *state = new ModelInstanceState(model_state, triton_model_instance); + } + catch (const std::exception& ex) + { + std::string errStr = std::string("unexpected error when creating modelInstanceState: ") + ex.what(); + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); + } + + return nullptr; // success +} + +ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) + : model_state_(model_state) + , mIsDecoupled(false) +{ + // Note: std::string::compare fails this test (always return non-zero + // value). Using old school strcmp instead. + auto gpt_model_type = model_state_->GetParameter("gpt_model_type"); + if (gpt_model_type == "V1" || gpt_model_type == "v1") + { + mTrtGptModelType = TrtGptModelType::V1; + } + else if (gpt_model_type == "inflight_batching") + { + mTrtGptModelType = TrtGptModelType::InflightBatching; + } + else if (gpt_model_type == "inflight_fused_batching") + { + mTrtGptModelType = TrtGptModelType::InflightFusedBatching; + } + else + { + throw std::runtime_error( + "Invalid gpt_model_type. Must be " + "v1/inflight_batching/inflight_fused_batching."); + } + + // Check if model is in decoupled mode: + triton::common::TritonJson::Value transaction_policy; + model_state_->GetModelConfig().MemberAsObject("model_transaction_policy", &transaction_policy); + transaction_policy.MemberAsBool("decoupled", &mIsDecoupled); + + // Note: std::string::compare fails this test (always return non-zero + // value). Using old school strcmp instead. 
+ mModelPath = model_state_->GetParameter("gpt_model_path"); + auto configPath = mModelPath + "/config.json"; + std::ifstream jsonStream(configPath); + + auto constexpr allowExceptions = true; + auto constexpr ingoreComments = true; + auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ingoreComments); + + int32_t maxBeamWidth = 1; + try + { + maxBeamWidth = model_state_->GetParameter("max_beam_width"); + } + catch (const std::exception& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING("max_beam_width is not specified, will use default value of 1"); + } + + std::optional maxTokensInPagedKvCache = std::nullopt; + try + { + maxTokensInPagedKvCache = model_state_->GetParameter("max_tokens_in_paged_kv_cache"); + } + catch (const std::exception& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING( + "max_tokens_in_paged_kv_cache is not specified, will " + "use default value"); + } + + auto schedulerPolicy = SchedulerPolicy::GUARANTEED_NO_EVICT; + try + { + std::string schedulerPolicyStr = model_state_->GetParameter("batch_scheduler_policy"); + if (schedulerPolicyStr == "max_utilization") + { + schedulerPolicy = SchedulerPolicy::MAX_UTILIZATION; + } + else if (schedulerPolicyStr == "guaranteed_no_evict") + { + schedulerPolicy = SchedulerPolicy::GUARANTEED_NO_EVICT; + } + else + { + throw std::runtime_error( + "batch_scheduler_policy parameter was not found or is invalid " + "(must be max_utilization or guaranteed_no_evict)"); + } + } + catch (const std::exception& e) + { + TLLM_LOG_WARNING(e.what()); + } + + if (mIsDecoupled && schedulerPolicy != SchedulerPolicy::GUARANTEED_NO_EVICT) + { + TLLM_LOG_WARNING( + "The batch scheduler policy will be set to guaranteed_no_evict" + "since the backend operates in decoupled mode"); + schedulerPolicy = SchedulerPolicy::GUARANTEED_NO_EVICT; + } + + std::optional kvCacheFreeGpuMemFraction = std::nullopt; + try + { + kvCacheFreeGpuMemFraction = model_state_->GetParameter("kv_cache_free_gpu_mem_fraction"); + } + catch (const std::exception& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING( + "kv_cache_free_gpu_mem_fraction is not specified, will use default value of 0.85 or " + "max_tokens_in_paged_kv_cache"); + } + + std::optional maxNumSequences = std::nullopt; + try + { + maxNumSequences = model_state_->GetParameter("max_num_sequences"); + } + catch (const std::exception& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING("max_num_sequences is not specified, will be set to the TRT engine max_batch_size"); + } + + bool enableTrtOverlap = true; + try + { + enableTrtOverlap = model_state_->GetParameter("enable_trt_overlap"); + } + catch (const std::exception& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING("enable_trt_overlap is not specified, will be set to true"); + } + + bool excludeInputInOutput = false; + try + { + excludeInputInOutput = model_state_->GetParameter("exclude_input_in_output"); + } + catch (const std::exception& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING("exclude_input_in_output is not specified, will be set to false"); + } + + std::optional maxKvCacheLength = std::nullopt; + try + { + maxKvCacheLength = model_state_->GetParameter("max_kv_cache_length"); + } + catch (const std::exception& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING( + "max_kv_cache_length is not specified, will " + "use default value"); + } + + TrtGptModelOptionalParams 
optionalParams; + optionalParams.maxNumSequences = maxNumSequences; + optionalParams.kvCacheConfig.maxTokens = maxTokensInPagedKvCache; + optionalParams.kvCacheConfig.freeGpuMemoryFraction = kvCacheFreeGpuMemFraction; + optionalParams.kvCacheConfig.maxKvCacheLength = maxKvCacheLength; + optionalParams.enableTrtOverlap = enableTrtOverlap; + + mBatchManager = std::make_shared( + mModelPath, mTrtGptModelType, maxBeamWidth, schedulerPolicy, + [this](int max_num_requests) { return get_inference_requests(max_num_requests); }, + [this](uint64_t requestId, std::list response_tensors, bool final_response, + const std::string& errMsg) { return sendResponse(requestId, response_tensors, final_response, errMsg); }, + [this]() { return pollStopSignals(); }, [this](const std::string& s) { return logStats(s); }, optionalParams, + std::nullopt, std::nullopt, excludeInputInOutput); + + if (getCommWorldRank() != 0) + { + while (true) + { + } + } +} + +// For stop requests, or in case of error during enqueue, we need to send a +// response to the client +void ModelInstanceState::sendEnqueueResponse(TRITONBACKEND_Request* request, const std::string& errMsg) +{ + TRITONBACKEND_ResponseFactory* factory_ptr; + // Create response factory for this request + LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request), "Cannot create response factory"); + + TRITONSERVER_Error* err = nullptr; + if (!errMsg.empty()) + { + TLLM_LOG_ERROR(errMsg); + err = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errMsg.c_str()); + } + TRITONBACKEND_Response* response; + LOG_IF_ERROR(TRITONBACKEND_ResponseNewFromFactory(&response, factory_ptr), "Cannot create response"); + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), "Cannot send response"); + LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryDelete(factory_ptr), "Cannot delete response factory"); +} + +void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count, bool isDecoupled) +{ + std::vector> requestsToPush; + + for (uint32_t r = 0; r < request_count; ++r) + { + TRITONBACKEND_Request* request = requests[r]; + try + { + auto requestId = utils::getRequestId(request); + bool stopRequest = utils::getRequestBooleanInputTensor(request, kStopInputTensorName); + + if (stopRequest) + { + if (requestId != 0) + { + // Check if request is in progress or in queue, if not ignore + mWorkItemsQueue.stopWorkItem(requestId); + // Send a response back to client for stop request + sendEnqueueResponse(request); + } + else + { + throw std::runtime_error("Cannot send stop request without specifying a request_id"); + } + } + else + { + requestsToPush.emplace_back(requestId, request); + } + } + catch (const std::exception& e) + { + // In case of error, no work item is added to queue, so response + // callback needs to be called + sendEnqueueResponse(request, e.what()); + } + } + + auto exceptions = mWorkItemsQueue.pushBatch(requestsToPush, isDecoupled); + + for (uint32_t r = 0; r < requestsToPush.size(); ++r) + { + auto request = requestsToPush.at(r).second; + auto e = exceptions.at(r); + if (e) + { + sendEnqueueResponse(request, e->what()); + } + } + + return; +} + +// Return up to max_num_requests inference requests. 
+std::list> ModelInstanceState::get_inference_requests(const int max_num_requests) +{ + std::list> rval; + if (max_num_requests <= 0) + { + return rval; + } + + auto world_size = getCommWorldSize(); + auto rank = getCommWorldRank(); + if (rank == 0) + { + auto numPendingWorkItems = mWorkItemsQueue.numPendingWorkItems(); + // Loop over the pending work items and include at most `max_num_requests` + for (size_t i = 0; i < numPendingWorkItems && static_cast(rval.size()) < max_num_requests; ++i) + { + auto [workItem, stoppedRequest] = mWorkItemsQueue.pop(); + + if (workItem) + { + if (!stoppedRequest) + { + rval.emplace_back(workItem->getInferenceRequest()); + } + else + { + std::string warnStr = std::string("request Id ") + std::to_string(workItem->requestId()) + + std::string(" has been stopped. Request is ignored."); + TLLM_LOG_WARNING(warnStr); + sendTritonResponse(workItem, {}, true, warnStr); + } + } + } + + if (world_size > 1) + { + int64_t num_new_work_items = rval.size(); + bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD); + + if (num_new_work_items > 0) + { + std::vector packed; + for (auto ir : rval) + { + auto vpacked = ir->serialize(); + packed.push_back(static_cast(vpacked.size())); + packed.insert(packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); + } + bcast(packed, 0, COMM_WORLD); + } + } + } + else + { + // subordinate ranks hang until master rank sends work + int64_t num_new_work_items; + bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD); + if (num_new_work_items > 0) + { + std::vector packed; + bcast(packed, 0, COMM_WORLD); + int64_t* packed_ptr = packed.data(); + for (int64_t count = 0; count < num_new_work_items; ++count) + { + int64_t n = *(packed_ptr++); + auto ir = InferenceRequest::deserialize(packed_ptr); + packed_ptr += n; + rval.emplace_back(ir); + } + } + } + return rval; +} + +void ModelInstanceState::sendResponse( + uint64_t requestId, std::list const& response_tensors, bool final_response, const std::string& errMsg) +{ + if (getCommWorldRank() == 0) + { + std::string errStr = std::string("Failed to send Triton response for requestId: ") + std::to_string(requestId); + try + { + auto workItem = mWorkItemsQueue.getInProgressWorkItem(requestId); + auto tritonErr = sendTritonResponse(workItem, response_tensors, final_response, errMsg); + LOG_IF_ERROR(tritonErr, errStr); + } + catch (const std::exception& e) + { + TLLM_LOG_ERROR(errStr); + } + } +} + +std::unordered_set ModelInstanceState::pollStopSignals() +{ + auto stoppedReqIds = mWorkItemsQueue.getStoppedReqIds(); + + // Merge cancelled requests into stopped requests Ids + auto cancelledReqIds = mWorkItemsQueue.getCancelledInProgressReqIds(); + stoppedReqIds.insert(cancelledReqIds.begin(), cancelledReqIds.end()); + + int64_t nStoppedReqIds = static_cast(stoppedReqIds.size()); + + if (getCommWorldSize() > 1) + { + // Broadcast number of stopped requests + bcast(&nStoppedReqIds, 1, MPI_TYPE_INT64_T, 0, COMM_WORLD); + + if (nStoppedReqIds > 0) + { + // Broadcast stopped requests Ids + if (getCommWorldRank() == 0) + { + // Store the requestIds in a contiguous vector + std::vector stoppedReqIdsVec(stoppedReqIds.begin(), stoppedReqIds.end()); + bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0, COMM_WORLD); + } + else + { + std::vector stoppedReqIdsVec(nStoppedReqIds); + bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0, COMM_WORLD); + // Store the requestIds in the set + stoppedReqIds.clear(); + 
std::copy(stoppedReqIdsVec.begin(), stoppedReqIdsVec.end(), + std::inserter(stoppedReqIds, stoppedReqIds.end())); + } + } + } + return stoppedReqIds; +} + +void ModelInstanceState::logStats(const std::string& s) +{ + LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, s.c_str()); +#ifdef TRITON_ENABLE_METRICS + LOG_IF_ERROR(model_state_->UpdateMetrics(s), "Failed updating metrics"); +#endif +} + +TRITONSERVER_Error* ModelInstanceState::sendTritonResponse(std::shared_ptr workItem, + std::list const& response_tensors, bool final_response, const std::string& errMsg) +{ + TRITONBACKEND_ResponseFactory* response_factory; + response_factory = workItem->response_factory(); + + TRITONBACKEND_Response* response; + RETURN_IF_ERROR(TRITONBACKEND_ResponseNewFromFactory(&response, response_factory)); + + auto requestId = workItem->requestId(); + if (final_response) + { + mWorkItemsQueue.markFinished(requestId); + } + + // Check if error + TRITONSERVER_Error* err = nullptr; + if (!errMsg.empty()) + { + std::string errStr = "Encountered error for requestId " + std::to_string(requestId) + ": " + errMsg; + TLLM_LOG_ERROR(errStr); + + bool is_cancelled = false; + TRITONBACKEND_ResponseFactoryIsCancelled(response_factory, &is_cancelled); + + auto err_code = is_cancelled ? TRITONSERVER_ERROR_CANCELLED : TRITONSERVER_ERROR_INTERNAL; + + err = TRITONSERVER_ErrorNew(err_code, errStr.c_str()); + final_response = true; + } + else + { + for (auto it = response_tensors.begin(); it != response_tensors.end(); ++it) + { + auto tensor = *it; + if (!workItem->hasOutputName(tensor.name)) + { + continue; + } + auto shape = tensor.tensor->getShape(); // returns std::vectorint64_t> + std::vector vshape(shape.nbDims); + for (std::size_t i = 0; i < vshape.size(); ++i) + { + vshape[i] = shape.d[i]; + } + + TRITONBACKEND_Output* output; + RETURN_IF_ERROR(TRITONBACKEND_ResponseOutput(response, &output, tensor.name.c_str(), + utils::to_triton_datatype(tensor.tensor->getDataType()), vshape.data(), shape.nbDims)); + + uint64_t buffersize = tensor.tensor->getSizeInBytes(); + void* buffer = 0L; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + RETURN_IF_ERROR(TRITONBACKEND_OutputBuffer(output, &buffer, buffersize, &memory_type, &memory_type_id)); + if (memory_type != TRITONSERVER_MEMORY_CPU && memory_type != TRITONSERVER_MEMORY_CPU_PINNED) + { + std::string errStr = "Triton failed to allocate output buffer on CPU"; + err = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); + break; + } + std::memcpy(buffer, tensor.tensor->data(), buffersize); + } + } + + RETURN_IF_ERROR( + TRITONBACKEND_ResponseSend(response, final_response ? TRITONSERVER_RESPONSE_COMPLETE_FINAL : 0, err)); + + return nullptr; +} + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/model_instance_state.h b/inflight_batcher_llm/src/model_instance_state.h new file mode 100644 index 00000000..e6829c60 --- /dev/null +++ b/inflight_batcher_llm/src/model_instance_state.h @@ -0,0 +1,130 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#define _GLIBCXX_USE_CXX11_ABI 0 + +#include + +#include "triton/backend/backend_common.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +#include "tensorrt_llm/batch_manager/BatchManager.h" +#include "tensorrt_llm/batch_manager/GptManager.h" +#include "tensorrt_llm/batch_manager/batchScheduler.h" +#include "tensorrt_llm/batch_manager/callbacks.h" +#include "tensorrt_llm/batch_manager/kvCacheConfig.h" +#include "tensorrt_llm/batch_manager/namedTensor.h" +#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" +#include "tensorrt_llm/common/mpiUtils.h" + +#include "model_state.h" +#include "work_item.h" +#include "work_items_queue.h" + +using namespace tensorrt_llm::batch_manager; +using namespace tensorrt_llm::batch_manager::batch_scheduler; +using namespace tensorrt_llm::mpi; + +namespace triton::backend::inflight_batcher_llm +{ + +// +// ModelInstanceState +// State associated with a model instance. An object of this class is +// created and associated with each +// TRITONBACKEND_ModelInstance. ModelInstanceState is derived from +// + +class ModelInstanceState +{ + using InferenceRequest = tensorrt_llm::batch_manager::InferenceRequest; + using NamedTensor = tensorrt_llm::batch_manager::NamedTensor; + using TrtGptModelType = tensorrt_llm::batch_manager::TrtGptModelType; + +public: + static TRITONSERVER_Error* Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, ModelInstanceState** state); + + virtual ~ModelInstanceState() + { + // terminate decoupled execution loop + { + mWorkItemsQueue.clear(); + } + } + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const + { + return model_state_; + } + + bool isDecoupled() const + { + return mIsDecoupled; + } + + /// @brief For stop requests, or in case of error during enqueue, we need to send a + /// response to the client + void sendEnqueueResponse(TRITONBACKEND_Request* request, const std::string& errMsg = ""); + + /// @brief Add the request to the WorkItemsQueue + void enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count, bool isDecoupled); + + /// @brief Callback passed to GptManager to get new inference requests + /// @return Up to max_num_requests inference requests. 
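+    /// On rank 0, pending work items are popped from the work items queue; on other
+    /// ranks, this call blocks until rank 0 broadcasts the serialized requests over MPI.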
+    std::list<std::shared_ptr<InferenceRequest>> get_inference_requests(const int max_num_requests);
+
+    /// @brief Callback passed to GptManager to send responses back to client
+    void sendResponse(uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
+        const std::string& errMsg);
+    /// @brief Callback passed to GptManager to get ids of stopped requests
+    std::unordered_set<uint64_t> pollStopSignals();
+
+    /// @brief Callback passed to GptManager to print stats
+    void logStats(const std::string& s);
+
+    /// @brief Method that sends Triton response back to client
+    TRITONSERVER_Error* sendTritonResponse(std::shared_ptr<WorkItem> workItem,
+        std::list<NamedTensor> const& response_tensors, bool final_response, const std::string& errMsg);
+
+private:
+    /// @brief Constructor
+    ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance);
+
+    ModelState* model_state_;
+
+    TrtGptModelType mTrtGptModelType;
+    std::string mModelPath;
+    bool mIsDecoupled;
+
+    std::shared_ptr<GptManager> mBatchManager;
+    WorkItemsQueue mWorkItemsQueue;
+};
+
+} // namespace triton::backend::inflight_batcher_llm
diff --git a/inflight_batcher_llm/src/model_state.cc b/inflight_batcher_llm/src/model_state.cc
new file mode 100644
index 00000000..7b963495
--- /dev/null
+++ b/inflight_batcher_llm/src/model_state.cc
@@ -0,0 +1,158 @@
+// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright
+// notice, this list of conditions and the following disclaimer in the
+// documentation and/or other materials provided with the distribution.
+// * Neither the name of NVIDIA CORPORATION nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "model_state.h"
+
+namespace triton::backend::inflight_batcher_llm
+{
+
+TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
+{
+    TRITONSERVER_Message* config_message;
+    RETURN_IF_ERROR(TRITONBACKEND_ModelConfig(triton_model, 1 /* config_version */, &config_message));
+
+    // We can get the model configuration as a json string from
+    // config_message, parse it with our favorite json parser to create
+    // DOM that we can access when we need to examine the
+    // configuration. We use TritonJson, which is a wrapper that returns
+    // nice errors (currently the underlying implementation is
+    // rapidjson...
but others could be added). You can use any json + // parser you prefer. + const char* buffer; + size_t byte_size; + RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(config_message, &buffer, &byte_size)); + + common::TritonJson::Value model_config; + TRITONSERVER_Error* err = model_config.Parse(buffer, byte_size); + RETURN_IF_ERROR(TRITONSERVER_MessageDelete(config_message)); + RETURN_IF_ERROR(err); + + try + { + *state = new ModelState(triton_model, std::move(model_config)); + } + catch (const std::exception& ex) + { + std::string errStr = std::string("unexpected error when creating modelState: ") + ex.what(); + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); + } + + return nullptr; // success +} + +common::TritonJson::Value& ModelState::GetModelConfig() +{ + return model_config_; +} + +template <> +std::string ModelState::GetParameter(const std::string& name) +{ + TritonJson::Value parameters; + TRITONSERVER_Error* err = model_config_.MemberAsObject("parameters", ¶meters); + if (err != nullptr) + { + throw std::runtime_error("Model config doesn't have a parameters section"); + TRITONSERVER_ErrorDelete(err); + } + TritonJson::Value value; + std::string str_value; + err = parameters.MemberAsObject(name.c_str(), &value); + if (err != nullptr) + { + std::string errStr = "Cannot find parameter with name: " + name; + throw std::runtime_error(errStr); + TRITONSERVER_ErrorDelete(err); + } + value.MemberAsString("string_value", &str_value); + return str_value; +} + +template <> +int32_t ModelState::GetParameter(const std::string& name) +{ + return std::stoi(GetParameter(name)); +} + +template <> +uint32_t ModelState::GetParameter(const std::string& name) +{ + return (uint32_t) std::stoul(GetParameter(name)); +} + +template <> +int64_t ModelState::GetParameter(const std::string& name) +{ + return std::stoll(GetParameter(name)); +} + +template <> +uint64_t ModelState::GetParameter(const std::string& name) +{ + return std::stoull(GetParameter(name)); +} + +template <> +float ModelState::GetParameter(const std::string& name) +{ + return std::stof(GetParameter(name)); +} + +template <> +bool ModelState::GetParameter(const std::string& name) +{ + auto val = GetParameter(name); + if (val == "True" || val == "true" || val == "TRUE" || val == "1") + { + return true; + } + else if (val == "False" || val == "false" || val == "FALSE" || val == "0") + { + return false; + } + else + { + std::string err = "Cannot convert " + val + " to a boolean."; + throw std::runtime_error(err); + } +} + +#ifdef TRITON_ENABLE_METRICS +TRITONSERVER_Error* ModelState::InitMetrics( + const std::string& model_name, const uint64_t version, const bool is_v1_model) +{ + RETURN_IF_ERROR(triton_metrics_->InitMetrics(model_name, version, is_v1_model)); + return nullptr; // success +} + +TRITONSERVER_Error* ModelState::UpdateMetrics(const std::string& statistics) +{ + RETURN_IF_ERROR(triton_metrics_->UpdateMetrics(statistics)); + return nullptr; // success +} +#endif + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/model_state.h b/inflight_batcher_llm/src/model_state.h new file mode 100644 index 00000000..342744aa --- /dev/null +++ b/inflight_batcher_llm/src/model_state.h @@ -0,0 +1,113 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#define _GLIBCXX_USE_CXX11_ABI 0 + +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/plugins/api/tllmPlugin.h" +#include "tensorrt_llm/runtime/tllmLogger.h" + +#include "triton/backend/backend_common.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +#ifdef TRITON_ENABLE_METRICS +#include "metrics/triton_metrics.h" +#endif + +using namespace ::triton::common; // TritonJson + +namespace triton::backend::inflight_batcher_llm +{ + +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. 
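+// It holds the parsed model configuration (config.pbtxt) and exposes typed
+// access to the entries of its `parameters` section via GetParameter<T>.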
+ +class ModelState +{ +public: + static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model, ModelState** state); + + template + T GetParameter(const std::string& name) + { + assert(false); + auto dummy = T(); + return dummy; + } + + virtual ~ModelState() = default; + +#ifdef TRITON_ENABLE_METRICS + TRITONSERVER_Error* InitMetrics(const std::string& model_name, const uint64_t version, const bool is_v1_model); + TRITONSERVER_Error* UpdateMetrics(const std::string& statistics); +#endif + common::TritonJson::Value& GetModelConfig(); + +private: +#ifdef TRITON_ENABLE_METRICS + std::unique_ptr triton_metrics_; +#endif + common::TritonJson::Value model_config_; + std::shared_ptr mTrtLogger{}; + + ModelState(TRITONBACKEND_Model* triton_model, TritonJson::Value&& model_config) + : model_config_(std::move(model_config)) + { + mTrtLogger = std::make_shared(); + initTrtLlmPlugins(mTrtLogger.get()); +#ifdef TRITON_ENABLE_METRICS + triton_metrics_ = std::make_unique(); +#endif + } +}; + +template <> +std::string ModelState::GetParameter(const std::string& name); + +template <> +int32_t ModelState::GetParameter(const std::string& name); + +template <> +uint32_t ModelState::GetParameter(const std::string& name); + +template <> +int64_t ModelState::GetParameter(const std::string& name); + +template <> +uint64_t ModelState::GetParameter(const std::string& name); + +template <> +float ModelState::GetParameter(const std::string& name); + +template <> +bool ModelState::GetParameter(const std::string& name); + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/utils.cc b/inflight_batcher_llm/src/utils.cc new file mode 100644 index 00000000..902ec62f --- /dev/null +++ b/inflight_batcher_llm/src/utils.cc @@ -0,0 +1,219 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
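+//
+// Utility helpers for the inflight batcher backend: conversions between Triton
+// and TensorRT data types, and accessors for the request id, requested output
+// names, and boolean input tensors of a Triton request.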
+ +#include "utils.h" +#include + +namespace triton::backend::inflight_batcher_llm::utils +{ + +nvinfer1::DataType to_trt_datatype(TRITONSERVER_DataType data_type) +{ + if (data_type == TRITONSERVER_TYPE_INVALID) + { + assert(false); + } + else if (data_type == TRITONSERVER_TYPE_BOOL) + { + return nvinfer1::DataType::kBOOL; + } + else if (data_type == TRITONSERVER_TYPE_UINT8) + { + return nvinfer1::DataType::kUINT8; + } + else if (data_type == TRITONSERVER_TYPE_UINT16) + { + assert(false); + } + else if (data_type == TRITONSERVER_TYPE_UINT32) + { + return nvinfer1::DataType::kINT32; + } + else if (data_type == TRITONSERVER_TYPE_UINT64) + { + return nvinfer1::DataType::kINT64; + } + else if (data_type == TRITONSERVER_TYPE_INT8) + { + return nvinfer1::DataType::kINT8; + } + else if (data_type == TRITONSERVER_TYPE_INT16) + { + assert(false); + } + else if (data_type == TRITONSERVER_TYPE_INT32) + { + return nvinfer1::DataType::kINT32; + } + else if (data_type == TRITONSERVER_TYPE_INT64) + { + return nvinfer1::DataType::kINT64; + } + else if (data_type == TRITONSERVER_TYPE_FP16) + { + return nvinfer1::DataType::kHALF; + } + else if (data_type == TRITONSERVER_TYPE_FP32) + { + return nvinfer1::DataType::kFLOAT; + } + else if (data_type == TRITONSERVER_TYPE_FP64) + { + assert(false); + } + else if (data_type == TRITONSERVER_TYPE_BYTES) + { + return nvinfer1::DataType::kINT8; + } + else if (data_type == TRITONSERVER_TYPE_BF16) + { + return nvinfer1::DataType::kBF16; + } + else + { + assert(false); + } + return nvinfer1::DataType(0); +} + +TRITONSERVER_DataType to_triton_datatype(nvinfer1::DataType data_type) +{ + if (data_type == nvinfer1::DataType::kBOOL) + { + return TRITONSERVER_TYPE_BOOL; + } + else if (data_type == nvinfer1::DataType::kUINT8) + { + return TRITONSERVER_TYPE_UINT8; + } + else if (data_type == nvinfer1::DataType::kHALF) + { + return TRITONSERVER_TYPE_BF16; + } + else if (data_type == nvinfer1::DataType::kINT8) + { + return TRITONSERVER_TYPE_INT8; + } + else if (data_type == nvinfer1::DataType::kINT32) + { + return TRITONSERVER_TYPE_INT32; + } + else if (data_type == nvinfer1::DataType::kINT64) + { + return TRITONSERVER_TYPE_INT64; + } + else if (data_type == nvinfer1::DataType::kFLOAT) + { + return TRITONSERVER_TYPE_FP32; + } + else if (data_type == nvinfer1::DataType::kBF16) + { + return TRITONSERVER_TYPE_BF16; + } + else + { + return TRITONSERVER_TYPE_INVALID; + } +} + +uint64_t getRequestId(TRITONBACKEND_Request* request) +{ + const char* charRequestId; + TRITONBACKEND_RequestId(request, &charRequestId); + uint64_t requestId = 0; + if (charRequestId != nullptr) + { + std::string strRequestId(charRequestId); + if (!strRequestId.empty()) + { + try + { + requestId = stoul(strRequestId); + } + catch (const std::exception& e) + { + std::string err = std::string("Invalid requestId, must be uint64_t. 
Got ") + strRequestId; + throw std::runtime_error(err); + } + } + } + + return requestId; +} + +std::unordered_set getRequestOutputNames(TRITONBACKEND_Request* request) +{ + std::unordered_set outputNames; + uint32_t outputCount; + LOG_IF_ERROR(TRITONBACKEND_RequestOutputCount(request, &outputCount), "Error getting request output count"); + for (size_t i = 0; i < outputCount; ++i) + { + const char* name; + LOG_IF_ERROR(TRITONBACKEND_RequestOutputName(request, i, &name), "Error getting request output name"); + std::string name_s(name); + outputNames.insert(std::move(name_s)); + } + return outputNames; +} + +bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::string& inputTensorName) +{ + // Get stop signal from the request + TRITONBACKEND_Input* input; + TRITONSERVER_Error* error = TRITONBACKEND_RequestInput(request, inputTensorName.c_str(), &input); + if (error) + { + // If the user does not provide input "stop", then regard the request as + // unstopped + std::string msg + = "ModelInstanceState::getRequestBooleanInputTensor: user " + "did not not provide " + + inputTensorName + " input for the request"; + LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, msg.c_str()); + return false; + } + + uint64_t input_byte_size = 0; + uint32_t buffer_count = 0; + TRITONBACKEND_InputProperties(input, nullptr, nullptr, nullptr, nullptr, &input_byte_size, &buffer_count); + + LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, + ("ModelInstanceState::getRequestStopSignal: buffer_count = " + std::to_string(buffer_count)).c_str()); + + const void* buffer = 0L; + uint64_t buffer_byte_size = 0; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + TRITONBACKEND_InputBuffer(input, 0, &buffer, &buffer_byte_size, &memory_type, &memory_type_id); + + assert((memory_type == TRITONSERVER_MEMORY_CPU) || (memory_type == TRITONSERVER_MEMORY_CPU_PINNED)); + + bool boolean = *reinterpret_cast(buffer); + + return boolean; +} + +} // namespace triton::backend::inflight_batcher_llm::utils diff --git a/inflight_batcher_llm/src/utils.h b/inflight_batcher_llm/src/utils.h new file mode 100644 index 00000000..9abf0ff7 --- /dev/null +++ b/inflight_batcher_llm/src/utils.h @@ -0,0 +1,64 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#define _GLIBCXX_USE_CXX11_ABI 0 + +#include "NvInfer.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/tllmLogger.h" +#include "triton/backend/backend_common.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" +#include +#include + +namespace triton::backend::inflight_batcher_llm +{ +inline static const std::string kStopInputTensorName = "stop"; +inline static const std::string kStreamingInputTensorName = "streaming"; + +namespace utils +{ + +/// @brief Convert Triton datatype to TRT datatype +nvinfer1::DataType to_trt_datatype(TRITONSERVER_DataType data_type); + +/// @brief Convert TRT datatype to Triton datatype +TRITONSERVER_DataType to_triton_datatype(nvinfer1::DataType data_type); + +/// @brief get the requestId of the request +/// @return Returns 0 if not specified. Throws an error if request_id cannot be convert to uint64_t +uint64_t getRequestId(TRITONBACKEND_Request* request); + +/// @brief Get the requested output names +std::unordered_set getRequestOutputNames(TRITONBACKEND_Request* request); + +/// @brief Get the value of a boolean tensor +bool getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::string& inputTensorName); + +} // namespace utils +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/work_item.cc b/inflight_batcher_llm/src/work_item.cc new file mode 100644 index 00000000..8a80efc0 --- /dev/null +++ b/inflight_batcher_llm/src/work_item.cc @@ -0,0 +1,155 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "work_item.h" + +namespace triton::backend::inflight_batcher_llm +{ + +WorkItem::WorkItem(TRITONBACKEND_Request* request, bool isDecoupled) +{ + uint64_t requestId = (rand() % INT64_MAX) + 1; + Initialize(request, requestId, isDecoupled); +} + +WorkItem::WorkItem(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) +{ + Initialize(request, requestId, isDecoupled); +} + +WorkItem::WorkItem(std::shared_ptr ir, uint64_t RequestId) + : mInferenceRequest(ir) + , mRequestId(RequestId) +{ + factory_ptr_ = nullptr; +} + +WorkItem::~WorkItem() +{ + if (factory_ptr_ != nullptr) + { + TRITONBACKEND_ResponseFactoryDelete(factory_ptr_); + } +} + +TRITONBACKEND_ResponseFactory* WorkItem::response_factory() +{ + assert(factory_ptr_ != nullptr); + return factory_ptr_; +} + +uint64_t WorkItem::requestId() const +{ + return mRequestId; +} + +std::shared_ptr WorkItem::getInferenceRequest() const +{ + return mInferenceRequest; +} + +bool WorkItem::hasOutputName(const std::string& outputName) +{ + return (mRequestOutputNames.find(outputName) != mRequestOutputNames.end()); +} + +std::shared_ptr WorkItem::createInferenceRequest( + TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) +{ + auto inferenceRequest = std::make_shared(requestId); + + // Extract input tensors + std::map input_tensors; + uint32_t num_inputs; + LOG_IF_ERROR(TRITONBACKEND_RequestInputCount(request, &num_inputs), "Error getting input count"); + for (uint32_t idx = 0; idx < num_inputs; ++idx) + { + TRITONBACKEND_Input* input = 0L; + TRITONBACKEND_RequestInputByIndex(request, idx, &input); + + const char* input_name = 0L; + TRITONSERVER_DataType data_type = TRITONSERVER_TYPE_INVALID; + const int64_t* shape = 0L; + uint32_t dims_count = 0; + uint64_t byte_size = 0; + uint32_t buffer_count = 0; + TRITONBACKEND_InputProperties(input, &input_name, &data_type, &shape, &dims_count, &byte_size, &buffer_count); + + if (std::string(input_name) == "START" || std::string(input_name) == "CORRID" + || std::string(input_name) == "END" || std::string(input_name) == kStopInputTensorName + || std::string(input_name) == kStreamingInputTensorName) + { + continue; + } + + std::vector shapev; + for (uint32_t i = 0; i < dims_count; ++i) + { + shapev.push_back(shape[i]); + } + + NamedTensor t(utils::to_trt_datatype(data_type), shapev, input_name); + uint64_t buffer_offset = 0; + for (int64_t buffer_id = 0; buffer_id < buffer_count; ++buffer_id) + { + const void* buffer = 0L; + uint64_t buffer_byte_size = 0; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + TRITONBACKEND_InputBuffer(input, buffer_id, &buffer, &buffer_byte_size, &memory_type, &memory_type_id); + assert((memory_type == TRITONSERVER_MEMORY_CPU) || (memory_type == TRITONSERVER_MEMORY_CPU_PINNED)); + // TODO: Do we need to handle GPU mem input buffers?? 
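+            // Copy each Triton input buffer fragment into the contiguous host tensor
+            // at the running byte offset.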
+ std::memcpy(static_cast(t.tensor->data()) + buffer_offset, buffer, buffer_byte_size); + buffer_offset += buffer_byte_size; + } + + inferenceRequest->emplaceInputTensor(t.name, std::move(t.tensor)); + } + + bool streamingFlag = utils::getRequestBooleanInputTensor(request, kStreamingInputTensorName); + inferenceRequest->setIsStreaming(streamingFlag); + + if (streamingFlag && !isDecoupled) + { + throw std::runtime_error( + "Streaming is only supported if model is " + "deployed using decoupled mode."); + } + + return inferenceRequest; +} + +void WorkItem::Initialize(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) +{ + mRequestId = requestId; + mInferenceRequest = createInferenceRequest(request, requestId, isDecoupled); + mRequestOutputNames = utils::getRequestOutputNames(request); + + // Create response factory for this request + TRITONBACKEND_ResponseFactoryNew(&factory_ptr_, request); +} + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/work_item.h b/inflight_batcher_llm/src/work_item.h new file mode 100644 index 00000000..508179ce --- /dev/null +++ b/inflight_batcher_llm/src/work_item.h @@ -0,0 +1,75 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once +#define _GLIBCXX_USE_CXX11_ABI 0 + +#include "tensorrt_llm/batch_manager/inferenceRequest.h" +#include "triton/backend/backend_common.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" +#include +#include + +namespace triton::backend::inflight_batcher_llm +{ + +// Class holding all infos regarding a single work item. +// This includes the original request, associated response factor +// and state. 
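+// A WorkItem also owns the TRITONBACKEND_ResponseFactory created for its
+// request and releases it in its destructor.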
+class WorkItem
+{
+    using InferenceRequest = tensorrt_llm::batch_manager::InferenceRequest;
+    using NamedTensor = tensorrt_llm::batch_manager::NamedTensor;
+
+public:
+    WorkItem(TRITONBACKEND_Request* request, bool isDecoupled);
+    WorkItem(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled);
+    WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t RequestId);
+    ~WorkItem();
+
+    TRITONBACKEND_ResponseFactory* response_factory();
+
+    uint64_t requestId() const;
+
+    std::shared_ptr<InferenceRequest> getInferenceRequest() const;
+
+    bool hasOutputName(const std::string& outputName);
+
+private:
+    // Convert Triton request to trtllm InferenceRequest
+    static std::shared_ptr<InferenceRequest> createInferenceRequest(
+        TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled);
+
+    void Initialize(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled);
+
+    std::shared_ptr<InferenceRequest> mInferenceRequest;
+    TRITONBACKEND_ResponseFactory* factory_ptr_;
+    uint64_t mRequestId;
+    std::unordered_set<std::string> mRequestOutputNames;
+};
+
+} // namespace triton::backend::inflight_batcher_llm
diff --git a/inflight_batcher_llm/src/work_items_queue.cc b/inflight_batcher_llm/src/work_items_queue.cc
new file mode 100644
index 00000000..68d44c36
--- /dev/null
+++ b/inflight_batcher_llm/src/work_items_queue.cc
@@ -0,0 +1,150 @@
+// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright
+// notice, this list of conditions and the following disclaimer in the
+// documentation and/or other materials provided with the distribution.
+// * Neither the name of NVIDIA CORPORATION nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ +#include "work_items_queue.h" +#include "work_item.h" + +namespace triton::backend::inflight_batcher_llm +{ + +void WorkItemsQueue::clear() +{ + std::lock_guard lk(mMutex); + mPendingWorkItems.clear(); + mPendingWorkItemsReqIds.clear(); + mInProgressWorkItems.clear(); + mStoppedReqIds.clear(); +} + +/// @brief Add a batch of new work item to the queue +/// Throws an error if requestId already exists +std::vector> WorkItemsQueue::pushBatch( + std::vector>& requestsToPush, bool isDecoupled) +{ + std::lock_guard lk(mMutex); + std::vector> reqExceptions; + for (auto& [requestId, request] : requestsToPush) + { + if (requestId != 0 && (hasInProgressReqId(requestId) || hasPendingReqId(requestId))) + { + std::string errStr + = "requestId " + std::to_string(requestId) + " is already in progress, request is ignored."; + reqExceptions.emplace_back(std::make_shared(errStr)); + } + else + { + auto workItem = requestId != 0 ? std::make_shared(request, requestId, isDecoupled) + : std::make_shared(request, isDecoupled); + mPendingWorkItems.push_back(workItem); + mPendingWorkItemsReqIds.insert(workItem->requestId()); + reqExceptions.push_back(nullptr); + } + } + return reqExceptions; +} + +std::tuple, bool> WorkItemsQueue::pop() +{ + std::lock_guard lk(mMutex); + if (mPendingWorkItems.empty()) + { + return {nullptr, false}; + } + + auto workItem = mPendingWorkItems.front(); + mPendingWorkItems.pop_front(); + mPendingWorkItemsReqIds.erase(workItem->requestId()); + + // Check if work item has been stopped + bool is_stopped = mStoppedReqIds.count(workItem->requestId()); + + // Check if the Triton request has been cancelled + bool is_cancelled = false; + TRITONBACKEND_ResponseFactoryIsCancelled(workItem->response_factory(), &is_cancelled); + + bool stoppedRequest = false; + if (!is_stopped && !is_cancelled) + { + mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem)); + } + else + { + mStoppedReqIds.erase(workItem->requestId()); + stoppedRequest = true; + } + + return {workItem, stoppedRequest}; +} + +void WorkItemsQueue::markFinished(const uint64_t requestId) +{ + std::lock_guard lk(mMutex); + if (hasInProgressReqId(requestId)) + { + mInProgressWorkItems.erase(requestId); + } + + if (mStoppedReqIds.find(requestId) != mStoppedReqIds.end()) + { + mStoppedReqIds.erase(requestId); + } +} + +void WorkItemsQueue::stopWorkItem(const uint64_t requestId) +{ + std::lock_guard lk(mMutex); + TLLM_LOG_DEBUG("Stopping request"); + if (hasInProgressReqId(requestId) || hasPendingReqId(requestId)) + { + mStoppedReqIds.emplace(requestId); + } + else + { + std::string errStr = std::string("Received stop request for requestId ") + std::to_string(requestId) + + std::string(" but it's not active (might be completed already)."); + throw std::runtime_error(errStr); + } +} + +std::unordered_set WorkItemsQueue::getCancelledInProgressReqIds() const +{ + std::unordered_set cancelledInProgressReqIds; + { + std::lock_guard lk(mMutex); + for (const auto& pair : mInProgressWorkItems) + { + bool is_cancelled = false; + TRITONBACKEND_ResponseFactoryIsCancelled(pair.second->response_factory(), &is_cancelled); + if (is_cancelled) + { + cancelledInProgressReqIds.emplace(pair.first); + } + } + } + return cancelledInProgressReqIds; +} + +} // namespace triton::backend::inflight_batcher_llm diff --git a/inflight_batcher_llm/src/work_items_queue.h b/inflight_batcher_llm/src/work_items_queue.h new file mode 100644 index 00000000..f0de74cd --- /dev/null +++ b/inflight_batcher_llm/src/work_items_queue.h @@ -0,0 +1,115 @@ +// 
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright
+// notice, this list of conditions and the following disclaimer in the
+// documentation and/or other materials provided with the distribution.
+// * Neither the name of NVIDIA CORPORATION nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#pragma once
+#define _GLIBCXX_USE_CXX11_ABI 0
+
+#include "tensorrt_llm/common/logger.h"
+#include "triton/backend/backend_common.h"
+#include "triton/core/tritonbackend.h"
+#include "utils.h"
+#include "work_item.h"
+#include
+
+namespace triton::backend::inflight_batcher_llm
+{
+
+/// @brief Thread-safe queue of work items
+class WorkItemsQueue
+{
+public:
+    /// @brief Clear the queue
+    void clear();
+
+    // Note: this function should only be called under a lock
+    bool hasInProgressReqId(const uint64_t reqId) const
+    {
+        return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end());
+    }
+
+    // Note: this function should only be called under a lock
+    bool hasPendingReqId(const uint64_t reqId) const
+    {
+        return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end());
+    }
+
+    /// @brief Add a batch of new work items to the queue
+    /// Throws an error if requestId already exists
+    std::vector<std::shared_ptr<std::exception>> pushBatch(
+        std::vector<std::pair<uint64_t, TRITONBACKEND_Request*>>& requestsToPush, bool isDecoupled);
+
+    /// @brief Get a new work item from the queue, and move it to the list of
+    /// in progress work items if it hasn't been stopped
+    /// @return A tuple of the workItem and a boolean flag indicating if the work
+    /// item has been marked in progress
+    /// In case the queue is empty, return nullptr
+    std::tuple<std::shared_ptr<WorkItem>, bool> pop();
+
+    size_t numPendingWorkItems() const
+    {
+        std::lock_guard<std::mutex> lk(mMutex);
+        return mPendingWorkItems.size();
+    }
+
+    std::shared_ptr<WorkItem> getInProgressWorkItem(uint64_t requestId)
+    {
+        std::lock_guard<std::mutex> lk(mMutex);
+        return mInProgressWorkItems.at(requestId);
+    }
+
+    /// @brief Mark a request as being finished
+    /// @param requestId
+    void markFinished(const uint64_t requestId);
+
+    // Stop a request by adding the request Id to a set
+    // The set of stopped request ids is used by the poll callback
+    // and the pop function
+    void stopWorkItem(const uint64_t requestId);
+
+    std::unordered_set<uint64_t> getStoppedReqIds() const
+    {
+        std::lock_guard<std::mutex> lk(mMutex);
+        return mStoppedReqIds;
+    }
+
+    std::unordered_set<uint64_t> getCancelledInProgressReqIds() const;
+
+private:
+    /// Queue of work items
+    std::list<std::shared_ptr<WorkItem>> mPendingWorkItems;
+    /// requestIds of work items in the queue
+    std::set<uint64_t> mPendingWorkItemsReqIds;
+
+    /// work items currently in progress
+    std::unordered_map<uint64_t, std::shared_ptr<WorkItem>> mInProgressWorkItems;
+
+    /// ids of the work items that have been stopped
+    std::unordered_set<uint64_t> mStoppedReqIds;
+
+    mutable std::mutex mMutex;
+};
+
+} // namespace triton::backend::inflight_batcher_llm
diff --git a/tensorrt_llm b/tensorrt_llm
index 711a28d9..71f60f6d 160000
--- a/tensorrt_llm
+++ b/tensorrt_llm
@@ -1 +1 @@
-Subproject commit 711a28d9bf80a0e8a8be1e597bf809b55db981c8
+Subproject commit 71f60f6df09b29ab33861f98ac8d97cbe5417e45