diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..a7116d55 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +__pycache__/ +.vscode +*.cache +*.nsys-rep +.VSCodeCounter +build/ +*.so +*.egg-info/ +.coverage +*.csv +*.onnx +tmp/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..e69de29b diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..a20459c8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +repos: +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort +- repo: https://github.com/Lucas-C/pre-commit-hooks.git + rev: v1.1.13 + hooks: + - id: remove-crlf + files: (?!.*third_party)^.*$ | (?!.*book)^.*$ +- repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.32.0 + hooks: + - id: yapf +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + files: (?!.*third_party)^.*$ | (?!.*book)^.*$ + - id: end-of-file-fixer + - id: check-yaml + - id: trailing-whitespace +- repo: https://github.com/PyCQA/autoflake + rev: v1.6.1 + hooks: + - id: autoflake + args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variables'] diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..57be24dc
--- /dev/null
+++ b/README.md
@@ -0,0 +1,117 @@
+# TensorRT-LLM Backend
+The Triton backend for TensorRT-LLM.
+
+## Usage
+
+### Launch the backend *within Docker*
+
+```bash
+# 1. Launch the container
+nvidia-docker run -it --rm -e LOCAL_USER_ID=`id -u ${USER}` --shm-size=2g -v <host_path>:<container_path> <image> bash
+
+# 2. Modify parameters:
+1. all_models/<model_name>/tensorrt_llm/config.pbtxt
+2. all_models/<model_name>/preprocessing/config.pbtxt
+3. all_models/<model_name>/postprocessing/config.pbtxt
+
+# 3. Launch the Triton server
+python3 scripts/launch_triton_server.py --world_size=1 \
+    --model_repo=all_models/<model_name>
+```
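+
+Once the server is up, you can send requests directly to the `ensemble` model, for example with the Triton Python HTTP client. The snippet below is only a minimal sketch: it assumes `tritonclient[http]` is installed, that the server listens on the default HTTP port 8000, and that the `end_id`/`pad_id` values match your tokenizer (50256 is GPT-2's EOS id); `make_input` is just a local helper, and the input names follow `all_models/<model_name>/ensemble/config.pbtxt`.
+
+```python
+import numpy as np
+import tritonclient.http as httpclient
+
+
+def make_input(name, array, dtype):
+    """Wrap a numpy array into a Triton InferInput."""
+    tensor = httpclient.InferInput(name, list(array.shape), dtype)
+    tensor.set_data_from_numpy(array)
+    return tensor
+
+
+client = httpclient.InferenceServerClient(url="localhost:8000")
+
+inputs = [
+    # INPUT_0: prompt, INPUT_1: number of tokens to generate,
+    # INPUT_2 / INPUT_3: bad-words / stop-words lists (empty here).
+    make_input("INPUT_0", np.array([["Born in north-east France, Soyer trained as a"]], dtype=object), "BYTES"),
+    make_input("INPUT_1", np.array([[10]], dtype=np.uint32), "UINT32"),
+    make_input("INPUT_2", np.array([[""]], dtype=object), "BYTES"),
+    make_input("INPUT_3", np.array([[""]], dtype=object), "BYTES"),
+    # End/pad token ids; adjust these to your tokenizer.
+    make_input("end_id", np.array([[50256]], dtype=np.uint32), "UINT32"),
+    make_input("pad_id", np.array([[50256]], dtype=np.uint32), "UINT32"),
+]
+
+result = client.infer("ensemble", inputs)
+print(result.as_numpy("OUTPUT_0"))
+```
+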
+### Launch the backend *within Slurm based clusters*
+1. Prepare some scripts
+
+`tensorrt_llm_triton.sub`
+```bash
+#!/bin/bash
+#SBATCH -o logs/tensorrt_llm.out
+#SBATCH -e logs/tensorrt_llm.error
+#SBATCH -J gpu-comparch-ftp:mgmn
+#SBATCH -A gpu-comparch
+#SBATCH -p luna
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8
+#SBATCH --time=00:30:00

+sudo nvidia-smi -lgc 1410,1410
+
+srun --mpi=pmix --container-image <image> \
+    --container-mounts <host_path>:<container_path> \
+    --container-workdir <work_dir> \
+    --output logs/tensorrt_llm_%t.out \
+    bash <work_dir>/tensorrt_llm_triton.sh
+```
+
+`tensorrt_llm_triton.sh`
+```
+TRITONSERVER="/opt/tritonserver/bin/tritonserver"
+MODEL_REPO="/triton_backend/"
+
+${TRITONSERVER} --model-repository=${MODEL_REPO} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix${SLURM_PROCID}_
+```
+
+2. Submit a Slurm job
+```
+sbatch tensorrt_llm_triton.sub
+```
+
+### Kill the backend
+
+```bash
+pgrep tritonserver | xargs kill -9
+```
+
+## Examples
+
+### GPT/OPT/LLaMA/GPT-J...
+```bash
+cd tools/gpt/
+
+# Download vocab and merge table for HF models
+# Take GPT as an example:
+rm -rf gpt2 && git clone https://huggingface.co/gpt2
+pushd gpt2 && rm pytorch_model.bin model.safetensors && \
+    wget -q https://huggingface.co/gpt2/resolve/main/pytorch_model.bin && popd
+
+python3 client.py \
+    --text="Born in north-east France, Soyer trained as a" \
+    --output_len=10 \
+    --tokenizer_dir gpt2 \
+    --tokenizer_type auto
+
+# Example output:
+# [INFO] Latency: 92.278 ms
+# Input: Born in north-east France, Soyer trained as a
+# Output: chef and a cook at the local restaurant, La
+```
+*Please note that the example outputs are only for reference; the specific performance numbers depend on the GPU you're using.*
+
+## Test
+
+```bash
+cd tools/gpt/
+
+# Identity test
+python3 identity_test.py \
+    --batch_size=8 --start_len=128 --output_len=20
+# Results:
+# [INFO] Batch size: 8, Start len: 8, Output len: 10
+# [INFO] Latency: 70.782 ms
+# [INFO] Throughput: 113.023 sentences / sec
+
+# Benchmark using Perf Analyzer
+python3 gen_input_data.py
+perf_analyzer -m tensorrt_llm \
+    -b 8 --input-data input_data.json \
+    --concurrency-range 1:10:2 \
+    -u 'localhost:8000'
+
+# Results:
+# Concurrency: 1, throughput: 99.9875 infer/sec, latency 79797 usec
+# Concurrency: 3, throughput: 197.308 infer/sec, latency 121342 usec
+# Concurrency: 5, throughput: 259.077 infer/sec, latency 153693 usec
+# Concurrency: 7, throughput: 286.18 infer/sec, latency 195011 usec
+# Concurrency: 9, throughput: 307.067 infer/sec, latency 233354 usec
+```
+*Please note that the example outputs are only for reference; the specific performance numbers depend on the GPU you're using.*
diff --git a/all_models/gpt/ensemble/1/.tmp b/all_models/gpt/ensemble/1/.tmp
new file mode 100644
index 00000000..e69de29b
diff --git a/all_models/gpt/ensemble/config.pbtxt b/all_models/gpt/ensemble/config.pbtxt
new file mode 100755
index 00000000..1702cf44
--- /dev/null
+++ b/all_models/gpt/ensemble/config.pbtxt
@@ -0,0 +1,220 @@
+name: "ensemble"
+platform: "ensemble"
+max_batch_size: 1024
+input [
+  {
+    name: "INPUT_0"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  },
+  {
+    name: "INPUT_1"
+    data_type: TYPE_UINT32
+    dims: [ -1 ]
+  },
+  {
+    name: "INPUT_2"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  },
+  {
+    name: "INPUT_3"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  },
+  {
+    name: "end_id"
+    data_type: TYPE_UINT32
+    dims: [ 1 ]
+    optional: true
+  },
+  {
+    name: "pad_id"
+    data_type: TYPE_UINT32
+    dims: [ 1 ]
+    optional: true
+  },
+  {
+    name: "runtime_top_k"
+    data_type: TYPE_UINT32
+    dims: [ 1 ]
+    optional: true
}, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "output_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "OUTPUT_0" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "INPUT_0" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "INPUT_1" + } + input_map { + key: "BAD_WORDS_DICT" + value: "INPUT_2" + } + input_map { + key: "STOP_WORDS_DICT" + value: "INPUT_3" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "end_id" + } + input_map { + key: "pad_id" + value: "pad_id" + } + input_map { + key: "runtime_top_k" + value: "runtime_top_k" + } + input_map { + key: "runtime_top_p" + value: "runtime_top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "len_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: "presence_penalty" + value: "presence_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "output_log_probs" + value: "output_log_probs" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + output_map { + key: "OUTPUT" + value: "OUTPUT_0" + } + } + ] +} diff --git a/all_models/gpt/postprocessing/1/model.py b/all_models/gpt/postprocessing/1/model.py new file mode 100644 index 00000000..032c12ae --- /dev/null +++ b/all_models/gpt/postprocessing/1/model.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +import json + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. 
+ Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args['model_config']) + tokenizer_dir = model_config['parameters']['tokenizer_dir'][ + 'string_value'] + tokenizer_type = model_config['parameters']['tokenizer_type'][ + 'string_value'] + + if tokenizer_type == 't5': + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'auto': + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'llama': + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {tokenizer_type}') + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT") + + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy( + output_config['data_type']) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + tokens_batch = pb_utils.get_input_tensor_by_name( + request, 'TOKENS_BATCH').as_numpy() + + # Reshape Input + # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) + # tokens_batch = tokens_batch.T + + # Postprocessing output data. + outputs = self._postprocessing(tokens_batch) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + output_tensor = pb_utils.Tensor( + 'OUTPUT', + np.array(outputs).astype(self.output_dtype)) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor]) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. 
+ return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') + + def _postprocessing(self, tokens_batch): + outputs = [] + for beam_tokens in tokens_batch: + for tokens in beam_tokens: + output = self.tokenizer.decode(tokens) + outputs.append(output.encode('utf8')) + return outputs diff --git a/all_models/gpt/postprocessing/config.pbtxt b/all_models/gpt/postprocessing/config.pbtxt new file mode 100755 index 00000000..64908b5f --- /dev/null +++ b/all_models/gpt/postprocessing/config.pbtxt @@ -0,0 +1,38 @@ +name: "postprocessing" +backend: "python" +max_batch_size: 1024 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "${tokenizer_dir}" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "${tokenizer_type}" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/all_models/gpt/preprocessing/1/model.py b/all_models/gpt/preprocessing/1/model.py new file mode 100644 index 00000000..d9fd740e --- /dev/null +++ b/all_models/gpt/preprocessing/1/model.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +import csv +import json + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + +from tensorrt_llm.runtime import to_word_list_format + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args['model_config']) + tokenizer_dir = model_config['parameters']['tokenizer_dir'][ + 'string_value'] + tokenizer_type = model_config['parameters']['tokenizer_type'][ + 'string_value'] + + if tokenizer_type == 't5': + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'auto': + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'llama': + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {tokenizer_type}') + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.pad_id = self.tokenizer.encode(self.tokenizer.pad_token, + add_special_tokens=False)[0] + + # Parse model output configs and convert Triton types to numpy types + input_names = [ + "INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS" + ] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name( + model_config, input_name)['data_type'])) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, + 'QUERY').as_numpy() + request_output_len = pb_utils.get_input_tensor_by_name( + request, 'REQUEST_OUTPUT_LEN').as_numpy() + + bad_words_dict = pb_utils.get_input_tensor_by_name( + request, 'BAD_WORDS_DICT').as_numpy() + stop_words_dict = pb_utils.get_input_tensor_by_name( + request, 'STOP_WORDS_DICT').as_numpy() + + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + bad_words = to_word_list_format(bad_words_dict, self.tokenizer) + stop_words = to_word_list_format(stop_words_dict, self.tokenizer) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. 
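+            # The numpy dtypes used below (input_id_dtype, request_input_len_dtype)
+            # were read from config.pbtxt in initialize(), so the output tensors
+            # match the Triton data types declared for this model.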
+ input_id_tensor = pb_utils.Tensor( + 'INPUT_ID', + np.array(input_id).astype(self.input_id_dtype)) + request_input_len_tensor = pb_utils.Tensor( + 'REQUEST_INPUT_LEN', + np.array(request_input_len).astype( + self.request_input_len_dtype)) + request_output_len_tensor = pb_utils.Tensor( + 'REQUEST_OUTPUT_LEN', request_output_len) + bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words) + stop_words_ids_tensor = pb_utils.Tensor('STOP_WORDS_IDS', + stop_words) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse(output_tensors=[ + input_id_tensor, bad_words_ids_tensor, stop_words_ids_tensor, + request_input_len_tensor, request_output_len_tensor + ]) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + start_ids = [ + torch.IntTensor(self.tokenizer.encode(s[0].decode())) + for s in query + ] + start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) + + start_ids = pad_sequence(start_ids, + batch_first=True, + padding_value=self.pad_id) + # input_len = min(start_lengths) + #attn_mask = torch.ones((batch_size, input_len, input_len)).tril() + + return start_ids, start_lengths + + def _create_word_list(self, word_dict): + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + words = list(csv.reader([word_dict_item[0].decode()]))[0] + for word in words: + ids = self._encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), + constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), + constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose( + (1, 0, 2)) + + def _encode(self, sentence): + sentence = sentence.decode() if isinstance(sentence, + bytes) else sentence + return self.tokenizer.encode(sentence) diff --git a/all_models/gpt/preprocessing/config.pbtxt b/all_models/gpt/preprocessing/config.pbtxt new file mode 100644 index 00000000..99f14c31 --- /dev/null +++ b/all_models/gpt/preprocessing/config.pbtxt @@ -0,0 +1,78 @@ +name: "preprocessing" +backend: "python" +max_batch_size: 1024 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "BAD_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "STOP_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 
+ dims: [ 1 ] + }, + { + name: "BAD_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "STOP_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "PROMPT_LEARNING_TASK_NAME_IDS" + data_type: TYPE_UINT32 + dims: [ 1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "${tokenizer_dir}" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "${tokenizer_type}" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/all_models/gpt/tensorrt_llm/1/model.py b/all_models/gpt/tensorrt_llm/1/model.py new file mode 100644 index 00000000..06c9b4bf --- /dev/null +++ b/all_models/gpt/tensorrt_llm/1/model.py @@ -0,0 +1,249 @@ +import json +import os + +import torch +import triton_python_backend_utils as pb_utils +from torch import from_numpy + +import tensorrt_llm +from tensorrt_llm.runtime import GenerationSession, ModelConfig, SamplingConfig + + +def mpi_comm(): + from mpi4py import MPI + return MPI.COMM_WORLD + + +def mpi_rank(): + return mpi_comm().Get_rank() + + +def get_engine_name(model, dtype, tp_size, rank): + return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank) + + +def get_input_tensor_by_name(request, name): + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is not None: + # Triton tensor -> numpy tensor -> PyTorch tensor + return from_numpy(tensor.as_numpy()) + else: + return tensor + + +def get_input_scalar_by_name(request, name): + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is not None: + # Triton tensor -> numpy tensor -> first scalar + tensor = tensor.as_numpy() + return tensor.reshape((tensor.size, ))[0] + else: + return tensor + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + model_config = json.loads(args['model_config']) + engine_dir = model_config['parameters']['engine_dir']['string_value'] + config_path = os.path.join(engine_dir, 'config.json') + with open(config_path, 'r') as f: + config = json.load(f) + use_gpt_attention_plugin = config['plugin_config'][ + 'gpt_attention_plugin'] + self.remove_input_padding = config['plugin_config'][ + 'remove_input_padding'] + model = config['builder_config']['name'] + dtype = config['builder_config']['precision'] + world_size = config['builder_config']['tensor_parallel'] + assert world_size == tensorrt_llm.mpi_world_size(), \ + f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})' + num_heads = config['builder_config']['num_heads'] // world_size + hidden_size = config['builder_config']['hidden_size'] // world_size + vocab_size = config['builder_config']['vocab_size'] + num_layers = config['builder_config']['num_layers'] + num_kv_heads = num_heads + if "num_kv_heads" in config['builder_config'].keys(): + num_kv_heads = config['builder_config']['num_kv_heads'] + + self.comm = mpi_comm() + self.rank = mpi_rank() + + model_config = ModelConfig( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + num_layers=num_layers, + gpt_attention_plugin=use_gpt_attention_plugin, + remove_input_padding=self.remove_input_padding) + engine_name = get_engine_name(model, dtype, world_size, self.rank) + serialize_path = os.path.join(engine_dir, engine_name) + with open(serialize_path, 'rb') as f: + engine_buffer = f.read() + runtime_mapping = tensorrt_llm.Mapping(world_size, self.rank) + torch.cuda.set_device(self.rank % runtime_mapping.gpus_per_node) + self.decoder = GenerationSession(model_config, engine_buffer, + runtime_mapping) + + if self.rank != 0: + while (True): + self.execute([None]) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. + + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + responses = [] + + # Every Python backend must iterate through list of requests and create + # an instance of pb_utils.InferenceResponse class for each of them. You + # should avoid storing any of the input Tensors in the class attributes + # as they will be overridden in subsequent inference requests. You can + # make a copy of the underlying NumPy array and store it if it is + # required. + for request in requests: + # Perform inference on the request and append it to responses list... 
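+            # Only rank 0 is driven by real Triton requests: it reads the request
+            # tensors into `inputs` and broadcasts them below, while the other MPI
+            # ranks reach this point through the execute([None]) loop started at
+            # the end of initialize().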
+ inputs = {} + if self.rank == 0: + inputs['input_ids'] = get_input_tensor_by_name( + request, 'input_ids') + inputs['input_lengths'] = get_input_tensor_by_name( + request, 'input_lengths') + inputs['request_output_len'] = get_input_scalar_by_name( + request, 'request_output_len') + inputs['end_id'] = get_input_scalar_by_name(request, 'end_id') + inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id') + inputs['beam_width'] = get_input_scalar_by_name( + request, 'beam_width') + inputs['temperature'] = get_input_scalar_by_name( + request, 'temperature') + inputs['runtime_top_k'] = get_input_scalar_by_name( + request, 'runtime_top_k') + inputs['runtime_top_p'] = get_input_scalar_by_name( + request, 'runtime_top_p') + inputs['len_penalty'] = get_input_scalar_by_name( + request, 'len_penalty') + inputs['repetition_penalty'] = get_input_scalar_by_name( + request, 'repetition_penalty') + inputs['min_length'] = get_input_scalar_by_name( + request, 'min_length') + inputs['presence_penalty'] = get_input_scalar_by_name( + request, 'presence_penalty') + inputs['random_seed'] = get_input_scalar_by_name( + request, 'random_seed') + inputs['output_log_probs'] = get_input_scalar_by_name( + request, 'output_log_probs') + + # Broadcast requests to other clients + inputs = self.comm.bcast(inputs, root=0) + input_ids = inputs['input_ids'].cuda() + input_lengths = inputs['input_lengths'].cuda() + end_id = inputs['end_id'] + pad_id = inputs['pad_id'] + + sampling_config = SamplingConfig(end_id=end_id, pad_id=pad_id) + if inputs['beam_width'] is not None: + sampling_config.num_beams = inputs['beam_width'] + if inputs['temperature'] is not None: + sampling_config.temperature = inputs['temperature'] + if inputs['runtime_top_k'] is not None: + sampling_config.top_k = inputs['runtime_top_k'] + if inputs['runtime_top_p'] is not None: + sampling_config.top_p = inputs['runtime_top_p'] + if inputs['len_penalty'] is not None: + sampling_config.length_penalty = inputs['len_penalty'] + if inputs['repetition_penalty'] is not None: + sampling_config.repetition_penalty = inputs[ + 'repetition_penalty'] + if inputs['min_length'] is not None: + sampling_config.min_length = inputs['min_length'] + if inputs['presence_penalty'] is not None: + sampling_config.presence_penalty = inputs['presence_penalty'] + sampling_config.random_seed = inputs['random_seed'] + sampling_config.output_log_probs = inputs['output_log_probs'] + if self.remove_input_padding: + self.decoder.setup( + batch_size=input_ids.size(0), + max_context_length=torch.max(input_lengths).item(), + max_new_tokens=inputs['request_output_len']) + else: + self.decoder.setup(input_ids.size(0), input_ids.size(1), + inputs['request_output_len']) + if self.remove_input_padding: + output_ids = self.decoder.decode_batch(input_ids, + sampling_config) + else: + output_ids = self.decoder.decode(input_ids, input_lengths, + sampling_config) + + if self.rank == 0: + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + torch.cuda.synchronize() + output_tensors = [ + pb_utils.Tensor("output_ids", + output_ids.cpu().numpy()) + ] + + if sampling_config.output_log_probs: + # [max_new_tokens, batch_size, num_beams] -> [batch_size, max_new_tokens, num_beams] + log_probs = self.decoder.log_probs.transpose( + 0, 1).cpu().numpy() + output_tensors.append( + pb_utils.Tensor("log_probs", log_probs)) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. 
+ # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occured")) + + inference_response = pb_utils.InferenceResponse(output_tensors) + else: + inference_response = pb_utils.InferenceResponse([]) + responses.append(inference_response) + + # You must return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + return diff --git a/all_models/gpt/tensorrt_llm/config.pbtxt b/all_models/gpt/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..15ab5100 --- /dev/null +++ b/all_models/gpt/tensorrt_llm/config.pbtxt @@ -0,0 +1,139 @@ +name: "tensorrt_llm" +backend: "python" +max_batch_size: 1024 + +# # Uncomment this for dynamic_batching +# dynamic_batching { +# max_queue_delay_microseconds: 50000 +# } + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "output_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + }, + { + name: "log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters { + key: "engine_dir" + value: { + string_value: "${engine_dir}" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} diff --git a/all_models/inflight_batcher_llm/ensemble/1/.tmp b/all_models/inflight_batcher_llm/ensemble/1/.tmp new file mode 100644 index 00000000..e69de29b diff --git a/all_models/inflight_batcher_llm/ensemble/config.pbtxt b/all_models/inflight_batcher_llm/ensemble/config.pbtxt new file mode 100755 index 00000000..1b883e7d --- /dev/null +++ b/all_models/inflight_batcher_llm/ensemble/config.pbtxt @@ -0,0 +1,220 @@ +name: "ensemble" +platform: "ensemble" 
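+# Same ensemble graph (preprocessing -> tensorrt_llm -> postprocessing) as
+# all_models/gpt/ensemble, but sized for the in-flight batcher (smaller
+# max_batch_size) and exposing an optional `streaming` flag instead of
+# `output_log_probs`.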
+max_batch_size: 128 +input [ + { + name: "INPUT_0" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "INPUT_1" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "INPUT_2" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "INPUT_3" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "OUTPUT_0" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "INPUT_0" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "INPUT_1" + } + input_map { + key: "BAD_WORDS_DICT" + value: "INPUT_2" + } + input_map { + key: "STOP_WORDS_DICT" + value: "INPUT_3" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "end_id" + } + input_map { + key: "pad_id" + value: "pad_id" + } + input_map { + key: "runtime_top_k" + value: "runtime_top_k" + } + input_map { + key: "runtime_top_p" + value: "runtime_top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "len_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: "presence_penalty" + value: "presence_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "streaming" + value: "streaming" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + output_map { + key: "OUTPUT" + value: "OUTPUT_0" + } + } + ] +} diff --git a/all_models/inflight_batcher_llm/postprocessing/1/model.py b/all_models/inflight_batcher_llm/postprocessing/1/model.py new file mode 100644 index 00000000..032c12ae --- /dev/null +++ b/all_models/inflight_batcher_llm/postprocessing/1/model.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +import json + +import numpy as 
np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args['model_config']) + tokenizer_dir = model_config['parameters']['tokenizer_dir'][ + 'string_value'] + tokenizer_type = model_config['parameters']['tokenizer_type'][ + 'string_value'] + + if tokenizer_type == 't5': + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'auto': + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'llama': + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {tokenizer_type}') + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT") + + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy( + output_config['data_type']) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + tokens_batch = pb_utils.get_input_tensor_by_name( + request, 'TOKENS_BATCH').as_numpy() + + # Reshape Input + # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) + # tokens_batch = tokens_batch.T + + # Postprocessing output data. + outputs = self._postprocessing(tokens_batch) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + output_tensor = pb_utils.Tensor( + 'OUTPUT', + np.array(outputs).astype(self.output_dtype)) + + # Create InferenceResponse. 
You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor]) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') + + def _postprocessing(self, tokens_batch): + outputs = [] + for beam_tokens in tokens_batch: + for tokens in beam_tokens: + output = self.tokenizer.decode(tokens) + outputs.append(output.encode('utf8')) + return outputs diff --git a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt new file mode 100755 index 00000000..ab9e531e --- /dev/null +++ b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt @@ -0,0 +1,38 @@ +name: "postprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "${tokenizer_dir}" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "${tokenizer_type}" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/all_models/inflight_batcher_llm/preprocessing/1/model.py b/all_models/inflight_batcher_llm/preprocessing/1/model.py new file mode 100644 index 00000000..d9fd740e --- /dev/null +++ b/all_models/inflight_batcher_llm/preprocessing/1/model.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +import csv +import json + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + +from tensorrt_llm.runtime import to_word_list_format + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args['model_config']) + tokenizer_dir = model_config['parameters']['tokenizer_dir'][ + 'string_value'] + tokenizer_type = model_config['parameters']['tokenizer_type'][ + 'string_value'] + + if tokenizer_type == 't5': + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'auto': + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, + padding_side='left') + elif tokenizer_type == 'llama': + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {tokenizer_type}') + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.pad_id = self.tokenizer.encode(self.tokenizer.pad_token, + add_special_tokens=False)[0] + + # Parse model output configs and convert Triton types to numpy types + input_names = [ + "INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS" + ] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name( + model_config, input_name)['data_type'])) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, + 'QUERY').as_numpy() + request_output_len = pb_utils.get_input_tensor_by_name( + request, 'REQUEST_OUTPUT_LEN').as_numpy() + + bad_words_dict = pb_utils.get_input_tensor_by_name( + request, 'BAD_WORDS_DICT').as_numpy() + stop_words_dict = pb_utils.get_input_tensor_by_name( + request, 'STOP_WORDS_DICT').as_numpy() + + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + bad_words = to_word_list_format(bad_words_dict, self.tokenizer) + stop_words = to_word_list_format(stop_words_dict, self.tokenizer) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. 
+ input_id_tensor = pb_utils.Tensor( + 'INPUT_ID', + np.array(input_id).astype(self.input_id_dtype)) + request_input_len_tensor = pb_utils.Tensor( + 'REQUEST_INPUT_LEN', + np.array(request_input_len).astype( + self.request_input_len_dtype)) + request_output_len_tensor = pb_utils.Tensor( + 'REQUEST_OUTPUT_LEN', request_output_len) + bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words) + stop_words_ids_tensor = pb_utils.Tensor('STOP_WORDS_IDS', + stop_words) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse(output_tensors=[ + input_id_tensor, bad_words_ids_tensor, stop_words_ids_tensor, + request_input_len_tensor, request_output_len_tensor + ]) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + start_ids = [ + torch.IntTensor(self.tokenizer.encode(s[0].decode())) + for s in query + ] + start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) + + start_ids = pad_sequence(start_ids, + batch_first=True, + padding_value=self.pad_id) + # input_len = min(start_lengths) + #attn_mask = torch.ones((batch_size, input_len, input_len)).tril() + + return start_ids, start_lengths + + def _create_word_list(self, word_dict): + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + words = list(csv.reader([word_dict_item[0].decode()]))[0] + for word in words: + ids = self._encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), + constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), + constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose( + (1, 0, 2)) + + def _encode(self, sentence): + sentence = sentence.decode() if isinstance(sentence, + bytes) else sentence + return self.tokenizer.encode(sentence) diff --git a/all_models/inflight_batcher_llm/preprocessing/config.pbtxt b/all_models/inflight_batcher_llm/preprocessing/config.pbtxt new file mode 100644 index 00000000..c5ac0c55 --- /dev/null +++ b/all_models/inflight_batcher_llm/preprocessing/config.pbtxt @@ -0,0 +1,73 @@ +name: "preprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "BAD_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "STOP_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + 
name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "BAD_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "STOP_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "${tokenizer_dir}" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "${tokenizer_type}" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/1/.gitkeep b/all_models/inflight_batcher_llm/tensorrt_llm/1/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..36309f2f --- /dev/null +++ b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt @@ -0,0 +1,164 @@ +name: "tensorrt_llm" +backend: "inflight_batcher_llm" +max_batch_size: 128 + +model_transaction_policy { + decoupled: ${decoupled_mode} +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +# { +# name: "output_log_probs" +# data_type: TYPE_BOOL +# dims: [ 1 ] +# reshape: { shape: [ ] } +# optional: true +# } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +# { +# name: "log_probs" +# data_type: TYPE_FP32 +# dims: [ -1, -1 ] +# } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "${engine_dir}" + } +} diff --git a/dockerfile/Dockerfile.trt_llm_backend 
b/dockerfile/Dockerfile.trt_llm_backend new file mode 100644 index 00000000..22945540 --- /dev/null +++ b/dockerfile/Dockerfile.trt_llm_backend @@ -0,0 +1,51 @@ +ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:23.07-py3 + +FROM ${BASE_IMAGE} as base + +COPY requirements.txt /tmp/ +RUN pip3 install -r /tmp/requirements.txt --extra-index-url https://pypi.ngc.nvidia.com + +# Remove prevous TRT installation +# We didn't remove libnvinfer* here because tritonserver depends on the pre-installed libraries. +RUN apt-get remove --purge -y tensorrt* +RUN pip uninstall -y tensorrt + +# Download and install TensorRT +RUN wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.0.1/tars/TensorRT-9.0.1.4.Linux.x86_64-gnu.cuda-12.2.tar.gz -P /workspace +RUN tar -xvf /workspace/TensorRT-9.0.1.4.Linux.x86_64-gnu.cuda-12.2.tar.gz -C /usr/local/ && mv /usr/local/TensorRT-9.0.1.4 /usr/local/tensorrt +RUN pip install /usr/local/tensorrt/python/tensorrt-9.0.1*cp310-none-linux_x86_64.whl && rm -fr /workspace/TensorRT-9.0.1.4.Linux.x86_64-gnu.cuda-12.2.tar.gz +ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib/:$LD_LIBRARY_PATH +ENV TRT_ROOT=/usr/local/tensorrt + +FROM base as dev + +# Download and install polygraphy, only required if you need to run TRT-LLM python tests +RUN pip install https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.0.1/tars/polygraphy-0.48.1-py2.py3-none-any.whl + +# CMake +RUN wget https://github.com/Kitware/CMake/releases/download/v3.18.1/cmake-3.18.1-Linux-x86_64.sh +RUN bash cmake-3.18.1-Linux-x86_64.sh --prefix=/usr/local --exclude-subdir +ENV PATH="/usr/local/bin:${PATH}" + +COPY tensorrt_llm/requirements-dev.txt /tmp/ +RUN pip install -r /tmp/requirements-dev.txt --extra-index-url https://pypi.ngc.nvidia.com + +FROM dev as trt_llm_builder + +WORKDIR /app +COPY scripts scripts +COPY tensorrt_llm tensorrt_llm +RUN cd tensorrt_llm; python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" -i; cd .. + +FROM trt_llm_builder as trt_llm_backend_builder + +WORKDIR /app/ +COPY inflight_batcher_llm inflight_batcher_llm +RUN cd inflight_batcher_llm; bash scripts/build.sh; cd .. + +FROM trt_llm_backend_builder as final + +#Install inflight batcher backend +RUN mkdir /opt/tritonserver/backends/inflight_batcher_llm +RUN mkdir -p /opt/tensorrt_llm/lib +COPY --from=trt_llm_backend_builder /app/inflight_batcher_llm/build/libtriton_inflight_batcher_llm.so /opt/tritonserver/backends/inflight_batcher_llm diff --git a/inflight_batcher_llm/CMakeLists.txt b/inflight_batcher_llm/CMakeLists.txt new file mode 100644 index 00000000..acd68a33 --- /dev/null +++ b/inflight_batcher_llm/CMakeLists.txt @@ -0,0 +1,306 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.17) +include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/set_ifndef.cmake) + +set_ifndef(TRTLLM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../tensorrt_llm) +include(${TRTLLM_DIR}/cpp/cmake/modules/find_library_create_target.cmake) + +project(tutorialinflight_batcher_llmbackend LANGUAGES C CXX) + +# +# Options +# +# Must include options required for this project as well as any +# projects included in this one by FetchContent. +# +# GPU support is disabled by default because inflight_batcher_llm backend doesn't +# use GPUs. +# +option(TRITON_ENABLE_GPU "Enable GPU support in backend" OFF) +option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) + +# The tags here should match the tag of base image in Dockerfile +set(TRITON_COMMON_REPO_TAG "r23.07" CACHE STRING "Tag for triton-inference-server/common repo") +set(TRITON_CORE_REPO_TAG "r23.07" CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_BACKEND_REPO_TAG "r23.07" CACHE STRING "Tag for triton-inference-server/backend repo") + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDA_PATH}/include) +message(STATUS "COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}") + +# +# Dependencies +# +# FetchContent requires us to include the transitive closure of all +# repos that we depend on so that we can override the tags. +# +include(FetchContent) + +FetchContent_Declare( + repo-common + GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_TAG ${TRITON_COMMON_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_TAG ${TRITON_CORE_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-backend + GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_TAG ${TRITON_BACKEND_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_MakeAvailable(repo-common repo-core repo-backend) + +# +# The backend must be built into a shared library. Use an ldscript to +# hide all symbols except for the TRITONBACKEND API. 
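+# The configured ldscript is attached to the target through LINK_DEPENDS and
+# the "-Wl,--version-script" LINK_FLAGS (non-Windows builds only; see the
+# set_target_properties call further down in this file).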
+# +configure_file(src/libtriton_inflight_batcher_llm.ldscript libtriton_inflight_batcher_llm.ldscript COPYONLY) + +add_library( + triton-inflight_batcher_llm-backend SHARED + src/inflight_batcher_llm.cc +) + +add_library( + InflightBatcherLLM::triton-inflight_batcher_llm-backend ALIAS triton-inflight_batcher_llm-backend +) + +enable_language(CUDA) + +find_package(CUDA ${CUDA_REQUIRED_VERSION} REQUIRED) + +find_library( + CUDNN_LIB cudnn + HINTS ${CUDA_TOOLKIT_ROOT_DIR} ${CUDNN_ROOT_DIR} + PATH_SUFFIXES lib64 lib) +find_library( + CUBLAS_LIB cublas + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib64 lib lib/stubs) +find_library( + CUBLASLT_LIB cublasLt + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib64 lib lib/stubs) +find_library( + CUDART_LIB cudart + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64) +find_library( + CUDA_DRV_LIB cuda + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/stubs lib64/stubs) +set(CUDA_LIBRARIES ${CUDART_LIB}) + +find_package(MPI REQUIRED) +message(STATUS "Using MPI_INCLUDE_PATH: ${MPI_INCLUDE_PATH}") +message(STATUS "Using MPI_LIBRARIES: ${MPI_LIBRARIES}") + +# NCCL dependencies +set_ifndef(NCCL_LIB_DIR /usr/lib/x86_64-linux-gnu/) +set_ifndef(NCCL_INCLUDE_DIR /usr/include/) +find_library(NCCL_LIB nccl HINTS ${NCCL_LIB_DIR}) + +set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR}) +set_ifndef(TRT_INCLUDE_DIR /usr/include/x86_64-linux-gnu) +set(TRT_LIB nvinfer) +find_library_create_target(${TRT_LIB} nvinfer SHARED ${TRT_LIB_DIR}) +find_library_create_target(nvuffparser nvparsers SHARED ${TRT_LIB_DIR}) + + +file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" VERSION_STRINGS + REGEX "#define NV_TENSORRT_.*") +foreach(TYPE MAJOR MINOR PATCH BUILD) + string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]" TRT_TYPE_STRING + ${VERSION_STRINGS}) + string(REGEX MATCH "[0-9]" TRT_${TYPE} ${TRT_TYPE_STRING}) +endforeach(TYPE) + +foreach(TYPE MAJOR MINOR PATCH) + string(REGEX MATCH "NV_TENSORRT_SONAME_${TYPE} [0-9]" TRT_TYPE_STRING + ${VERSION_STRINGS}) + string(REGEX MATCH "[0-9]" TRT_SO_${TYPE} ${TRT_TYPE_STRING}) +endforeach(TYPE) + +set(TRT_VERSION + "${TRT_MAJOR}.${TRT_MINOR}.${TRT_PATCH}" + CACHE STRING "TensorRT project version") +set(TRT_SOVERSION + "${TRT_SO_MAJOR}" + CACHE STRING "TensorRT library so version") +message( + STATUS + "Building for TensorRT version: ${TRT_VERSION}, library version: ${TRT_SOVERSION}" +) + +list(APPEND COMMON_HEADER_DIRS ${TORCH_INCLUDE_DIRS} ${TRT_INCLUDE_DIR}) +include_directories(${COMMON_HEADER_DIRS}) + + +target_include_directories( + triton-inflight_batcher_llm-backend + PRIVATE + ${TRTLLM_DIR}/cpp + ${TRTLLM_DIR}/cpp/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CUDA_INCLUDE_DIRS} + ${CUDNN_ROOT_DIR}/include + ${NCCL_INCLUDE_DIR} + ${3RDPARTY_DIR}/cutlass/include + ${MPI_INCLUDE_PATH} + ${COMMON_HEADER_DIR} +) + +#include_directories(${CUDA_INCLUDE_DIRS} ${CUDNN_ROOT_DIR}/include +# ${NCCL_INCLUDE_DIR} ${3RDPARTY_DIR}/cutlass/include) + +target_compile_features(triton-inflight_batcher_llm-backend PRIVATE cxx_std_17) +target_compile_options( + triton-inflight_batcher_llm-backend PRIVATE + $<$,$,$>: + -Wall -Wextra -Wno-unused-parameter -Wno-type-limits> + $<$:/Wall /D_WIN32_WINNT=0x0A00 /EHsc> +) + +add_library(tensorrt_llm STATIC IMPORTED) +set_property(TARGET tensorrt_llm PROPERTY IMPORTED_LOCATION "${TRTLLM_DIR}/cpp/build/tensorrt_llm/libtensorrt_llm_static.a") + +add_library(tensorrt_llm_batch_manager STATIC IMPORTED) +execute_process( + COMMAND + ${Python3_EXECUTABLE} "-c" + "import torch; 
print(torch.compiled_with_cxx11_abi(),end='');" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE USE_CXX11_ABI) +message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}") +if(USE_CXX11_ABI) +set_property(TARGET tensorrt_llm_batch_manager PROPERTY IMPORTED_LOCATION "${TRTLLM_DIR}/cpp/tensorrt_llm/batch_manager/libtensorrt_llm_batch_manager_static.a") +else() +set_property(TARGET tensorrt_llm_batch_manager PROPERTY IMPORTED_LOCATION "${TRTLLM_DIR}/cpp/tensorrt_llm/batch_manager/libtensorrt_llm_batch_manager_static.pre_cxx11.a") +endif() + + +add_library(nvinfer_plugin_tensorrt_llm SHARED IMPORTED) +set_property(TARGET nvinfer_plugin_tensorrt_llm PROPERTY IMPORTED_LOCATION "${TRTLLM_DIR}/cpp/build/tensorrt_llm/plugins/libnvinfer_plugin_tensorrt_llm.so") + +target_link_libraries( + triton-inflight_batcher_llm-backend + PRIVATE + tensorrt_llm_batch_manager + tensorrt_llm + triton-core-serverapi # from repo-core + triton-core-backendapi # from repo-core + triton-core-serverstub # from repo-core + triton-backend-utils # from repo-backend + ${MPI_LIBRARIES} + nvinfer + nvinfer_plugin_tensorrt_llm +) + +FetchContent_Declare( + json + GIT_REPOSITORY https://github.com/nlohmann/json.git + GIT_TAG v3.11.2) + +FetchContent_MakeAvailable(json) + +target_link_libraries(triton-inflight_batcher_llm-backend PRIVATE nlohmann_json::nlohmann_json) + +if(WIN32) + set_target_properties( + triton-inflight_batcher_llm-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_inflight_batcher_llm + ) +else() + set_target_properties( + triton-inflight_batcher_llm-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_inflight_batcher_llm + LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_inflight_batcher_llm.ldscript + LINK_FLAGS "-Wl,--version-script libtriton_inflight_batcher_llm.ldscript" + ) +endif() + +# +# Install +# +include(GNUInstallDirs) +set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/InflightBatcherLLMBackend) + +install( + TARGETS + triton-inflight_batcher_llm-backend + EXPORT + triton-inflight_batcher_llm-backend-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/inflight_batcher_llm + RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/inflight_batcher_llm +) + +install( + EXPORT + triton-inflight_batcher_llm-backend-targets + FILE + InflightBatcherLLMBackendTargets.cmake + NAMESPACE + InflightBatcherLLMBackend:: + DESTINATION + ${INSTALL_CONFIGDIR} +) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + ${CMAKE_CURRENT_LIST_DIR}/cmake/InflightBatcherLLMBackendConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/InflightBatcherLLMBackendConfig.cmake + INSTALL_DESTINATION ${INSTALL_CONFIGDIR} +) + +install( + FILES + ${CMAKE_CURRENT_BINARY_DIR}/InflightBatcherLLMBackendConfig.cmake + DESTINATION ${INSTALL_CONFIGDIR} +) + +# +# Export from build tree +# +export( + EXPORT triton-inflight_batcher_llm-backend-targets + FILE ${CMAKE_CURRENT_BINARY_DIR}/InflightBatcherLLMBackendTargets.cmake + NAMESPACE InflightBatcherLLMBackend:: +) + +export(PACKAGE InflightBatcherLLMBackend) diff --git a/inflight_batcher_llm/README.md b/inflight_batcher_llm/README.md new file mode 100644 index 00000000..4f4a06e2 --- /dev/null +++ b/inflight_batcher_llm/README.md @@ -0,0 +1,160 @@ +# Instructions to run TRT-LLM in-flight batching Triton backend: + +## Build TensorRT-LLM engine for inflight batching + +To configure a Triton server that runs a model using TensorRT-LLM, it is needed to compile a TensorRT-LLM engine for that model. 
+ +For example, for LLaMA 7B, change to the `tensorrt_llm/examples/llama` directory: + +``` +cd tensorrt_llm/examples/llama +``` +Prepare the checkpoint of the model by following the instructions [here](https://huggingface.co/docs/transformers/main/en/model_doc/llama) and store it in a model directory. Then, create the engine: + +``` +python build.py --model_dir ${model_directory} \ + --dtype bfloat16 \ + --use_gpt_attention_plugin bfloat16 \ + --use_inflight_batching \ + --paged_kv_cache \ + --remove_input_padding \ + --use_gemm_plugin bfloat16 \ + --output_dir engines/bf16/1-gpu/ +``` + +To disable the support for in-flight batching (i.e. use the V1 batching mode), remove `--use_inflight_batching`. + +Similarly, for a GPT model, change to `tensorrt_llm/examples/gpt` directory: +``` +cd tensorrt_llm/examples/gpt + +``` +Prepare the model checkpoint following the instructions in the README file, store it in a model directory and build the TRT engine with: + +``` +python3 build.py --model_dir=${model_directory} \ + --dtype float16 \ + --use_inflight_batching \ + --use_gpt_attention_plugin float16 \ + --paged_kv_cache \ + --use_gemm_plugin float16 \ + --remove_input_padding \ + --use_layernorm_plugin float16 \ + --hidden_act gelu \ + --output_dir=engines/fp16/1-gpu +``` + +## Build the Triton server image that includes the TRT-LLM in-flight batching backend: + +From `tensorrt_llm_backend` root folder: + +``` +docker build -f dockerfile/Dockerfile.trt_llm_backend -t tritonserver:w_trt_llm_backend . +``` + +## Create a model repository folder + +First run: +``` +rm -rf triton_model_repo +mkdir triton_model_repo +cp -R all_models/inflight_batcher_llm/ triton_model_repo +``` + +Then copy the TRT engine to `triton_model_repo/tensorrt_llm/1/`. For example for the LLaMA 7B example above, run: + +``` +cp -R tensorrt_llm/examples/llama/engines/bf16/1-gpu/ triton_model_repo/tensorrt_llm/1 +``` + +For the GPT example above, run: +``` +cp -R tensorrt_llm/examples/gpt/engines/fp16/1-gpu/ triton_model_repo/tensorrt_llm/1 +``` + + +Edit the `triton_model_repo/tensorrt_llm/config.pbtxt` file and replace `${decoupled_mode}` with `True` or `False`, and `${engine_dir}` with `/triton_model_repo/tensorrt_llm/1/` since the `triton_model_repo` folder created above will be mounted to `/triton_model_repo` in the Docker container. Decoupled mode must be set to true if using the streaming option from the client. 
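+The placeholder substitution can also be scripted. The snippet below is only an
+illustrative sketch (the file name `fill_template.py` and the chosen values are
+assumptions, not files shipped in this repository); it fills the two
+placeholders discussed above:
+
+```python
+# fill_template.py -- illustrative sketch, not part of this repository.
+from pathlib import Path
+
+config = Path("triton_model_repo/tensorrt_llm/config.pbtxt")
+text = config.read_text()
+for placeholder, value in {
+    "${decoupled_mode}": "True",  # must be True when clients use --streaming
+    "${engine_dir}": "/triton_model_repo/tensorrt_llm/1/",
+}.items():
+    text = text.replace(placeholder, value)
+config.write_text(text)
+```
+
+The preprocessing and postprocessing `config.pbtxt` files use the same `${...}`
+placeholder convention for `tokenizer_dir` and `tokenizer_type` and can be
+filled in the same way.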
+ + +To use V1 batching, the `config.pbtxt` should have: +``` +parameters: { + key: "gpt_model_type" + value: { + string_value: "V1" + } +} +``` + +For in-flight batching, use: +``` +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +``` + +## Launch the Triton server container using the model_repository you just created + +``` +docker run --rm -it --net host --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --gpus='"'device=0'"' -v $(pwd)/triton_model_repo:/triton_model_repo tritonserver:w_trt_llm_backend /bin/bash -c "tritonserver --model-repository=/triton_model_repo" +``` + +## Run the provided client to send a request + +You can test the inflight batcher server with the provided reference python client as following: +``` +python3 inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 +``` + +You can also stop the generation process early by using the `--stop-after-ms` option to send a stop request after a few milliseconds: + +``` +python inflight_batcher_llm_client.py --stop-after-ms 200 --request-output-len 200 +``` + +You will find that the generation process is stopped early and therefore the number of generated tokens is lower than 200. + +You can have a look at the client code to see how early stopping is achieved. + +## Run the e2e/identity test to benchmark + +### End to end test +End to end test script sends requests to deployed ensemble model. + +Ensemble model is ensembled by three models: preprocessing, tensorrt_llm and postprocessing. +* preprocessing: Tokenizing, meaning the conversion from prompts(string) to input_ids(list of ints). +* tensorrt_llm: Inferencing. +* postprocessing: De-tokenizing, meaning the conversion from output_ids(list of ints) to outputs(string). + +The end to end latency includes the total latency of the three parts of an ensemble model. + +``` +cd tools/inflight_batcher_llm +python3 end_to_end_test.py --dataset +``` +Expected outputs +``` +[INFO] Functionality test succeed. +[INFO] Warm up for benchmarking. +[INFO] Start benchmarking on 125 prompts. +[INFO] Total Latency: 11099.243 ms +``` + +### Identity test + +Identity test script sends requests directly to deployed tensorrt_llm model, the identity test latency indicates the inference latency of TensorRT-LLM, not including the pre/post-processing latency which is usually handled by a third-party library such as HuggingFace. + +``` +cd tools/inflight_batcher_llm +python3 identity_test.py --dataset +``` +Expected outputs +``` +[INFO] Warm up for benchmarking. +[INFO] Start benchmarking on 125 prompts. +[INFO] Total Latency: 10213.462 ms +``` +*Please note that the expected outputs in that document are only for reference, specific performance numbers depend on the GPU you're using.* diff --git a/inflight_batcher_llm/client/inflight_batcher_llm_client.py b/inflight_batcher_llm/client/inflight_batcher_llm_client.py new file mode 100755 index 00000000..54c99124 --- /dev/null +++ b/inflight_batcher_llm/client/inflight_batcher_llm_client.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import queue +import sys +import time +from functools import partial + +import numpy as np +import tritonclient.grpc as grpcclient +import tritonclient.http as httpclient +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer +from tritonclient.utils import InferenceServerException, np_to_triton_dtype + +# +# Simple streaming client for TRT-LLM inflight bacthing backend +# +# In order for this code to work properly, config.pbtxt must contain these values: +# +# model_transaction_policy { +# decoupled: True +# } +# +# parameters: { +# key: "gpt_model_type" +# value: { +# string_value: "inflight_batching" +# } +# } +# +# In order for gpt_model_type 'inflight_batching' to work, you must copy engine from +# +# tensorrt_llm/cpp/tests/resources/models/rt_engine/gpt2/fp16-inflight-batching-plugin/1-gpu/ +# + + +class UserData: + + def __init__(self): + self._completed_requests = queue.Queue() + + +def prepare_tensor(name, input, protocol): + client_util = httpclient if protocol == "http" else grpcclient + t = client_util.InferInput(name, input.shape, + np_to_triton_dtype(input.dtype)) + t.set_data_from_numpy(input) + return t + + +def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data, + beam_width_data, temperature_data, streaming_data): + protocol = 'grpc' + inputs = [ + prepare_tensor("input_ids", input_ids_data, protocol), + prepare_tensor("input_lengths", input_lengths_data, protocol), + prepare_tensor("request_output_len", request_output_len_data, + protocol), + prepare_tensor("beam_width", beam_width_data, protocol), + prepare_tensor("temperature", temperature_data, protocol), + prepare_tensor("streaming", streaming_data, protocol), + ] + + return inputs + + +def prepare_stop_signals(): + + inputs = [ + grpcclient.InferInput('input_ids', [1, 1], "INT32"), + grpcclient.InferInput('input_lengths', [1, 1], "INT32"), + grpcclient.InferInput('request_output_len', [1, 1], "UINT32"), + grpcclient.InferInput('stop', [1, 1], "BOOL"), + ] + + inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32)) + inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32)) + inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32)) + inputs[3].set_data_from_numpy(np.array([[True]], dtype='bool')) + + return inputs + + +# Define the callback function. 
Note the last two parameters should be +# result and error. InferenceServerClient would povide the results of an +# inference as grpcclient.InferResult in result. For successful +# inference, error will be None, otherwise it will be an object of +# tritonclientutils.InferenceServerException holding the error details +def callback(user_data, result, error): + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-v", + "--verbose", + action="store_true", + required=False, + default=False, + help="Enable verbose output", + ) + parser.add_argument( + "-u", + "--url", + type=str, + required=False, + default="localhost:8001", + help="Inference server URL. Default is localhost:8001.", + ) + parser.add_argument( + '--text', + type=str, + required=False, + default='Born in north-east France, Soyer trained as a', + help='Input text') + parser.add_argument( + "-s", + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL encrypted channel to the server", + ) + parser.add_argument( + "-t", + "--stream-timeout", + type=float, + required=False, + default=None, + help="Stream timeout in seconds. Default is None.", + ) + parser.add_argument( + "-r", + "--root-certificates", + type=str, + required=False, + default=None, + help="File holding PEM-encoded root certificates. Default is None.", + ) + parser.add_argument( + "-p", + "--private-key", + type=str, + required=False, + default=None, + help="File holding PEM-encoded private key. Default is None.", + ) + parser.add_argument( + "-x", + "--certificate-chain", + type=str, + required=False, + default=None, + help="File holding PEM-encoded certificate chain. Default is None.", + ) + parser.add_argument( + "-C", + "--grpc-compression-algorithm", + type=str, + required=False, + default=None, + help= + "The compression algorithm to be used when sending request to server. Default is None.", + ) + parser.add_argument( + "-S", + "--streaming", + action="store_true", + required=False, + default=False, + help="Enable streaming mode. 
Default is False.", + ) + parser.add_argument( + "-c", + "--check-output", + action="store_true", + required=False, + default=False, + help="Enable check of output ids for CI", + ) + + parser.add_argument( + "-b", + "--beam-width", + required=False, + type=int, + default=1, + help="Beam width value", + ) + parser.add_argument( + "--temperature", + type=float, + required=False, + default=1.0, + help="temperature value", + ) + parser.add_argument( + "--request-output-len", + type=int, + required=False, + default=16, + help="temperature value", + ) + parser.add_argument( + '--stop-after-ms', + type=int, + required=False, + default=0, + help='Early stop the generation after a few milliseconds') + parser.add_argument('--tokenizer_dir', + type=str, + required=True, + help='Specify tokenizer directory') + parser.add_argument('--tokenizer_type', + type=str, + default='auto', + required=False, + choices=['auto', 't5', 'llama'], + help='Specify tokenizer type') + + FLAGS = parser.parse_args() + + print('=========') + if FLAGS.tokenizer_type == 't5': + tokenizer = T5Tokenizer(vocab_file=FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'auto': + tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(FLAGS.tokenizer_dir, + legacy=False, + padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {FLAGS.tokenizer_type}') + tokenizer.pad_token = tokenizer.eos_token + pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0] + end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0] + + input_ids = [tokenizer.encode(FLAGS.text)] + input_ids_data = np.array(input_ids, dtype=np.int32) + input_lengths = [[len(ii)] for ii in input_ids] + input_lengths_data = np.array(input_lengths, dtype=np.int32) + request_output_len = [[FLAGS.request_output_len]] + request_output_len_data = np.array(request_output_len, dtype=np.uint32) + beam_width = [[FLAGS.beam_width]] + beam_width_data = np.array(beam_width, dtype=np.uint32) + temperature = [[FLAGS.temperature]] + temperature_data = np.array(temperature, dtype=np.float32) + streaming = [[FLAGS.streaming]] + streaming_data = np.array(streaming, dtype=bool) + + inputs = prepare_inputs(input_ids_data, input_lengths_data, + request_output_len_data, beam_width_data, + temperature_data, streaming_data) + + if FLAGS.stop_after_ms > 0: + stop_inputs = prepare_stop_signals() + else: + stop_inputs = None + + request_id = "12345" + + expected_output_ids = [ + input_ids[0] + [ + 21221, 290, 257, 4255, 379, 262, 1957, 7072, 11, 4689, 347, 2852, + 2564, 494, 13, 679 + ] + ] + if FLAGS.streaming: + actual_output_ids = [input_ids[0]] + else: + actual_output_ids = [] + + user_data = UserData() + with grpcclient.InferenceServerClient( + url=FLAGS.url, + verbose=FLAGS.verbose, + ssl=FLAGS.ssl, + root_certificates=FLAGS.root_certificates, + private_key=FLAGS.private_key, + certificate_chain=FLAGS.certificate_chain, + ) as triton_client: + try: + + if FLAGS.streaming: + + # Establish stream + triton_client.start_stream( + callback=partial(callback, user_data), + stream_timeout=FLAGS.stream_timeout, + ) + # Send request + triton_client.async_stream_infer( + 'tensorrt_llm', + inputs, + request_id=request_id, + ) + + if stop_inputs is not None: + + time.sleep(FLAGS.stop_after_ms / 1000.0) + + triton_client.async_stream_infer( + 'tensorrt_llm', + stop_inputs, + request_id=request_id, + 
parameters={'Streaming': FLAGS.streaming}) + + #Wait for server to close the stream + triton_client.stop_stream() + + # Parse the responses + while True: + try: + result = user_data._completed_requests.get(block=False) + except Exception: + break + + if type(result) == InferenceServerException: + print("Received an error from server:") + print(result) + else: + output_ids = result.as_numpy('output_ids') + + if output_ids is not None: + if (FLAGS.streaming): + # Only one beam is supported + tokens = list(output_ids[0][0]) + actual_output_ids[ + 0] = actual_output_ids[0] + tokens + else: + for beam_output_ids in output_ids[0]: + tokens = list(beam_output_ids) + actual_output_ids.append(tokens) + else: + print("Got cancellation response from server") + else: + # Send request + triton_client.async_infer( + 'tensorrt_llm', + inputs, + request_id=request_id, + callback=partial(callback, user_data), + parameters={'Streaming': FLAGS.streaming}) + + if stop_inputs is not None: + + time.sleep(FLAGS.stop_after_ms / 1000.0) + + triton_client.async_infer( + 'tensorrt_llm', + stop_inputs, + request_id=request_id, + callback=partial(callback, user_data), + parameters={'Streaming': FLAGS.streaming}) + + processed_count = 0 + expected_responses = 1 + (1 if stop_inputs is not None else 0) + while processed_count < expected_responses: + try: + result = user_data._completed_requests.get() + print("Got completed request", flush=True) + except Exception: + break + + if type(result) == InferenceServerException: + print("Received an error from server:") + print(result) + else: + output_ids = result.as_numpy('output_ids') + if output_ids is not None: + for beam_output_ids in output_ids[0]: + tokens = list(beam_output_ids) + actual_output_ids.append(tokens) + else: + print("Got response for cancellation request") + + processed_count = processed_count + 1 + except Exception as e: + print("channel creation failed: " + str(e)) + sys.exit() + + passed = True + + print("output_ids = ", actual_output_ids) + output_ids = output_ids.reshape( + (output_ids.size, )).tolist()[input_ids_data.shape[1]:] + output_text = tokenizer.decode(output_ids) + print(f'Input: {FLAGS.text}') + print(f'Output: {output_text}') + if (FLAGS.check_output): + passed = (actual_output_ids == expected_output_ids) + print("expected_output_ids = ", expected_output_ids) + print("\n=====") + print("PASS!" if passed else "FAIL!") + print("=====") + + sys.exit(not passed) diff --git a/inflight_batcher_llm/cmake/InflightBatcherLLMBackendConfig.cmake.in b/inflight_batcher_llm/cmake/InflightBatcherLLMBackendConfig.cmake.in new file mode 100644 index 00000000..eed42d0d --- /dev/null +++ b/inflight_batcher_llm/cmake/InflightBatcherLLMBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TUTORIALINFLIGHTBATCHERLLMBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TUTORIALINFLIGHTBATCHERLLMBACKEND_CMAKE_DIR}) + +if(NOT TARGET TutorialInflightBatcherLLMBackend::triton-minimal-backend) + include("${TUTORIALINFLIGHTBATCHERLLMBACKEND_CMAKE_DIR}/TutorialInflightBatcherLLMBackendTargets.cmake") +endif() + +set(TUTORIALINFLIGHTBATCHERLLMBACKEND_LIBRARIES TutorialInflightBatcherLLMBackend::triton-minimal-backend) diff --git a/inflight_batcher_llm/cmake/modules/set_ifndef.cmake b/inflight_batcher_llm/cmake/modules/set_ifndef.cmake new file mode 100644 index 00000000..bd8f0a3e --- /dev/null +++ b/inflight_batcher_llm/cmake/modules/set_ifndef.cmake @@ -0,0 +1,24 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# + +function(set_ifndef variable value) + if(NOT DEFINED ${variable}) + set(${variable} + ${value} + PARENT_SCOPE) + endif() +endfunction() diff --git a/inflight_batcher_llm/scripts/build.sh b/inflight_batcher_llm/scripts/build.sh new file mode 100644 index 00000000..f7dc016c --- /dev/null +++ b/inflight_batcher_llm/scripts/build.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +TRT_ROOT=${1:-'/usr/local/tensorrt'} + +set -x +apt-get update +apt-get install -y --no-install-recommends rapidjson-dev + +BUILD_DIR=$(dirname $0)/../build +mkdir $BUILD_DIR +BUILD_DIR=$(cd -- "$BUILD_DIR" && pwd) +cd $BUILD_DIR + +cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \ + -DTRT_LIB_DIR=${TRT_ROOT}/targets/x86_64-linux-gnu/lib \ + -DTRT_INCLUDE_DIR=${TRT_ROOT}/include .. +make install diff --git a/inflight_batcher_llm/src/inflight_batcher_llm.cc b/inflight_batcher_llm/src/inflight_batcher_llm.cc new file mode 100644 index 00000000..181e9526 --- /dev/null +++ b/inflight_batcher_llm/src/inflight_batcher_llm.cc @@ -0,0 +1,1080 @@ +// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#define _GLIBCXX_USE_CXX11_ABI 0 +#include +#include +#include +#include +#include +#include +#include + +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" +#include "triton/core/tritonbackend.h" +#include "triton/core/tritonserver.h" + +#include "tensorrt_llm/batch_manager/NamedTensor.h" +#include "tensorrt_llm/batch_manager/callbacks.h" +#include "tensorrt_llm/batch_manager/inferenceRequest.h" +#include "tensorrt_llm/batch_manager/GptManager.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/tllmLogger.h" + +#include +#include + +#include "mpiUtils.h" + +using namespace ::triton::common; // TritonJson + +// +// Mockup of LLM inflight batcher based on triton 'minimal' backend example +// + +using namespace tensorrt_llm::batch_manager; +using namespace tensorrt_llm::runtime; +using namespace std::placeholders; // for _1, _2 etc. 
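+// Overview: enqueue() turns each incoming TRITONBACKEND_Request into a
+// WorkItem held in a thread-safe WorkItemsQueue; get_inference_requests()
+// later drains up to max_num_requests pending items for the batch manager,
+// with rank 0 broadcasting the count to the other MPI ranks when
+// world_size > 1.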
+//template class inflight_batcher::batch_manager::GPTManager; + +namespace triton { namespace backend { namespace inflight_batcher_llm { + +inline static const std::string kStopInputTensorName = "stop"; +inline static const std::string kStreamingInputTensorName = "streaming"; + +bool +getRequestBooleanInputTensor(TRITONBACKEND_Request* request, const std::string& inputTensorName) +{ + // Get stop signal from the request + TRITONBACKEND_Input* input; + TRITONSERVER_Error * error = TRITONBACKEND_RequestInput(request, inputTensorName.c_str(), &input); + if (error) + { + // If the user does not provide input "stop", then regard the request as unstopped + std::string msg = "ModelInstanceState::getRequestBooleanInputTensor: user did not not provide " + inputTensorName + " input for the request"; + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, msg.c_str()); + return false; + } + + uint64_t input_byte_size = 0; + uint32_t buffer_count = 0; + TRITONBACKEND_InputProperties( + input, nullptr, nullptr, nullptr, nullptr, + &input_byte_size, &buffer_count); + + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + ("ModelInstanceState::getRequestStopSignal: buffer_count = " + + std::to_string(buffer_count)).c_str()); + + const void* buffer = 0L; + uint64_t buffer_byte_size = 0; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id = 0; + TRITONBACKEND_InputBuffer(input, 0, &buffer, &buffer_byte_size, &memory_type, &memory_type_id); + + assert((memory_type == TRITONSERVER_MEMORY_CPU) || (memory_type == TRITONSERVER_MEMORY_CPU_PINNED)); + + bool boolean = *reinterpret_cast(buffer); + + return boolean; +} + +nvinfer1::DataType to_trt_datatype(TRITONSERVER_DataType data_type) +{ + if (data_type == TRITONSERVER_TYPE_INVALID) { + assert(false); + } else if (data_type == TRITONSERVER_TYPE_BOOL) { + return nvinfer1::DataType::kBOOL; + } else if (data_type == TRITONSERVER_TYPE_UINT8) { + return nvinfer1::DataType::kUINT8; + } else if (data_type == TRITONSERVER_TYPE_UINT16) { + assert(false); + } else if (data_type == TRITONSERVER_TYPE_UINT32) { + return nvinfer1::DataType::kINT32; + } else if (data_type == TRITONSERVER_TYPE_UINT64) { + return nvinfer1::DataType::kINT64; + } else if (data_type == TRITONSERVER_TYPE_INT8) { + return nvinfer1::DataType::kINT8; + } else if (data_type == TRITONSERVER_TYPE_INT16) { + assert(false); + } else if (data_type == TRITONSERVER_TYPE_INT32) { + return nvinfer1::DataType::kINT32; + } else if (data_type == TRITONSERVER_TYPE_INT64) { + return nvinfer1::DataType::kINT64; + } else if (data_type == TRITONSERVER_TYPE_FP16) { + return nvinfer1::DataType::kBF16; + } else if (data_type == TRITONSERVER_TYPE_FP32) { + return nvinfer1::DataType::kFLOAT; + } else if (data_type == TRITONSERVER_TYPE_FP64) { + assert(false); + } else if (data_type == TRITONSERVER_TYPE_BYTES) { + return nvinfer1::DataType::kINT8; + } else if (data_type == TRITONSERVER_TYPE_BF16) { + return nvinfer1::DataType::kBF16; + } else { + assert(false); + } + return nvinfer1::DataType(0); +} + +TRITONSERVER_DataType to_triton_datatype(nvinfer1::DataType data_type) +{ + if (data_type == nvinfer1::DataType::kBOOL) { + return TRITONSERVER_TYPE_BOOL; + } else if (data_type == nvinfer1::DataType::kUINT8) { + return TRITONSERVER_TYPE_UINT8; + } else if (data_type == nvinfer1::DataType::kHALF) { + return TRITONSERVER_TYPE_BF16; + } else if (data_type == nvinfer1::DataType::kINT8) { + return TRITONSERVER_TYPE_INT8; + } else if (data_type == nvinfer1::DataType::kINT32) { + return TRITONSERVER_TYPE_INT32; + } else if (data_type == 
nvinfer1::DataType::kINT64) { + return TRITONSERVER_TYPE_INT64; + } else if (data_type == nvinfer1::DataType::kFLOAT) { + return TRITONSERVER_TYPE_FP32; + } else if (data_type == nvinfer1::DataType::kBF16) { + return TRITONSERVER_TYPE_BF16; + } else { + return TRITONSERVER_TYPE_INVALID; + } +} + + +///////////// + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. ModelState is derived from BackendModel class +// provided in the backend utilities that provides many common +// functions. +// +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create( + TRITONBACKEND_Model* triton_model, ModelState** state); + + template + T GetParameter(const std::string& name) { + assert(false); + } + + virtual ~ModelState() = default; + + common::TritonJson::Value& GetModelConfig(); + + private: + + common::TritonJson::Value model_config_; + std::shared_ptr mTrtLogger{}; + + ModelState(TRITONBACKEND_Model* triton_model, TritonJson::Value&& model_config) : BackendModel(triton_model, true), model_config_(std::move(model_config)) + { + mTrtLogger = std::make_shared(); + initLibNvInferPlugins(mTrtLogger.get(), "tensorrt_llm"); + } +}; + + +TRITONSERVER_Error* +ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) +{ + TRITONSERVER_Message* config_message; + RETURN_IF_ERROR(TRITONBACKEND_ModelConfig( + triton_model, 1 /* config_version */, &config_message)); + + + // We can get the model configuration as a json string from + // config_message, parse it with our favorite json parser to create + // DOM that we can access when we need to example the + // configuration. We use TritonJson, which is a wrapper that returns + // nice errors (currently the underlying implementation is + // rapidjson... but others could be added). You can use any json + // parser you prefer. 
+ const char* buffer; + size_t byte_size; + RETURN_IF_ERROR( + TRITONSERVER_MessageSerializeToJson(config_message, &buffer, &byte_size)); + + common::TritonJson::Value model_config; + TRITONSERVER_Error* err = model_config.Parse(buffer, byte_size); + RETURN_IF_ERROR(TRITONSERVER_MessageDelete(config_message)); + RETURN_IF_ERROR(err); + + + try { + *state = new ModelState(triton_model, std::move(model_config)); + } + catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +common::TritonJson::Value& +ModelState::GetModelConfig() { return model_config_; } + +template <> +std::string +ModelState::GetParameter( + const std::string& name) +{ + //TODO: Error handling + TritonJson::Value parameters; + model_config_.MemberAsObject("parameters", ¶meters); + TritonJson::Value value; + std::string str_value; + parameters.MemberAsObject(name.c_str(), &value); + value.MemberAsString("string_value", &str_value); + return str_value; +} + +template <> +int32_t +ModelState::GetParameter(const std::string& name) +{ + return std::stoi(GetParameter(name)); +} + +template <> +uint32_t +ModelState::GetParameter(const std::string& name) +{ + return (uint32_t)std::stoul(GetParameter( name)); +} + +template <> +int64_t +ModelState::GetParameter(const std::string& name) +{ + return std::stoll(GetParameter(name)); +} + +template <> +uint64_t +ModelState::GetParameter(const std::string& name) +{ + return std::stoull(GetParameter(name)); +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInitialize when a model is loaded +// to allow the backend to create any state associated with the model, +// and to also examine the model configuration to determine if the +// configuration is suitable for the backend. Any errors reported by +// this function will prevent the model from loading. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) +{ + // Create a ModelState object and associate it with the + // TRITONBACKEND_Model. If anything goes wrong with initialization + // of the model state then an error is returned and Triton will fail + // to load the model. + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelFinalize when a model is no longer +// needed. The backend should cleanup any state associated with the +// model. This function will not be called until all model instances +// of the model have been finalized. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + ModelState* model_state = reinterpret_cast(vstate); + delete model_state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +// Class holding all infos regarding a single work item. +// This includes the original request, associated response factor +// and state. 
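+// A WorkItem also owns the request id (randomly generated when the client
+// does not supply one), the InferenceRequest built from the Triton request
+// (control tensors such as START, CORRID, END, stop and streaming are not
+// forwarded), and the TRITONBACKEND_ResponseFactory used to send responses,
+// which is released in the destructor.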
+class WorkItem +{ + public: + WorkItem(TRITONBACKEND_Request* request, bool isDecoupled) + { + mRequestId = (rand() % INT64_MAX) + 1; + mInferenceRequest = createInferenceRequest(request, mRequestId, isDecoupled); + + // Create response factory for this request + TRITONBACKEND_ResponseFactoryNew(&factory_ptr_, request); + } + + WorkItem(TRITONBACKEND_Request* request, uint64_t request_id, bool isDecoupled) + : mRequestId(request_id) + { + mInferenceRequest = createInferenceRequest(request, mRequestId, isDecoupled); + + // Create response factory for this request + TRITONBACKEND_ResponseFactoryNew(&factory_ptr_, request); + } + + WorkItem(std::shared_ptr ir, + uint64_t RequestId) + : mInferenceRequest(ir), + mRequestId(RequestId) + { + factory_ptr_ = nullptr; + } + + ~WorkItem() + { + if (factory_ptr_ != nullptr) + { + TRITONBACKEND_ResponseFactoryDelete(factory_ptr_); + } + } + + TRITONBACKEND_ResponseFactory* response_factory() + { + assert(factory_ptr_ != nullptr); + return factory_ptr_; + } + + uint64_t requestId() const + { + return mRequestId; + } + + std::shared_ptr getInferenceRequest() const + { + return mInferenceRequest; + } + + private: + + // Convert info from original backend request to data structures defined in common/common.h + std::shared_ptr createInferenceRequest(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) + { + auto inferenceRequest = std::make_shared(requestId); + + // Extract input tensors + std::map input_tensors; + uint32_t num_inputs; + LOG_IF_ERROR(TRITONBACKEND_RequestInputCount(request, &num_inputs), "Error getting input count"); + for (uint32_t idx = 0; idx < num_inputs; ++idx) + { + TRITONBACKEND_Input* input = 0L; + TRITONBACKEND_RequestInputByIndex(request, idx, &input); + + const char* input_name = 0L; + TRITONSERVER_DataType data_type = TRITONSERVER_TYPE_INVALID; + const int64_t* shape = 0L; + uint32_t dims_count = 0; + uint64_t byte_size = 0; + uint32_t buffer_count = 0; + TRITONBACKEND_InputProperties(input, &input_name, &data_type, &shape, &dims_count, &byte_size, &buffer_count); + + if (std::string(input_name) == "START" || std::string(input_name) == "CORRID" || std::string(input_name) == "END" || + std::string(input_name) == kStopInputTensorName || std::string(input_name) == kStreamingInputTensorName) + { + continue; + } + + std::vector shapev; + for (uint32_t i = 0; i < dims_count; ++i) { + shapev.push_back(shape[i]); + } + + NamedTensor t(to_trt_datatype(data_type), shapev, input_name); + uint64_t buffer_offset = 0; + for (int64_t buffer_id=0; buffer_id < buffer_count; ++buffer_id) + { + const void* buffer = 0L; + uint64_t buffer_byte_size = 0; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id = 0; + TRITONBACKEND_InputBuffer(input, buffer_id, &buffer, &buffer_byte_size, &memory_type, &memory_type_id); + assert((memory_type == TRITONSERVER_MEMORY_CPU) || (memory_type == TRITONSERVER_MEMORY_CPU_PINNED)); + // TODO: Do we need to handle GPU mem input buffers?? 
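+                // The assert above restricts input buffers to CPU or pinned
+                // CPU memory, so a plain memcpy into the NamedTensor's data
+                // buffer is sufficient here; buffer_offset stitches together
+                // inputs that Triton delivers in multiple chunks.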
+ std::memcpy(static_cast(t.tensor->data()) + buffer_offset, buffer, buffer_byte_size); + buffer_offset += buffer_byte_size; + } + + inferenceRequest->emplaceInputTensor(t.name, std::move(t.tensor)); + } + + bool streamingFlag = getRequestBooleanInputTensor(request, kStreamingInputTensorName); + inferenceRequest->setIsStreaming(streamingFlag); + + if (streamingFlag && !isDecoupled) { + throw std::runtime_error("Streaming is only supported if model is deployed using decoupled mode."); + } + + return inferenceRequest; + } + + std::shared_ptr mInferenceRequest; + TRITONBACKEND_ResponseFactory* factory_ptr_; + uint64_t mRequestId; +}; + +/// @brief Thread-safe queue of work items + +class WorkItemsQueue +{ + public: + void clear() + { + std::lock_guard lk(mMutex); + mPendingWorkItems.clear(); + mPendingWorkItemsReqIds.clear(); + mInProgressWorkItems.clear(); + mStoppedReqIds.clear(); + + } + + // Note: this function only be called under a lock + bool hasInProgressReqId(const uint64_t reqId) const { + return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end()); + } + + // Note: this function only be called under a lock + bool hasPendingReqId(const uint64_t reqId) const { + return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end()); + } + + /// @brief Add a new work item to the queue + /// Throws an error if requestId already exists + + void push(TRITONBACKEND_Request* request, uint64_t requestId, bool isDecoupled) + { + std::lock_guard lk(mMutex); + if (hasInProgressReqId(requestId) || hasPendingReqId(requestId)) + { + std::string errStr = "requestId " + std::to_string(requestId) + " is already in progress, request is ignored."; + throw std::runtime_error(errStr); + } + else + { + auto workItem = std::make_shared(request, requestId, isDecoupled); + mPendingWorkItems.push_back(workItem); + mPendingWorkItemsReqIds.insert(workItem->requestId()); + } + } + + void push(TRITONBACKEND_Request* request, bool isDecoupled) + { + std::lock_guard lk(mMutex); + auto workItem = std::make_shared(request, isDecoupled); + mPendingWorkItems.push_back(workItem); + mPendingWorkItemsReqIds.insert(workItem->requestId()); + } + + /// @brief Get a new work item from the queue, and move it to the list of + /// in progress work items if it hasn't been stopped + /// @return A tuple of the workItem and a boolean flag indicating if the work item + /// has been marked in progress + std::tuple, bool> pop() + { + std::lock_guard lk(mMutex); + + auto workItem = mPendingWorkItems.front(); + mPendingWorkItems.pop_front(); + mPendingWorkItemsReqIds.erase(workItem->requestId()); + + bool markedInProgress; + // Check if work item has been stopped + if (mStoppedReqIds.find(workItem->requestId()) == mStoppedReqIds.end()) + { + mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem)); + markedInProgress = true; + } else { + mStoppedReqIds.erase(workItem->requestId()); + markedInProgress = false; + } + + return {workItem, markedInProgress}; + } + + size_t numPendingWorkItems() const + { + std::lock_guard lk(mMutex); + return mPendingWorkItems.size(); + } + + std::shared_ptr getInProgressWorkItem(uint64_t requestId) + { + std::lock_guard lk(mMutex); + return mInProgressWorkItems.at(requestId); + } + + /// @brief Mark a request as being finished + /// @param requestId + void markFinished(const uint64_t requestId) + { + std::lock_guard lk(mMutex); + if (hasInProgressReqId(requestId)) + { + mInProgressWorkItems.erase(requestId); + } + + if (mStoppedReqIds.find(requestId) != 
mStoppedReqIds.end()) + { + mStoppedReqIds.erase(requestId); + } + } + + // Stop a request by adding the request Id to a set + // The set of stopped request id is used by the poll callback + // and the pop function + void stopWorkItem(const uint64_t requestId) + { + std::lock_guard lk(mMutex); + TLLM_LOG_DEBUG("Stopping request"); + if (hasInProgressReqId(requestId) || hasPendingReqId(requestId)) + { + mStoppedReqIds.emplace(requestId); + } + else + { + std::string errStr = std::string("Received stop request for requestId ") + std::to_string(requestId) + std::string(" but it's not active (might be completed already)."); + throw std::runtime_error(errStr); + } + } + + std::unordered_set getStoppedReqIds() const + { + std::lock_guard lk(mMutex); + return mStoppedReqIds; + } + + private: + /// Queue of work items + std::list> mPendingWorkItems; + /// requestIds of work items in the queue + std::set mPendingWorkItemsReqIds; + + /// work items currently in progress + std::unordered_map> mInProgressWorkItems; + + /// ids of the work items that have been stopped + std::unordered_set mStoppedReqIds; + + mutable std::mutex mMutex; +}; + +// +// ModelInstanceState +// +// State associated with a model instance. An object of this class is +// created and associated with each +// TRITONBACKEND_ModelInstance. ModelInstanceState is derived from +// BackendModelInstance class provided in the backend utilities that +// provides many common functions. +// +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + virtual ~ModelInstanceState() + { + // terminate decoupled execution loop + { + mWorkItemsQueue.clear(); + } + } + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + bool isDecoupled() const + { + return mIsDecoupled; + } + + uint64_t getRequestId(TRITONBACKEND_Request* request) + { + const char* charRequestId; + TRITONBACKEND_RequestId(request, &charRequestId); + uint64_t requestId = 0; + if (charRequestId != nullptr) + { + std::string strRequestId(charRequestId); + if (!strRequestId.empty()) { + try { + requestId = stoul(strRequestId); + } catch (const std::exception& e) { + std::string err = std::string("Invalid requestId, must be uint64_t. 
Got ") + strRequestId; + throw std::runtime_error(err); + } + } + } + + return requestId; + } + + // For stop requests, or in case of error during enqueue, we need to send a response to the client + void sendEnqueueResponse(TRITONBACKEND_Request* request, const std::string& errMsg = "") + { + TRITONBACKEND_ResponseFactory* factory_ptr; + // Create response factory for this request + LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request), "Cannot create response factory"); + + TRITONSERVER_Error* err = nullptr; + if (!errMsg.empty()) + { + TLLM_LOG_ERROR(errMsg); + err = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errMsg.c_str()); + } + TRITONBACKEND_Response* response; + LOG_IF_ERROR(TRITONBACKEND_ResponseNewFromFactory(&response, factory_ptr), "Cannot create response"); + LOG_IF_ERROR(TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), "Cannot send response"); + LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryDelete(factory_ptr), "Cannot delete response factory"); + } + + void enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count, bool isDecoupled) + { + for (uint32_t r = 0; r < request_count; ++r) { + + TRITONBACKEND_Request* request = requests[r]; + try { + auto requestId = getRequestId(request); + bool stopRequest = getRequestBooleanInputTensor(request, kStopInputTensorName); + + if (requestId != 0) + { + if (stopRequest) + { + // Check if request is in progress or in queue, if not ignore + mWorkItemsQueue.stopWorkItem(requestId); + // Send a response back to client for stop request + sendEnqueueResponse(request); + } + else + { + mWorkItemsQueue.push(request, requestId, isDecoupled); + } + + } + else if (!stopRequest) + { + mWorkItemsQueue.push(request, isDecoupled); + } + else + { + throw std::runtime_error("Cannot send stop request without specifying a request_id"); + } + + } catch (const std::exception& e) { + // In case of error, no work item is added to queue, so response callback needs to be called + sendEnqueueResponse(request, e.what()); + } + } + return; + } + + // Return up to max_num_requests inference requests. + std::list> get_inference_requests(const int max_num_requests) + { + std::list> rval; + if (max_num_requests > 0) + { + auto world_size = getCommWorldSize(); + auto rank = getCommWorldRank(); + if (rank == 0) + { + int64_t num_new_work_items = std::min(static_cast(mWorkItemsQueue.numPendingWorkItems()), static_cast(max_num_requests)); + if (world_size > 1) + { + bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0); + } + + if (num_new_work_items > 0) + { + int count = 0; + while (count < num_new_work_items) + { + auto [workItem, markedInProgress] = mWorkItemsQueue.pop(); + + if (markedInProgress) { + rval.emplace_back(workItem->getInferenceRequest()); + count++; + } else { + std::string warnStr = + std::string("request Id ") + std::to_string(workItem->requestId()) + + std::string(" has been stopped. 
Request is ignored."); + TLLM_LOG_WARNING(warnStr); + sendTritonResponse(workItem, {}, true, warnStr); + } + } + if (world_size > 1) + { + std::vector packed; + for (auto ir : rval) + { + auto vpacked = ir->serialize(); + packed.push_back(static_cast(vpacked.size())); + packed.insert(packed.end(), + std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); + } + int64_t nWords1 = static_cast(packed.size()); + bcast(&nWords1, 1, MPI_TYPE_INT64_T, 0); + bcast(packed, 0); + } + } + } + else + { + // subordinate ranks hang until master rank sends work + int64_t num_new_work_items; + bcast(&num_new_work_items, 1, MPI_TYPE_INT64_T, 0); + if (num_new_work_items > 0) + { + int nWords1; + bcast(&nWords1, 1, MPI_TYPE_INT64_T, 0); + std::vector packed(nWords1); + bcast(packed, 0); + int64_t* packed_ptr = packed.data(); + for (int64_t count = 0; count < num_new_work_items; ++count) + { + int64_t n = *(packed_ptr++); + auto ir = InferenceRequest::deserialize(packed_ptr); + packed_ptr += n; + rval.emplace_back(ir); + } + } + } + } + return rval; + } + + TRITONSERVER_Error* sendTritonResponse( + std::shared_ptr workItem, + std::list const& response_tensors, + bool final_response, const std::string& errMsg) + { + TRITONBACKEND_ResponseFactory* response_factory; + response_factory = workItem->response_factory(); + + TRITONBACKEND_Response* response; + RETURN_IF_ERROR(TRITONBACKEND_ResponseNewFromFactory(&response, response_factory)); + + auto requestId = workItem->requestId(); + if (final_response) + { + mWorkItemsQueue.markFinished(requestId); + } + + // Check if error + TRITONSERVER_Error* err = nullptr; + if (!errMsg.empty()) { + std::string errStr = "Encountered error for requestId " + std::to_string(requestId) + ": " + errMsg; + TLLM_LOG_ERROR(errStr); + + err = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); + final_response = true; + } else { + for (auto it = response_tensors.begin(); it != response_tensors.end(); ++it) + { + auto tensor = *it; + auto shape = tensor.tensor->getShape(); // returns std::vectorint64_t> + std::vector vshape(shape.nbDims); + for (int i = 0; i < vshape.size(); ++i) { + vshape[i] = shape.d[i]; + } + + TRITONBACKEND_Output* output; + RETURN_IF_ERROR(TRITONBACKEND_ResponseOutput( + response, &output, tensor.name.c_str(), to_triton_datatype(tensor.tensor->getDataType()), + vshape.data(), shape.nbDims)); + + uint64_t buffersize = tensor.tensor->getSizeInBytes(); + void* buffer = 0L; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + RETURN_IF_ERROR(TRITONBACKEND_OutputBuffer(output, &buffer, buffersize, &memory_type, &memory_type_id)); + if (memory_type != TRITONSERVER_MEMORY_CPU && memory_type != TRITONSERVER_MEMORY_CPU_PINNED) + { + std::string errStr = "Triton failed to allocate output buffer on CPU"; + err = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, errStr.c_str()); + break; + } + std::memcpy(buffer, tensor.tensor->data(), buffersize); + } + } + + RETURN_IF_ERROR( + TRITONBACKEND_ResponseSend( + response, final_response ? 
TRITONSERVER_RESPONSE_COMPLETE_FINAL : 0, err)); + + return nullptr; + } + + void sendResponse(uint64_t requestId, std::list const& response_tensors, bool final_response, const std::string& errMsg) + { + if (getCommWorldRank() == 0) + { + std::string errStr = std::string("Failed to send Triton response for requestId: ") + std::to_string(requestId); + try { + auto workItem = mWorkItemsQueue.getInProgressWorkItem(requestId); + auto tritonErr = sendTritonResponse(workItem, response_tensors, final_response, errMsg); + LOG_IF_ERROR(tritonErr, errStr); + } catch (const std::exception& e) { + TLLM_LOG_ERROR(errStr); + } + } + } + + std::unordered_set pollStopSignals() + { + auto stoppedReqIds = mWorkItemsQueue.getStoppedReqIds(); + int64_t nStoppedReqIds = static_cast(stoppedReqIds.size()); + + if (getCommWorldSize() > 1) + { + // Broadcast number of stopped requests + bcast(&nStoppedReqIds, 1, MPI_TYPE_INT64_T, 0); + + if (nStoppedReqIds > 0) + { + // Broadcast stopped requests Ids + if (getCommWorldRank() == 0) + { + // Store the requestIds in a contiguous vector + std::vector stoppedReqIdsVec(stoppedReqIds.begin(), stoppedReqIds.end()); + bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0); + } else { + std::vector stoppedReqIdsVec(nStoppedReqIds); + bcast(stoppedReqIdsVec.data(), stoppedReqIdsVec.size(), MPI_TYPE_UINT64_T, 0); + // Store the requestIds in the set + stoppedReqIds.clear(); + std::copy(stoppedReqIdsVec.begin(), stoppedReqIdsVec.end(), std::inserter(stoppedReqIds, stoppedReqIds.end())); + } + } + } + return stoppedReqIds; + } + + private: + ModelInstanceState( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance), + model_state_(model_state), mIsDecoupled(false) + { + // Note: std::string::compare fails this test (always return non-zero value). + // Using old school strcmp instead. + if (model_state_->GetParameter("gpt_model_type") == "V1" || + model_state_->GetParameter("gpt_model_type") == "v1") + { + mTrtGptModelType = TrtGptModelType::V1; + } + else if (model_state_->GetParameter("gpt_model_type") == "inflight_batching") + { + mTrtGptModelType = TrtGptModelType::InflightBatching; + } + else if (model_state_->GetParameter("gpt_model_type") == "inflight_fused_batching") + { + mTrtGptModelType = TrtGptModelType::InflightFusedBatching; + } + else + { + throw std::runtime_error("Invalid gpt_model_type. Must be v1/inflight_batching/inflight_fused_batching."); + } + + // Check if model is in decoupled mode: + triton::common::TritonJson::Value transaction_policy; + model_state_->GetModelConfig().MemberAsObject("model_transaction_policy", &transaction_policy); + transaction_policy.MemberAsBool("decoupled", &mIsDecoupled); + + // Note: std::string::compare fails this test (always return non-zero value). + // Using old school strcmp instead. 
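+        // The block below sizes the batch manager from the engine's config.json, which the
+        // TensorRT-LLM engine builder writes next to the serialized engine. Illustratively, the
+        // fields read here would appear as (values are placeholders, not shipped defaults):
+        //   { "builder_config": { "max_input_len": 1024, "max_output_len": 512, "max_batch_size": 8 } }
+        // The three lambdas handed to the batch manager wire request fetching, response delivery
+        // and stop-signal polling back into this ModelInstanceState.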
+ mModelPath = model_state_->GetParameter("gpt_model_path"); + auto configPath = mModelPath + "/config.json"; + std::ifstream jsonStream(configPath); + + auto constexpr allowExceptions = true; + auto constexpr ingoreComments = true; + auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ingoreComments); + + auto const& builderConfig = json.at("builder_config"); + int maxInputLen = builderConfig.at("max_input_len"); + int maxOutputLen = builderConfig.at("max_output_len"); + int maxSeqLen = maxInputLen + maxOutputLen; + int maxNumRequests = builderConfig.at("max_batch_size"); + int32_t maxBeamWidth = model_state_->GetParameter("max_beam_width"); + + mBatchManager = std::make_shared(mModelPath, mTrtGptModelType, maxSeqLen, maxNumRequests, maxBeamWidth, + [this](int max_num_requests){return get_inference_requests(max_num_requests);}, + [this](uint64_t requestId, std::list const& response_tensors, bool final_response, const std::string& errMsg){return sendResponse(requestId, response_tensors, final_response, errMsg);}, + [this](){return pollStopSignals();}); + + if (getCommWorldRank() != 0) + { + while (true) {} + } + } + + ModelState* model_state_; + + // + // inflight batcher is a decoupled design. + // It uses response factory objects to decouple responses from requests. + // + // New requests are added to mWorkItems list. This list is processed + // in an infinite loop run by a worker thread. Requests take multiple + // iterations to complete, and number of iterations is not known in + // advance. To facilitate this, we use response factory objects to + // decouple requests and responses. + // + TrtGptModelType mTrtGptModelType; + std::string mModelPath; + bool mIsDecoupled; + + std::shared_ptr mBatchManager; + + WorkItemsQueue mWorkItemsQueue; +}; + +TRITONSERVER_Error* +ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) +{ + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } + catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInstanceInitialize when a model +// instance is created to allow the backend to initialize any state +// associated with the instance. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) +{ + // Get the model state associated with this instance's model. + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + void* vmodelstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); + ModelState* model_state = reinterpret_cast(vmodelstate); + + // Create a ModelInstanceState object and associate it with the + // TRITONBACKEND_ModelInstance. + ModelInstanceState* instance_state; + RETURN_IF_ERROR( + ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState( + instance, reinterpret_cast(instance_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelInstanceFinalize when a model +// instance is no longer needed. The backend should cleanup any state +// associated with the model instance. 
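+// Note that for multi-GPU runs (world size > 1) only rank 0 ever returns from the
+// ModelInstanceState constructor above; the other ranks spin inside it and never reach
+// the finalize/execute entry points below.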
+// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); + ModelInstanceState* instance_state = + reinterpret_cast(vstate); + delete instance_state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +extern "C" { + +// When Triton calls TRITONBACKEND_ModelInstanceExecute it is required +// that a backend create a response for each request in the batch. A +// response may be the output tensors required for that request or may +// be an error that is returned in the response. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceExecute( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_count) +{ + ModelInstanceState* instance_state; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState( + instance, reinterpret_cast(&instance_state))); + + auto isDecoupled = instance_state->isDecoupled(); + + instance_state->enqueue(requests, request_count, isDecoupled); + + for (uint32_t r = 0; r < request_count; ++r) { + TRITONBACKEND_Request* request = requests[r]; + TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL); + } + + return nullptr; // success +} + +} // extern "C" + + +}}} // namespace triton::backend::minimal diff --git a/inflight_batcher_llm/src/libtriton_inflight_batcher_llm.ldscript b/inflight_batcher_llm/src/libtriton_inflight_batcher_llm.ldscript new file mode 100644 index 00000000..748714d1 --- /dev/null +++ b/inflight_batcher_llm/src/libtriton_inflight_batcher_llm.ldscript @@ -0,0 +1,30 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +{ + global: + TRITONBACKEND_*; + local: *; +}; diff --git a/inflight_batcher_llm/src/mpiUtils.h b/inflight_batcher_llm/src/mpiUtils.h new file mode 100644 index 00000000..41dc0a6a --- /dev/null +++ b/inflight_batcher_llm/src/mpiUtils.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <mpi.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <unordered_map>
+#include <vector>
+
+#define MPICHECK(cmd) \
+    do \
+    { \
+        int e = cmd; \
+        if (e != MPI_SUCCESS) \
+        { \
+            printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
+            exit(EXIT_FAILURE); \
+        } \
+    } while (0)
+
+enum MpiType
+{
+    MPI_TYPE_BYTE,
+    MPI_TYPE_CHAR,
+    MPI_TYPE_INT,
+    MPI_TYPE_INT64_T,
+    MPI_TYPE_UINT32_T,
+    MPI_TYPE_UINT64_T,
+    MPI_TYPE_UNSIGNED_LONG_LONG,
+};
+
+inline MPI_Datatype getMpiDtype(MpiType dtype)
+{
+    static const std::unordered_map<MpiType, MPI_Datatype> dtype_map{
+        {MPI_TYPE_BYTE, MPI_BYTE},
+        {MPI_TYPE_CHAR, MPI_CHAR},
+        {MPI_TYPE_INT, MPI_INT},
+        {MPI_TYPE_INT64_T, MPI_INT64_T},
+        {MPI_TYPE_UINT32_T, MPI_UINT32_T},
+        {MPI_TYPE_UINT64_T, MPI_UINT64_T},
+        {MPI_TYPE_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG_LONG},
+    };
+    return dtype_map.at(dtype);
+}
+
+inline int getCommWorldSize()
+{
+    int size;
+    MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &size));
+    return size;
+}
+
+inline int getCommWorldRank()
+{
+    int rank;
+    MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
+    return rank;
+}
+
+inline void barrier()
+{
+    MPICHECK(MPI_Barrier(MPI_COMM_WORLD));
+}
+
+inline void bcast(void* buffer, size_t size, MpiType dtype, int root)
+{
+    MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, MPI_COMM_WORLD));
+}
+
+inline void bcast(std::vector<int64_t>& packed, int root)
+{
+    MPICHECK(MPI_Bcast(packed.data(), packed.size(), MPI_INT64_T, root, MPI_COMM_WORLD));
+}
diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..a9f4082c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +regex +fire +tritonclient[all] +transformers==4.31.0 diff --git a/scripts/launch_triton_server.py b/scripts/launch_triton_server.py new file mode 100644 index 00000000..1009a81f --- /dev/null +++ b/scripts/launch_triton_server.py @@ -0,0 +1,32 @@ +import argparse
+import subprocess
+from pathlib import Path
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--world_size',
+                        type=int,
+                        default=1,
+                        help='world size, only support tensor parallelism now')
+    parser.add_argument('--tritonserver',
+                        type=str,
+                        default='/opt/tritonserver/bin/tritonserver')
+    path = str(Path(__file__).parent.absolute()) + '/../all_models/gpt'
+    parser.add_argument('--model_repo', type=str, default=path)
+    return parser.parse_args()
+
+
+def get_cmd(world_size, tritonserver, model_repo):
+    cmd = 'mpirun --allow-run-as-root '
+    for i in range(world_size):
+        cmd += ' -n 1 {} --model-repository={} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{}_ : '.format(
+            tritonserver, model_repo, i)
+    cmd += '&'
+    return cmd
+
+
+if __name__ == '__main__':
+    args = parse_arguments()
+    cmd = get_cmd(int(args.world_size), args.tritonserver, args.model_repo)
+    subprocess.call(cmd, shell=True)
diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/fill_template.py b/tools/fill_template.py new file mode 100644 index
00000000..cb298b31 --- /dev/null +++ b/tools/fill_template.py @@ -0,0 +1,38 @@ +#! /usr/bin/env python3 +from argparse import ArgumentParser +from string import Template + + +def main(file_path, substitutions, in_place): + with open(file_path) as f: + pbtxt = Template(f.read()) + + sub_dict = {} + for sub in substitutions.split(","): + key, value = sub.split(":") + sub_dict[key] = value + + pbtxt = pbtxt.safe_substitute(sub_dict) + + if in_place: + with open(file_path, "w") as f: + f.write(pbtxt) + else: + print(pbtxt) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("file_path", help="path of the .pbtxt to modify") + parser.add_argument( + "substitutions", + help= + "substitions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." + ) + parser.add_argument("--in_place", + "-i", + action="store_true", + help="do the operation in-place") + args = parser.parse_args() + + main(**vars(args)) diff --git a/tools/gpt/client.py b/tools/gpt/client.py new file mode 100644 index 00000000..7bd44af0 --- /dev/null +++ b/tools/gpt/client.py @@ -0,0 +1,137 @@ +#!/usr/bin/python + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +import argparse +from datetime import datetime + +import numpy as np +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer +from utils import utils + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='http', + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. 
Default is "http".') + parser.add_argument( + '-t', + '--text', + type=str, + required=False, + default='Born in north-east France, Soyer trained as a', + help='Input text') + parser.add_argument('-c', + '--concurrency', + type=int, + default=1, + required=False, + help='Specify concurrency') + parser.add_argument('-beam', + '--beam_width', + type=int, + default=1, + required=False, + help='Specify beam width') + parser.add_argument('-topk', + '--topk', + type=int, + default=1, + required=False, + help='topk for sampling') + parser.add_argument('-topp', + '--topp', + type=float, + default=0.0, + required=False, + help='topp for sampling') + parser.add_argument('-o', + '--output_len', + type=int, + default=10, + required=False, + help='Specify output length') + parser.add_argument('--tokenizer_dir', + type=str, + required=True, + help='Specify tokenizer directory') + parser.add_argument('--tokenizer_type', + type=str, + default='auto', + required=False, + choices=['auto', 't5', 'llama'], + help='Specify tokenizer type') + + FLAGS = parser.parse_args() + if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"): + print( + "unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format( + FLAGS.protocol)) + exit(1) + + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + + if FLAGS.tokenizer_type == 't5': + tokenizer = T5Tokenizer(vocab_file=FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'auto': + tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(FLAGS.tokenizer_dir, + legacy=False, + padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {FLAGS.tokenizer_type}') + tokenizer.pad_token = tokenizer.eos_token + pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0] + end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0] + + line = tokenizer.encode(FLAGS.text) + input_start_ids = np.array([line], np.int32) + input_len = np.array([[len(line)]], np.int32) + inputs = utils.prepare_inputs(input_start_ids, input_len, pad_id, end_id, + FLAGS) + + start_time = datetime.now() + + with utils.create_inference_server_client(FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) as client: + results = utils.send_requests('tensorrt_llm', + inputs, + client, + request_parallelism=1) + output_ids = results[0].as_numpy("output_ids") + + stop_time = datetime.now() + latency = (stop_time - start_time).total_seconds() * 1000.0 + latency = round(latency, 3) + print(f"[INFO] Latency: {latency} ms") + + output_ids = output_ids.reshape( + (output_ids.size, )).tolist()[input_start_ids.shape[1]:] + output_text = tokenizer.decode(output_ids) + print(f'Input: {FLAGS.text}') + print(f'Output: {output_text}') diff --git a/tools/gpt/client_async.py b/tools/gpt/client_async.py new file mode 100644 index 00000000..f20c1458 --- /dev/null +++ b/tools/gpt/client_async.py @@ -0,0 +1,150 @@ +#!/usr/bin/python + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +import argparse +from datetime import datetime + +import numpy as np +import tritonclient.grpc as grpcclient +import tritonclient.http as httpclient +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer +from utils import utils + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + 
parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='http', + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. Default is "http".') + parser.add_argument( + '-t', + '--text', + type=str, + required=False, + default='Born in north-east France, Soyer trained as a', + help='Input text') + parser.add_argument('-c', + '--concurrency', + type=int, + default=1, + required=False, + help='Specify concurrency') + parser.add_argument('-beam', + '--beam_width', + type=int, + default=1, + required=False, + help='Specify beam width') + parser.add_argument('-topk', + '--topk', + type=int, + default=1, + required=False, + help='topk for sampling') + parser.add_argument('-topp', + '--topp', + type=float, + default=0.0, + required=False, + help='topp for sampling') + parser.add_argument('-o', + '--output_len', + type=int, + default=10, + required=False, + help='Specify output length') + parser.add_argument('--tokenizer_dir', + type=str, + required=True, + help='Specify tokenizer directory') + parser.add_argument('--tokenizer_type', + type=str, + default='auto', + required=False, + choices=['auto', 't5', 'llama'], + help='Specify tokenizer type') + + FLAGS = parser.parse_args() + if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"): + print( + "unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format( + FLAGS.protocol)) + exit(1) + + client_util = httpclient if FLAGS.protocol == "http" else grpcclient + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + + if FLAGS.tokenizer_type == 't5': + tokenizer = T5Tokenizer(vocab_file=FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'auto': + tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(FLAGS.tokenizer_dir, + legacy=False, + padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {FLAGS.tokenizer_type}') + tokenizer.pad_token = tokenizer.eos_token + pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0] + end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0] + + line = tokenizer.encode(FLAGS.text) + input_start_ids = np.array([line], np.int32) + input_len = np.array([[len(line)]], np.int32) + inputs = utils.prepare_inputs(input_start_ids, input_len, pad_id, end_id, + FLAGS) + + start_time = datetime.now() + + with utils.create_inference_server_client(FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) as client: + if FLAGS.protocol == "http": + async_requests = utils.send_requests_async('tensorrt_llm', + inputs, + client, + FLAGS, + request_parallelism=1) + results = utils.get_http_results(async_requests) + else: + user_data = utils.send_requests_async('tensorrt_llm', + inputs, + client, + FLAGS, + request_parallelism=1) + results = utils.get_grpc_results(user_data, request_parallelism=1) + output_ids = results[0].as_numpy("output_ids") + + stop_time = datetime.now() + latency = (stop_time - start_time).total_seconds() * 1000.0 + latency = round(latency, 3) + print(f"[INFO] Latency: {latency} ms") + + output_ids = output_ids.reshape( + (output_ids.size, 
)).tolist()[input_start_ids.shape[1]:] + output_text = tokenizer.decode(output_ids) + print(f'Input: {FLAGS.text}') + print(f'Output: {output_text}') diff --git a/tools/gpt/end_to_end_test.py b/tools/gpt/end_to_end_test.py new file mode 100644 index 00000000..e55ccc34 --- /dev/null +++ b/tools/gpt/end_to_end_test.py @@ -0,0 +1,264 @@ +#!/usr/bin/python + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +import argparse + +import numpy as np +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer +from utils import utils + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='http', + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. Default is "http".') + parser.add_argument('-c', + '--concurrency', + type=int, + default=1, + required=False, + help='Specify concurrency') + parser.add_argument('-beam', + '--beam_width', + type=int, + default=1, + required=False, + help='Specify beam width') + parser.add_argument('-topk', + '--topk', + type=int, + default=1, + required=False, + help='topk for sampling') + parser.add_argument('-topp', + '--topp', + type=float, + default=0.0, + required=False, + help='topp for sampling') + parser.add_argument('-o', + '--output_len', + type=int, + default=10, + required=False, + help='Specify output length') + parser.add_argument('--tokenizer_dir', + type=str, + required=True, + help='Specify tokenizer directory') + parser.add_argument('--tokenizer_type', + type=str, + default='auto', + required=False, + choices=['auto', 't5', 'llama'], + help='Specify tokenizer type') + + FLAGS = parser.parse_args() + if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"): + print( + "unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format( + FLAGS.protocol)) + exit(1) + + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + + if FLAGS.tokenizer_type == 't5': + tokenizer = T5Tokenizer(vocab_file=FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'auto': + tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(FLAGS.tokenizer_dir, + legacy=False, + padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {FLAGS.tokenizer_type}') + tokenizer.pad_token = tokenizer.eos_token + pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0] + end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0] + + model_name = 'preprocessing' + with utils.create_inference_server_client(FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) as client: + input0 = [["Blackhawks\n The 2015 Hilltoppers"], + ["Data sources you can use to make a decision:"], + ["\n if(angle = 0) { if(angle"], + ["GMs typically get 78% female enrollment, but the "], + ["Previous Chapter | Index | Next Chapter"], + ["Michael, an American Jew, called Jews"], + ["Born in north-east France, Soyer trained as a"], + ["Data sources you can use to make a comparison:"]] + input0_data = 
np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.uint32) * FLAGS.output_len + bad_words_list = np.array( + [["Hawks, Hawks"], [""], [""], [""], [""], [""], [""], [""]], + dtype=object) + stop_words_list = np.array( + [[""], [""], [""], [""], [""], [""], [""], ["month, month"]], + dtype=object) + inputs = [ + utils.prepare_tensor("QUERY", input0_data, FLAGS.protocol), + utils.prepare_tensor("BAD_WORDS_DICT", bad_words_list, + FLAGS.protocol), + utils.prepare_tensor("STOP_WORDS_DICT", stop_words_list, + FLAGS.protocol), + utils.prepare_tensor("REQUEST_OUTPUT_LEN", output0_len, + FLAGS.protocol), + ] + + try: + result = client.infer(model_name, inputs) + output0 = result.as_numpy("INPUT_ID") + output1 = result.as_numpy("REQUEST_INPUT_LEN") + output2 = result.as_numpy("REQUEST_OUTPUT_LEN") + output3 = result.as_numpy("BAD_WORDS_IDS") + output4 = result.as_numpy("STOP_WORDS_IDS") + except Exception as e: + print(e) + + model_name = "tensorrt_llm" + with utils.create_inference_server_client(FLAGS.protocol, + FLAGS.url, + concurrency=1, + verbose=FLAGS.verbose) as client: + inputs = utils.prepare_inputs(output0, output1, pad_id, end_id, FLAGS) + + try: + result = client.infer(model_name, inputs) + output0 = result.as_numpy("output_ids") + except Exception as e: + print(e) + + model_name = "postprocessing" + with utils.create_inference_server_client(FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) as client: + inputs = [ + utils.prepare_tensor("TOKENS_BATCH", output0, FLAGS.protocol) + ] + inputs[0].set_data_from_numpy(output0) + + try: + result = client.infer(model_name, inputs) + output0 = result.as_numpy("OUTPUT") + print("============After postprocessing============") + batch_size = len(input0) + output0 = output0.reshape([-1, batch_size]).T.tolist() + output0 = [[char.decode('UTF-8') for char in line] + for line in output0] + output0 = [''.join(line) for line in output0] + for line in output0: + print(f"{line}") + print("===========================================\n\n\n") + except Exception as e: + print(e) + + model_name = "ensemble" + with utils.create_inference_server_client(FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) as client: + input0 = [["Blackhawks\n The 2015 Hilltoppers"], + ["Data sources you can use to make a decision:"], + ["\n if(angle = 0) { if(angle"], + ["GMs typically get 78% female enrollment, but the "], + ["Previous Chapter | Index | Next Chapter"], + ["Michael, an American Jew, called Jews"], + ["Born in north-east France, Soyer trained as a"], + ["Data sources you can use to make a comparison:"]] + bad_words_list = np.array( + [["Hawks, Hawks"], [""], [""], [""], [""], [""], [""], [""]], + dtype=object) + stop_words_list = np.array( + [[""], [""], [""], [""], [""], [""], [""], ["month, month"]], + dtype=object) + input0_data = np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.uint32) * FLAGS.output_len + runtime_top_k = (FLAGS.topk * + np.ones([input0_data.shape[0], 1])).astype(np.uint32) + runtime_top_p = FLAGS.topp * np.ones([input0_data.shape[0], 1]).astype( + np.float32) + temperature = 1.0 * np.ones([input0_data.shape[0], 1]).astype( + np.float32) + len_penalty = 1.0 * np.ones([input0_data.shape[0], 1]).astype( + np.float32) + repetition_penalty = 1.0 * np.ones([input0_data.shape[0], 1]).astype( + np.float32) + random_seed = 0 * np.ones([input0_data.shape[0], 1]).astype(np.uint64) + output_log_probs = True * 
np.ones([input0_data.shape[0], 1 + ]).astype(bool) + beam_width = (FLAGS.beam_width * + np.ones([input0_data.shape[0], 1])).astype(np.uint32) + pad_ids = pad_id * \ + np.ones([input0_data.shape[0], 1]).astype(np.uint32) + end_ids = end_id * \ + np.ones([input0_data.shape[0], 1]).astype(np.uint32) + min_length = 1 * \ + np.ones([input0_data.shape[0], 1]).astype(np.uint32) + presence_penalty = 0.0 * \ + np.ones([input0_data.shape[0], 1]).astype(np.float32) + inputs = [ + utils.prepare_tensor("INPUT_0", input0_data, FLAGS.protocol), + utils.prepare_tensor("INPUT_1", output0_len, FLAGS.protocol), + utils.prepare_tensor("INPUT_2", bad_words_list, FLAGS.protocol), + utils.prepare_tensor("INPUT_3", stop_words_list, FLAGS.protocol), + utils.prepare_tensor("pad_id", pad_ids, FLAGS.protocol), + utils.prepare_tensor("end_id", end_ids, FLAGS.protocol), + utils.prepare_tensor("beam_width", beam_width, FLAGS.protocol), + utils.prepare_tensor("runtime_top_k", runtime_top_k, + FLAGS.protocol), + utils.prepare_tensor("runtime_top_p", runtime_top_p, + FLAGS.protocol), + utils.prepare_tensor("temperature", temperature, FLAGS.protocol), + utils.prepare_tensor("len_penalty", len_penalty, FLAGS.protocol), + utils.prepare_tensor("repetition_penalty", repetition_penalty, + FLAGS.protocol), + utils.prepare_tensor("min_length", min_length, FLAGS.protocol), + utils.prepare_tensor("presence_penalty", presence_penalty, + FLAGS.protocol), + utils.prepare_tensor("random_seed", random_seed, FLAGS.protocol), + utils.prepare_tensor("output_log_probs", output_log_probs, + FLAGS.protocol), + ] + + try: + result = client.infer(model_name, inputs) + ensemble_output0 = result.as_numpy("OUTPUT_0") + print("============After ensemble============") + batch_size = len(input0) + ensemble_output0 = ensemble_output0.reshape([-1, batch_size + ]).T.tolist() + ensemble_output0 = [[char.decode('UTF-8') for char in line] + for line in ensemble_output0] + ensemble_output0 = [''.join(line) for line in ensemble_output0] + for line in ensemble_output0: + print(f"{line}") + except Exception as e: + print(e) + + assert output0 == ensemble_output0 diff --git a/tools/gpt/gen_input_data.py b/tools/gpt/gen_input_data.py new file mode 100644 index 00000000..00a29dcb --- /dev/null +++ b/tools/gpt/gen_input_data.py @@ -0,0 +1,108 @@ +import argparse +import json + +import numpy as np + + +def add_sample(sample, name, array): + sample[name] = {'content': array.flatten().tolist(), 'shape': array.shape} + + +def main(args): + data = {'data': []} + input_start_ids = np.random.randint(0, + 50255, + size=(args.start_len), + dtype=np.int32) + input_len = np.array([input_start_ids.shape[0]], np.int32) + output_len = np.ones([1]).astype(np.uint32) * args.output_len + runtime_top_k = (args.topk * np.ones([1])).astype(np.uint32) + runtime_top_p = args.topp * np.ones([1]).astype(np.float32) + beam_search_diversity_rate = 0.0 * np.ones([1]).astype(np.float32) + temperature = 1.0 * np.ones([1]).astype(np.float32) + len_penalty = 1.0 * np.ones([1]).astype(np.float32) + repetition_penalty = 1.0 * np.ones([1]).astype(np.float32) + random_seed = 0 * np.ones([1]).astype(np.uint64) + # is_return_log_probs = True * np.ones([1]).astype(bool) + beam_width = (args.beam_width * np.ones([1])).astype(np.uint32) + # start_ids = 50256 * np.ones([1]).astype(np.uint32) + # end_ids = 50256 * np.ones([1]).astype(np.uint32) + # bad_words_list = np.concatenate([ + # np.zeros([1, 1]).astype(np.int32), + # (-1 * np.ones([1, 1])).astype(np.int32) + # ], + # axis=1) + # stop_word_list = 
np.concatenate([ + # np.zeros([1, 1]).astype(np.int32), + # (-1 * np.ones([1, 1])).astype(np.int32) + # ], + # axis=1) + + for _ in range(args.num_samples): + sample = {} + add_sample(sample, 'input_ids', input_start_ids) + add_sample(sample, 'input_lengths', input_len) + add_sample(sample, 'request_output_len', output_len) + add_sample(sample, 'runtime_top_k', runtime_top_k) + add_sample(sample, 'runtime_top_p', runtime_top_p) + add_sample(sample, 'beam_search_diversity_rate', + beam_search_diversity_rate) + add_sample(sample, 'temperature', temperature) + add_sample(sample, 'len_penalty', len_penalty) + add_sample(sample, 'repetition_penalty', repetition_penalty) + add_sample(sample, 'random_seed', random_seed) + add_sample(sample, 'beam_width', beam_width) + # add_sample(sample, 'top_p_decay', top_p_decay) + # add_sample(sample, 'top_p_min', top_p_min) + # add_sample(sample, 'top_p_reset_ids', top_p_reset_ids) + data['data'].append(sample) + + with open('input_data.json', 'w') as f: + json.dump(data, f, indent=4) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-b', + '--batch_size', + type=int, + default=8, + required=False, + help='Specify batch size') + parser.add_argument('-beam', + '--beam_width', + type=int, + default=1, + required=False, + help='Specify beam width') + parser.add_argument('-topk', + '--topk', + type=int, + default=1, + required=False, + help='topk for sampling') + parser.add_argument('-topp', + '--topp', + type=float, + default=0.0, + required=False, + help='topp for sampling') + parser.add_argument('-s', + '--start_len', + type=int, + default=8, + required=False, + help='Specify input length') + parser.add_argument('-o', + '--output_len', + type=int, + default=10, + required=False, + help='Specify output length') + parser.add_argument('--num_samples', + type=int, + default=10000, + required=False, + help='Specify number of samples to generate') + args = parser.parse_args() + main(args) diff --git a/tools/gpt/identity_test.py b/tools/gpt/identity_test.py new file mode 100644 index 00000000..fdb4e93a --- /dev/null +++ b/tools/gpt/identity_test.py @@ -0,0 +1,178 @@ +#!/usr/bin/python + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +import argparse +import statistics as s +from builtins import range +from datetime import datetime + +import numpy as np +from utils import utils + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='http', + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. 
Default is "http".') + parser.add_argument('-w', + '--warm_up', + action="store_true", + required=False, + default=False, + help='Enable warm_up before benchmark') + parser.add_argument('-c', + '--concurrency', + type=int, + default=1, + required=False, + help='Specify concurrency') + parser.add_argument('-p', + '--request_parallelism', + type=int, + default=10, + required=False, + help='Specify request parallelism') + parser.add_argument('-m', + '--mode', + type=str, + required=False, + default='sync', + help='Mode ("sync"/"async").') + parser.add_argument('-b', + '--batch_size', + type=int, + default=8, + required=False, + help='Specify batch size') + parser.add_argument('-beam', + '--beam_width', + type=int, + default=1, + required=False, + help='Specify beam width') + parser.add_argument('-topk', + '--topk', + type=int, + default=1, + required=False, + help='topk for sampling') + parser.add_argument('-topp', + '--topp', + type=float, + default=0.0, + required=False, + help='topp for sampling') + parser.add_argument('-s', + '--start_len', + type=int, + default=8, + required=False, + help='Specify input length') + parser.add_argument('-o', + '--output_len', + type=int, + default=10, + required=False, + help='Specify output length') + parser.add_argument( + '-n', + '--num_runs', + type=int, + default=1, + required=False, + help="Spedifty number of runs to get the average latency") + + FLAGS = parser.parse_args() + if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"): + print( + "unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format( + FLAGS.protocol)) + exit(1) + + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + input_start_ids = np.random.randint(0, + 50255, + size=(FLAGS.batch_size, + FLAGS.start_len), + dtype=np.int32) + input_len = np.array([[input_start_ids.shape[1]] + for _ in range(input_start_ids.shape[0])], np.int32) + inputs = utils.prepare_inputs(input_start_ids, + input_len, + pad_id=0, + end_id=2, + flags=FLAGS) + + # warm up + if FLAGS.warm_up: + print("[INFO] sending requests to warm up") + with utils.create_inference_server_client( + FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) as client: + utils.send_requests('tensorrt_llm', + inputs, + client, + request_parallelism=2) + + latencies = [] + for i in range(FLAGS.num_runs): + start_time = datetime.now() + + with utils.create_inference_server_client( + FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) as client: + if FLAGS.mode == 'sync': + utils.send_requests('tensorrt_llm', inputs, client, + FLAGS.request_parallelism) + else: + if FLAGS.protocol == "http": + async_requests = utils.send_requests_async( + 'tensorrt_llm', inputs, client, FLAGS, + FLAGS.request_parallelism) + results = utils.get_http_results(async_requests) + else: + user_data = utils.send_requests_async( + 'tensorrt_llm', inputs, client, FLAGS, + FLAGS.request_parallelism) + results = utils.get_grpc_results(user_data, + FLAGS.request_parallelism) + + stop_time = datetime.now() + latencies.append((stop_time - start_time).total_seconds() * 1000.0 / + FLAGS.request_parallelism) + + if FLAGS.num_runs > 1: + latency = s.mean(latencies) + else: + latency = latencies[0] + latency = round(latency, 3) + throughput = round(1000 / latency * FLAGS.batch_size, 3) + print( + f"[INFO] Batch size: {FLAGS.batch_size}, Start len: {FLAGS.start_len}, Output len: {FLAGS.output_len}" + ) + print(f"[INFO] Latency: {latency} ms") + 
print(f"[INFO] Throughput: {throughput} sentences / sec") diff --git a/tools/inflight_batcher_llm/end_to_end_streaming_client.py b/tools/inflight_batcher_llm/end_to_end_streaming_client.py new file mode 100644 index 00000000..b508eaa5 --- /dev/null +++ b/tools/inflight_batcher_llm/end_to_end_streaming_client.py @@ -0,0 +1,131 @@ +#!/usr/bin/python + +import os +import sys +from functools import partial + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + +import argparse +import queue +import sys + +import numpy as np +import tritonclient.grpc as grpcclient +from tritonclient.utils import InferenceServerException +from utils import utils + + +class UserData: + + def __init__(self): + self._completed_requests = queue.Queue() + + +def callback(user_data, result, error): + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) + output = result.as_numpy('OUTPUT_0') + print(output[0], flush=True) + + +def test(triton_client, prompt): + model_name = "ensemble" + + input0 = [[prompt]] + input0_data = np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.uint32) * FLAGS.output_len + bad_words_list = np.array([[""]], dtype=object) + stop_words_list = np.array([[""]], dtype=object) + streaming = [[FLAGS.streaming]] + streaming_data = np.array(streaming, dtype=bool) + + inputs = [ + utils.prepare_tensor("INPUT_0", input0_data, FLAGS.protocol), + utils.prepare_tensor("INPUT_1", output0_len, FLAGS.protocol), + utils.prepare_tensor("INPUT_2", bad_words_list, FLAGS.protocol), + utils.prepare_tensor("INPUT_3", stop_words_list, FLAGS.protocol), + utils.prepare_tensor("streaming", streaming_data, FLAGS.protocol), + ] + + user_data = UserData() + # Establish stream + triton_client.start_stream(callback=partial(callback, user_data)) + # Send request + triton_client.async_stream_infer(model_name, inputs) + + #Wait for server to close the stream + triton_client.stop_stream() + + # Parse the responses + while True: + try: + result = user_data._completed_requests.get(block=False) + except Exception: + break + + if type(result) == InferenceServerException: + print("Received an error from server:") + print(result) + else: + result.as_numpy('OUTPUT_0') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + + parser.add_argument('-p', + '--prompt', + type=str, + required=True, + help='Input prompt.') + parser.add_argument( + "-S", + "--streaming", + action="store_true", + required=False, + default=False, + help="Enable streaming mode. Default is False.", + ) + + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='grpc', + choices=['grpc'], + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. 
Default is "http".') + + parser.add_argument('-o', + '--output_len', + type=int, + default=100, + required=False, + help='Specify output length') + + FLAGS = parser.parse_args() + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + + try: + client = grpcclient.InferenceServerClient(url=FLAGS.url) + except Exception as e: + print("client creation failed: " + str(e)) + sys.exit(1) + + test(client, FLAGS.prompt) diff --git a/tools/inflight_batcher_llm/end_to_end_test.py b/tools/inflight_batcher_llm/end_to_end_test.py new file mode 100644 index 00000000..70bc8f07 --- /dev/null +++ b/tools/inflight_batcher_llm/end_to_end_test.py @@ -0,0 +1,226 @@ +#!/usr/bin/python + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + +import argparse +import json +import sys +from datetime import datetime +from functools import partial + +import numpy as np +from utils import utils + + +def callback(user_data, start_time, result, error): + user_data._completed_requests.put((result, error)) + stop_time = datetime.now() + latency = (stop_time - start_time).total_seconds() * 1000.0 + latency = round(latency, 3) + user_data._latencies.append(latency) + + +def test_functionality(client, prompts, output_lens): + print(f"[INFO] Start testing on {len(prompts)} prompts.") + for i, prompt in enumerate(prompts): + + # 1. Ensemble models manually: preprocessing -> tensorrt_llm -> postprocessing + model_name = 'preprocessing' + input0 = [[prompt]] + input0_data = np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.uint32) * output_lens[i] + bad_words_list = np.array([[""]], dtype=object) + stop_words_list = np.array([[""]], dtype=object) + + inputs = [ + utils.prepare_tensor("QUERY", input0_data, FLAGS.protocol), + utils.prepare_tensor("BAD_WORDS_DICT", bad_words_list, + FLAGS.protocol), + utils.prepare_tensor("STOP_WORDS_DICT", stop_words_list, + FLAGS.protocol), + utils.prepare_tensor("REQUEST_OUTPUT_LEN", output0_len, + FLAGS.protocol), + ] + result = client.infer(model_name, inputs, request_id=str(i)) + output0 = result.as_numpy("INPUT_ID") + output1 = result.as_numpy("REQUEST_INPUT_LEN") + output2 = result.as_numpy("REQUEST_OUTPUT_LEN") + + model_name = "tensorrt_llm" + inputs = [ + utils.prepare_tensor("input_ids", output0, FLAGS.protocol), + utils.prepare_tensor("input_lengths", output1, FLAGS.protocol), + utils.prepare_tensor("request_output_len", output2, + FLAGS.protocol), + ] + result = client.infer(model_name, inputs, request_id=str(i)) + output0 = result.as_numpy("output_ids") + + model_name = "postprocessing" + inputs = [ + utils.prepare_tensor("TOKENS_BATCH", output0, FLAGS.protocol) + ] + inputs[0].set_data_from_numpy(output0) + + result = client.infer(model_name, inputs, request_id=str(i)) + output0 = result.as_numpy("OUTPUT") + + # 2. Use ensemble model + model_name = "ensemble" + input0 = [[prompt]] + input0_data = np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.uint32) * output_lens[i] + bad_words_list = np.array([[""]], dtype=object) + stop_words_list = np.array([[""]], dtype=object) + + inputs = [ + utils.prepare_tensor("INPUT_0", input0_data, FLAGS.protocol), + utils.prepare_tensor("INPUT_1", output0_len, FLAGS.protocol), + utils.prepare_tensor("INPUT_2", bad_words_list, FLAGS.protocol), + utils.prepare_tensor("INPUT_3", stop_words_list, FLAGS.protocol), + ] + + result = client.infer(model_name, inputs, request_id=str(i)) + + # 3. 
Check the results between manually ensembled models and the ensemble model + ensemble_output = result.as_numpy('OUTPUT_0') + assert output0 == ensemble_output + print('Response: {}'.format(result.get_response())) + print('Output: {}'.format(ensemble_output)) + print(f"[INFO] Functionality test succeed.") + + +def test_performance(client, prompts, output_lens): + model_name = "ensemble" + + print(f"[INFO] Warm up for benchmarking.") + for i in range(10): + input0 = [[prompts[0]]] + input0_data = np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.uint32) * output_lens[i] + bad_words_list = np.array([[""]], dtype=object) + stop_words_list = np.array([[""]], dtype=object) + + inputs = [ + utils.prepare_tensor("INPUT_0", input0_data, FLAGS.protocol), + utils.prepare_tensor("INPUT_1", output0_len, FLAGS.protocol), + utils.prepare_tensor("INPUT_2", bad_words_list, FLAGS.protocol), + utils.prepare_tensor("INPUT_3", stop_words_list, FLAGS.protocol), + ] + + client.infer(model_name, inputs, request_id=str(i)) + + print(f"[INFO] Start benchmarking on {len(prompts)} prompts.") + latency = 0 + async_requests = [] + start_time = datetime.now() + user_data = utils.UserData() + for i, prompt in enumerate(prompts): + input0 = [[prompt]] + input0_data = np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.uint32) * output_lens[i] + bad_words_list = np.array([[""]], dtype=object) + stop_words_list = np.array([[""]], dtype=object) + + inputs = [ + utils.prepare_tensor("INPUT_0", input0_data, FLAGS.protocol), + utils.prepare_tensor("INPUT_1", output0_len, FLAGS.protocol), + utils.prepare_tensor("INPUT_2", bad_words_list, FLAGS.protocol), + utils.prepare_tensor("INPUT_3", stop_words_list, FLAGS.protocol), + ] + + if FLAGS.protocol == "http": + async_requests.append( + client.async_infer(model_name, inputs, request_id=str(i))) + elif FLAGS.protocol == "grpc": + async_requests.append( + client.async_infer(model_name, + inputs, + callback=partial(callback, user_data, + datetime.now()), + request_id=str(i))) + + if FLAGS.protocol == "http": + utils.get_http_results(async_requests) + elif FLAGS.protocol == "grpc": + utils.get_grpc_results(user_data, len(prompts)) + else: + raise RuntimeError("Invalid protocol") + + stop_time = datetime.now() + latency = (stop_time - start_time).total_seconds() * 1000.0 + latency = round(latency, 3) + print(f"[INFO] Total Latency: {latency} ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='http', + choices=['http', 'grpc'], + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. 
Default is "http".') + parser.add_argument('-c', + '--concurrency', + type=int, + default=128, + required=False, + help='Specify concurrency') + parser.add_argument('--max_input_len', + type=int, + required=True, + help='Specify max input length') + + parser.add_argument('--dataset', + type=str, + required=True, + help='Dataset path used for the test.') + + FLAGS = parser.parse_args() + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + + try: + client = utils.create_inference_server_client( + FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) + except Exception as e: + print("channel creation failed: " + str(e)) + sys.exit(1) + + prompts = [] + output_lens = [] + with open(FLAGS.dataset, 'r') as f: + data_dict = json.load(f) + for req in data_dict: + prompt = req['input'] + ' ' + req['instruction'] + output = req['output'] + # 1.3 is a magic number that converts number of words to number of tokens + if int(len(prompt.split(' ')) / 1.3) > FLAGS.max_input_len: + continue + prompts.append(prompt) + # 1.3 is a magic number that converts number of words to number of tokens + output_lens.append(int(len(output.split(' ')) * 1.3)) + + test_functionality(client, prompts, output_lens) + test_performance(client, prompts, output_lens) diff --git a/tools/inflight_batcher_llm/identity_test.py b/tools/inflight_batcher_llm/identity_test.py new file mode 100644 index 00000000..20ae0d89 --- /dev/null +++ b/tools/inflight_batcher_llm/identity_test.py @@ -0,0 +1,187 @@ +#!/usr/bin/python + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + +import argparse +import json +import sys +from datetime import datetime +from functools import partial + +import numpy as np +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer +from utils import utils + + +def callback(user_data, start_time, result, error): + user_data._completed_requests.put((result, error)) + stop_time = datetime.now() + latency = (stop_time - start_time).total_seconds() * 1000.0 + latency = round(latency, 3) + user_data._latencies.append(latency) + + +def test_performance(client, input_start_ids, input_lens, output_lens): + model_name = "tensorrt_llm" + + print(f"[INFO] Warm up for benchmarking.") + for i in range(10): + output0_len = np.ones_like([[1]]).astype(np.uint32) * 100 + inputs = [ + utils.prepare_tensor("input_ids", input_start_ids[0], + FLAGS.protocol), + utils.prepare_tensor("input_lengths", input_lens[i], + FLAGS.protocol), + utils.prepare_tensor("request_output_len", output0_len, + FLAGS.protocol), + ] + client.infer(model_name, inputs, request_id=str(i)) + + print(f"[INFO] Start benchmarking on {len(input_start_ids)} prompts.") + latency = 0 + async_requests = [] + start_time = datetime.now() + user_data = utils.UserData() + for i, ids in enumerate(input_start_ids): + output0_len = np.ones_like([[1]]).astype(np.uint32) * output_lens[i] + inputs = [ + utils.prepare_tensor("input_ids", ids, FLAGS.protocol), + utils.prepare_tensor("input_lengths", input_lens[i], + FLAGS.protocol), + utils.prepare_tensor("request_output_len", output0_len, + FLAGS.protocol), + ] + + if FLAGS.protocol == "http": + async_requests.append( + client.async_infer(model_name, inputs, request_id=str(i))) + elif FLAGS.protocol == "grpc": + async_requests.append( + client.async_infer(model_name, + inputs, + callback=partial(callback, user_data, + datetime.now()), + request_id=str(i))) + + try: + if 
FLAGS.protocol == "http": + utils.get_http_results(async_requests) + elif FLAGS.protocol == "grpc": + utils.get_grpc_results(user_data, len(input_start_ids)) + else: + raise RuntimeError("Invalid protocol") + + stop_time = datetime.now() + latency = (stop_time - start_time).total_seconds() * 1000.0 + latency = round(latency, 3) + print(f"[INFO] Total Latency: {latency} ms") + + # Get latencies per request + if FLAGS.protocol == "grpc": + request_latencies = 0.0 + for latency in user_data._latencies: + request_latencies += latency + print(f"[INFO] Total request latencies: {request_latencies} ms") + + except Exception as e: + print("Failed receiving responses: " + str(e)) + sys.exit(1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='http', + choices=['http', 'grpc'], + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. Default is "http".') + parser.add_argument('-c', + '--concurrency', + type=int, + default=128, + required=False, + help='Specify concurrency') + parser.add_argument('--max_input_len', + type=int, + required=True, + help='Specify max input length') + + parser.add_argument('--dataset', + type=str, + required=True, + help='Dataset path used for the test.') + parser.add_argument('--tokenizer_dir', + type=str, + required=True, + help='Specify tokenizer directory') + parser.add_argument('--tokenizer_type', + type=str, + default='auto', + required=False, + choices=['auto', 't5', 'llama'], + help='Specify tokenizer type') + + FLAGS = parser.parse_args() + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + + try: + client = utils.create_inference_server_client( + FLAGS.protocol, + FLAGS.url, + concurrency=FLAGS.concurrency, + verbose=FLAGS.verbose) + except Exception as e: + print("channel creation failed: " + str(e)) + sys.exit(1) + + if FLAGS.tokenizer_type == 't5': + tokenizer = T5Tokenizer(vocab_file=FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'auto': + tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer_dir, + padding_side='left') + elif FLAGS.tokenizer_type == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(FLAGS.tokenizer_dir, + legacy=False, + padding_side='left') + else: + raise AttributeError( + f'Unexpected tokenizer type: {FLAGS.tokenizer_type}') + tokenizer.pad_token = tokenizer.eos_token + + input_start_ids = [] + input_lens = [] + output_lens = [] + with open(FLAGS.dataset, 'r') as f: + data_dict = json.load(f) + for req in data_dict: + prompt = req['input'] + ' ' + req['instruction'] + output = req['output'] + line = tokenizer.encode(prompt) + if len(line) > FLAGS.max_input_len: + continue + input_start_ids.append(np.array([line], np.int32)) + input_lens.append(np.array([[len(line)]], np.int32)) + # 1.3 is a magic number that converts number of words to number of tokens + output_lens.append(int(len(output.split(' ')) * 1.3)) + + test_performance(client, input_start_ids, input_lens, output_lens) diff --git a/tools/utils.sh b/tools/utils.sh new file mode 100644 index 00000000..042f1e5a --- /dev/null +++ b/tools/utils.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Wait until server health endpoint shows ready. 
Sets WAIT_RET to 0 on +# success, 1 on failure +function wait_for_server_ready() { + local spid="$1"; shift + local wait_time_secs="${1:-30}"; shift + + WAIT_RET=0 + + local wait_secs=$wait_time_secs + until test $wait_secs -eq 0 ; do + if ! kill -0 $spid; then + echo "=== Server not running." + WAIT_RET=1 + return + fi + + sleep 1; + + set +e + code=`curl -s -w %{http_code} localhost:8000/v2/health/ready` + set -e + if [ "$code" == "200" ]; then + return + fi + + ((wait_secs--)); + done + + echo "=== Timeout $wait_time_secs secs. Server not ready." + WAIT_RET=1 +} diff --git a/tools/utils/__init__.py b/tools/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/utils/utils.py b/tools/utils/utils.py new file mode 100644 index 00000000..33e17926 --- /dev/null +++ b/tools/utils/utils.py @@ -0,0 +1,155 @@ +import queue +from functools import partial + +import numpy as np +import tritonclient.grpc as grpcclient +import tritonclient.http as httpclient +from tritonclient.utils import np_to_triton_dtype + + +class UserData: + + def __init__(self): + self._completed_requests = queue.Queue() + self._latencies = [] + + +# Callback function used for async_stream_infer() +def completion_callback(user_data, result, error): + # passing error raise and handling out + user_data._completed_requests.put((result, error)) + + +def prepare_tensor(name, input, protocol): + client_util = httpclient if protocol == "http" else grpcclient + t = client_util.InferInput(name, input.shape, + np_to_triton_dtype(input.dtype)) + t.set_data_from_numpy(input) + return t + + +def prepare_inputs(input_start_ids, input_len, pad_id, end_id, flags): + output_len = np.ones([input_start_ids.shape[0], 1]).astype( + np.uint32) * flags.output_len + runtime_top_k = (flags.topk * + np.ones([input_start_ids.shape[0], 1])).astype(np.uint32) + runtime_top_p = flags.topp * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.float32) + beam_search_diversity_rate = 0.0 * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.float32) + temperature = 1.0 * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.float32) + len_penalty = 1.0 * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.float32) + repetition_penalty = 1.0 * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.float32) + random_seed = 0 * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.uint64) + output_log_probs = True * \ + np.ones([input_start_ids.shape[0], 1]).astype(bool) + beam_width = (flags.beam_width * + np.ones([input_start_ids.shape[0], 1])).astype(np.uint32) + pad_ids = pad_id * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.uint32) + end_ids = end_id * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.uint32) + min_length = 1 * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.uint32) + presence_penalty = 0.0 * \ + np.ones([input_start_ids.shape[0], 1]).astype(np.float32) + bad_words_list = np.concatenate([ + np.zeros([input_start_ids.shape[0], 1, 1]).astype(np.int32), + (-1 * np.ones([input_start_ids.shape[0], 1, 1])).astype(np.int32) + ], + axis=1) + stop_word_list = np.concatenate([ + np.zeros([input_start_ids.shape[0], 1, 1]).astype(np.int32), + (-1 * np.ones([input_start_ids.shape[0], 1, 1])).astype(np.int32) + ], + axis=1) + inputs = [ + prepare_tensor("input_ids", input_start_ids, flags.protocol), + prepare_tensor("input_lengths", input_len, flags.protocol), + prepare_tensor("request_output_len", output_len, flags.protocol), + prepare_tensor("pad_id", pad_ids, flags.protocol), + prepare_tensor("end_id", 
end_ids, flags.protocol), + prepare_tensor("beam_width", beam_width, flags.protocol), + prepare_tensor("temperature", temperature, flags.protocol), + prepare_tensor("runtime_top_k", runtime_top_k, flags.protocol), + prepare_tensor("runtime_top_p", runtime_top_p, flags.protocol), + prepare_tensor("len_penalty", len_penalty, flags.protocol), + prepare_tensor("repetition_penalty", repetition_penalty, + flags.protocol), + prepare_tensor("min_length", min_length, flags.protocol), + prepare_tensor("presence_penalty", presence_penalty, flags.protocol), + prepare_tensor("random_seed", random_seed, flags.protocol), + prepare_tensor("output_log_probs", output_log_probs, flags.protocol), + # prepare_tensor("bad_words_list", bad_words_list, flags.protocol), + # prepare_tensor("stop_words_list", stop_word_list, flags.protocol), + ] + return inputs + + +def create_inference_server_client(protocol, url, concurrency, verbose): + client_util = httpclient if protocol == "http" else grpcclient + if protocol == "http": + return client_util.InferenceServerClient(url, + concurrency=concurrency, + verbose=verbose) + elif protocol == "grpc": + return client_util.InferenceServerClient(url, verbose=verbose) + + +def send_requests(model_name, inputs, client, request_parallelism): + results = [] + for _ in range(request_parallelism): + result = client.infer(model_name, inputs) + results.append(result) + return results + + +def send_requests_async(model_name, inputs, client, flags, + request_parallelism): + if flags.protocol == "http": + async_requests = [] + for _ in range(request_parallelism): + async_requests.append(client.async_infer(model_name, inputs)) + return async_requests + else: + user_data = UserData() + for _ in range(request_parallelism): + client.async_infer(model_name, inputs, + partial(completion_callback, user_data)) + return user_data + + +def get_http_results(async_requests): + results = [] + for async_request in async_requests: + results.append(async_request.get_result()) + return results + + +def get_grpc_results(user_data, request_parallelism): + results = [] + processed_count = 0 + while processed_count < request_parallelism: + (result, error) = user_data._completed_requests.get() + processed_count += 1 + if error is not None: + raise RuntimeError(error) + results.append(result) + return results + + +def append_start_and_end_ids(inputs, + batch_size, + flags, + start_id=None, + end_id=None): + if start_id is not None: + start_ids = start_id * np.ones([batch_size, 1]).astype(np.uint32) + inputs.append(prepare_tensor("start_id", start_ids, flags.protocol)) + if end_id is not None: + end_ids = end_id * np.ones([batch_size, 1]).astype(np.uint32) + inputs.append(prepare_tensor("end_id", end_ids, flags.protocol))
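
As a rough illustration of how the helpers in tools/utils/utils.py fit together, the sketch below sends a single synchronous request to the tensorrt_llm model. It is only a sketch under assumptions: a Triton server is already running with that model loaded and reachable at localhost:8000 over HTTP, tools/ is on PYTHONPATH (the test scripts above add it via sys.path), and the token ids are placeholders; in practice they come from the tokenizer (identity_test.py) or the preprocessing model (end_to_end_test.py).

# Illustrative sketch only; assumes a running Triton server with the
# "tensorrt_llm" model loaded at localhost:8000 and tools/ on PYTHONPATH.
import numpy as np

from utils import utils

protocol = "http"
client = utils.create_inference_server_client(protocol,
                                              "localhost:8000",
                                              concurrency=1,
                                              verbose=False)

# Placeholder token ids; real ids come from a tokenizer or the preprocessing model.
input_ids = np.array([[1, 2, 3, 4]], dtype=np.int32)
input_lengths = np.array([[input_ids.shape[1]]], dtype=np.int32)
request_output_len = np.array([[16]], dtype=np.uint32)

inputs = [
    utils.prepare_tensor("input_ids", input_ids, protocol),
    utils.prepare_tensor("input_lengths", input_lengths, protocol),
    utils.prepare_tensor("request_output_len", request_output_len, protocol),
]

# send_requests() issues `request_parallelism` synchronous requests and returns the results.
results = utils.send_requests("tensorrt_llm", inputs, client, request_parallelism=1)
print(results[0].as_numpy("output_ids"))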