diff --git a/tools/gpt/end_to_end_test.py b/tools/gpt/end_to_end_test.py
index ad2f677f..ac461336 100644
--- a/tools/gpt/end_to_end_test.py
+++ b/tools/gpt/end_to_end_test.py
@@ -10,7 +10,7 @@
 from transformers import AutoTokenizer
 
 from utils import utils
 
-if __name__ == '__main__':
+def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('-v',
                         '--verbose',
@@ -29,8 +29,8 @@
                         type=str,
                         required=False,
                         default='http',
-                        help='Protocol ("http"/"grpc") used to ' +
-                        'communicate with inference service. Default is "http".')
+                        choices=['http', 'grpc'],
+                        help='Protocol ("http"/"grpc") used to communicate with inference service. Default is "http".')
     parser.add_argument('-c',
                         '--concurrency',
                         type=int,
                         required=False,
@@ -65,14 +65,11 @@
                         type=str,
                         required=True,
                         help='Specify tokenizer directory')
+    return parser.parse_args()
 
-    FLAGS = parser.parse_args()
-    if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"):
-        print(
-            "unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format(
-                FLAGS.protocol))
-        exit(1)
-
+
+def main():
+    FLAGS = parse_args()
+
     if FLAGS.url is None:
         FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001"
@@ -188,65 +185,4 @@
         runtime_top_p = FLAGS.topp * np.ones([input0_data.shape[0], 1]).astype(
             np.float32)
         temperature = 1.0 * np.ones([input0_data.shape[0], 1]).astype(
-            np.float32)
-        len_penalty = 1.0 * np.ones([input0_data.shape[0], 1]).astype(
-            np.float32)
-        repetition_penalty = 1.0 * np.ones([input0_data.shape[0], 1]).astype(
-            np.float32)
-        random_seed = 0 * np.ones([input0_data.shape[0], 1]).astype(np.uint64)
-        output_log_probs = True * np.ones([input0_data.shape[0], 1
-                                           ]).astype(bool)
-        beam_width = (FLAGS.beam_width *
-                      np.ones([input0_data.shape[0], 1])).astype(np.int32)
-        pad_ids = pad_id * \
-            np.ones([input0_data.shape[0], 1]).astype(np.int32)
-        end_ids = end_id * \
-            np.ones([input0_data.shape[0], 1]).astype(np.int32)
-        min_length = 1 * \
-            np.ones([input0_data.shape[0], 1]).astype(np.int32)
-        presence_penalty = 0.0 * \
-            np.ones([input0_data.shape[0], 1]).astype(np.float32)
-        frequency_penalty = 0.0 * \
-            np.ones([input0_data.shape[0], 1]).astype(np.float32)
-        inputs = [
-            utils.prepare_tensor("text_input", input0_data, FLAGS.protocol),
-            utils.prepare_tensor("max_tokens", output0_len, FLAGS.protocol),
-            utils.prepare_tensor("bad_words", bad_words_list, FLAGS.protocol),
-            utils.prepare_tensor("stop_words", stop_words_list,
-                                 FLAGS.protocol),
-            utils.prepare_tensor("pad_id", pad_ids, FLAGS.protocol),
-            utils.prepare_tensor("end_id", end_ids, FLAGS.protocol),
-            utils.prepare_tensor("beam_width", beam_width, FLAGS.protocol),
-            utils.prepare_tensor("top_k", runtime_top_k, FLAGS.protocol),
-            utils.prepare_tensor("top_p", runtime_top_p, FLAGS.protocol),
-            utils.prepare_tensor("temperature", temperature, FLAGS.protocol),
-            utils.prepare_tensor("length_penalty", len_penalty,
-                                 FLAGS.protocol),
-            utils.prepare_tensor("repetition_penalty", repetition_penalty,
-                                 FLAGS.protocol),
-            utils.prepare_tensor("min_length", min_length, FLAGS.protocol),
-            utils.prepare_tensor("presence_penalty", presence_penalty,
-                                 FLAGS.protocol),
-            utils.prepare_tensor("frequency_penalty", frequency_penalty,
-                                 FLAGS.protocol),
-            utils.prepare_tensor("random_seed", random_seed, FLAGS.protocol),
-            utils.prepare_tensor("output_log_probs", output_log_probs,
-                                 FLAGS.protocol),
-        ]
-
-        try:
-            result = client.infer(model_name, inputs)
-            ensemble_output0 = result.as_numpy("text_output")
-            print("============After ensemble============")
-            batch_size = len(input0)
-            ensemble_output0 = ensemble_output0.reshape([-1, batch_size
-                                                         ]).T.tolist()
-            ensemble_output0 = [[char.decode('UTF-8') for char in line]
-                                for line in ensemble_output0]
-            ensemble_output0 = [''.join(line) for line in ensemble_output0]
-            for line in ensemble_output0:
-                print(f"{line}")
-        except Exception as e:
-            print(e)
-
-        assert output0 == ensemble_output0
+