
Commit

* Update

* Update doc for lfs usage

* Update TensorRT-LLM submodule
kaiyux authored Oct 27, 2023
1 parent 99de6ed commit 06f63fe
Showing 21 changed files with 721 additions and 167 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -101,9 +101,8 @@ don't need by removing the corresponding flags.
```bash
# Update the submodules
cd tensorrtllm_backend
git submodule update --init --recursive
git lfs install
git lfs pull
git submodule update --init --recursive

# Use the Dockerfile to build the backend in a container
# For x86_64
10 changes: 9 additions & 1 deletion all_models/inflight_batcher_llm/ensemble/config.pbtxt
@@ -125,7 +125,7 @@ output [
{
name: "text_output"
data_type: TYPE_STRING
dims: [ -1, -1 ]
dims: [ -1 ]
}
]
ensemble_scheduling {
@@ -229,6 +229,10 @@ ensemble_scheduling {
key: "output_ids"
value: "_TOKENS_BATCH"
}
output_map {
key: "sequence_length"
value: "_SEQUENCE_LENGTH"
}
},
{
model_name: "postprocessing"
Expand All @@ -237,6 +241,10 @@ ensemble_scheduling {
key: "TOKENS_BATCH"
value: "_TOKENS_BATCH"
}
input_map {
key: "SEQUENCE_LENGTH"
value: "_SEQUENCE_LENGTH"
}
output_map {
key: "OUTPUT"
value: "text_output"
15 changes: 10 additions & 5 deletions all_models/inflight_batcher_llm/postprocessing/1/model.py
@@ -109,12 +109,16 @@ def execute(self, requests):
tokens_batch = pb_utils.get_input_tensor_by_name(
request, 'TOKENS_BATCH').as_numpy()

# Get sequence length
sequence_lengths = pb_utils.get_input_tensor_by_name(
request, 'SEQUENCE_LENGTH').as_numpy()

# Reshape Input
# tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]])
# tokens_batch = tokens_batch.T

# Postprocessing output data.
outputs = self._postprocessing(tokens_batch)
outputs = self._postprocessing(tokens_batch, sequence_lengths)

# Create output tensors. You need pb_utils.Tensor
# objects to create pb_utils.InferenceResponse.
@@ -144,10 +148,11 @@ def finalize(self):
"""
print('Cleaning up...')

def _postprocessing(self, tokens_batch):
def _postprocessing(self, tokens_batch, sequence_lengths):
outputs = []
for beam_tokens in tokens_batch:
for tokens in beam_tokens:
output = self.tokenizer.decode(tokens)
for batch_idx, beam_tokens in enumerate(tokens_batch):
for beam_idx, tokens in enumerate(beam_tokens):
seq_len = sequence_lengths[batch_idx][beam_idx]
output = self.tokenizer.decode(tokens[:seq_len])
outputs.append(output.encode('utf8'))
return outputs
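
The postprocessor now receives a SEQUENCE_LENGTH tensor alongside TOKENS_BATCH and decodes only the valid prefix of each beam, so padding added to equalize row lengths never reaches the tokenizer. A minimal standalone sketch of that trimming, with made-up token ids and a stand-in decoder instead of the Hugging Face tokenizer the model actually loads:

```python
import numpy as np

def fake_decode(token_ids):
    # Stand-in for self.tokenizer.decode(...) in the real model.
    return " ".join(str(t) for t in token_ids)

tokens_batch = np.array([[[101, 7592, 2088, 0, 0]]])  # [batch, beam, max_len]; trailing 0s are padding
sequence_lengths = np.array([[3]])                    # [batch, beam]

outputs = []
for batch_idx, beam_tokens in enumerate(tokens_batch):
    for beam_idx, tokens in enumerate(beam_tokens):
        seq_len = sequence_lengths[batch_idx][beam_idx]
        outputs.append(fake_decode(tokens[:seq_len]))  # padded tail is dropped

print(outputs)  # ['101 7592 2088']
```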
7 changes: 6 additions & 1 deletion all_models/inflight_batcher_llm/postprocessing/config.pbtxt
@@ -32,13 +32,18 @@ input [
name: "TOKENS_BATCH"
data_type: TYPE_INT32
dims: [ -1, -1 ]
},
{
name: "SEQUENCE_LENGTH"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
{
name: "OUTPUT"
data_type: TYPE_STRING
dims: [ -1, -1 ]
dims: [ -1 ]
}
]

27 changes: 13 additions & 14 deletions all_models/inflight_batcher_llm/preprocessing/1/model.py
@@ -29,9 +29,7 @@
from typing import List

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer


@@ -135,12 +133,10 @@ def execute(self, requests):
# Create output tensors. You need pb_utils.Tensor
# objects to create pb_utils.InferenceResponse.
input_id_tensor = pb_utils.Tensor(
'INPUT_ID',
np.array(input_id).astype(self.input_id_dtype))
'INPUT_ID', input_id.astype(self.input_id_dtype))
request_input_len_tensor = pb_utils.Tensor(
'REQUEST_INPUT_LEN',
np.array(request_input_len).astype(
self.request_input_len_dtype))
request_input_len.astype(self.request_input_len_dtype))
request_output_len_tensor = pb_utils.Tensor(
'REQUEST_OUTPUT_LEN', request_output_len)
bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words)
@@ -176,16 +172,19 @@ def _create_request(self, query):
query : batch string (2D numpy array)
"""
start_ids = [
torch.IntTensor(self.tokenizer.encode(s[0].decode()))
np.array(self.tokenizer.encode(s[0].decode())).astype(int)
for s in query
]
start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])

start_ids = pad_sequence(start_ids,
batch_first=True,
padding_value=self.pad_id)
# input_len = min(start_lengths)
#attn_mask = torch.ones((batch_size, input_len, input_len)).tril()
start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int)

max_len = 0
for seq in start_ids:
max_len = max(max_len, seq.shape[0])
start_ids = np.stack([
np.pad(seq, (0, max_len - seq.shape[0]),
'constant',
constant_values=(0, self.pad_id)) for seq in start_ids
])

return start_ids, start_lengths
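
With the torch import removed, _create_request pads the tokenized batch with plain NumPy instead of torch.nn.utils.rnn.pad_sequence. A self-contained sketch of the same padding, using made-up token ids and pad_id (in the model both come from the tokenizer):

```python
import numpy as np

pad_id = 2  # illustrative; the model takes this from the tokenizer
start_ids = [np.array([5, 9, 13]), np.array([7, 11])]

start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int)

max_len = max(seq.shape[0] for seq in start_ids)
padded = np.stack([
    np.pad(seq, (0, max_len - seq.shape[0]), 'constant',
           constant_values=(0, pad_id))  # right-pad each row with pad_id
    for seq in start_ids
])

print(padded)          # [[ 5  9 13]
                       #  [ 7 11  2]]
print(start_lengths)   # [[3]
                       #  [2]]
```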

7 changes: 6 additions & 1 deletion all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
@@ -144,6 +144,11 @@ output [
name: "output_ids"
data_type: TYPE_INT32
dims: [ -1, -1 ]
},
{
name: "sequence_length"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
instance_group [
@@ -167,7 +172,7 @@ parameters: {
parameters: {
key: "gpt_model_type"
value: {
string_value: "inflight_fused_batching"
string_value: "${batching_strategy}"
}
}
parameters: {
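
gpt_model_type is no longer hard-coded to inflight_fused_batching; it now holds a ${batching_strategy} placeholder that must be substituted before the model repository is served. The repository provides its own tooling for filling these templates; the sketch below is illustrative only, showing the kind of substitution involved with Python's string.Template and hypothetical paths:

```python
from string import Template

# Illustrative only -- not the repo's actual helper. The destination path is hypothetical.
src = 'all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt'
dst = 'triton_model_repo/tensorrt_llm/config.pbtxt'

with open(src) as f:
    config = Template(f.read())

# safe_substitute leaves any other ${...} placeholders in the file untouched.
filled = config.safe_substitute(batching_strategy='inflight_fused_batching')

with open(dst, 'w') as f:
    f.write(filled)
```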
2 changes: 2 additions & 0 deletions dockerfile/Dockerfile.trt_llm_backend
@@ -15,6 +15,8 @@ RUN pip uninstall -y tensorrt

FROM base as dev

ENV SHINIT_FILE=${BASH_ENV}

# Download & install internal TRT release
COPY tensorrt_llm/docker/common/install_tensorrt.sh /tmp/
RUN bash /tmp/install_tensorrt.sh && rm /tmp/install_tensorrt.sh
2 changes: 1 addition & 1 deletion inflight_batcher_llm/CMakeLists.txt
@@ -190,7 +190,6 @@ set_ifndef(TRT_INCLUDE_DIR /usr/include/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu)

set(TRT_LIB nvinfer)
find_library_create_target(${TRT_LIB} nvinfer SHARED ${TRT_LIB_DIR})
find_library_create_target(nvuffparser nvparsers SHARED ${TRT_LIB_DIR})

file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" VERSION_STRINGS
REGEX "#define NV_TENSORRT_.*")
@@ -311,6 +310,7 @@ target_link_libraries(
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
${MPI_LIBRARIES}
${CUDA_LIBRARIES}
nvinfer
nvinfer_plugin_tensorrt_llm)

73 changes: 40 additions & 33 deletions inflight_batcher_llm/client/inflight_batcher_llm_client.py
@@ -74,7 +74,8 @@ def prepare_tensor(name, input, protocol):


def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data,
beam_width_data, temperature_data, streaming_data):
beam_width_data, temperature_data, streaming_data, end_id,
pad_id):
protocol = 'grpc'
inputs = [
prepare_tensor("input_ids", input_ids_data, protocol),
@@ -84,6 +85,8 @@ def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data,
prepare_tensor("beam_width", beam_width_data, protocol),
prepare_tensor("temperature", temperature_data, protocol),
prepare_tensor("streaming", streaming_data, protocol),
prepare_tensor("end_id", end_id, protocol),
prepare_tensor("pad_id", pad_id, protocol),
]

return inputs
@@ -118,8 +121,9 @@ def callback(user_data, result, error):
user_data._completed_requests.put(result)
if (FLAGS.streaming):
output_ids = result.as_numpy('output_ids')
tokens = list(output_ids[0][0])
print(tokens, flush=True)
if output_ids != None:
tokens = list(output_ids[0][0])
print(tokens, flush=True)


if __name__ == "__main__":
@@ -275,6 +279,8 @@ def callback(user_data, result, error):
tokenizer.pad_token = tokenizer.eos_token
pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]
end_id_data = np.array([[end_id]], dtype=np.uint32)
pad_id_data = np.array([[pad_id]], dtype=np.uint32)

input_ids = [tokenizer.encode(FLAGS.text)]
input_ids_data = np.array(input_ids, dtype=np.int32)
@@ -291,7 +297,8 @@

inputs = prepare_inputs(input_ids_data, input_lengths_data,
request_output_len_data, beam_width_data,
temperature_data, streaming_data)
temperature_data, streaming_data, end_id_data,
pad_id_data)

if FLAGS.stop_after_ms > 0:
stop_inputs = prepare_stop_signals()
@@ -300,17 +307,18 @@

request_id = FLAGS.request_id

expected_output_ids = [
input_ids[0] + [
21221, 290, 257, 4255, 379, 262, 1957, 7072, 11, 4689, 347, 2852,
2564, 494, 13, 679
]
expected_output_ids = input_ids[0] + [
21221, 290, 257, 4255, 379, 262, 1957, 7072, 11, 4689, 347, 2852, 2564,
494, 13, 679
]

if FLAGS.streaming:
actual_output_ids = [input_ids[0]]
else:
actual_output_ids = []

sequence_lengths = []

user_data = UserData()
with grpcclient.InferenceServerClient(
url=FLAGS.url,
@@ -361,17 +369,12 @@ def callback(user_data, result, error):
print(result)
else:
output_ids = result.as_numpy('output_ids')

sequence_lengths = result.as_numpy('sequence_length')
if output_ids is not None:
if (FLAGS.streaming):
# Only one beam is supported
tokens = list(output_ids[0][0])
actual_output_ids[
0] = actual_output_ids[0] + tokens
else:
for beam_output_ids in output_ids[0]:
tokens = list(beam_output_ids)
actual_output_ids.append(tokens)
# Only one beam is supported
tokens = list(output_ids[0][0])
actual_output_ids[
0] = actual_output_ids[0] + tokens
else:
print("Got cancellation response from server")
else:
@@ -408,12 +411,13 @@ def callback(user_data, result, error):
print(result)
else:
output_ids = result.as_numpy('output_ids')
sequence_lengths = result.as_numpy('sequence_length')
if output_ids is not None:
for beam_output_ids in output_ids[0]:
tokens = list(beam_output_ids)
actual_output_ids.append(tokens)
else:
print("Got response for cancellation request")
print("Got cancellation response from server")

processed_count = processed_count + 1
except Exception as e:
@@ -422,18 +426,21 @@

passed = True

print("output_ids = ", actual_output_ids)
output_ids = np.array(actual_output_ids)
output_ids = output_ids.reshape(
(output_ids.size, )).tolist()[input_ids_data.shape[1]:]
output_text = tokenizer.decode(output_ids)
print(f'Input: {FLAGS.text}')
print(f'Output: {output_text}')
if (FLAGS.check_output):
passed = (actual_output_ids == expected_output_ids)
print("expected_output_ids = ", expected_output_ids)
print("\n=====")
print("PASS!" if passed else "FAIL!")
print("=====")
for beam in range(FLAGS.beam_width):
seq_len = sequence_lengths[0][
beam] if not FLAGS.streaming else len(actual_output_ids[beam])
output_ids_w_prompt = actual_output_ids[beam][:seq_len]
output_ids_wo_prompt = output_ids_w_prompt[input_ids_data.
shape[1]:]
output_text = tokenizer.decode(output_ids_wo_prompt)
print(f'Input: {FLAGS.text}')
print(f'Output beam {beam}: {output_text}')
if (FLAGS.check_output and beam == 0):
passed = (output_ids_w_prompt == expected_output_ids)
print("output_ids = ", output_ids_w_prompt)
print("expected_output_ids = ", expected_output_ids)
print("\n=====")
print("PASS!" if passed else "FAIL!")
print("=====")

sys.exit(not passed)
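
The verification block now trims each beam to its reported sequence_length before stripping the prompt and decoding. A toy trace of that slicing with made-up token ids (in the client these values come from the server response and the tokenizer):

```python
import numpy as np

# Made-up values standing in for the server response.
input_len = 3                                    # input_ids_data.shape[1]
actual_output_ids = [[11, 12, 13, 101, 102, 0]]  # beam 0, padded to max length
sequence_lengths = np.array([[5]])               # valid tokens per beam

beam = 0
seq_len = sequence_lengths[0][beam]
output_ids_w_prompt = actual_output_ids[beam][:seq_len]  # [11, 12, 13, 101, 102]
output_ids_wo_prompt = output_ids_w_prompt[input_len:]   # [101, 102], then decoded into text
print(output_ids_w_prompt, output_ids_wo_prompt)
```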