diff --git a/README.md b/README.md index 9007cdaa..bac790cc 100644 --- a/README.md +++ b/README.md @@ -70,10 +70,10 @@ The below commands will build the same Triton TRT-LLM container as the one on th # Prepare the TRT-LLM base image using the dockerfile from tensorrtllm_backend. cd tensorrtllm_backend # Specify the build args for the dockerfile. -BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3 -TRT_VERSION=9.3.0.1 -TRT_URL_x86=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.3.0/tensorrt-9.3.0.1.linux.x86_64-gnu.cuda-12.2.tar.gz -TRT_URL_ARM=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.3.0/tensorrt-9.3.0.1.ubuntu-22.04.aarch64-gnu.cuda-12.2.tar.gz +BASE_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 +TRT_VERSION=10.0.1.6 +TRT_URL_x86=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz +TRT_URL_ARM=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.ubuntu-22.04.aarch64-gnu.cuda-12.4.tar.gz docker build -t trtllm_base \ --build-arg BASE_IMAGE="${BASE_IMAGE}" \ @@ -297,9 +297,9 @@ The following table shows the fields that may need to be modified before deployment: | `max_tokens_in_paged_kv_cache` | Optional (default=unspecified). The maximum size of the KV cache in number of tokens. If unspecified, the value is interpreted as 'infinite'. KV cache allocation is the min of max_tokens_in_paged_kv_cache and the value derived from kv_cache_free_gpu_mem_fraction below. | | `max_attention_window_size` | Optional (default=max_sequence_length). When using techniques like sliding window attention, the maximum number of tokens that are attended to generate one token. By default, attends to all tokens in the sequence. | | `kv_cache_free_gpu_mem_fraction` | Optional (default=0.9). Set to a number between 0 and 1 to indicate the maximum fraction of GPU memory (after loading the model) that may be used for KV cache.| -| `enable_trt_overlap` | Optional (default=`false`). Set to `true` to partition available requests into 2 'microbatches' that can be run concurrently to hide exposed CPU runtime | | `exclude_input_in_output` | Optional (default=`false`). Set to `true` to only return completion tokens in a response. Set to `false` to return the prompt tokens concatenated with the generated tokens. | | `cancellation_check_period_ms` | Optional (default=100). The time for the cancellation check thread to sleep before doing the next check. It checks whether any of the currently active requests have been cancelled through Triton and prevents further execution of them. | +| `stats_check_period_ms` | Optional (default=100). The time for the statistics reporting thread to sleep before doing the next check. | | `iter_stats_max_iterations` | Optional (default=executor::kDefaultIterStatsMaxIterations). The number of iteration stats to be kept. | | `request_stats_max_iterations` | Optional (default=executor::kDefaultRequestStatsMaxIterations). The number of request stats to be kept. | | `normalize_log_probs` | Optional (default=`true`). 
Set to `false` to skip normalization of `output_log_probs` | diff --git a/all_models/inflight_batcher_llm/postprocessing/1/model.py b/all_models/inflight_batcher_llm/postprocessing/1/model.py index 02aafad7..3766efcb 100644 --- a/all_models/inflight_batcher_llm/postprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/postprocessing/1/model.py @@ -55,11 +55,28 @@ def initialize(self, args): model_config = json.loads(args['model_config']) tokenizer_dir = model_config['parameters']['tokenizer_dir'][ 'string_value'] - self.skip_special_tokens = model_config['parameters'].get( - 'skip_special_tokens', - {'string_value': "true"})['string_value'].lower() in [ - 'true', '1', 't', 'y', 'yes' - ] + + skip_special_tokens = model_config['parameters'].get( + 'skip_special_tokens') + if skip_special_tokens is not None: + skip_special_tokens_str = skip_special_tokens[ + 'string_value'].lower() + if skip_special_tokens_str in [ + 'true', 'false', '1', '0', 't', 'f', 'y', 'n', 'yes', 'no' + ]: + self.skip_special_tokens = skip_special_tokens_str in [ + 'true', '1', 't', 'y', 'yes' + ] + else: + print( + f"[TensorRT-LLM][WARNING] Unrecognized value '{skip_special_tokens['string_value']}' for 'skip_special_tokens'; defaulting to True." + ) + self.skip_special_tokens = True + else: + print( + "[TensorRT-LLM][WARNING] 'skip_special_tokens' is not set; defaulting to True." + ) + self.skip_special_tokens = True self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, legacy=False, diff --git a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt index 60d0290a..aaecb134 100644 --- a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt +++ b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt @@ -101,7 +101,7 @@ parameters { parameters { key: "skip_special_tokens" value: { - string_value: "True" + string_value: "${skip_special_tokens}" } } diff --git a/all_models/inflight_batcher_llm/preprocessing/1/model.py b/all_models/inflight_batcher_llm/preprocessing/1/model.py index 62ab2430..1824e2f8 100644 --- a/all_models/inflight_batcher_llm/preprocessing/1/model.py +++ b/all_models/inflight_batcher_llm/preprocessing/1/model.py @@ -56,11 +56,27 @@ def initialize(self, args): model_config = json.loads(args['model_config']) tokenizer_dir = model_config['parameters']['tokenizer_dir'][ 'string_value'] - self.add_special_tokens = model_config['parameters'].get( - 'add_special_tokens', - {'string_value': "false"})['string_value'].lower() in [ - 'true', '1', 't', 'y', 'yes' - ] + + add_special_tokens = model_config['parameters'].get( + 'add_special_tokens') + if add_special_tokens is not None: + add_special_tokens_str = add_special_tokens['string_value'].lower() + if add_special_tokens_str in [ + 'true', 'false', '1', '0', 't', 'f', 'y', 'n', 'yes', 'no' + ]: + self.add_special_tokens = add_special_tokens_str in [ + 'true', '1', 't', 'y', 'yes' + ] + else: + print( + f"[TensorRT-LLM][WARNING] Unrecognized value '{add_special_tokens['string_value']}' for 'add_special_tokens'; defaulting to True." + ) + self.add_special_tokens = True + else: + print( + "[TensorRT-LLM][WARNING] 'add_special_tokens' is not set; defaulting to True." 
+ ) + self.add_special_tokens = True self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, legacy=False, diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py new file mode 100644 index 00000000..828f7ea2 --- /dev/null +++ b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py @@ -0,0 +1,582 @@ +import datetime +import json +import os +import time +from threading import Lock, Thread + +import numpy as np +import triton_python_backend_utils as pb_utils +from torch import from_numpy + +import tensorrt_llm.bindings.executor as trtllm + + +def get_input_tensor_by_name(request, name): + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is None: + return None + return tensor.as_numpy() + + +def get_input_scalar_by_name(request, name): + tensor = get_input_tensor_by_name(request, name) + if tensor is None: + return None + if tensor.size != 1: + raise pb_utils.TritonModelException( + f"Expected a single value for {name}") + return tensor.item() + + +def read_parameter_as_type(value, name, pytype=str): + if value == "": + return None + if value.startswith("${") and value.endswith("}"): + return None + if pytype is bool: + return value.lower() in ["1", "true"] + try: + result = pytype(value) + return result + except: + pb_utils.Logger.log_warning( + f"Could not read parameter '{name}' with value '{value}', will use default." + ) + return None + + +def get_parameter(model_config, name, pytype=str): + if name not in model_config['parameters']: + return None + return read_parameter_as_type( + model_config['parameters'][name]['string_value'], name, pytype) + + +def convert_word_list(word_list): + if word_list is None: + return None + word_list = word_list.tolist() + if len(word_list) == 0 or len(word_list[0]) != 2: + raise pb_utils.TritonModelException(f"Invalid format for word list.") + words, indices = word_list[0] + result = [] + current_index = 0 + for i in indices: + if i == -1: + continue + if i > len(words): + raise pb_utils.TritonModelException( + f"Invalid format for word list.") + current_word = [] + while current_index < i: + current_word.append(words[current_index]) + current_index += 1 + result.append(current_word) + return result + + +def parse_medusa_choices(medusa_choices): + if medusa_choices is None: + return None + try: + result = json.loads( + "[" + medusa_choices.replace("{", "[").replace("}", "]") + "]") + assert isinstance(result, list) and len(result) > 0 + assert all([isinstance(x, list) for x in result]) + assert all([isinstance(y, int) for x in result for y in x]) + except Exception: + raise pb_utils.TritonModelException( + "Invalid format for medusa_choices") + return result + + +def get_sampling_config_from_request(request): + kwargs = {} + kwargs['beam_width'] = get_input_scalar_by_name(request, 'beam_width') or 1 + kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k') + kwargs['top_p'] = get_input_scalar_by_name(request, 'runtime_top_p') + kwargs['top_p'] = None if kwargs['top_p'] is None or kwargs[ + 'top_p'] <= 0 else kwargs['top_p'] + kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed') + kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature') + kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length') + kwargs['repetition_penalty'] = get_input_scalar_by_name( + request, 'repetition_penalty') + kwargs['presence_penalty'] = get_input_scalar_by_name( + request, 'presence_penalty') + 
kwargs['frequency_penalty'] = get_input_scalar_by_name( + request, 'frequency_penalty') + kwargs['length_penalty'] = get_input_scalar_by_name(request, 'len_penalty') + kwargs['top_p_min'] = get_input_scalar_by_name(request, + 'runtime_top_p_min') + kwargs['top_p_reset_ids'] = get_input_scalar_by_name( + request, 'runtime_top_p_reset_ids') + kwargs['top_p_decay'] = get_input_scalar_by_name(request, + 'runtime_top_p_decay') + kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name( + request, 'beam_search_diversity_rate') + kwargs['early_stopping'] = get_input_scalar_by_name( + request, 'early_stopping') + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return trtllm.SamplingConfig(**kwargs) + + +def get_output_config_from_request(request, exclude_input_from_output): + kwargs = {} + kwargs["return_log_probs"] = get_input_scalar_by_name( + request, 'return_log_probs') + kwargs["return_context_logits"] = get_input_scalar_by_name( + request, 'return_context_logits') + kwargs["return_generation_logits"] = get_input_scalar_by_name( + request, 'return_generation_logits') + kwargs["exclude_input_from_output"] = exclude_input_from_output + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return trtllm.OutputConfig(**kwargs) + + +def get_speculative_decoding_config_from_request(request): + kwargs = {} + draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids') + if draft_input_ids is not None: + kwargs['tokens'] = draft_input_ids.tolist() + draft_logits = get_input_tensor_by_name(request, 'draft_logits') + if draft_logits is not None: + kwargs['logits'] = from_numpy(draft_logits) + kwargs['acceptance_threshold'] = get_input_scalar_by_name( + request, 'draft_acceptance_threshold') + kwargs = {k: v for k, v in kwargs.items() if v is not None} + if len(kwargs) > 0: + return trtllm.SpeculativeDecodingConfig(**kwargs) + return None + + +def get_prompt_tuning_config_from_request(request): + # prompt_vocab_size is unused by executor. + kwargs = {} + prompt_embedding_table = get_input_tensor_by_name( + request, 'prompt_embedding_table') + if prompt_embedding_table is not None: + kwargs["embedding_table"] = from_numpy(prompt_embedding_table) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + if len(kwargs) > 0: + return trtllm.PromptTuningConfig(**kwargs) + return None + + +def get_lora_config_from_request(request): + kwargs = {} + kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id') + lora_weights = get_input_tensor_by_name(request, 'lora_weights') + if lora_weights is not None: + kwargs["weights"] = from_numpy(lora_weights) + lora_config = get_input_tensor_by_name(request, 'lora_config') + if lora_config is not None: + kwargs["config"] = from_numpy(lora_config) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + if len(kwargs) > 0: + return trtllm.LoraConfig(**kwargs) + return None + + +def convert_request(request, exclude_input_from_output, decoupled): + inputs = {} + input_token_ids = get_input_tensor_by_name(request, 'input_ids') + if input_token_ids is None: + raise pb_utils.TritonModelException( + "A value is required for input_ids") + input_token_ids = input_token_ids.tolist() + if len(input_token_ids) == 0: + raise pb_utils.TritonModelException(f"Invalid format for input_ids") + inputs['input_token_ids'] = input_token_ids[0] + # input_lengths is not used by executor. 
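+    # request_output_len (mapped to max_new_tokens) is the only other required
+    # input below; the remaining fields are optional and stay None when the
+    # corresponding input tensor is absent from the Triton request.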
+ inputs['max_new_tokens'] = get_input_scalar_by_name( + request, 'request_output_len') + if inputs['max_new_tokens'] is None: + raise pb_utils.TritonModelException( + "A value is required for request_output_len") + inputs['streaming'] = get_input_scalar_by_name(request, 'streaming') + if inputs['streaming'] and not decoupled: + raise pb_utils.TritonModelException( + "Streaming is only supported in decoupled mode.") + inputs['end_id'] = get_input_scalar_by_name(request, 'end_id') + inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id') + inputs['stop_words'] = convert_word_list( + get_input_tensor_by_name(request, 'stop_words_list')) + inputs['bad_words'] = convert_word_list( + get_input_tensor_by_name(request, 'bad_words_list')) + embedding_bias = get_input_tensor_by_name(request, 'embedding_bias') + if embedding_bias is not None and embedding_bias.size != 0: + inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze() + + sampling_config = get_sampling_config_from_request(request) + output_config = get_output_config_from_request(request, + exclude_input_from_output) + speculative_decoding_config = get_speculative_decoding_config_from_request( + request) + prompt_tuning_config = get_prompt_tuning_config_from_request(request) + lora_config = get_lora_config_from_request(request) + + return trtllm.Request( + **inputs, + sampling_config=sampling_config, + output_config=output_config, + speculative_decoding_config=speculative_decoding_config, + prompt_tuning_config=prompt_tuning_config, + lora_config=lora_config, + ) + + +def convert_response(response): + if response.has_error(): + return pb_utils.InferenceResponse(output_tensors=[], + error=pb_utils.TritonError( + response.error_msg)), True + result = response.result + beam_lengths = np.expand_dims( + np.array([len(beam) for beam in result.output_token_ids], np.int32), 0) + max_beam_length = max([len(beam) for beam in result.output_token_ids]) + output_ids = np.full((1, len(result.output_token_ids), max_beam_length), + -1, np.int32) + for idx, beam in enumerate(result.output_token_ids): + output_ids[0, idx, :len(beam)] = beam + output_tensors = [ + pb_utils.Tensor("output_ids", output_ids), + pb_utils.Tensor("sequence_length", beam_lengths), + ] + output_tensors.append( + pb_utils.Tensor( + "cum_log_probs", + np.expand_dims(np.array(result.cum_log_probs, np.float32), 0) + if result.cum_log_probs is not None else np.zeros( + (1, 1), np.float32))) + output_tensors.append( + pb_utils.Tensor( + "output_log_probs", + np.expand_dims(np.array(result.log_probs, np.float32), 0) if + result.log_probs is not None else np.zeros((1, 1, 1), np.float32))) + output_tensors.append( + pb_utils.Tensor( + "context_logits", + np.expand_dims(np.array(result.context_logits, np.float32), 0) + if result.context_logits is not None else np.zeros( + (1, 1, 1), np.float32))) + output_tensors.append( + pb_utils.Tensor( + "generation_logits", + np.expand_dims(np.array(result.generation_logits, np.float32), 0) + if result.generation_logits is not None else np.zeros( + (1, 1, 1, 1), np.float32))) + return pb_utils.InferenceResponse(output_tensors), result.is_final + + +def convert_scheduler_policy(batch_scheduler_policy: str): + if batch_scheduler_policy.lower() == "max_utilization": + return trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION + elif batch_scheduler_policy.lower() == "guaranteed_no_evict": + return trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT + raise pb_utils.TritonModelException( + f"batch_scheduler_policy value of '{batch_scheduler_policy}' 
is not supported." + ) + + +def convert_batching_type(gpt_model_type: str): + if gpt_model_type is None: + return None + if gpt_model_type.lower( + ) == "inflight_fused_batching" or gpt_model_type.lower( + ) == "inflight_batching": + return trtllm.BatchingType.INFLIGHT + elif gpt_model_type.lower() == "v1": + return trtllm.BatchingType.STATIC + raise pb_utils.TritonModelException( + f"gpt_model_type value of '{gpt_model_type}' is not supported.") + + +def convert_decoding_mode(decoding_mode: str): + if decoding_mode is None: + return None + elif decoding_mode == "none": + return trtllm.DecodingMode.NONE + elif decoding_mode == "top_k": + return trtllm.DecodingMode.TOP_K + elif decoding_mode == "top_p": + return trtllm.DecodingMode.TOP_P + elif decoding_mode == "top_k_top_p": + return trtllm.DecodingMode.TOP_K_TOP_P + elif decoding_mode == "beam_search": + return trtllm.DecodingMode.BEAM_SEARCH + elif decoding_mode == "medusa": + return trtllm.DecodingMode.MEDUSA + raise pb_utils.TritonModelException( + f"decoding_mode value of '{decoding_mode}' is not supported.") + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def get_scheduler_config(self, model_config): + batch_scheduler_policy = get_parameter(model_config, + "batch_scheduler_policy") + if batch_scheduler_policy is None: + return trtllm.SchedulerConfig() + return trtllm.SchedulerConfig( + convert_scheduler_policy(batch_scheduler_policy)) + + def get_kv_cache_config(self, model_config): + kwargs = { + "enable_block_reuse": + get_parameter(model_config, "enable_kv_cache_reuse", bool), + "max_tokens": + get_parameter(model_config, "max_tokens_in_paged_kv_cache", int), + "sink_token_length": + get_parameter(model_config, "sink_token_length", int), + "max_attention_window": + get_parameter(model_config, "max_attention_window_size", int), + "free_gpu_memory_fraction": + get_parameter(model_config, "kv_cache_free_gpu_mem_fraction", + float), + "host_cache_size": + get_parameter(model_config, "kv_cache_host_memory_bytes", int), + "onboard_blocks": + get_parameter(model_config, "kv_cache_onboard_blocks", bool), + } + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return trtllm.KvCacheConfig(**kwargs) + + def get_parallel_config(self, model_config): + kwargs = {} + gpu_device_ids = get_parameter(model_config, "gpu_device_ids") + if gpu_device_ids: + kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")] + self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR", + "0") == "1" + if self.use_orchestrator_mode: + kwargs[ + "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR + worker_path = get_parameter(model_config, "worker_path") + if worker_path is not None: + raise pb_utils.TritonModelException( + "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable." 
+ ) + executor_worker_path = get_parameter(model_config, + "executor_worker_path") + kwargs["orchestrator_config"] = trtllm.OrchestratorConfig( + True, executor_worker_path) + if len(kwargs) > 0: + return trtllm.ParallelConfig(**kwargs) + return None + + def get_peft_cache_config(self, model_config): + kwargs = { + "optimal_adapter_size": + get_parameter(model_config, "lora_cache_optimal_adapter_size", + int), + "max_adapter_size": + get_parameter(model_config, "lora_cache_max_adapter_size", int), + "device_cache_percent": + get_parameter(model_config, "lora_cache_gpu_memory_fraction", + float), + "host_cache_size": + get_parameter(model_config, "lora_cache_host_memory_bytes", int), + } + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return trtllm.PeftCacheConfig(**kwargs) + + def get_executor_config(self, model_config): + kwargs = { + "max_beam_width": + get_parameter(model_config, "max_beam_width", int), + "scheduler_config": + self.get_scheduler_config(model_config), + "kv_cache_config": + self.get_kv_cache_config(model_config), + "enable_chunked_context": + get_parameter(model_config, "enable_chunked_context", bool), + "normalize_log_probs": + get_parameter(model_config, "normalize_log_probs", bool), + "batching_type": + convert_batching_type(get_parameter(model_config, + "gpt_model_type")), + "parallel_config": + self.get_parallel_config(model_config), + "peft_cache_config": + self.get_peft_cache_config(model_config), + "medusa_choices": + parse_medusa_choices(get_parameter(model_config, + "medusa_choices")), + "decoding_mode": + convert_decoding_mode(get_parameter(model_config, + "decoding_mode")), + } + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return trtllm.ExecutorConfig(**kwargs) + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + model_config = json.loads(args['model_config']) + gpt_model_path = get_parameter(model_config, "gpt_model_path") + if get_parameter(model_config, "enable_trt_overlap", bool): + raise pb_utils.TritonModelException( + f"enable_trt_overlap=true is not supported.") + self.exclude_input_from_output = get_parameter( + model_config, "exclude_input_in_output", bool) + executor_config = self.get_executor_config(model_config) + self.executor = trtllm.Executor(gpt_model_path, + trtllm.ModelType.DECODER_ONLY, + executor_config) + self.decoupled = pb_utils.using_decoupled_model_transaction_policy( + model_config) + self.cancellation_check_period_ms = get_parameter( + model_config, "cancellation_check_period_ms", int) or 100 + + if not self.decoupled: + raise pb_utils.TritonModelException( + "Please enable decoupled transaction policy in the model configuration to serve this model" + ) + + self.triton_id_to_req_id = {} + self.req_id_to_response_sender = {} + self.lock = Lock() + self.running = False + self.awaiter_thread = Thread(target=self.awaiter_loop) + self.cancellation_thread = Thread(target=self.cancellation_loop) + if self.executor.can_enqueue_requests(): + self.running = True + self.awaiter_thread.start() + self.cancellation_thread.start() + else: + # In leader mode, worker ranks will wait here until leader is done. + self.executor.shutdown() + + def handle_stop_request(self, triton_id, response_sender): + if triton_id is None or triton_id == "": + response_sender.send( + pb_utils.InferenceResponse(error=pb_utils.TritonError( + "A request id must be provided for request cancellation")), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + return + + if triton_id in self.triton_id_to_req_id: + req_id = self.triton_id_to_req_id[triton_id] + self.executor.cancel_request(req_id) + + response_sender.send( + pb_utils.InferenceResponse(), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. + + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + if not self.executor.can_enqueue_requests(): + return + + # Convert to executor requests. 
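+        # For each request: explicit stop requests are handled immediately; all
+        # other requests are converted to executor requests, with conversion errors
+        # returned to the client as error responses. Successfully converted requests
+        # are then enqueued as a single batch, and the Triton/executor id mappings
+        # are recorded for use by the awaiter and cancellation threads.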
+ triton_requests = [] + executor_requests = [] + for request in requests: + response_sender = request.get_response_sender() + if get_input_scalar_by_name(request, 'stop'): + self.handle_stop_request(request.request_id(), response_sender) + else: + try: + converted = convert_request(request, + self.exclude_input_from_output, + self.decoupled) + except Exception as e: + response_sender.send( + pb_utils.InferenceResponse(error=pb_utils.TritonError( + f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'" + )), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + else: + triton_requests.append(request) + executor_requests.append(converted) + + with self.lock: + request_ids = self.executor.enqueue_requests(executor_requests) + for req_id, request in zip(request_ids, triton_requests): + triton_id = request.request_id() + self.req_id_to_response_sender[ + req_id] = triton_id, request.get_response_sender() + self.triton_id_to_req_id[triton_id] = req_id + return None + + def awaiter_loop(self): + """Gets responses from executor and returns the results.""" + while self.running: + for response in self.executor.await_responses( + timeout=datetime.timedelta(milliseconds=1)): + req_id = response.request_id + with self.lock: + if req_id not in self.req_id_to_response_sender: + continue + triton_id, response_sender = self.req_id_to_response_sender[ + req_id] + + triton_response, is_final = convert_response(response) + response_sender.send( + triton_response, + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + if is_final else 0) + + if is_final: + with self.lock: + del self.triton_id_to_req_id[triton_id] + del self.req_id_to_response_sender[req_id] + # TODO: Read stats: https://jirasw.nvidia.com/browse/TRTLLM-563 + + def cancellation_loop(self): + """Checks if any pending requests have been cancelled.""" + while self.running: + time.sleep(self.cancellation_check_period_ms / 1000.0) + with self.lock: + cancelled_ids = [] + for req_id, (triton_id, response_sender + ) in self.req_id_to_response_sender.items(): + if response_sender.is_cancelled(): + self.executor.cancel_request(req_id) + cancelled_ids.append((req_id, triton_id)) + for req_id, triton_id in cancelled_ids: + del self.triton_id_to_req_id[triton_id] + del self.req_id_to_response_sender[req_id] + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + if self.executor.can_enqueue_requests(): + self.running = False + self.awaiter_thread.join() + self.cancellation_thread.join() + self.executor.shutdown() diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt index de865f6d..6e5ebf74 100644 --- a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt +++ b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt @@ -25,7 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
name: "tensorrt_llm" -backend: "tensorrtllm" +backend: "${triton_backend}" max_batch_size: ${triton_max_batch_size} model_transaction_policy { @@ -410,12 +410,13 @@ parameters: { string_value: "${kv_cache_onboard_blocks}" } } -parameters: { - key: "enable_trt_overlap" - value: { - string_value: "${enable_trt_overlap}" - } -} +# enable_trt_overlap is deprecated and doesn't have any effect on the runtime +# parameters: { +# key: "enable_trt_overlap" +# value: { +# string_value: "${enable_trt_overlap}" +# } +# } parameters: { key: "exclude_input_in_output" value: { @@ -428,6 +429,12 @@ parameters: { string_value: "${cancellation_check_period_ms}" } } +parameters: { + key: "stats_check_period_ms" + value: { + string_value: "${stats_check_period_ms}" + } +} parameters: { key: "iter_stats_max_iterations" value: { diff --git a/all_models/tests/test_python_backend.py b/all_models/tests/test_python_backend.py new file mode 100644 index 00000000..0e98a913 --- /dev/null +++ b/all_models/tests/test_python_backend.py @@ -0,0 +1,576 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
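+# These unit tests cover the helpers in
+# all_models/inflight_batcher_llm/tensorrt_llm/1/model.py: input-tensor and
+# parameter parsing, request/response conversion, and ExecutorConfig
+# construction, with triton_python_backend_utils mocked out.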
+ +import os +import sys +from dataclasses import dataclass +from typing import Dict, List +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +# Mock pb_utils +sys.modules["triton_python_backend_utils"] = MagicMock() + +# Use PYTHONPATH=../inflight_batcher_llm/tensorrt_llm/1/ +from model import * + +import tensorrt_llm.bindings.executor as trtllm + + +@dataclass +class MockTritonTensor: + _name: str + _tensor: np.ndarray + + def name(self) -> str: + return self._name + + def as_numpy(self) -> np.ndarray: + return self._tensor + + +@dataclass +class MockTritonError: + message: str + + +@dataclass +class MockTritonResponse: + tensors: Dict[str, MockTritonTensor] + error: MockTritonError + + def __init__(self, + output_tensors: List[MockTritonTensor], + error: MockTritonError = None): + self.tensors = {} + for tensor in output_tensors: + self.tensors[tensor.name()] = tensor + self.error = error + + def output_tensors(self): + return self.tensors.values() + + def has_error(self): + return self.error is not None + + +@dataclass +class MockTritonRequest: + tensors: Dict[str, MockTritonTensor] + + def get_input_tensor_by_name(self, name: str) -> MockTritonTensor: + return self.tensors[name] if name in self.tensors else None + + def get_response_sender(self): + return None + + +def mock_pb_utils_get_input_tensor_by_name_side_effect( + request: MockTritonRequest, name: str) -> MockTritonTensor: + return request.get_input_tensor_by_name(name) + + +@pytest.fixture(autouse=True) +def apply_patches(): + patch("model.pb_utils.Tensor", new=MockTritonTensor).start() + patch("model.pb_utils.InferenceResponse", new=MockTritonResponse).start() + patch("model.pb_utils.TritonError", new=MockTritonError).start() + patch("model.pb_utils.InferenceRequest", new=MockTritonRequest).start() + patch("model.pb_utils.get_input_tensor_by_name", + new=mock_pb_utils_get_input_tensor_by_name_side_effect).start() + patch("model.pb_utils.TritonModelException", new=Exception).start() + + +@pytest.fixture +def triton_request() -> MockTritonRequest: + inputs = { + "input_ids": [[28524, 287, 5093, 12]], + "request_output_len": [[16]], + "streaming": [[True]], + "end_id": [50256], + "pad_id": [50256], + "stop_words_list": [[[14480, 326, 262, 1171], [1, 4, -1, -1]]], + "bad_words_list": [[[24044, 76, 1230], [2, 3, -1]]], + "embedding_bias": + np.array([[0., 0., 0.]], dtype=np.float32), + "beam_width": [2], + "runtime_top_k": [1], + "runtime_top_p": [0.], + "random_seed": [4], + "temperature": [1.], + "min_length": [3], + "repetition_penalty": [1.0], + "presence_penalty": [2.0], + "frequency_penalty": [4.0], + "len_penalty": [8.0], + "runtime_top_p_min": [1.0], + "runtime_top_p_reset_ids": [1], + "runtime_top_p_decay": [1.0], + "beam_search_diversity_rate": [1.0], + "early_stopping": [True], + "return_log_probs": + True, + "return_context_logits": + True, + "return_generation_logits": + True, + "draft_input_ids": [0, 1], + "draft_logits": + np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), + "draft_acceptance_threshold": + 1.0, + "prompt_embedding_table": + np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16), + "lora_task_id": [1], + "lora_weights": + np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16), + "lora_config": + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), + # Unused by executor backend but may still be in the request. 
+ "input_lengths": [4], + "prompt_vocab_size": [2], + } + return MockTritonRequest( + {k: MockTritonTensor(k, np.array(v)) + for k, v in inputs.items()}) + + +@pytest.fixture +def triton_request_minimal() -> MockTritonRequest: + inputs = { + "input_ids": [[28524, 287, 5093, 12]], + "request_output_len": [[16]], + } + return MockTritonRequest( + {k: MockTritonTensor(k, np.array(v)) + for k, v in inputs.items()}) + + +@pytest.fixture +def trtllm_response() -> trtllm.Response: + result = trtllm.Result() + result.is_final = True + result.output_token_ids = [[1, 2, 3]] + result.cum_log_probs = [1] + result.log_probs = [[1, 3]] + result.context_logits = torch.ones(3, 10) + result.generation_logits = torch.ones(1, 5, 10) + return trtllm.Response(0, result) + + +@pytest.fixture +def trtllm_response_minimal() -> trtllm.Response: + result = trtllm.Result() + result.is_final = False + result.output_token_ids = [[1, 2, 3]] + return trtllm.Response(0, result) + + +@pytest.fixture +def trtllm_response_error() -> trtllm.Response: + return trtllm.Response(0, "internal error") + + +def test_get_input_tensor_by_name(triton_request: MockTritonRequest): + assert (get_input_tensor_by_name(triton_request, "input_ids") == np.array( + [[28524, 287, 5093, 12]])).all() + assert get_input_tensor_by_name(triton_request, "no_value") is None + + +def test_get_input_scalar_by_name(triton_request: MockTritonRequest): + assert get_input_scalar_by_name(triton_request, "request_output_len") == 16 + assert get_input_scalar_by_name(triton_request, "streaming") == True + assert get_input_scalar_by_name(triton_request, "end_id") == 50256 + assert get_input_scalar_by_name(triton_request, "pad_id") == 50256 + assert get_input_scalar_by_name(triton_request, "beam_width") == 2 + assert get_input_scalar_by_name(triton_request, "runtime_top_k") == 1 + assert get_input_scalar_by_name(triton_request, "runtime_top_p") == 0. + assert get_input_scalar_by_name(triton_request, "temperature") == 1. 
+ + +def test_read_parameter_as_type(): + assert read_parameter_as_type("", "name") is None + assert read_parameter_as_type("", "name", int) is None + assert read_parameter_as_type("", "name", float) is None + assert read_parameter_as_type("", "name", bool) is None + assert read_parameter_as_type("${unfilled_parameter}", "name") is None + assert read_parameter_as_type("foo", "name", int) is None + assert read_parameter_as_type("string_value", "name") == "string_value" + assert read_parameter_as_type("4", "name", int) == 4 + assert read_parameter_as_type("0.5", "name", float) == 0.5 + assert read_parameter_as_type("1", "name", bool) == True + assert read_parameter_as_type("true", "name", bool) == True + assert read_parameter_as_type("True", "name", bool) == True + assert read_parameter_as_type("0", "name", bool) == False + assert read_parameter_as_type("false", "name", bool) == False + assert read_parameter_as_type("False", "name", bool) == False + + +def test_get_parameter(): + model_config = {"parameters": {"max_beam_width": {"string_value": "1"}}} + assert get_parameter(model_config, "max_beam_width", int) == 1 + assert get_parameter(model_config, "gpt_model_type", str) is None + + +def test_convert_word_list(): + assert convert_word_list(None) is None + assert convert_word_list(np.array([[[], []]])) == [] + assert convert_word_list( + np.array([[[14480, 326, 262, 1171], [1, 4, -1, + -1]]])) == [[14480], + [326, 262, 1171]] + assert convert_word_list(np.array([[[24044, 76, 1230], + [2, 3, -1]]])) == [[24044, 76], [1230]] + assert convert_word_list(np.array([[[326, 262, 1230], + [3, -1, -1]]])) == [[326, 262, 1230]] + for bad_format in [ + np.array([]), + np.array([[]]), + np.array([[[]]]), + np.array([[[1], [2], [3]]]), + np.array([[[262], [5]]]), + ]: + with pytest.raises(Exception, match="Invalid format for word list"): + convert_word_list(bad_format) + + +def test_parse_medusa_choices(): + assert parse_medusa_choices("{0, 0, 0}, {0, 1}") == [[0, 0, 0], [0, 1]] + for bad_format in [ + "{{}", + "{", + "{{}", + "}", + "{0, 1, 2", + "0, 1, 2", + "{0, 1, 2}, {\"foo\"}", + ]: + with pytest.raises(Exception, + match="Invalid format for medusa_choices"): + parse_medusa_choices(bad_format) + + +def test_convert_request(triton_request: MockTritonRequest): + converted = convert_request(triton_request, + exclude_input_from_output=True, + decoupled=True) + assert isinstance(converted, trtllm.Request) + assert converted.input_token_ids == [28524, 287, 5093, 12] + assert converted.max_new_tokens == 16 + assert converted.streaming == True + assert converted.end_id == 50256 + assert converted.pad_id == 50256 + assert converted.stop_words == [[14480], [326, 262, 1171]] + assert converted.bad_words == [[24044, 76], [1230]] + assert (converted.embedding_bias == torch.tensor([0., 0., 0.])).all() + assert converted.logits_post_processor_name is None + + assert isinstance(converted.speculative_decoding_config, + trtllm.SpeculativeDecodingConfig) + assert converted.speculative_decoding_config.tokens == [0, 1] + assert (converted.speculative_decoding_config.logits == torch.tensor( + [[1.0, 2.0], [3.0, 4.0]])).all() + assert converted.speculative_decoding_config.acceptance_threshold == 1.0 + + assert isinstance(converted.prompt_tuning_config, + trtllm.PromptTuningConfig) + assert (converted.prompt_tuning_config.embedding_table == torch.tensor( + [[1.0, 2.0], [3.0, 4.0]])).all() + + assert isinstance(converted.lora_config, trtllm.LoraConfig) + assert converted.lora_config.task_id == 1 + assert 
(converted.lora_config.weights == torch.tensor([[1.0, 2.0], + [3.0, 4.0]])).all() + assert (converted.lora_config.config == torch.tensor([[1, 2, 3], + [4, 5, 6]])).all() + + assert converted.sampling_config.beam_width == 2 + assert converted.sampling_config.top_k == 1 + assert converted.sampling_config.top_p is None + assert converted.sampling_config.top_p_min == 1.0 + assert converted.sampling_config.top_p_reset_ids == 1 + assert converted.sampling_config.top_p_decay == 1.0 + assert converted.sampling_config.random_seed == 4 + assert converted.sampling_config.temperature == 1.0 + assert converted.sampling_config.min_length == 3 + assert converted.sampling_config.beam_search_diversity_rate == 1.0 + assert converted.sampling_config.repetition_penalty == 1.0 + assert converted.sampling_config.presence_penalty == 2.0 + assert converted.sampling_config.frequency_penalty == 4.0 + assert converted.sampling_config.length_penalty == 8.0 + assert converted.sampling_config.early_stopping == True + + assert converted.output_config.return_log_probs == True + assert converted.output_config.return_context_logits == True + assert converted.output_config.return_generation_logits == True + assert converted.output_config.exclude_input_from_output == True + + +def test_convert_request_minimal(triton_request_minimal: MockTritonRequest): + converted = convert_request(triton_request_minimal, + exclude_input_from_output=False, + decoupled=False) + assert converted.input_token_ids == [28524, 287, 5093, 12] + assert converted.max_new_tokens == 16 + assert converted.streaming == False + assert converted.end_id is None + assert converted.pad_id is None + assert converted.stop_words is None + assert converted.bad_words is None + assert converted.embedding_bias is None + assert converted.logits_post_processor_name is None + assert converted.speculative_decoding_config is None + assert converted.prompt_tuning_config is None + assert converted.lora_config is None + + assert converted.sampling_config.beam_width == 1 + assert converted.sampling_config.top_k is None + assert converted.sampling_config.top_p is None + assert converted.sampling_config.top_p_min is None + assert converted.sampling_config.top_p_reset_ids is None + assert converted.sampling_config.top_p_decay is None + assert converted.sampling_config.random_seed is None + assert converted.sampling_config.temperature is None + assert converted.sampling_config.min_length is None + assert converted.sampling_config.beam_search_diversity_rate is None + assert converted.sampling_config.repetition_penalty is None + assert converted.sampling_config.presence_penalty is None + assert converted.sampling_config.frequency_penalty is None + assert converted.sampling_config.length_penalty is None + assert converted.sampling_config.early_stopping is None + + assert converted.output_config.return_log_probs == False + assert converted.output_config.return_context_logits == False + assert converted.output_config.return_generation_logits == False + assert converted.output_config.exclude_input_from_output == False + + +def test_convert_request_invalid(): + with pytest.raises(Exception, match="A value is required for input_ids"): + no_input_ids = MockTritonRequest({ + "request_output_len": + MockTritonTensor("request_output_len", np.array([[128]])) + }) + convert_request(no_input_ids, False, False) + with pytest.raises(Exception, match="Invalid format for input_ids"): + bad_input_ids = MockTritonRequest( + {"input_ids": MockTritonTensor("input_ids", np.array([]))}) + 
convert_request(bad_input_ids, False, False) + with pytest.raises(Exception, + match="A value is required for request_output_len"): + no_output_len = MockTritonRequest({ + "input_ids": + MockTritonTensor("input_ids", np.array([[1, 2, 3]])) + }) + convert_request(no_output_len, False, False) + with pytest.raises(Exception, + match="Streaming is only supported in decoupled mode."): + streaming_non_decoupled = MockTritonRequest({ + "input_ids": + MockTritonTensor("input_ids", np.array([[1, 2, 3]])), + "request_output_len": + MockTritonTensor("request_output_len", np.array([[128]])), + "streaming": + MockTritonTensor("streaming", np.array([[True]])), + }) + convert_request(streaming_non_decoupled, False, False) + + +def test_convert_response(trtllm_response: trtllm.Response): + response, is_final = convert_response(trtllm_response) + assert is_final == True + assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2, 3] + ])).all() + assert (response.tensors["sequence_length"].as_numpy() == np.array( + [[3]])).all() + assert (response.tensors["cum_log_probs"].as_numpy() == np.array( + [1])).all() + assert (response.tensors["output_log_probs"].as_numpy() == np.array( + [[1, 3]])).all() + assert (response.tensors["context_logits"].as_numpy() == np.ones( + (3, 10), dtype=np.float32)).all() + assert (response.tensors["generation_logits"].as_numpy() == np.ones( + (1, 5, 10), dtype=np.float32)).all() + + +def test_convert_response_minimal(trtllm_response_minimal: trtllm.Response): + response, is_final = convert_response(trtllm_response_minimal) + assert is_final == False + assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2, 3] + ])).all() + assert (response.tensors["sequence_length"].as_numpy() == np.array( + [[3]])).all() + assert (response.tensors["cum_log_probs"].as_numpy() == np.zeros( + (1, 1), np.float32)).all() + assert (response.tensors["output_log_probs"].as_numpy() == np.zeros( + (1, 1, 1), np.float32)).all() + assert (response.tensors["context_logits"].as_numpy() == np.zeros( + (1, 1, 1), np.float32)).all() + assert (response.tensors["generation_logits"].as_numpy() == np.zeros( + (1, 1, 1, 1), np.float32)).all() + + +def test_convert_response_error(trtllm_response_error: trtllm.Response): + response, is_final = convert_response(trtllm_response_error) + assert is_final == True + assert response.has_error() and response.error.message == "internal error" + + +def test_convert_scheduler_policy(): + assert convert_scheduler_policy( + "max_utilization") == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION + assert convert_scheduler_policy( + "guaranteed_no_evict" + ) == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT + with pytest.raises( + Exception, + match="batch_scheduler_policy value of 'other' is not supported"): + convert_scheduler_policy("other") + + +def test_convert_batching_type(): + assert convert_batching_type( + "inflight_fused_batching") == trtllm.BatchingType.INFLIGHT + assert convert_batching_type( + "inflight_batching") == trtllm.BatchingType.INFLIGHT + assert convert_batching_type("v1") == trtllm.BatchingType.STATIC + with pytest.raises( + Exception, + match="gpt_model_type value of 'other' is not supported"): + convert_batching_type("other") + + +def test_convert_decoding_mode(): + assert convert_decoding_mode(None) is None + assert convert_decoding_mode("none") == trtllm.DecodingMode.NONE + assert convert_decoding_mode("top_k") == trtllm.DecodingMode.TOP_K + assert convert_decoding_mode("top_p") == trtllm.DecodingMode.TOP_P + assert 
convert_decoding_mode( + "top_k_top_p") == trtllm.DecodingMode.TOP_K_TOP_P + assert convert_decoding_mode( + "beam_search") == trtllm.DecodingMode.BEAM_SEARCH + assert convert_decoding_mode("medusa") == trtllm.DecodingMode.MEDUSA + with pytest.raises( + Exception, + match="decoding_mode value of 'other' is not supported"): + convert_decoding_mode("other") + + +@pytest.fixture +def model_config() -> Dict: + config = { + "max_beam_width": "2", + "enable_chunked_context": "true", + "normalize_log_probs": "false", + "gpt_model_type": "inflight_batching", + "medusa_choices": "{1, 2, 3, 4}, {5, 6, 7}", + "decoding_mode": "top_k_top_p", + "batch_scheduler_policy": "max_utilization", + "enable_kv_cache_reuse": "true", + "max_tokens_in_paged_kv_cache": "1", + "max_attention_window_size": "2", + "sink_token_length": "3", + "kv_cache_free_gpu_mem_fraction": "0.5", + "kv_cache_host_memory_bytes": "4", + "kv_cache_onboard_blocks": "false", + "gpu_device_ids": "0,1,2,3", + "executor_worker_path": str(os.path.abspath(__file__)), + "lora_cache_optimal_adapter_size": "1", + "lora_cache_max_adapter_size": "2", + "lora_cache_gpu_memory_fraction": "0.5", + "lora_cache_host_memory_bytes": "4", + } + return {"parameters": {k: {"string_value": v} for k, v in config.items()}} + + +def test_get_executor_config(model_config: Dict): + os.environ["TRTLLM_ORCHESTRATOR"] = "0" + config = TritonPythonModel().get_executor_config(model_config) + assert config.max_beam_width == 2 + assert config.enable_chunked_context == True + assert config.normalize_log_probs == False + assert config.batching_type == trtllm.BatchingType.INFLIGHT + assert config.medusa_choices == [[1, 2, 3, 4], [5, 6, 7]] + assert config.decoding_mode == trtllm.DecodingMode.TOP_K_TOP_P + assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION + assert config.kv_cache_config.enable_block_reuse == True + assert config.kv_cache_config.max_tokens == 1 + assert config.kv_cache_config.max_attention_window == 2 + assert config.kv_cache_config.sink_token_length == 3 + assert config.kv_cache_config.free_gpu_memory_fraction == 0.5 + assert config.kv_cache_config.host_cache_size == 4 + assert config.kv_cache_config.onboard_blocks == False + assert config.parallel_config.device_ids == [0, 1, 2, 3] + assert config.parallel_config.orchestrator_config is None + assert config.peft_cache_config.optimal_adapter_size == 1 + assert config.peft_cache_config.max_adapter_size == 2 + assert config.peft_cache_config.device_cache_percent == 0.5 + assert config.peft_cache_config.host_cache_size == 4 + assert config.iter_stats_max_iterations == 1000 + assert config.request_stats_max_iterations == 0 + assert config.logits_post_processor_map is None + del os.environ["TRTLLM_ORCHESTRATOR"] + + +def test_get_executor_config_orchestrator_mode(model_config: Dict): + os.environ["TRTLLM_ORCHESTRATOR"] = "1" + config = TritonPythonModel().get_executor_config(model_config) + assert config.parallel_config.device_ids == [0, 1, 2, 3] + assert config.parallel_config.orchestrator_config.is_orchestrator == True + assert config.parallel_config.orchestrator_config.worker_executable_path == str( + os.path.abspath(__file__)) + del os.environ["TRTLLM_ORCHESTRATOR"] + + +def test_get_executor_config_minimal(): + if "TRTLLM_ORCHESTRATOR" in os.environ: + del os.environ["TRTLLM_ORCHESTRATOR"] + config = TritonPythonModel().get_executor_config({"parameters": {}}) + assert config.max_beam_width == 1 + assert config.enable_chunked_context == False + assert 
config.normalize_log_probs == True + assert config.batching_type == trtllm.BatchingType.INFLIGHT + assert config.medusa_choices is None + assert config.decoding_mode is None + assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT + assert config.kv_cache_config.enable_block_reuse == False + assert config.kv_cache_config.max_tokens is None + assert config.kv_cache_config.max_attention_window is None + assert config.kv_cache_config.sink_token_length is None + assert config.kv_cache_config.free_gpu_memory_fraction is None + assert config.kv_cache_config.host_cache_size is None + assert config.kv_cache_config.onboard_blocks == True + assert config.parallel_config is None + assert config.peft_cache_config.optimal_adapter_size == 8 + assert config.peft_cache_config.max_adapter_size == 64 + assert config.peft_cache_config.device_cache_percent is None + assert config.peft_cache_config.host_cache_size is None + assert config.iter_stats_max_iterations == 1000 + assert config.request_stats_max_iterations == 0 + assert config.logits_post_processor_map is None diff --git a/ci/L0_backend_trtllm/custom_metrics_verification_tests.py b/ci/L0_backend_trtllm/custom_metrics_verification_tests.py index c593d3d2..b163a030 100644 --- a/ci/L0_backend_trtllm/custom_metrics_verification_tests.py +++ b/ci/L0_backend_trtllm/custom_metrics_verification_tests.py @@ -25,10 +25,13 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json +import os import re import unittest from datetime import datetime, timedelta +AVAILABLE_GPUS = int(os.environ.get("AVAILABLE_GPUS", "1")) + metric_to_stat_dict = { "request_type=context": "Context Requests", "request_type=scheduled": "Scheduled Requests", @@ -108,7 +111,7 @@ def _base_test(self, stats_file, metrics_file, is_v1): int(metrics[metric_key])) else: dt_log = datetime.strptime(stats[metric_key], - '%m-%d-%Y %H:%M:%S.%f') + '%m-%d-%Y %H:%M:%S') dt_curl = datetime.utcfromtimestamp( int(metrics[metric_key]) // 1000000) difference = dt_log - dt_curl @@ -128,29 +131,33 @@ def test_1_gpu_IFB_stream(self): self._base_test("1gpu_IFB_streaming_server.log", "1gpu_IFB_stream_metrics.out", False) - def test_2_gpu_v1(self): - self._base_test("2gpu_v1_no_streaming_server.log", - "2gpu_v1_no_stream_metrics.out", True) + if AVAILABLE_GPUS >= 2: + + def test_2_gpu_v1(self): + self._base_test("2gpu_v1_no_streaming_server.log", + "2gpu_v1_no_stream_metrics.out", True) + + def test_2_gpu_IFB_no_stream(self): + self._base_test("2gpu_IFB_no_streaming_server.log", + "2gpu_IFB_no_stream_metrics.out", False) - def test_2_gpu_IFB_no_stream(self): - self._base_test("2gpu_IFB_no_streaming_server.log", - "2gpu_IFB_no_stream_metrics.out", False) + def test_2_gpu_IFB_stream(self): + self._base_test("2gpu_IFB_streaming_server.log", + "2gpu_IFB_stream_metrics.out", False) - def test_2_gpu_IFB_stream(self): - self._base_test("2gpu_IFB_streaming_server.log", - "2gpu_IFB_stream_metrics.out", False) + if AVAILABLE_GPUS >= 4: - def test_4_gpu_v1(self): - self._base_test("4gpu_v1_no_streaming_server.log", - "4gpu_v1_no_stream_metrics.out", True) + def test_4_gpu_v1(self): + self._base_test("4gpu_v1_no_streaming_server.log", + "4gpu_v1_no_stream_metrics.out", True) - def test_4_gpu_IFB_no_stream(self): - self._base_test("4gpu_IFB_no_streaming_server.log", - "4gpu_IFB_no_stream_metrics.out", False) + def test_4_gpu_IFB_no_stream(self): + 
self._base_test("4gpu_IFB_no_streaming_server.log", + "4gpu_IFB_no_stream_metrics.out", False) - def test_4_gpu_IFB_stream(self): - self._base_test("4gpu_IFB_streaming_server.log", - "4gpu_IFB_stream_metrics.out", False) + def test_4_gpu_IFB_stream(self): + self._base_test("4gpu_IFB_streaming_server.log", + "4gpu_IFB_stream_metrics.out", False) if __name__ == "__main__": diff --git a/ci/L0_backend_trtllm/test.sh b/ci/L0_backend_trtllm/test.sh index 6570b629..47b44b92 100644 --- a/ci/L0_backend_trtllm/test.sh +++ b/ci/L0_backend_trtllm/test.sh @@ -104,6 +104,9 @@ function reset_model_repo { function kill_server { pgrep tritonserver | xargs kill -SIGINT + if pgrep -x "trtllmExecutorWorker" > /dev/null; then + pkill -SIGINT -f "trtllmExecutorWorker" + fi } function wait_for_server_terminated { @@ -142,13 +145,14 @@ python3 -m pip install --upgrade pip && \ pip3 install pandas && \ pip3 install tabulate +export AVAILABLE_GPUS=$(nvidia-smi -L | wc -l) + RET=0 NUM_GPUS_TO_TEST=("1" "2" "4") for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do - AVAILABLE_GPUS=$(nvidia-smi -L | wc -l) if [ "$AVAILABLE_GPUS" -lt "$NUM_GPU" ]; then - exit $RET + break fi SERVER_ARGS="--world_size=${NUM_GPU} --model_repo=${MODEL_DIR}" @@ -166,6 +170,7 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do replace_config_tags '${batching_strategy}' 'INVALID' "${MODEL_DIR}/tensorrt_llm/config.pbtxt" replace_config_tags '${engine_dir}' "${MODEL_DIR}/tensorrt_llm/1/inflight_${NUM_GPU}_gpu/" "${MODEL_DIR}/tensorrt_llm/config.pbtxt" replace_config_tags '${max_queue_delay_microseconds}' "50000" "${MODEL_DIR}/tensorrt_llm/config.pbtxt" + replace_config_tags '${triton_backend}' "tensorrtllm" "${MODEL_DIR}/tensorrt_llm/config.pbtxt" replace_config_tags '${triton_max_batch_size}' "128" "${MODEL_DIR}/postprocessing/config.pbtxt" replace_config_tags '${tokenizer_dir}' "${TOKENIZER_DIR}/" "${MODEL_DIR}/postprocessing/config.pbtxt" replace_config_tags '${postprocessing_instance_count}' '1' "${MODEL_DIR}/postprocessing/config.pbtxt" @@ -301,6 +306,7 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do python3 ${BASE_METRICS_VERIFICATION_TEST} >> ${BASE_METRICS_VERIFICATION_LOG} 2>&1 if [ $? -ne 0 ]; then cat ${BASE_METRICS_VERIFICATION_LOG} + echo -e "\n***\n*** Error executing base metrics verification test with ${NUM_GPU}GPU(s): line ${LINENO}\n***" RET=1 fi set +e diff --git a/dockerfile/Dockerfile.triton.trt_llm_backend b/dockerfile/Dockerfile.triton.trt_llm_backend index 455b7351..07b2d76d 100644 --- a/dockerfile/Dockerfile.triton.trt_llm_backend +++ b/dockerfile/Dockerfile.triton.trt_llm_backend @@ -3,9 +3,9 @@ ARG BASE_IMAGE FROM ${BASE_IMAGE} as base RUN apt-get update -q=2 && apt-get install -y --no-install-recommends python3-pip ccache git-lfs + # Remove previous TRT installation -# We didn't remove libnvinfer* here because tritonserver depends on the pre-installed libraries. -RUN apt-get remove -y tensorrt* +RUN apt-get remove --purge -y tensorrt* libnvinfer* RUN pip3 uninstall -y tensorrt ARG TRT_VER @@ -39,3 +39,12 @@ RUN pip3 install /usr/local/tensorrt/python/tensorrt-*-cp$( python3 -c "import s ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:${LD_LIBRARY_PATH} ENV TRT_ROOT=/usr/local/tensorrt + +# Align with the pre-installed CUDA / NVCC / NVRTC versions from +# https://docs.nvidia.com/cuda/archive/12.4.0/cuda-toolkit-release-notes/index.html +# NVRTC static library doesn't exist in NGC PyTorch container. 
+ENV NVRTC_VER="12.4.99-1" +RUN apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev* +RUN CUDA_VER_SHORT=$(echo $CUDA_VER | awk -F. '{print $1"."$2}') \ + && NVRTC_CUDA_VERSION=$(echo $CUDA_VER_SHORT | sed 's/\./-/g') \ + && apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER} diff --git a/docs/baichuan.md b/docs/baichuan.md index 9d238809..fef7f5eb 100644 --- a/docs/baichuan.md +++ b/docs/baichuan.md @@ -44,7 +44,7 @@ python3 tools/fill_template.py -i baichuan_ifb/preprocessing/config.pbtxt tokeni python3 tools/fill_template.py -i baichuan_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},triton_max_batch_size:64,postprocessing_instance_count:1 python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False python3 tools/fill_template.py -i baichuan_ifb/ensemble/config.pbtxt triton_max_batch_size:64 -python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0 +python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0 ``` * Launch server @@ -178,7 +178,7 @@ python3 tools/fill_template.py -i baichuan_ifb/preprocessing/config.pbtxt tokeni python3 tools/fill_template.py -i baichuan_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_BAICHUAN_MODEL},triton_max_batch_size:64,postprocessing_instance_count:1 python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:True python3 tools/fill_template.py -i baichuan_ifb/ensemble/config.pbtxt triton_max_batch_size:64 -python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0 +python3 tools/fill_template.py -i baichuan_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:/tmp/baichuan/13B/trt_engines/fp16/1-gpu/,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0 pip install SentencePiece # please add `trust_remote_code=True` in tokenizer of preprocessing and postprocessing. Considering the security, we don't add it by default. 
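For reference, once the `baichuan_ifb` templates are filled in as above, the repository is launched and smoke-tested the same way as the other models in these docs. A minimal sketch, assuming the non-streaming configuration (`decoupled_mode:False`), a 1-GPU engine, and the default HTTP port 8000; the prompt and token budget below are illustrative only:

```bash
# Launch Triton with the filled-in Baichuan model repository (world size 1 for a 1-GPU engine).
python3 scripts/launch_triton_server.py --world_size 1 --model_repo=baichuan_ifb/

# Smoke-test the deployment through the ensemble's generate endpoint.
curl -X POST localhost:8000/v2/models/ensemble/generate \
    -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": ""}'
```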
diff --git a/docs/gemma.md b/docs/gemma.md index 31d8fe2a..fed782ae 100644 --- a/docs/gemma.md +++ b/docs/gemma.md @@ -20,7 +20,7 @@ python3 tools/fill_template.py -i gemma/preprocessing/config.pbtxt tokenizer_dir python3 tools/fill_template.py -i gemma/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},tokenizer_type:sp,triton_max_batch_size:64,postprocessing_instance_count:1 python3 tools/fill_template.py -i gemma/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False python3 tools/fill_template.py -i gemma/ensemble/config.pbtxt triton_max_batch_size:64 -python3 tools/fill_template.py -i gemma/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,batch_scheduler_policy:guaranteed_no_evict,enable_trt_overlap:False +python3 tools/fill_template.py -i gemma/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,batch_scheduler_policy:guaranteed_no_evict ``` diff --git a/docs/llama.md b/docs/llama.md index 8cda4ff5..b1baf0b4 100644 --- a/docs/llama.md +++ b/docs/llama.md @@ -30,7 +30,7 @@ python3 tools/fill_template.py -i llama_ifb/preprocessing/config.pbtxt tokenizer python3 tools/fill_template.py -i llama_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},triton_max_batch_size:64,postprocessing_instance_count:1 python3 tools/fill_template.py -i llama_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False python3 tools/fill_template.py -i llama_ifb/ensemble/config.pbtxt triton_max_batch_size:64 -python3 tools/fill_template.py -i llama_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0 +python3 tools/fill_template.py -i llama_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0 ``` * Launch server @@ -119,7 +119,7 @@ python3 tools/fill_template.py -i llama_ifb/preprocessing/config.pbtxt tokenizer python3 tools/fill_template.py -i llama_ifb/postprocessing/config.pbtxt tokenizer_dir:${HF_LLAMA_MODEL},triton_max_batch_size:64,postprocessing_instance_count:1 python3 tools/fill_template.py -i llama_ifb/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:True python3 tools/fill_template.py -i llama_ifb/ensemble/config.pbtxt triton_max_batch_size:64 -python3 
tools/fill_template.py -i llama_ifb/tensorrt_llm/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_batching,max_queue_delay_microseconds:0 +python3 tools/fill_template.py -i llama_ifb/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_batching,max_queue_delay_microseconds:0 pip install SentencePiece python3 scripts/launch_triton_server.py --world_size 1 --model_repo=llama_ifb/ diff --git a/inflight_batcher_llm/CMakeLists.txt b/inflight_batcher_llm/CMakeLists.txt index fda8c167..1a5e96cf 100644 --- a/inflight_batcher_llm/CMakeLists.txt +++ b/inflight_batcher_llm/CMakeLists.txt @@ -34,7 +34,16 @@ include(${TRTLLM_DIR}/cpp/cmake/modules/find_library_create_target.cmake) project(tritontensorrtllmbackend LANGUAGES C CXX) -add_compile_options("-D_GLIBCXX_USE_CXX11_ABI=0" "-DENABLE_MULTI_DEVICE=1") +add_compile_options("-DENABLE_MULTI_DEVICE=1") +# https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_dual_abi.html +option(USE_CXX11_ABI "Using CXX11 ABI of libstdc++" OFF) +message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}") +if(USE_CXX11_ABI) + add_compile_options("-D_GLIBCXX_USE_CXX11_ABI=1") +else() + add_compile_options("-D_GLIBCXX_USE_CXX11_ABI=0") +endif() + # # Options # @@ -134,7 +143,11 @@ find_library( CUDA_DRV_LIB cuda HINTS ${CUDA_TOOLKIT_ROOT_DIR} PATH_SUFFIXES lib lib64 lib/stubs lib64/stubs) -set(CUDA_LIBRARIES ${CUDART_LIB}) +find_library( + NVIDIA_ML_LIB nvidia-ml + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/stubs lib64/stubs) +set(CUDA_LIBRARIES ${CUDART_LIB} ${NVIDIA_ML_LIB}) find_package(MPI REQUIRED) message(STATUS "Using MPI_INCLUDE_PATH: ${MPI_INCLUDE_PATH}") diff --git a/inflight_batcher_llm/README.md b/inflight_batcher_llm/README.md index cc4afd54..849c1947 100644 --- a/inflight_batcher_llm/README.md +++ b/inflight_batcher_llm/README.md @@ -87,26 +87,8 @@ parameters: { } ``` -In-flight batching is able to overlap the execution of batches of -requests. It may have a negative impact on performance when the number of -requests is too small. To enable that feature, set the `enable_trt_overlap` -parameter to `True` in the `config.pbtxt` file: - -``` -parameters: { - key: "enable_trt_overlap" - value: { - string_value: "True" - } -} -``` - -Or, equivalently, add `enable_trt_overlap:True` to the invocation of the -`fill_template.py` tool: - -```bash -python3 tools/fill_template.py -i all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt "enable_trt_overlap:True" -``` +Note that the parameter `enable_trt_overlap` has been removed from the `config.pbtxt`. This option allowed the execution of two micro-batches to be overlapped in order to hide CPU overhead. +Optimization work has been done to reduce the CPU overhead and it was found that the overlapping of micro-batches did not provide additional benefits. To reuse previously computed KV cache values (e.g.
for system prompt), set `enable_kv_cache_reuse` parameter to `True` in the `config.pbtxt` file: diff --git a/inflight_batcher_llm/src/model_instance_state.cc b/inflight_batcher_llm/src/model_instance_state.cc index 193c7b3f..c20b8319 100644 --- a/inflight_batcher_llm/src/model_instance_state.cc +++ b/inflight_batcher_llm/src/model_instance_state.cc @@ -156,10 +156,10 @@ executor::KvCacheConfig ModelInstanceState::getKvCacheConfigFromParams() TLLM_LOG_WARNING("enable_kv_cache_reuse is not specified, will be set to false"); } - std::optional maxAttentionWindowSizeType = std::nullopt; + std::optional maxAttentionWindowSizeType = std::nullopt; if (maxAttentionWindow.has_value()) { - maxAttentionWindowSizeType = static_cast(maxAttentionWindow.value()); + maxAttentionWindowSizeType = static_cast(maxAttentionWindow.value()); } return executor::KvCacheConfig(enableKVCacheReuse, maxTokensInPagedKvCache, maxAttentionWindowSizeType, @@ -194,15 +194,15 @@ executor::PeftCacheConfig ModelInstanceState::getPeftCacheConfigFromParams() // lora_cache_gpu_memory_fraction // lora_cache_host_memory_bytes - SizeType maxAdapterSize = 64; - SizeType optimalAdapterSize = 8; + SizeType32 maxAdapterSize = 64; + SizeType32 optimalAdapterSize = 8; std::optional hostCacheSize = std::nullopt; std::optional deviceCachePercent = std::nullopt; std::string fieldName = "lora_cache_max_adapter_size"; try { - maxAdapterSize = model_state_->GetParameter(fieldName); + maxAdapterSize = model_state_->GetParameter(fieldName); } catch (std::exception const& e) { @@ -212,7 +212,7 @@ executor::PeftCacheConfig ModelInstanceState::getPeftCacheConfigFromParams() fieldName = "lora_cache_optimal_adapter_size"; try { - optimalAdapterSize = model_state_->GetParameter(fieldName); + optimalAdapterSize = model_state_->GetParameter(fieldName); } catch (std::exception const& e) { @@ -244,17 +244,18 @@ executor::PeftCacheConfig ModelInstanceState::getPeftCacheConfigFromParams() executor::SchedulerConfig ModelInstanceState::getSchedulerConfigFromParams(bool enableChunkedContext) { - auto schedulerPolicy = executor::SchedulerPolicy::kGUARANTEED_NO_EVICT; + using executor::CapacitySchedulerPolicy; + auto schedulerPolicy = CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT; try { std::string schedulerPolicyStr = model_state_->GetParameter("batch_scheduler_policy"); if (schedulerPolicyStr == "max_utilization") { - schedulerPolicy = executor::SchedulerPolicy::kMAX_UTILIZATION; + schedulerPolicy = CapacitySchedulerPolicy::kMAX_UTILIZATION; } else if (schedulerPolicyStr == "guaranteed_no_evict") { - schedulerPolicy = executor::SchedulerPolicy::kGUARANTEED_NO_EVICT; + schedulerPolicy = CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT; } else { @@ -268,7 +269,7 @@ executor::SchedulerConfig ModelInstanceState::getSchedulerConfigFromParams(bool TLLM_LOG_WARNING(e.what()); } - if (isDecoupled() && schedulerPolicy != executor::SchedulerPolicy::kGUARANTEED_NO_EVICT) + if (isDecoupled() && schedulerPolicy != CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT) { if (!enableChunkedContext) { @@ -278,7 +279,7 @@ executor::SchedulerConfig ModelInstanceState::getSchedulerConfigFromParams(bool "enable_chunked_context to true. 
" "The batch scheduler policy will be set to guaranteed_no_evict " "since enable_chunked_context is false."); - schedulerPolicy = executor::SchedulerPolicy::kGUARANTEED_NO_EVICT; + schedulerPolicy = CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT; } } return executor::SchedulerConfig(schedulerPolicy); @@ -464,6 +465,18 @@ ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_Mo } mInstanceSpecificConfig.cancellationCheckPeriodMs = cancellationCheckPeriodMs; + int statsCheckPeriodMs = 100; + try + { + statsCheckPeriodMs = model_state_->GetParameter("stats_check_period_ms"); + } + catch (std::exception const& e) + { + // If parameter is not specified, just ignore + TLLM_LOG_WARNING("stats_check_period_ms is not specified, will be set to 100 (ms)"); + } + mInstanceSpecificConfig.statsCheckPeriodMs = statsCheckPeriodMs; + if (mExecutor->canEnqueueRequests()) { mStopWaitForResponse = false; @@ -509,6 +522,7 @@ bool ModelInstanceState::handleStopRequest(TRITONBACKEND_Request* request, std:: { throw std::runtime_error("Trying to stop a request but request ID is not provided"); } + std::lock_guard lock(mRequestIdToRequestDataMutex); if (mTritonRequestIdToRequestId.count(tritonRequestId)) { auto requestId = mTritonRequestIdToRequestId[tritonRequestId]; @@ -536,6 +550,10 @@ executor::Request ModelInstanceState::createExecutorRequest( void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, uint32_t const request_count) { + + uint64_t exec_start_ns{0}; + SET_TIMESTAMP(exec_start_ns); + for (uint32_t i = 0; i < request_count; ++i) { TRITONBACKEND_Request* request = requests[i]; @@ -559,8 +577,10 @@ void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, uint32_t cons = createExecutorRequest(request, mInstanceSpecificConfig.excludeInputFromOutput, isDecoupled()); int64_t inputTokensSize = executorRequest.getInputTokenIds().size(); - executor::SizeType beamWidthCopy = executorRequest.getSamplingConfig().getBeamWidth(); + executor::SizeType32 beamWidthCopy = executorRequest.getSamplingConfig().getBeamWidth(); std::lock_guard lock(mRequestIdToRequestDataMutex); + uint64_t compute_start_ns{0}; + SET_TIMESTAMP(compute_start_ns); auto requestId = mExecutor->enqueueRequest(executorRequest); if (mRequestIdToRequestData.count(requestId)) { @@ -572,8 +592,10 @@ void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, uint32_t cons TRITONBACKEND_ResponseFactory* factory; LOG_IF_ERROR( TRITONBACKEND_ResponseFactoryNew(&factory, request), "failed to create triton response factory"); - mRequestIdToRequestData.emplace( - requestId, RequestData{factory, request, tritonRequestId, inputTokensSize, beamWidthCopy}); + + mRequestIdToRequestData.emplace(requestId, + RequestData{factory, request, tritonRequestId, inputTokensSize, beamWidthCopy, + {exec_start_ns, compute_start_ns, 0, 0}}); if (tritonRequestId != "") { mTritonRequestIdToRequestId[tritonRequestId] = requestId; @@ -587,6 +609,24 @@ void ModelInstanceState::enqueue(TRITONBACKEND_Request** requests, uint32_t cons return; } +TRITONSERVER_Error* ModelInstanceState::reportBaseMetrics(RequestData& requestData, TRITONSERVER_Error* error) +{ + auto& timestamps = requestData.timestamps; + SET_TIMESTAMP(timestamps.exec_end_ns); + + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceReportStatistics(modelInstance_, requestData.tritonRequest, (error == nullptr), + timestamps.exec_start_ns, timestamps.compute_start_ns, timestamps.compute_end_ns, timestamps.exec_end_ns)); + + // For now we will assume a batch size of 1 for each 
request. This may change in the future but for + // now it seems that even when requests are dynamically batched together each workItem is associated + // with its own request object and is handled independently due to the nature of IFB. + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(modelInstance_, 1 /* batch size */, + timestamps.exec_start_ns, timestamps.compute_start_ns, timestamps.compute_end_ns, timestamps.exec_end_ns)); + + return nullptr; // success +} + std::tuple ModelInstanceState::fillTritonResponse( TRITONBACKEND_ResponseFactory* factory, executor::Response const& response, RequestData const& requestData) { @@ -719,6 +759,8 @@ void ModelInstanceState::WaitForResponse() { std::chrono::milliseconds waitTime(1); auto responses = mExecutor->awaitResponses(waitTime); + uint64_t compute_end_ns{0}; + SET_TIMESTAMP(compute_end_ns); for (auto const& response : responses) { @@ -749,6 +791,10 @@ void ModelInstanceState::WaitForResponse() { mTritonRequestIdToRequestId.erase(requestData.tritonRequestId); } + + requestData.timestamps.compute_end_ns = compute_end_ns; + LOG_IF_ERROR(reportBaseMetrics(requestData, error), "Error reporting metrics"); + LOG_IF_ERROR(TRITONBACKEND_RequestRelease(requestData.tritonRequest, TRITONSERVER_REQUEST_RELEASE_ALL), "Cannot release request"); LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryDelete(factory), "Cannot delete response factory"); @@ -762,6 +808,7 @@ void ModelInstanceState::WaitForStats() { while (!mStopWaitForStats) { + std::this_thread::sleep_for(std::chrono::milliseconds(mInstanceSpecificConfig.statsCheckPeriodMs)); auto stats = mExecutor->getLatestIterationStats(); for (auto const& stat : stats) { diff --git a/inflight_batcher_llm/src/model_instance_state.h b/inflight_batcher_llm/src/model_instance_state.h index f7e0fa43..f1645401 100644 --- a/inflight_batcher_llm/src/model_instance_state.h +++ b/inflight_batcher_llm/src/model_instance_state.h @@ -38,7 +38,6 @@ #include "tensorrt_llm/batch_manager/callbacks.h" #include "tensorrt_llm/batch_manager/kvCacheConfig.h" #include "tensorrt_llm/batch_manager/namedTensor.h" -#include "tensorrt_llm/batch_manager/schedulerPolicy.h" #include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" #include "tensorrt_llm/runtime/decodingMode.h" @@ -50,7 +49,6 @@ using namespace tensorrt_llm; using namespace tensorrt_llm::batch_manager; -using namespace tensorrt_llm::batch_manager::batch_scheduler; namespace triton::backend::inflight_batcher_llm { @@ -60,6 +58,24 @@ struct InstanceSpecificConfig { bool excludeInputFromOutput; int cancellationCheckPeriodMs; + int statsCheckPeriodMs; +}; + +/// @brief Timestamps for each request, used to report Triton metrics +struct Timestamps +{ + uint64_t exec_start_ns = 0; + uint64_t compute_start_ns = 0; + uint64_t compute_end_ns = 0; + uint64_t exec_end_ns = 0; + + void Reset() + { + exec_start_ns = 0; + compute_start_ns = 0; + compute_end_ns = 0; + exec_end_ns = 0; + } }; /// @brief Per-request data stored for handling requests @@ -69,7 +85,8 @@ struct RequestData TRITONBACKEND_Request* tritonRequest; std::string tritonRequestId; int64_t inputTokensSize; - executor::SizeType beamWidth; + executor::SizeType32 beamWidth; + Timestamps timestamps; }; // @@ -87,11 +104,11 @@ class ModelInstanceState public: // number of cpu workers used to move weights host cache to gpu cache - static constexpr SizeType kPeftCacheNumEnsureWorkers = 4; + static constexpr SizeType32 kPeftCacheNumEnsureWorkers = 4; // number of cuda streams used for H2D copies of peft cache 
pages - static constexpr SizeType kPeftCacheNumCopyStreams = 4; + static constexpr SizeType32 kPeftCacheNumCopyStreams = 4; // number of cpu workers used to load weight into host cache - static constexpr SizeType kPeftCacheNumPutWorkers = 4; + static constexpr SizeType32 kPeftCacheNumPutWorkers = 4; /// @brief Create a ModelInstanceObject static TRITONSERVER_Error* Create( @@ -121,7 +138,7 @@ class ModelInstanceState } /// @brief Add the request to the executor - void enqueue(TRITONBACKEND_Request** requests, const uint32_t request_count); + void enqueue(TRITONBACKEND_Request** requests, uint32_t const request_count); private: /// @brief Get batching type @@ -169,6 +186,9 @@ class ModelInstanceState /// @brief Config to be used when sending requests to executor InstanceSpecificConfig mInstanceSpecificConfig; + /// @brief Report Triton base metrics for a given request + TRITONSERVER_Error* reportBaseMetrics(RequestData& requestData, TRITONSERVER_Error* error); + /// @brief Retrieve responses from the executor void WaitForResponse(); /// @brief The thread for WaitForResponse() to run diff --git a/inflight_batcher_llm/src/model_state.cc b/inflight_batcher_llm/src/model_state.cc index 3ff8eaf3..d0539311 100644 --- a/inflight_batcher_llm/src/model_state.cc +++ b/inflight_batcher_llm/src/model_state.cc @@ -119,7 +119,7 @@ std::string const ModelState::GetExecutorWorkerPath() auto workerPath = GetParameter("worker_path"); TLLM_THROW( "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path " - "instead to specify the location of the trtllmExecuutorWorker executable."); + "instead to specify the location of the trtllmExecutorWorker executable."); } catch (std::exception const& e) { diff --git a/inflight_batcher_llm/src/utils.cc b/inflight_batcher_llm/src/utils.cc index 629bfcad..6611b3ab 100644 --- a/inflight_batcher_llm/src/utils.cc +++ b/inflight_batcher_llm/src/utils.cc @@ -376,7 +376,7 @@ executor::SamplingConfig getSamplingConfigFromTensors(InputTensors const& inputs // If beam_width is specified, set it from config.pbtxt extractSingleton(inputsTensors, InputFieldsNames::beamWidth, beamWidth); - std::optional topK{std::nullopt}; + std::optional topK{std::nullopt}; extractOptionalSingleton(inputsTensors, InputFieldsNames::topK, topK); std::optional topP{std::nullopt}; @@ -531,16 +531,16 @@ executor::Request createRequestFromInputTensors(std::unordered_map(inputsTensors, InputFieldsNames::maxNewTokens, maxNewTokens)) { throw std::runtime_error("request_output_len is not present in the request"); } - std::optional endId{std::nullopt}; + std::optional endId{std::nullopt}; utils::extractOptionalSingleton(inputsTensors, InputFieldsNames::endId, endId); - std::optional padId{std::nullopt}; + std::optional padId{std::nullopt}; utils::extractOptionalSingleton(inputsTensors, InputFieldsNames::padId, padId); if (streaming && !isDecoupled) diff --git a/scripts/benchmarking/build_model.sh b/scripts/benchmarking/build_model.sh deleted file mode 100644 index 661f72fa..00000000 --- a/scripts/benchmarking/build_model.sh +++ /dev/null @@ -1,342 +0,0 @@ -#!/usr/bin/bash - -MODEL=$1 -ENGINE_PATH=$2 -BS=$3 -MAX_INPUT_SEQLEN=$4 -MAX_OUTPUT_SEQLEN=$5 -MAX_TOKENS=$6 -TP=$7 -PP=$8 -WORLD_SIZE=$9 - -GPT2=/trt_llm_data/llm-models/gpt2 -OPT_125M=/trt_llm_data/llm-models/opt-125m -LLAMA=/trt_llm_data/llm-models/llama-models/llama-7b-hf -GPTJ=/trt_llm_data/llm-models/gpt-j-6b -MISTRAL=/trt_llm_data/llm-models/mistral-7b-v0.1 - -set -e -pushd ../../ - -if [ "$MODEL" = 
"llama-7b-fp16" ]; then - - pushd tensorrt_llm/examples/llama - - pip install -r requirements.txt - - python3 build.py --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype float16 \ - --use_gpt_attention_plugin float16 \ - --max_batch_size "$BS" \ - --max_input_len "$MAX_INPUT_SEQLEN" \ - --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --use_inflight_batching \ - --paged_kv_cache \ - --max_num_tokens "$MAX_TOKENS" \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --n_layer 32 --n_head 32 --n_embd 4096 --inter_size 11008 \ - --vocab_size 32000 --n_positions 4096 --hidden_act "silu" \ - --use_gemm_plugin float16 \ - - popd - -fi - -if [ "$MODEL" = "mistral-7b-fp16" ]; then - - pushd tensorrt_llm/examples/llama - - pip install -r requirements.txt - - python3 build.py --model_dir /tensorrtllm_backens/models/Mistral-7B-v0.1 --dtype float16 \ - --use_gpt_attention_plugin float16 \ - --use_gemm_plugin float16 \ - --output_dir "$ENGINE_PATH" \ - --max_batch_size "$BS" --max_input_len 32256 --max_output_len 512 \ - --enable_context_fmha --remove_input_padding \ - --use_inflight_batching --paged_kv_cache \ - --max_num_tokens "$MAX_TOKENS" - - popd - -fi - -if [ "$MODEL" = "llama-7b-fp8" ]; then - - pushd tensorrt_llm/examples/llama - - pip install -r requirements.txt - - python3 build.py --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype float16 \ - --use_gpt_attention_plugin float16 \ - --max_batch_size "$BS" \ - --max_input_len "$MAX_INPUT_SEQLEN" \ - --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --use_inflight_batching \ - --paged_kv_cache \ - --max_num_tokens "$MAX_TOKENS" \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --n_layer 32 --n_head 32 --n_embd 4096 --inter_size 11008 \ - --vocab_size 32000 --n_positions 4096 --hidden_act "silu" \ - --enable_fp8 \ - --fp8_kv_cache \ - --strongly_typed - - popd - -fi - -if [ "$MODEL" = "llama-13b-fp8" ]; then - - pushd tensorrt_llm/examples/llama - - pip install -r requirements.txt - - python3 build.py --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype float16 \ - --use_gpt_attention_plugin float16 \ - --max_batch_size "$BS" \ - --max_input_len "$MAX_INPUT_SEQLEN" \ - --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --use_inflight_batching \ - --paged_kv_cache \ - --max_num_tokens "$MAX_TOKENS" \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --n_layer 40 --n_head 40 --n_embd 5120 --inter_size 13824 \ - --vocab_size 32000 --n_positions 4096 --hidden_act "silu" \ - --enable_fp8 \ - --fp8_kv_cache \ - --strongly_typed - - popd - -fi - -if [ "$MODEL" = "llama-13b-fp16" ]; then - - pushd tensorrt_llm/examples/llama - - pip install -r requirements.txt - - python3 build.py --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype float16 \ - --use_gpt_attention_plugin float16 \ - --max_batch_size "$BS" \ - --max_input_len "$MAX_INPUT_SEQLEN" \ - --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --use_inflight_batching \ - --paged_kv_cache \ - --max_num_tokens "$MAX_TOKENS" \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --n_layer 40 --n_head 40 --n_embd 5120 --inter_size 13824 \ - --vocab_size 32000 --n_positions 4096 --hidden_act "silu" \ - --use_gemm_plugin float16 - - popd - -fi - -if [ "$MODEL" = "llama-70b-fp8" ]; then - - pushd 
tensorrt_llm/examples/llama - - pip install -r requirements.txt - - if [ "$PP" > 1 ]; then - # Use gen_micro_batch_size as max_batch_size for engine build - ENGINE_BS=$(expr $BS / $PP) - else - ENGINE_BS=$BS - fi - - python3 build.py --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype float16 \ - --use_gpt_attention_plugin float16 \ - --max_batch_size "$ENGINE_BS" \ - --max_input_len "$MAX_INPUT_SEQLEN" \ - --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --use_inflight_batching \ - --paged_kv_cache \ - --max_num_tokens "$MAX_TOKENS" \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --n_layer 80 --n_head 64 --n_kv_head 8 --n_embd 8192 --inter_size 28672 \ - --vocab_size 32000 --n_positions 4096 --hidden_act "silu" \ - --ffn_dim_multiplier 1.3 --multiple_of 4096 \ - --enable_fp8 \ - --fp8_kv_cache \ - --strongly_typed - - popd - -fi - -if [ "$MODEL" = "llama-70b-fp16" ]; then - - pushd tensorrt_llm/examples/llama - - pip install -r requirements.txt - - if [ "$PP" > 1 ]; then - # Use gen_micro_batch_size as max_batch_size for engine build - ENGINE_BS=$(expr $BS / $PP) - else - ENGINE_BS=$BS - fi - - python3 build.py --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype float16 \ - --use_gpt_attention_plugin float16 \ - --max_batch_size "$ENGINE_BS" \ - --max_input_len "$MAX_INPUT_SEQLEN" \ - --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --use_inflight_batching \ - --paged_kv_cache \ - --max_num_tokens "$MAX_TOKENS" \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --n_layer 80 --n_head 64 --n_kv_head 8 --n_embd 8192 --inter_size 28672 \ - --vocab_size 32000 --n_positions 4096 --hidden_act "silu" \ - --ffn_dim_multiplier 1.3 --multiple_of 4096 \ - --use_gemm_plugin float16 - - popd - -fi - -if [ "$MODEL" = "gptj-6b-fp8" ]; then - - pushd tensorrt_llm/examples/gptj - - pip install -r requirements.txt - - # No pipeline parallelism argument in build.py for now. - python3 build.py --dtype=float16 \ - --use_gpt_attention_plugin float16 \ - --max_batch_size "$BS" --max_input_len "$MAX_INPUT_SEQLEN" --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --vocab_size 50401 --max_beam_width 1 \ - --output_dir "$ENGINE_PATH" \ - --model_dir /mlperf_inference_data/models/GPTJ-6B/checkpoint-final \ - --enable_context_fmha \ - --fp8_kv_cache \ - --enable_fp8 \ - --parallel_build \ - --world_size "$WORLD_SIZE" \ - --paged_kv_cache \ - --use_inflight_batching \ - --remove_input_padding \ - --strongly_typed \ - --max_num_tokens "$MAX_TOKENS" - - popd - -fi - -if [ "$MODEL" = "gptj-6b-fp16" ]; then - - pushd tensorrt_llm/examples/gptj - - pip install -r requirements.txt - - # No pipeline parallelism argument in build.py for now. 
- python3 build.py --dtype=float16 \ - --use_gpt_attention_plugin float16 \ - --use_gemm_plugin float16 \ - --max_batch_size "$BS" --max_input_len "$MAX_INPUT_SEQLEN" --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --vocab_size 50401 --max_beam_width 1 \ - --output_dir "$ENGINE_PATH" \ - --model_dir /mlperf_inference_data/models/GPTJ-6B/checkpoint-final \ - --enable_context_fmha \ - --paged_kv_cache \ - --parallel_build \ - --world_size "$WORLD_SIZE" \ - --use_inflight_batching \ - --remove_input_padding \ - --max_num_tokens "$MAX_TOKENS" - - popd - -fi - -if [ "$MODEL" = "falcon-180b-fp8" ]; then - - pushd tensorrt_llm/examples/falcon - - pip install -r requirements.txt - - python3 build.py --use_inflight_batching \ - --paged_kv_cache \ - --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype bfloat16 \ - --use_gpt_attention_plugin bfloat16 \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --max_batch_size "$BS" --max_input_len "$MAX_INPUT_SEQLEN" --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --enable_fp8 --fp8_kv_cache \ - --strongly_typed \ - --n_layer 80 --n_head 232 --n_kv_head 8 --n_embd 14848 --vocab_size 65024 --new_decoder_architecture \ - --max_num_tokens "$MAX_TOKENS" - - popd - -fi - -if [ "$MODEL" = "falcon-180b-fp16" ]; then - - pushd tensorrt_llm/examples/falcon - - pip install -r requirements.txt - - python3 build.py --use_inflight_batching \ - --paged_kv_cache \ - --remove_input_padding \ - --enable_context_fmha \ - --parallel_build \ - --output_dir "$ENGINE_PATH" \ - --dtype bfloat16 \ - --use_gemm_plugin bfloat16 \ - --use_gpt_attention_plugin bfloat16 \ - --world_size "$WORLD_SIZE" \ - --tp_size "$TP" \ - --pp_size "$PP" \ - --max_batch_size "$BS" --max_input_len "$MAX_INPUT_SEQLEN" --max_output_len "$MAX_OUTPUT_SEQLEN" \ - --n_layer 80 --n_head 232 --n_kv_head 8 --n_embd 14848 --vocab_size 65024 --new_decoder_architecture \ - --max_num_tokens "$MAX_TOKENS" - - popd - -fi diff --git a/scripts/benchmarking/collate_reports.py b/scripts/benchmarking/collate_reports.py deleted file mode 100644 index e0f559e9..00000000 --- a/scripts/benchmarking/collate_reports.py +++ /dev/null @@ -1,59 +0,0 @@ -import argparse -import csv -import os - - -def combine_csv_files(args): - - # List to store the data rows - data_rows = [] - - # Loop through each CSV file in the directory - for filename in os.listdir(args.directory_path): - if filename.endswith('.csv'): - with open(os.path.join(args.directory_path, filename), - 'r', - newline='') as csvfile: - reader = csv.reader(csvfile) - # Read the header (row 1) from each CSV file - header = next(reader) - # Read the data (row 2) from each CSV file - data = next(reader) - file_info = (os.path.splitext(filename)[0]).split("__") - if file_info[2] == "v1": - file_info[3] = "NA" - new_data = file_info + data - data_rows.append(new_data) - - appended_header = [ - "Machine", "Model", "Batching Scheme", "Scheduler Policy", "Dataset", - "REQ_RATE" - ] + header - # Write the combined data to the output CSV file - with open(args.output_filename, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - # Write the header fields to the output file - writer.writerow(appended_header) - # Write the data rows to the output file - writer.writerows(data_rows) - - print(f"Combined data has been written to {args.output_filename}") - - -def main(): - parser = argparse.ArgumentParser( - description="Combine CSV files in a directory.") - parser.add_argument("--directory_path", - type=str, 
- help="Path to the directory containing CSV files") - parser.add_argument("--output_filename", - type=str, - help="Name of the output CSV file") - - args = parser.parse_args() - - combine_csv_files(args) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarking/replace_bs.py b/scripts/benchmarking/replace_bs.py deleted file mode 100644 index b962bf87..00000000 --- a/scripts/benchmarking/replace_bs.py +++ /dev/null @@ -1,44 +0,0 @@ -import argparse -import os - - -# Function to perform in-place substitution -def replace_line(file_path, search_string, replace_string): - with open(file_path, 'r') as file: - lines = file.readlines() - - with open(file_path, 'w') as file: - for line in lines: - if search_string in line: - file.write("max_batch_size: {}\n".format(replace_string)) - else: - file.write(line) - - -# Function to search for and replace lines in files -def replace_in_files(directory, search_string, replace_string): - for root, _, files in os.walk(directory): - for file in files: - if file == 'config.pbtxt': - file_path = os.path.join(root, file) - replace_line(file_path, search_string, replace_string) - - -def main(): - parser = argparse.ArgumentParser( - description="Recursively replace lines in config.pbtxt files.") - parser.add_argument("directory", - type=str, - help="The directory to search for config.pbtxt files.") - #parser.add_argument("search_string", help="The string to search for in the lines.") - parser.add_argument("bs_replace_value", - type=str, - help="The string to replace matching lines with.") - - args = parser.parse_args() - search_string = "max_batch_size" - replace_in_files(args.directory, search_string, args.bs_replace_value) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarking/test.sh b/scripts/benchmarking/test.sh deleted file mode 100644 index 4b4d4c4c..00000000 --- a/scripts/benchmarking/test.sh +++ /dev/null @@ -1,214 +0,0 @@ -#!/usr/bin/bash - -MODEL=$1 -ENGINE_PATH=$2 -TOKENIZER_PATH=$3 -BS=$5 -MAX_INPUT_SEQLEN=$6 -TP=$7 -PP=$8 -WORLD_SIZE=$9 -RECORD_LOG=${10} -MAX_ATTENTION_WINDOW_SIZE=${11} -GET_NSYS_REP="${12:-"false"}" - -set -e -nvidia-smi - -pushd ../../ -source tools/utils.sh - -#----------------------- WORKLOAD_DETAILS -----------------------# - -# token normal distribution. -# (ip_mean, ip_stdev, op_mean, op_stdev, num_prompts) -TOKEN_DIST_LIST=( "128,0,1,0,8192" "32,0,1024,0,1024" ) -DATASETS=( "" ) # names of datasets - -# key: dataset name, value: path to dataset json file -declare -A dataset_dict=( [""]="" ) - -# dictionary[workload] = list of request rates to shmoo over. 
Should contain keys from TOKEN_DIST_LIST and DATASETS -declare -A REQ_RATES=( ["128,0,1,0,8192"]="-1" - ["32,0,1024,0,1024"]="-1" - ["cnn"]="-1" - ["openweb"]="-1" - ) -REQ_RATES_HIST="" # -#-----------------------------------------------------------------# - -EXCLUDE_INPUT_IN_OUTPUT="false" -ENABLE_TRT_OVERLAP="false" -MAX_QUEUE_DELAY_MICROSECONDS="0" -MAX_BEAM_WIDTH="1" - -gpu_info=$(nvidia-smi --query-gpu=name --format=csv,noheader,nounits) -if [[ $gpu_info == *"H100"* ]]; then - MACHINE="H100" -elif [[ $gpu_info == *"A100"* ]]; then - MACHINE="A100" -elif [[ $gpu_info == *"L40S"* ]]; then - MACHINE="L40S" -fi - -fill_triton_repo () { - # Modify config.pbtxt - python3 tools/fill_template.py -i my_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt engine_dir:${ENGINE_PATH},decoupled_mode:"False",batching_strategy:${BATCHING_STRATEGY},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${BS},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_trt_overlap:${ENABLE_TRT_OVERLAP} - python3 tools/fill_template.py -i my_models/inflight_batcher_llm/preprocessing/config.pbtxt triton_max_batch_size:${BS},tokenizer_dir:${TOKENIZER_PATH},preprocessing_instance_count:1 - python3 tools/fill_template.py -i my_models/inflight_batcher_llm/postprocessing/config.pbtxt triton_max_batch_size:${BS},tokenizer_dir:${TOKENIZER_PATH},postprocessing_instance_count:1 - python3 tools/fill_template.py -i my_models/inflight_batcher_llm/ensemble/config.pbtxt triton_max_batch_size:${BS} - python3 tools/fill_template.py -i my_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${BS},decoupled_mode:"False",accumulate_tokens:"False",bls_instance_count:1 -} - -print_test_params () { - - echo "----------------------------------" - echo " Test parameters:" - echo "----------------------------------" - echo "BATCHING_STRATEGY: ${BATCHING_STRATEGY}" - echo "BATCH_SCHEDULER_POLICY: ${BATCH_SCHEDULER_POLICY}" - echo "ENABLE_TRT_OVERLAP: ${ENABLE_TRT_OVERLAP}" - echo "EXCLUDE_INPUT_IN_OUTPUT: ${EXCLUDE_INPUT_IN_OUTPUT}" - echo "TRITON_MAX_BATCH_SIZE: ${BS}" - echo "MAX_QUEUE_DELAY_MICROSECONDS: ${MAX_QUEUE_DELAY_MICROSECONDS}" - echo "MAX_BEAM_WIDTH: ${MAX_BEAM_WIDTH}" - echo "MAX_ATTENTION_WINDOW_SIZE: ${MAX_ATTENTION_WINDOW_SIZE}" - echo "----------------------------------" -} - -if true; then - echo "TRUE" - - BATCHING_STRATEGIES=( "inflight_fused_batching" "v1" ) - - for BATCHING_STRATEGY in "${BATCHING_STRATEGIES[@]}"; do - - BATCH_SCHEDULER_POLICIES=( "guaranteed_no_evict" ) - - for BATCH_SCHEDULER_POLICY in "${BATCH_SCHEDULER_POLICIES[@]}"; do - - echo -e " \n ================= INITIALIZING TRITONSERVER FOR =============== \n" - print_test_params - - # Start each server with fresh configuration - rm -rf my_models - cp -R all_models my_models - - fill_triton_repo - - if [ "$RECORD_LOG" == "true" ]; then - echo -e " \n ========= Collecting log for the server ======== \n" - python3 scripts/launch_triton_server.py --world_size $WORLD_SIZE --model_repo my_models/inflight_batcher_llm/ --log --log-file triton_log.txt - elif [ "$GET_NSYS_REP" == "true" ]; then - # Change nsys profile delay and duration according to the server launch and dataset preprocessing time. 
- PROFILE_DELAY=30 - PROFILE_DURATION=120 - NSYS_OUT_NAME="trtllm" - echo -e " \n ========= Collecting Nsys report for the server (profile delay: ${PROFILE_DELAY} s, profile duration: ${PROFILE_DURATION} s) ======== \n" - nsys profile --trace cuda,nvtx --sample cpu -o $NSYS_OUT_NAME -f true --gpu-metrics-device=all --gpu-metrics-frequency=20000 --export sqlite -y ${PROFILE_DELAY} -d ${PROFILE_DURATION} \ - python3 scripts/launch_triton_server.py --world_size $WORLD_SIZE --model_repo my_models/inflight_batcher_llm/ & - else - python3 scripts/launch_triton_server.py --world_size $WORLD_SIZE --model_repo my_models/inflight_batcher_llm/ - fi - # Use pgrep to find the PID of the "mpirun"/"nsys" process - mpirun_pid=$(pgrep mpirun) - nsys_pid=$(pgrep nsys | head -1) - if [ -n "$mpirun_pid" ]; then - echo "PID of mpirun process: $mpirun_pid" - export SERVER_PID=($mpirun_pid) - elif [ -n "$nsys_pid" ]; then - echo "PID of nsys process: $nsys_pid" - export SERVER_PID=($nsys_pid) - else - echo "No mpirun or nsys process found." - fi - wait_for_server_ready ${SERVER_PID} 1200 - - pushd tools/inflight_batcher_llm/ - if [ $? -eq 0 ]; then - for DATASET in "${DATASETS[@]}"; do - IFS=',' read -ra REQUEST_RATES <<< "${REQ_RATES[${DATASET}]}" - for REQ_RATE in "${REQUEST_RATES[@]}"; do - op_stats_name="${MACHINE}__${MODEL}-tp${TP}-pp${PP}__${BATCHING_STRATEGY}__${BATCH_SCHEDULER_POLICY}__${DATASET}__${REQ_RATE}" - op_stats_csv_name="$op_stats_name.csv" - - echo -e "DATASET: $DATASET \n\n" - echo -e " ======== BENCHMARK_CORE_MODEL --> OP STATS FILE = ${op_stats_csv_name} ============== \n" - dataset_path="${dataset_dict[$DATASET]}" - python3 benchmark_core_model.py \ - -i grpc --max-input-len $MAX_INPUT_SEQLEN \ - --request-rate $REQ_RATE --op-stats-csv "$op_stats_csv_name" \ - --num-requests 15000 \ - dataset \ - --dataset $dataset_path \ - --tokenizer-dir "$TOKENIZER_PATH" - - sleep 5 - - if [ -n "$PROFILE_DURATION" ]; then - sleep $PROFILE_DURATION - fi - done - done - - for TOKEN_DIST in "${TOKEN_DIST_LIST[@]}"; do - IFS=',' read -ra REQUEST_RATES <<< "${REQ_RATES[${TOKEN_DIST}]}" - for REQ_RATE in "${REQUEST_RATES[@]}"; do - - # Use IFS and read to split the string into an array - IFS=',' read -ra token_params <<< "$TOKEN_DIST" - ip_mean=${token_params[0]} - ip_stdev=${token_params[1]} - op_mean=${token_params[2]} - op_stdev=${token_params[3]} - num_prompts=${token_params[4]} - - op_stats_name="${MACHINE}__${MODEL}-tp${TP}-pp${PP}__${BATCHING_STRATEGY}__${BATCH_SCHEDULER_POLICY}__normal-token-dist-${ip_mean}-${ip_stdev}-${op_mean}-${op_stdev}__${REQ_RATE}" - op_stats_csv_name="$op_stats_name.csv" - echo -e "DATASET: normal-token-dist \n\n" - echo -e " ======== BENCHMARK_CORE_MODEL --> OP STATS FILE = ${op_stats_csv_name} ============== \n" - python3 benchmark_core_model.py \ - -i grpc --max-input-len $MAX_INPUT_SEQLEN \ - --request-rate $REQ_RATE --op-stats-csv "$op_stats_csv_name" \ - --num-requests $num_prompts \ - token-norm-dist \ - --input-mean $ip_mean --input-stdev $ip_stdev --output-mean $op_mean --output-stdev $op_stdev \ - - - sleep 5 - done - done - - IFS=',' read -ra REQUEST_RATES <<< $REQ_RATES_HIST - for REQ_RATE in "${REQUEST_RATES[@]}"; do - - op_stats_name="${MACHINE}__${MODEL}-tp${TP}-pp${PP}__${BATCHING_STRATEGY}__${BATCH_SCHEDULER_POLICY}__token-hist-example__${REQ_RATE}" - op_stats_csv_name="$op_stats_name.csv" - echo -e "DATASET: token-hist-example \n\n" - echo -e " ======== BENCHMARK_CORE_MODEL --> OP STATS FILE = ${op_stats_csv_name} ============== \n" - python3 
benchmark_core_model.py \ - -i grpc --max-input-len $MAX_INPUT_SEQLEN \ - --request-rate $REQ_RATE --op-stats-csv "$op_stats_csv_name" \ - token-from-histogram --histogram-key example - - sleep 5 - done - - echo -e " \n ========= KILLING TRITON SERVER WITH PID: #$SERVER_PID ============== \n" - triton_pid=$(pgrep triton | head -1) - if [ -n "$triton_pid" ]; then - kill -9 $triton_pid - fi - nsys_pid=$(pgrep nsys | head -1) - if [ -n "$nsys_pid" ]; then - kill -9 $nsys_pid - fi - kill -9 ${SERVER_PID} - else - echo -e "\n !!!!!!!!!!!! Triton Server initialization failed !!!!!!!!!!!!!!! \n" - fi - - popd # tools/inflight_batcher_llm - done - done -fi diff --git a/scripts/benchmarking/trtllm_perf.sh b/scripts/benchmarking/trtllm_perf.sh deleted file mode 100644 index 5bb159fc..00000000 --- a/scripts/benchmarking/trtllm_perf.sh +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/bash - -# MODEL_SPEC is defined as "MODEL_NAME,TP,PP" -MODEL_SPEC=$1 -RECORD_SERVER_STATS="${2:-"false"}" - -TOKENIZER_DIR=/trt_llm_data/llm-models/llama-models/llama-7b-hf - -set -e - -######################## STATIC VALUES ####################### - -gpu_info=$(nvidia-smi --query-gpu=name --format=csv,noheader,nounits) -script_dir=$(dirname "$(realpath "$0")") - -declare -A bs_dict -if [[ $gpu_info == *"A100"* ]] || [[ $gpu_info == *"H100"* ]]; then - bs_dict["llama-7b-fp8,1,1"]=2048 - bs_dict["llama-13b-fp8,1,1"]=2048 - bs_dict["llama-7b-fp16,1,1"]=1024 - bs_dict["mistral-7b-fp16,1,1"]=1024 - bs_dict["llama-13b-fp16,1,1"]=1024 - bs_dict["gptj-6b-fp8,1,1"]=96 - bs_dict["llama-70b-fp8,2,1"]=512 - bs_dict["llama-70b-fp8,4,1"]=1024 - bs_dict["llama-70b-fp16,2,1"]=256 - bs_dict["llama-70b-fp16,4,1"]=512 - bs_dict["falcon-180b-fp8,8,1"]=512 -elif [[ $gpu_info == *"L40S"* ]]; then - bs_dict["llama-7b-fp8,1,1"]=1024 - bs_dict["llama-13b-fp8,1,1"]=512 - bs_dict["gptj-6b-fp8,1,1"]=1024 - bs_dict["llama-70b-fp8,2,1"]=256 - bs_dict["llama-70b-fp8,4,1"]=256 - bs_dict["llama-70b-fp8,1,4"]=256 - bs_dict["llama-70b-fp16,4,1"]=256 - bs_dict["llama-70b-fp16,1,4"]=256 -fi - -MAX_TOKENS=50000 - -if [ -z "$MODEL_SPEC" ]; then - echo "No model spec specified. 
Will run default list for the MACHINE" - - if [[ $gpu_info == *"A100"* ]]; then - MODEL_SPEC_LIST=( "llama-7b-fp16,1,1" "mistral-7b-fp16,1,1" "llama-13b-fp16,1,1" "gptj-6b-fp16,1,1" "llama-70b-fp16,4,1" "falcon-180b-fp16,8,1" ) - MACHINE="a100" - elif [[ $gpu_info == *"H100"* ]]; then - MODEL_SPEC_LIST=( "llama-7b-fp8,1,1" "llama-13b-fp8,1,1" "llama-70b-fp8,4,1" "gptj-6b-fp8,1,1" "llama-70b-fp8,2,1" "falcon-180b-fp8,8,1" ) - MACHINE="h100" - elif [[ $gpu_info == *"L40S"* ]]; then - MODEL_SPEC_LIST=( "llama-7b-fp8,1,1" "llama-13b-fp8,1,1" "gptj-6b-fp8,1,1" "llama-70b-fp8,2,1" "llama-70b-fp8,4,1" "llama-70b-fp8,1,4" ) - MACHINE="l40s" - else - echo -e "Nothing to run for this MACHINE" - fi -else - MODEL_SPEC_LIST=( "$MODEL_SPEC" ) - MACHINE="h100" -fi - -for MODEL_SPEC in "${MODEL_SPEC_LIST[@]}"; do - IFS=',' read -ra MODEL_SPECS <<< "${MODEL_SPEC}" - MODEL=${MODEL_SPECS[0]} - TP=${MODEL_SPECS[1]} - PP=${MODEL_SPECS[2]} - WORLD_SIZE=$((TP*PP)) - - BS=${bs_dict[${MODEL_SPEC}]} - MAX_INPUT_SEQLEN=16384 - MAX_OUTPUT_SEQLEN=4096 - if [[ $MODEL == *"gptj-6b"* ]]; then - MAX_INPUT_SEQLEN=1535 - MAX_OUTPUT_SEQLEN=512 - elif [[ $MODEL == *"mistral-7b"* ]]; then - MAX_INPUT_SEQLEN=32256 - MAX_OUTPUT_SEQLEN=512 - fi - DIR="bs${BS}_tokens${MAX_TOKENS}_tp${TP}_pp${PP}_isl${MAX_INPUT_SEQLEN}_osl${MAX_OUTPUT_SEQLEN}" - ENGINE_PATH=${script_dir}/../../tensorrt_llm/trt_engines/${MACHINE}/${MODEL}/${DIR} - - echo -e " \n ******** BUILDING $MODEL with TP=$TP PP=$PP ************* \n" - bash build_model.sh $MODEL $ENGINE_PATH $BS $MAX_INPUT_SEQLEN $MAX_OUTPUT_SEQLEN $MAX_TOKENS $TP $PP $WORLD_SIZE - - echo -e " \n ******** RUNNING $MODEL with TP=$TP PP=$PP *************** \n" - bash test.sh $MODEL $ENGINE_PATH $TOKENIZER_DIR $BS $MAX_INPUT_SEQLEN $TP $PP $WORLD_SIZE $RECORD_SERVER_STATS - -done diff --git a/tensorrt_llm b/tensorrt_llm index 89ba1b1a..bf0a5afc 160000 --- a/tensorrt_llm +++ b/tensorrt_llm @@ -1 +1 @@ -Subproject commit 89ba1b1a67d570e41b03da87e5518eaff0d31fbf +Subproject commit bf0a5afc92f4b2b3191e9e55073953c1f600cf2d diff --git a/tools/inflight_batcher_llm/benchmark_core_model.py b/tools/inflight_batcher_llm/benchmark_core_model.py index 7ebfcc09..619935be 100644 --- a/tools/inflight_batcher_llm/benchmark_core_model.py +++ b/tools/inflight_batcher_llm/benchmark_core_model.py @@ -17,14 +17,23 @@ from utils import utils -def callback(user_data, start_time, req_id, result, error): +def callback(user_data, result, error): user_data._completed_requests.put((result, error)) + if result is None: + # There was an error. 
+ return + try: + # GRPC + req_id = result.get_response().id + except: + # HTTP + req_id = result.get_response()["id"] + start_time = user_data._start_time_dict[req_id] stop_time = datetime.now() latency = (stop_time - start_time).total_seconds() * 1000.0 latency = round(latency, 3) user_data._latencies.append(latency) user_data._latency_dict[req_id] = latency - user_data._start_time_dict[req_id] = start_time user_data._stop_time_dict[req_id] = stop_time @@ -33,6 +42,9 @@ def test_performance(client, input_start_ids, input_lens, output_lens, delays, model_name = "tensorrt_llm" print(f"[INFO] Warm up for benchmarking.") + if FLAGS.decoupled: + client.start_stream(callback=lambda result, error: None, + stream_timeout=FLAGS.stream_timeout) for i in range(10): output0_len = np.ones_like([[1]]).astype(np.int32) * 100 inputs = [ @@ -43,13 +55,22 @@ def test_performance(client, input_start_ids, input_lens, output_lens, delays, utils.prepare_tensor("request_output_len", output0_len, FLAGS.protocol), ] - client.infer(model_name, inputs, request_id=str(i)) + if FLAGS.decoupled: + client.async_stream_infer(model_name, inputs, request_id=str(i)) + else: + client.infer(model_name, inputs, request_id=str(i)) + if FLAGS.decoupled: + client.stop_stream() print(f"[INFO] Start benchmarking on {len(input_start_ids)} prompts.") latency = 0 async_requests = [] start_time = datetime.now() user_data = utils.UserData() + + if FLAGS.decoupled: + client.start_stream(callback=partial(callback, user_data), + stream_timeout=FLAGS.stream_timeout) for i, ids in enumerate(input_start_ids): output0_len = np.ones_like([[1]]).astype(np.int32) * output_lens[i] end_id = np.ones_like([[1]]).astype(np.int32) * -1 @@ -64,17 +85,23 @@ def test_performance(client, input_start_ids, input_lens, output_lens, delays, time.sleep(delays[i]) + user_data._start_time_dict[str(i)] = datetime.now() if FLAGS.protocol == "http": async_requests.append( client.async_infer(model_name, inputs, request_id=str(i))) elif FLAGS.protocol == "grpc": - async_requests.append( - client.async_infer(model_name, - inputs, - callback=partial(callback, user_data, - datetime.now(), i), - request_id=str(i))) - + if FLAGS.decoupled: + client.async_stream_infer(model_name, + inputs, + request_id=str(i)) + else: + async_requests.append( + client.async_infer(model_name, + inputs, + callback=partial(callback, user_data), + request_id=str(i))) + if FLAGS.decoupled: + client.stop_stream() try: if FLAGS.protocol == "http": utils.get_http_results(async_requests) @@ -213,6 +240,22 @@ def check_performance(data_dict, FLAGS): choices=['http', 'grpc'], help='Protocol ("http"/"grpc") used to ' + 'communicate with inference service. Default is "http".') + parser.add_argument( + '--decoupled', + action="store_true", + required=False, + default=False, + help= + 'Uses async_stream_infer which allows decoupled backends (must use grpc protocol)' + ), + parser.add_argument( + "-t", + "--stream-timeout", + type=float, + required=False, + default=None, + help="Stream timeout in seconds. 
Default is None.", + ) parser.add_argument('-c', '--concurrency', type=int, @@ -289,6 +332,9 @@ def check_performance(data_dict, FLAGS): FLAGS = parser.parse_args() if FLAGS.url is None: FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + if FLAGS.decoupled and FLAGS.protocol != 'grpc': + print("Protocol must be set to 'grpc' when using '--decoupled'.") + sys.exit(1) try: client = utils.create_inference_server_client( diff --git a/tools/version.txt b/tools/version.txt index f4674031..1f7f6d8d 100644 --- a/tools/version.txt +++ b/tools/version.txt @@ -1 +1 @@ -c4379185a60f6924b9b3e8d3f52dead6e4c81d17 +4b62e65843b24315c9200c9b39a898a0a3717a7d
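As a usage note for the `benchmark_core_model.py` changes above: the new `--decoupled` flag exercises the server through gRPC streaming (`async_stream_infer`), and `--stream-timeout` bounds how long each stream may stay open. A minimal sketch of an invocation, assuming the `tensorrt_llm` model is deployed in decoupled (streaming) mode and reachable on the default gRPC port; the sequence-length and token-distribution values are illustrative only:

```bash
# Run from tools/inflight_batcher_llm/; --decoupled requires the gRPC protocol (-i grpc).
cd tools/inflight_batcher_llm
python3 benchmark_core_model.py \
    -i grpc --decoupled --stream-timeout 600 \
    --max-input-len 2048 --request-rate -1 --num-requests 1024 \
    token-norm-dist \
    --input-mean 128 --input-stdev 0 --output-mean 128 --output-stdev 0
```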