Inference error encountered while using the draft target model. #678

Open
2 of 4 tasks
pimang62 opened this issue Jan 13, 2025 · 0 comments
Labels
bug Something isn't working


pimang62 commented Jan 13, 2025

System Info

  • GPU: A100 40GB × 4
  • Container: nvcr.io/nvidia/tritonserver:24.11-trtllm-python-py3
  • Model: EXAONE (llamafied)

Who can help?

@juney-nvidia @kaiyux

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  • I set the arguments in the Triton model configs (tensorrt_llm, tensorrt_llm_draft, ...) as shown below.
# Only at once
# mkdir triton_model
BATCH_SIZE=1
DRAFT_VERSION=7b
ENGINE_NAME=engine_target_32b_checkpoint_4gpu_f16_model_4_${BATCH_SIZE}
ENGINE_DIR=engines/${DRAFT_VERSION}_${ENGINE_NAME}
DRAFT_ENGINE_NAME=engine_draft_${DRAFT_VERSION}_checkpoint_4gpu_f16_model_4_${BATCH_SIZE}
DRAFT_ENGINE_DIR=engines/${DRAFT_ENGINE_NAME}
TOKENIZER_DIR=/ex_disk_LUCY/pimang62/ckpt/7b-ours/checkpoint-8619 # 2b-ours/checkpoint-184 # 7b-ours/checkpoint-8619
MODEL_FOLDER=triton_model/${DRAFT_VERSION}_${ENGINE_NAME}
TRITON_MAX_BATCH_SIZE=${BATCH_SIZE}
INSTANCE_COUNT=1
MAX_QUEUE_SIZE=0
FILL_TEMPLATE_SCRIPT=tensorrtllm_backend/tools/fill_template.py
DECOUPLED_MODE=false  # set "true" when streaming
ACCUMULATE_TOKENS=false
BATCH_SCHEDULER_POLICY=guaranteed_no_evict
KV_CACHE_FREE_GPU_MEM_FRACTION=0.4
EXCLUDE_INPUT_IN_OUTPUT=true  ##
MAX_TOKENS_IN_KV_CACHE=32768  ##
MAX_ATTENTION_WINDOW_SIZE=""
MAX_QUEUE_DELAY_MICROSECONDS=1000000
MAX_BEAM_WIDTH=1
ENABLE_KV_CACHE_REUSE=true
NORMALIZE_LOG_PROBS=true
# TARGET_GPU_DEVICE_IDS=0,1,2,3
# DRAFT_GPU_DEVICE_IDS=0,1,2,3
DECODING_MODE="top_k_top_p"
tensorrt_llm_model_name=tensorrt_llm
tensorrt_llm_draft_model_name=tensorrt_llm_draft
USE_DRAFT_LOGITS=false

mkdir $MODEL_FOLDER
cp -r tensorrtllm_backend/all_models/inflight_batcher_llm/* $MODEL_FOLDER/

python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/preprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},preprocessing_instance_count:${INSTANCE_COUNT}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},postprocessing_instance_count:${INSTANCE_COUNT}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},bls_instance_count:${INSTANCE_COUNT},accumulate_tokens:${ACCUMULATE_TOKENS},tensorrt_llm_model_name:${tensorrt_llm_model_name},tensorrt_llm_draft_model_name:${tensorrt_llm_draft_model_name}

cp -r $MODEL_FOLDER/tensorrt_llm $MODEL_FOLDER/tensorrt_llm_draft
sed -i 's/name: "tensorrt_llm"/name: "tensorrt_llm_draft"/g' ${MODEL_FOLDER}/tensorrt_llm_draft/config.pbtxt

python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:${ENGINE_DIR},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:inflight_fused_batching,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},decoding_mode:${DECODING_MODE},encoder_input_features_data_type:TYPE_FP16,use_draft_logits:${USE_DRAFT_LOGITS},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_draft/config.pbtxt triton_backend:tensorrtllm,engine_dir:${DRAFT_ENGINE_DIR},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:inflight_fused_batching,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},decoding_mode:${DECODING_MODE},use_draft_logits:${USE_DRAFT_LOGITS},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT}
  • And sent requests iteratively using the run_speculative_inference() function from the e2e_grpc_speculative_decoding_client.py file (the function actually executed is run_speculative_inference_with_defaults(); a usage sketch follows the first script below).
    • The model configs expose no "*_log_probs" or "*_logits" outputs, so I removed those tensors from the client code.
  1. Non-streaming mode (decoupled mode set to false):
#!/usr/bin/python

import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import argparse
import queue
import sys

import numpy as np
import grpc
import tritonclient.grpc as grpcclient
from tritonclient.grpc._utils import _get_inference_request 
from tritonclient.utils import InferenceServerException, np_to_triton_dtype


def prepare_tensor(name, input):
    t = grpcclient.InferInput(name, input.shape,
                              np_to_triton_dtype(input.dtype))
    t.set_data_from_numpy(input)
    return t


class UserData:

    def __init__(self):
        self._completed_requests = queue.Queue()


def callback(user_data, result, error):
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)
        output = result.as_numpy('text_output')
        print(output, flush=True)

def get_preprocessor_inputs(prompt, output_len, bad_words, stop_words, end_id,
                            pad_id):
    input0 = [[prompt]]
    input0_data = np.array(input0).astype(object)
    output0_len = np.ones_like(input0).astype(np.int32) * output_len

    preprocessor_inputs = [
        prepare_tensor("QUERY", input0_data),
        prepare_tensor("REQUEST_OUTPUT_LEN", output0_len),
    ]

    if bad_words:
        bad_words_list = np.array([bad_words], dtype=object)
        preprocessor_inputs += [
            prepare_tensor("BAD_WORDS_DICT", bad_words_list)
        ]

    if stop_words:
        stop_words_list = np.array([stop_words], dtype=object)
        preprocessor_inputs += [
            prepare_tensor("STOP_WORDS_DICT", stop_words_list)
        ]

    if end_id:
        end_id_data = np.array([[end_id]], dtype=np.int32)
        preprocessor_inputs += [prepare_tensor("END_ID", end_id_data)]

    if pad_id:
        pad_id_data = np.array([[pad_id]], dtype=np.int32)
        preprocessor_inputs += [prepare_tensor("PAD_ID", pad_id_data)]

    return preprocessor_inputs


def extract_preprocessor_outputs(result):

    input_ids = np.squeeze(result.as_numpy("INPUT_ID").astype(np.int32),
                           axis=0)
    bad_words_ids = result.as_numpy("BAD_WORDS_IDS").astype(np.int32)
    stop_words_ids = result.as_numpy("STOP_WORDS_IDS").astype(np.int32)
    end_id = result.as_numpy("OUT_END_ID").astype(np.int32)[0][0]
    pad_id = result.as_numpy("OUT_PAD_ID").astype(np.int32)[0][0]

    return input_ids, bad_words_ids, stop_words_ids, end_id, pad_id


def get_trtllm_inputs(input_ids,
                      input_length,
                      request_output_len,
                      draft_tokens,
                      beam_width,
                      temperature,
                      repetition_penalty,
                      presence_penalty,
                      frequency_penalty,
                      bad_words_ids,
                      stop_words_ids,
                      end_id,
                      pad_id,
                      return_draft_model_draft_logits=False,
                      return_target_model_accepted_token_logits=False):

    # These two flags correspond to the settings of draft model and target model respectively,
    # and only one of them can be true at a time.
    assert not (return_draft_model_draft_logits
                and return_target_model_accepted_token_logits)

    # input_ids is expected to have shape [input_length]
    # Add batch dimension of 1
    input_ids_data = np.expand_dims(input_ids, axis=0)
    inputs = [
        prepare_tensor("input_ids", input_ids_data),
        prepare_tensor("input_lengths",
                       np.array([[input_length]], dtype=np.int32)),
        prepare_tensor("request_output_len",
                       np.array([[request_output_len]], dtype=np.int32)),
        prepare_tensor("bad_words_list", bad_words_ids),
        prepare_tensor("stop_words_list", stop_words_ids),
        prepare_tensor("beam_width", np.array([[beam_width]], dtype=np.int32)),
        prepare_tensor("temperature",
                       np.array([[temperature]], dtype=np.float32)),
    ]

    if draft_tokens is not None:
        draft_tokens_data = np.array([draft_tokens], dtype=np.int32)
        inputs.append(prepare_tensor("draft_input_ids", draft_tokens_data))

    if repetition_penalty is not None:
        repetition_penalty_data = np.array([[repetition_penalty]],
                                           dtype=np.float32)
        inputs.append(
            prepare_tensor("repetition_penalty", repetition_penalty_data))

    if presence_penalty is not None:
        presence_penalty_data = np.array([[presence_penalty]],
                                         dtype=np.float32)
        inputs.append(prepare_tensor("presence_penalty",
                                     presence_penalty_data))

    if frequency_penalty is not None:
        frequency_penalty_data = np.array([[frequency_penalty]],
                                          dtype=np.float32)
        inputs.append(
            prepare_tensor("frequency_penalty", frequency_penalty_data))

    if end_id is not None:
        end_id_data = np.array([[end_id]], dtype=np.int32)
        inputs.append(prepare_tensor("end_id", end_id_data))

    if pad_id is not None:
        pad_id_data = np.array([[pad_id]], dtype=np.int32)
        inputs.append(prepare_tensor("pad_id", pad_id_data))

    if return_draft_model_draft_logits:
        return_draft_model_draft_logits_data = np.array(
            [[return_draft_model_draft_logits]], dtype=bool)
        inputs.append(
            prepare_tensor("return_generation_logits",
                           return_draft_model_draft_logits_data))

    if return_target_model_accepted_token_logits:
        return_target_model_accepted_token_logits_data = np.array(
            [[return_target_model_accepted_token_logits]], dtype=bool)
        inputs.append(
            prepare_tensor("return_generation_logits",
                           return_target_model_accepted_token_logits_data))

    return inputs


def check_result(result, model_name):
    if type(result) == InferenceServerException:
        print(
            f"Received an error from server while calling {model_name}: {result}"
        )


def extract_trtllm_outputs(result):
    # Get batch 0, beam 0 output_ids
    output_ids = np.squeeze(result.as_numpy("output_ids").astype(np.int32),
                            axis=(0, 1))
    sequence_length_data = result.as_numpy("sequence_length").astype(np.int32)
    assert sequence_length_data.shape[0] == 1
    assert sequence_length_data.shape[1] == 1
    sequence_length = sequence_length_data[0, 0]
    # cum_log_probs = result.as_numpy("cum_log_probs").astype(np.float32)
    # output_log_probs = result.as_numpy("output_log_probs").astype(np.float32)
    # context_logits = result.as_numpy("context_logits").astype(np.float32)
    # generation_logits = result.as_numpy("generation_logits").astype(np.float32)
    return output_ids, sequence_length #, cum_log_probs, output_log_probs, context_logits, generation_logits


def get_postprocessor_inputs(output_ids): # , cum_log_probs, output_log_probs, context_logits, generation_logits
    output_ids_data = np.expand_dims(output_ids, axis=(0, 1))
    inputs = [
        prepare_tensor("TOKENS_BATCH", output_ids_data),
        prepare_tensor("SEQUENCE_LENGTH",
                       np.array([[len(output_ids)]], dtype=np.int32)),
        # prepare_tensor("CUM_LOG_PROBS", cum_log_probs),
        # prepare_tensor("OUTPUT_LOG_PROBS", output_log_probs),
        # prepare_tensor("CONTEXT_LOGITS", context_logits),
        # prepare_tensor("GENERATION_LOGITS", generation_logits)
    ]

    return inputs


def encountered_stop_words(input_ids, stop_words_ids):
    for stop_word_ids in stop_words_ids:
        if np.array_equal(input_ids[-len(stop_word_ids):], stop_word_ids):
            return True
    return False


def run_speculative_inference(
        client_draft, client_target, prompt, output_len, in_num_draft_tokens,
        request_id, repetition_penalty, presence_penalty, frequency_penalty,
        temperature, stop_words, bad_words, end_id, pad_id, beam_width,
        preprocessor_model_name, draft_tensorrt_llm_model_name,
        target_tensorrt_llm_model_name, postprocessor_model_name,
        return_draft_model_draft_logits,
        return_target_model_accepted_token_logits, verbose):

    from datetime import datetime ##
    start_time = datetime.now() ##
    # Call the preprocessor
    preprocessor_inputs = get_preprocessor_inputs(prompt, output_len,
                                                  bad_words, stop_words,
                                                  end_id, pad_id)
    preprocessor_result = client_draft.infer(preprocessor_model_name,
                                             preprocessor_inputs,
                                             request_id=request_id)
    check_result(preprocessor_result, preprocessor_model_name)
    prompt_input_ids, bad_words_ids, stop_words_ids, end_id, pad_id = extract_preprocessor_outputs(
        preprocessor_result)

    input_ids = prompt_input_ids
    last_input_ids = None
    draft_output_ids = None

    while True:

        num_draft_tokens = min(
            in_num_draft_tokens,
            len(prompt_input_ids) + output_len - len(input_ids) - 1)

        if num_draft_tokens > 0:

            if verbose:
                print("Draft model input ids:")
                print(input_ids.tolist())

            #Generate up to num_draft_tokens with draft model
            draft_inputs = get_trtllm_inputs(
                input_ids,
                len(input_ids),
                num_draft_tokens,
                None,
                beam_width,
                temperature,
                repetition_penalty,
                presence_penalty,
                frequency_penalty,
                bad_words_ids,
                stop_words_ids,
                end_id,
                pad_id,
                return_draft_model_draft_logits=return_draft_model_draft_logits
            )

            draft_result = client_draft.infer(draft_tensorrt_llm_model_name,
                                              draft_inputs,
                                              request_id=request_id)
            check_result(draft_result, draft_tensorrt_llm_model_name)
            draft_output_ids, draft_seq_len = extract_trtllm_outputs(  # , cum_log_probs, output_log_probs, context_logits, generation_logits
                draft_result)

            if verbose:
                print("Draft model output ids:")
                print(draft_output_ids.tolist())
                print("draft_sequence_length")
                print(draft_seq_len)

            # Set the draft token and call the target model to generate up to num_draft_tokens + 1
            draft_tokens = draft_output_ids[len(input_ids):draft_seq_len]

            if verbose:
                print("draft_tokens")
                print(draft_tokens.tolist())
                if return_draft_model_draft_logits:
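                    # NOTE: extract_trtllm_outputs() no longer returns generation_logits
                    # (those outputs were removed above), so this branch would raise a
                    # NameError if --return-draft-model-draft-logits were enabled.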
                    draft_model_draft_token_logits = generation_logits.squeeze(
                        0)  # [beam_width, num_draft_tokens, vocab_size]
                    print(
                        f"draft model draft tokens' logits: shape: {draft_model_draft_token_logits.shape}, value: {draft_model_draft_token_logits}"
                    )

        if verbose:
            print("Target model input ids")
            print(input_ids.tolist())

        # Generate up to len(draft_tokens) + 1 with target model
        target_inputs = get_trtllm_inputs(
            input_ids,
            len(input_ids),
            len(draft_tokens) + 1 if num_draft_tokens > 0 else 1,
            draft_tokens if num_draft_tokens > 0 else None,
            beam_width,
            temperature,
            repetition_penalty,
            presence_penalty,
            frequency_penalty,
            bad_words_ids,
            stop_words_ids,
            end_id,
            pad_id,
            return_target_model_accepted_token_logits=
            return_target_model_accepted_token_logits)

        target_result = client_target.infer(target_tensorrt_llm_model_name,
                                            target_inputs,
                                            request_id=request_id)
        check_result(target_result, target_tensorrt_llm_model_name)
        target_output_ids, seq_length = extract_trtllm_outputs(  # , cum_log_probs, output_log_probs, context_logits, generation_logits
            target_result)

        if verbose:
            print("Target model output_ids")
            print(target_output_ids.tolist())
            print("target seq_length")
            print(seq_length)
            if return_target_model_accepted_token_logits:
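                # NOTE: generation_logits is likewise not returned by extract_trtllm_outputs(),
                # so this branch would raise a NameError if enabled.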
                target_model_accept_token_logits = generation_logits.squeeze(
                    0).squeeze(0)  # [num_accepted_tokens, vocab_size]
                print(
                    f"target model accepted tokens' logits: shape: {target_model_accept_token_logits.shape}, value: {target_model_accept_token_logits}"
                )
        
        # Store the last iteration input_ids to check if EOS was encountered
        last_input_ids = input_ids
        # Update the input ids with new output_ids
        input_ids = target_output_ids

        # Evaluate criteria to stop generation loop.
        # If we've hit or exceeded the max output length, should stop
        length_stop = (len(input_ids) >= len(prompt_input_ids) + output_len)
        # If draft and target have same outputs, should stop. Normally target should return 1 more token.
        # If they are the same length, they should differ at the last token
        target_draft_equal = draft_output_ids is not None and np.array_equal(
            draft_output_ids, target_output_ids)
        # If tokens no longer change, should stop, means we have hit early stopping
        last_current_equal = np.array_equal(last_input_ids, input_ids)
        # Need to check if stop words was encountered
        hit_stop_words = encountered_stop_words(input_ids, stop_words_ids[0])

        if verbose:
            print("length_stop:", length_stop)
            print("target_draft_equal:", target_draft_equal)
            print("last_current_equal:", last_current_equal)
            print("hit_stop_words:", hit_stop_words)

        if (length_stop or target_draft_equal or last_current_equal
                or hit_stop_words):
            break

    # Call the postprocessor
    postprocessor_inputs = get_postprocessor_inputs(input_ids)  # , cum_log_probs, output_log_probs, context_logits, generation_logits
    postprocessor_result = client_target.infer(postprocessor_model_name,
                                               postprocessor_inputs,
                                               request_id=request_id)
    check_result(postprocessor_result, postprocessor_model_name)
    output = postprocessor_result.as_numpy("OUTPUT")
    
    # print(f"Output: {output[0].decode('utf-8')}")
    response_time = (datetime.now() - start_time).total_seconds()
    print(f"Response Time: {response_time}") ##
    return output

def run_speculative_inference_with_defaults(
        prompt: str,
        url_target: str = "localhost:8001",
        url_draft: str = None,
        output_len: int = 1000,
        num_draft_tokens: int = 10,
        beam_width: int = 1,
        temperature: float = 1.0,
        repetition_penalty: float = None,
        presence_penalty: float = None,
        frequency_penalty: float = None,
        stop_words: list = None,
        bad_words: list = None,
        end_id: int = 2,
        pad_id: int = 0,
        preprocessor_model_name: str = "preprocessing",
        draft_tensorrt_llm_model_name: str = "tensorrt_llm_draft",
        target_tensorrt_llm_model_name: str = "tensorrt_llm",
        postprocessor_model_name: str = "postprocessing",
        return_draft_model_draft_logits: bool = False,
        return_target_model_accepted_token_logits: bool = False,
        verbose: bool = False
):
    # Ensure draft URL defaults to target URL if not provided
    if url_draft is None:
        url_draft = url_target

    # Create Triton clients for target and draft
    try:
        client_target = grpcclient.InferenceServerClient(url=url_target)
        client_draft = grpcclient.InferenceServerClient(
            url=url_draft) if url_target != url_draft else client_target
    except Exception as e:
        print(f"Failed to create Triton client: {e}")
        return None

    if beam_width > 1:
        raise Exception(
            'Beam width > 1 is not yet supported with speculative decoding'
        )

    # Call the speculative inference function
    return run_speculative_inference(
        client_draft=client_draft,
        client_target=client_target,
        prompt=prompt,
        output_len=output_len,
        in_num_draft_tokens=num_draft_tokens,
        request_id="1",  # Default request ID
        repetition_penalty=repetition_penalty,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        temperature=temperature,
        stop_words=stop_words,
        bad_words=bad_words,
        end_id=end_id,
        pad_id=pad_id,
        beam_width=beam_width,
        preprocessor_model_name=preprocessor_model_name,
        draft_tensorrt_llm_model_name=draft_tensorrt_llm_model_name,
        target_tensorrt_llm_model_name=target_tensorrt_llm_model_name,
        postprocessor_model_name=postprocessor_model_name,
        return_draft_model_draft_logits=return_draft_model_draft_logits,
        return_target_model_accepted_token_logits=
        return_target_model_accepted_token_logits,
        verbose=verbose
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')

    parser.add_argument('--url-target',
                        type=str,
                        required=True,
                        help='Inference server URL for the target model')

    parser.add_argument('--url-draft',
                        type=str,
                        required=False,
                        help='Inference server URL for the draft model')

    parser.add_argument(
        '--preprocessor-model-name',
        type=str,
        required=False,
        default="preprocessing",
        help='Name of the preprocessor model (should be hosted at url-draft)')

    parser.add_argument(
        '--postprocessor-model-name',
        type=str,
        required=False,
        default="postprocessing",
        help='Name of the postprocessor model (should be hosted at url-target)'
    )

    parser.add_argument(
        '--draft-tensorrt-llm-model-name',
        type=str,
        required=False,
        default="tensorrt_llm",
        help='Name of the tensorrt_llm draft model (hosted at url-draft)')

    parser.add_argument(
        '--target-tensorrt-llm-model-name',
        type=str,
        required=False,
        default="tensorrt_llm",
        help='Name of the tensorrt_llm target model (hosted at url-target)')

    parser.add_argument('-p',
                        '--prompt',
                        type=str,
                        required=True,
                        help='Input prompt.')

    parser.add_argument(
        "-b",
        "--beam-width",
        required=False,
        type=int,
        default=1,
        help="Beam width value",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        required=False,
        default=1.0,
        help="temperature value",
    )

    parser.add_argument(
        "--repetition-penalty",
        type=float,
        required=False,
        default=None,
        help="The repetition penalty value",
    )

    parser.add_argument(
        "--presence-penalty",
        type=float,
        required=False,
        default=None,
        help="The presence penalty value",
    )

    parser.add_argument(
        "--frequency-penalty",
        type=float,
        required=False,
        default=None,
        help="The frequency penalty value",
    )

    parser.add_argument('-o',
                        '--output-len',
                        type=int,
                        default=1000,
                        required=False,
                        help='Specify output length')

    parser.add_argument(
        '--num-draft-tokens',
        type=int,
        default=5,
        required=False,
        help=
        'Specify the number of speculative tokens for the draft model to generate per lookahead.'
    )

    parser.add_argument('--end-id',
                        type=int,
                        default=None,
                        required=False,
                        help='The end-of-sequence token id')

    parser.add_argument('--pad-id',
                        type=int,
                        default=None,
                        required=False,
                        help='The padding token id')

    parser.add_argument('--request-id',
                        type=str,
                        default='1',
                        required=False,
                        help='The request_id for the stop request')

    parser.add_argument('--stop-words',
                        nargs='+',
                        default=[],
                        help='The stop words')

    parser.add_argument('--bad-words',
                        nargs='+',
                        default=[],
                        help='The bad words')

    parser.add_argument(
        "--return-draft-model-draft-logits",
        action="store_true",
        required=False,
        default=False,
        help=
        "Return draft model's draft tokens' logits, require to enable `gather_generation_logits` when build engine"
    )

    parser.add_argument(
        "--return-target-model-accepted-token-logits",
        action="store_true",
        required=False,
        default=False,
        help=
        "Return target model's accepted token logits, require to enable `gather_generation_logits` when build engine",
    )

    FLAGS = parser.parse_args()
    if not FLAGS.url_target:
        FLAGS.url_target = "localhost:8001"

    if not FLAGS.url_draft:
        FLAGS.url_draft = FLAGS.url_target

    try:
        client_target = grpcclient.InferenceServerClient(url=FLAGS.url_target)
        client_draft = grpcclient.InferenceServerClient(
            url=FLAGS.url_draft) if (
                FLAGS.url_target != FLAGS.url_draft) else client_target
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)

    if (FLAGS.beam_width > 1):
        raise Exception(
            'Beam width > 1 is not yet supported with speculative decoding')

    output_text = run_speculative_inference(
        client_draft, client_target, FLAGS.prompt, FLAGS.output_len,
        FLAGS.num_draft_tokens, FLAGS.request_id, FLAGS.repetition_penalty,
        FLAGS.presence_penalty, FLAGS.frequency_penalty, FLAGS.temperature,
        FLAGS.stop_words, FLAGS.bad_words, FLAGS.end_id, FLAGS.pad_id,
        FLAGS.beam_width, FLAGS.preprocessor_model_name,
        FLAGS.draft_tensorrt_llm_model_name,
        FLAGS.target_tensorrt_llm_model_name, FLAGS.postprocessor_model_name,
        FLAGS.return_draft_model_draft_logits,
        FLAGS.return_target_model_accepted_token_logits, FLAGS.verbose)

    # Print the final text
    print("Final text:\n", output_text)
  2. Streaming mode (decoupled mode set to true):
#!/usr/bin/python

import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import argparse
import queue
import sys

import numpy as np
import grpc
import tritonclient.grpc as grpcclient
from tritonclient.grpc._utils import _get_inference_request 
from tritonclient.utils import InferenceServerException, np_to_triton_dtype

def prepare_tensor(name, input):
    t = grpcclient.InferInput(name, input.shape,
                              np_to_triton_dtype(input.dtype))
    t.set_data_from_numpy(input)
    return t

def callback(user_data, result, error):
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)
        output = result.as_numpy('text_output')
        print(output, flush=True)

class _RequestIterator:
    def __init__(self,
        model_name,
        inputs,
        model_version="",
        outputs=None,
        request_id="",
        sequence_id=0,
        sequence_start=False,
        sequence_end=False,
        priority=0,
        timeout=None,
        parameters=None):
        self.request_queue = queue.Queue()
        request = _get_inference_request(model_name = model_name,
                                                    inputs=inputs,
                                                    model_version=model_version,
                                                    request_id=request_id,
                                                    outputs=outputs,
                                                    sequence_id=sequence_id,
                                                    sequence_start=sequence_start,
                                                    sequence_end=sequence_end,
                                                    priority=priority,
                                                    timeout=timeout,
                                                    parameters = parameters)
        self.request_queue.put(request)
    def __next__(self):
        if self.request_queue.empty():
            raise StopIteration
        request = self.request_queue.get()
        return request


def get_preprocessor_inputs(prompt, output_len, bad_words, stop_words, end_id,
                            pad_id):
    input0 = [[prompt]]
    input0_data = np.array(input0).astype(object)
    output0_len = np.ones_like(input0).astype(np.int32) * output_len

    preprocessor_inputs = [
        prepare_tensor("QUERY", input0_data),
        prepare_tensor("REQUEST_OUTPUT_LEN", output0_len),
    ]

    if bad_words:
        bad_words_list = np.array([bad_words], dtype=object)
        preprocessor_inputs += [
            prepare_tensor("BAD_WORDS_DICT", bad_words_list)
        ]

    if stop_words:
        stop_words_list = np.array([stop_words], dtype=object)
        preprocessor_inputs += [
            prepare_tensor("STOP_WORDS_DICT", stop_words_list)
        ]

    if end_id:
        end_id_data = np.array([[end_id]], dtype=np.int32)
        preprocessor_inputs += [prepare_tensor("END_ID", end_id_data)]

    if pad_id:
        pad_id_data = np.array([[pad_id]], dtype=np.int32)
        preprocessor_inputs += [prepare_tensor("PAD_ID", pad_id_data)]

    return preprocessor_inputs


def extract_preprocessor_outputs(result):

    input_ids = np.squeeze(result.as_numpy("INPUT_ID").astype(np.int32),
                           axis=0)
    bad_words_ids = result.as_numpy("BAD_WORDS_IDS").astype(np.int32)
    stop_words_ids = result.as_numpy("STOP_WORDS_IDS").astype(np.int32)
    end_id = result.as_numpy("OUT_END_ID").astype(np.int32)[0][0]
    pad_id = result.as_numpy("OUT_PAD_ID").astype(np.int32)[0][0]

    return input_ids, bad_words_ids, stop_words_ids, end_id, pad_id


def get_trtllm_inputs(input_ids,
                      input_length,
                      request_output_len,
                      draft_tokens,
                      beam_width,
                      temperature,
                      repetition_penalty,
                      presence_penalty,
                      frequency_penalty,
                      bad_words_ids,
                      stop_words_ids,
                      end_id,
                      pad_id,
                      return_draft_model_draft_logits=False,
                      return_target_model_accepted_token_logits=False):

    # These two flags correspond to the settings of draft model and target model respectively,
    # and only one of them can be true at a time.
    assert not (return_draft_model_draft_logits
                and return_target_model_accepted_token_logits)

    # input_ids is expected to have shape [input_length]
    # Add batch dimension of 1
    input_ids_data = np.expand_dims(input_ids, axis=0)
    inputs = [
        prepare_tensor("input_ids", input_ids_data),
        prepare_tensor("input_lengths",
                       np.array([[input_length]], dtype=np.int32)),
        prepare_tensor("request_output_len",
                       np.array([[request_output_len]], dtype=np.int32)),
        prepare_tensor("bad_words_list", bad_words_ids),
        prepare_tensor("stop_words_list", stop_words_ids),
        prepare_tensor("beam_width", np.array([[beam_width]], dtype=np.int32)),
        prepare_tensor("temperature",
                       np.array([[temperature]], dtype=np.float32)),
    ]

    if draft_tokens is not None:
        draft_tokens_data = np.array([draft_tokens], dtype=np.int32)
        inputs.append(prepare_tensor("draft_input_ids", draft_tokens_data))

    if repetition_penalty is not None:
        repetition_penalty_data = np.array([[repetition_penalty]],
                                           dtype=np.float32)
        inputs.append(
            prepare_tensor("repetition_penalty", repetition_penalty_data))

    if presence_penalty is not None:
        presence_penalty_data = np.array([[presence_penalty]],
                                         dtype=np.float32)
        inputs.append(prepare_tensor("presence_penalty",
                                     presence_penalty_data))

    if frequency_penalty is not None:
        frequency_penalty_data = np.array([[frequency_penalty]],
                                          dtype=np.float32)
        inputs.append(
            prepare_tensor("frequency_penalty", frequency_penalty_data))

    if end_id is not None:
        end_id_data = np.array([[end_id]], dtype=np.int32)
        inputs.append(prepare_tensor("end_id", end_id_data))

    if pad_id is not None:
        pad_id_data = np.array([[pad_id]], dtype=np.int32)
        inputs.append(prepare_tensor("pad_id", pad_id_data))

    if return_draft_model_draft_logits:
        return_draft_model_draft_logits_data = np.array(
            [[return_draft_model_draft_logits]], dtype=bool)
        inputs.append(
            prepare_tensor("return_generation_logits",
                           return_draft_model_draft_logits_data))

    if return_target_model_accepted_token_logits:
        return_target_model_accepted_token_logits_data = np.array(
            [[return_target_model_accepted_token_logits]], dtype=bool)
        inputs.append(
            prepare_tensor("return_generation_logits",
                           return_target_model_accepted_token_logits_data))

    return inputs


def check_result(result, model_name):
    if type(result) == InferenceServerException:
        print(
            f"Received an error from server while calling {model_name}: {result}"
        )


def extract_trtllm_outputs(result):
    # Get batch 0, beam 0 output_ids
    output_ids = np.squeeze(result.as_numpy("output_ids").astype(np.int32),
                            axis=(0, 1))
    sequence_length_data = result.as_numpy("sequence_length").astype(np.int32)
    assert sequence_length_data.shape[0] == 1
    assert sequence_length_data.shape[1] == 1
    sequence_length = sequence_length_data[0, 0]
    # cum_log_probs = result.as_numpy("cum_log_probs").astype(np.float32)
    # output_log_probs = result.as_numpy("output_log_probs").astype(np.float32)
    # context_logits = result.as_numpy("context_logits").astype(np.float32)
    # generation_logits = result.as_numpy("generation_logits").astype(np.float32)
    return output_ids, sequence_length #, cum_log_probs, output_log_probs, context_logits, generation_logits


def get_postprocessor_inputs(output_ids): # , cum_log_probs, output_log_probs, context_logits, generation_logits
    output_ids_data = np.expand_dims(output_ids, axis=(0, 1))
    inputs = [
        prepare_tensor("TOKENS_BATCH", output_ids_data),
        prepare_tensor("SEQUENCE_LENGTH",
                       np.array([[len(output_ids)]], dtype=np.int32)),
        # prepare_tensor("CUM_LOG_PROBS", cum_log_probs),
        # prepare_tensor("OUTPUT_LOG_PROBS", output_log_probs),
        # prepare_tensor("CONTEXT_LOGITS", context_logits),
        # prepare_tensor("GENERATION_LOGITS", generation_logits)
    ]

    return inputs


def encountered_stop_words(input_ids, stop_words_ids):
    for stop_word_ids in stop_words_ids:
        if np.array_equal(input_ids[-len(stop_word_ids):], stop_word_ids):
            return True
    return False


# NOTE: unlike the non-streaming variant, this function opens its own gRPC channel,
# so it takes the server URLs instead of pre-built client objects (all requests go
# over a single channel to url_target; url_draft is accepted only for symmetry).
def run_speculative_inference(
        url_draft, url_target, prompt, output_len, in_num_draft_tokens,
        request_id, repetition_penalty, presence_penalty, frequency_penalty,
        temperature, stop_words, bad_words, end_id, pad_id, beam_width,
        preprocessor_model_name, draft_tensorrt_llm_model_name,
        target_tensorrt_llm_model_name, postprocessor_model_name,
        return_draft_model_draft_logits,
        return_target_model_accepted_token_logits, verbose):

    from datetime import datetime ##
    start_time = datetime.now() ##
    with grpc.insecure_channel(url_target) as channel:
        stub = grpcclient.service_pb2_grpc.GRPCInferenceServiceStub(channel)

        # Preprocessor : streaming
        preprocessor_inputs = get_preprocessor_inputs(prompt, output_len, bad_words, stop_words, end_id, pad_id)
        preprocessor_iterator = stub.ModelStreamInfer(
            _RequestIterator(
                model_name=preprocessor_model_name,
                inputs=preprocessor_inputs,
            ),
            metadata={},
            timeout=None,
            compression=grpc.Compression.NoCompression
        )
        for response in preprocessor_iterator:
            if response.error_message != "":
                raise RuntimeError(grpcclient.InferenceServerException(msg=response.error_message))
            preprocessor_result = grpcclient.InferResult(response.infer_response)
            break

        check_result(preprocessor_result, preprocessor_model_name)
        prompt_input_ids, bad_words_ids, stop_words_ids, end_id, pad_id = extract_preprocessor_outputs(preprocessor_result)

        input_ids = prompt_input_ids
        last_input_ids = prompt_input_ids
        draft_output_ids = None

        while True:
            num_draft_tokens = min(
                in_num_draft_tokens,
                len(prompt_input_ids) + output_len - len(input_ids) - 1
            )

            draft_tokens = None
            if num_draft_tokens > 0:
                if verbose:
                    print("Draft model input ids:")
                    print(input_ids.tolist())

                draft_inputs = get_trtllm_inputs(
                    input_ids,
                    len(input_ids),
                    num_draft_tokens,
                    None,
                    beam_width,
                    temperature,
                    repetition_penalty,
                    presence_penalty,
                    frequency_penalty,
                    bad_words_ids,
                    stop_words_ids,
                    end_id,
                    pad_id,
                    return_draft_model_draft_logits=return_draft_model_draft_logits
                )

                draft_iterator = stub.ModelStreamInfer(
                    _RequestIterator(
                        model_name=draft_tensorrt_llm_model_name,
                        inputs=draft_inputs
                    ),
                    metadata={},
                    timeout=None,
                    compression=grpc.Compression.NoCompression
                )
                for response in draft_iterator:
                    if response.error_message != "":
                        raise RuntimeError(grpcclient.InferenceServerException(msg=response.error_message))
                    draft_result = grpcclient.InferResult(response.infer_response)
                    break

                check_result(draft_result, draft_tensorrt_llm_model_name)
                draft_output_ids, draft_seq_len = extract_trtllm_outputs(draft_result)

                if verbose:
                    print("Draft model output ids:")
                    print(draft_output_ids.tolist())
                    print("draft_sequence_length:", draft_seq_len)

                draft_tokens = draft_output_ids[len(input_ids):draft_seq_len]
                if verbose:
                    print("draft_tokens:")
                    print(draft_tokens.tolist())

            if verbose:
                print("Target model input ids:")
                print(input_ids.tolist())

            target_inputs = get_trtllm_inputs(
                input_ids,
                len(input_ids),
                len(draft_tokens) + 1 if num_draft_tokens > 0 else 1,
                draft_tokens if num_draft_tokens > 0 else None,
                beam_width,
                temperature,
                repetition_penalty,
                presence_penalty,
                frequency_penalty,
                bad_words_ids,
                stop_words_ids,
                end_id,
                pad_id,
                return_target_model_accepted_token_logits=return_target_model_accepted_token_logits
            )

            target_iterator = stub.ModelStreamInfer(
                _RequestIterator(
                    model_name=target_tensorrt_llm_model_name,
                    inputs=target_inputs
                ),
                metadata={},
                timeout=None,
                compression=grpc.Compression.NoCompression
            )
            for response in target_iterator:
                if response.error_message != "":
                    raise RuntimeError(grpcclient.InferenceServerException(msg=response.error_message))
                target_result = grpcclient.InferResult(response.infer_response)
                break

            check_result(target_result, target_tensorrt_llm_model_name)
            target_output_ids, seq_length = extract_trtllm_outputs(target_result)

            if verbose:
                print("Target model output_ids:")
                print(target_output_ids.tolist())
                print("target seq_length:", seq_length)


            last_input_ids = input_ids
            input_ids = target_output_ids

            length_stop = (len(input_ids) >= len(prompt_input_ids) + output_len)
            target_draft_equal = draft_output_ids is not None and np.array_equal(draft_output_ids, target_output_ids)
            last_current_equal = np.array_equal(last_input_ids, input_ids)
            hit_stop_words = encountered_stop_words(input_ids, stop_words_ids)

            if verbose:
                print("length_stop:", length_stop)
                print("target_draft_equal:", target_draft_equal)
                print("last_current_equal:", last_current_equal)
                print("hit_stop_words:", hit_stop_words)

            if (length_stop or target_draft_equal or last_current_equal or hit_stop_words):
                break

        # Postprocessor: streaming
        postprocessor_inputs = get_postprocessor_inputs(input_ids)
        postprocessor_iterator = stub.ModelStreamInfer(
            _RequestIterator(
                model_name=postprocessor_model_name,
                inputs=postprocessor_inputs
            ),
            metadata={},
            timeout=None,
            compression=grpc.Compression.NoCompression
        )
        for response in postprocessor_iterator:
            if response.error_message != "":
                raise RuntimeError(grpcclient.InferenceServerException(msg=response.error_message))
            postprocessor_result = grpcclient.InferResult(response.infer_response)
            break

        check_result(postprocessor_result, postprocessor_model_name)
        output = postprocessor_result.as_numpy("OUTPUT")
        response_time = (datetime.now()-start_time).total_seconds() ##
        print(f"Response Time: {response_time}") ##

    return output

def run_speculative_inference_with_defaults(
        prompt: str,
        url_target: str = "localhost:8001",
        url_draft: str = None,
        output_len: int = 1000,
        num_draft_tokens: int = 10,
        beam_width: int = 1,
        temperature: float = 1.0,
        repetition_penalty: float = None,
        presence_penalty: float = None,
        frequency_penalty: float = None,
        stop_words: list = None,
        bad_words: list = None,
        end_id: int = 2,
        pad_id: int = 0,
        preprocessor_model_name: str = "preprocessing",
        draft_tensorrt_llm_model_name: str = "tensorrt_llm_draft",
        target_tensorrt_llm_model_name: str = "tensorrt_llm",
        postprocessor_model_name: str = "postprocessing",
        return_draft_model_draft_logits: bool = False,
        return_target_model_accepted_token_logits: bool = False,
        verbose: bool = False
):
    # Ensure draft URL defaults to target URL if not provided
    if url_draft is None:
        url_draft = url_target

    # Create Triton clients for target and draft
    try:
        client_target = grpcclient.InferenceServerClient(url=url_target)
        client_draft = grpcclient.InferenceServerClient(
            url=url_draft) if url_target != url_draft else client_target
    except Exception as e:
        print(f"Failed to create Triton client: {e}")
        return None

    if beam_width > 1:
        raise Exception(
            'Beam width > 1 is not yet supported with speculative decoding'
        )

    # Call the speculative inference function
    return run_speculative_inference(
        url_draft=url_draft,
        url_target=url_target,
        prompt=prompt,
        output_len=output_len,
        in_num_draft_tokens=num_draft_tokens,
        request_id="1",  # Default request ID
        repetition_penalty=repetition_penalty,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        temperature=temperature,
        stop_words=stop_words,
        bad_words=bad_words,
        end_id=end_id,
        pad_id=pad_id,
        beam_width=beam_width,
        preprocessor_model_name=preprocessor_model_name,
        draft_tensorrt_llm_model_name=draft_tensorrt_llm_model_name,
        target_tensorrt_llm_model_name=target_tensorrt_llm_model_name,
        postprocessor_model_name=postprocessor_model_name,
        return_draft_model_draft_logits=return_draft_model_draft_logits,
        return_target_model_accepted_token_logits=
        return_target_model_accepted_token_logits,
        verbose=verbose
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')

    parser.add_argument('--url-target',
                        type=str,
                        required=True,
                        help='Inference server URL for the target model')

    parser.add_argument('--url-draft',
                        type=str,
                        required=False,
                        help='Inference server URL for the draft model')

    parser.add_argument(
        '--preprocessor-model-name',
        type=str,
        required=False,
        default="preprocessing",
        help='Name of the preprocessor model (should be hosted at url-draft)')

    parser.add_argument(
        '--postprocessor-model-name',
        type=str,
        required=False,
        default="postprocessing",
        help='Name of the postprocessor model (should be hosted at url-target)'
    )

    parser.add_argument(
        '--draft-tensorrt-llm-model-name',
        type=str,
        required=False,
        default="tensorrt_llm",
        help='Name of the tensorrt_llm draft model (hosted at url-draft)')

    parser.add_argument(
        '--target-tensorrt-llm-model-name',
        type=str,
        required=False,
        default="tensorrt_llm",
        help='Name of the tensorrt_llm target model (hosted at url-target)')

    parser.add_argument('-p',
                        '--prompt',
                        type=str,
                        required=True,
                        help='Input prompt.')

    parser.add_argument(
        "-b",
        "--beam-width",
        required=False,
        type=int,
        default=1,
        help="Beam width value",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        required=False,
        default=1.0,
        help="temperature value",
    )

    parser.add_argument(
        "--repetition-penalty",
        type=float,
        required=False,
        default=None,
        help="The repetition penalty value",
    )

    parser.add_argument(
        "--presence-penalty",
        type=float,
        required=False,
        default=None,
        help="The presence penalty value",
    )

    parser.add_argument(
        "--frequency-penalty",
        type=float,
        required=False,
        default=None,
        help="The frequency penalty value",
    )

    parser.add_argument('-o',
                        '--output-len',
                        type=int,
                        default=1000,
                        required=False,
                        help='Specify output length')

    parser.add_argument(
        '--num-draft-tokens',
        type=int,
        default=5,
        required=False,
        help=
        'Specify the number of speculative tokens for the draft model to generate per lookahead.'
    )

    parser.add_argument('--end-id',
                        type=int,
                        default=None,
                        required=False,
                        help='The end-of-sequence token id')

    parser.add_argument('--pad-id',
                        type=int,
                        default=None,
                        required=False,
                        help='The padding token id')

    parser.add_argument('--request-id',
                        type=str,
                        default='1',
                        required=False,
                        help='The request_id for the stop request')

    parser.add_argument('--stop-words',
                        nargs='+',
                        default=[],
                        help='The stop words')

    parser.add_argument('--bad-words',
                        nargs='+',
                        default=[],
                        help='The bad words')

    parser.add_argument(
        "--return-draft-model-draft-logits",
        action="store_true",
        required=False,
        default=False,
        help=
        "Return draft model's draft tokens' logits, require to enable `gather_generation_logits` when build engine"
    )

    parser.add_argument(
        "--return-target-model-accepted-token-logits",
        action="store_true",
        required=False,
        default=False,
        help=
        "Return target model's accepted token logits, require to enable `gather_generation_logits` when build engine",
    )

    FLAGS = parser.parse_args()
    if not FLAGS.url_target:
        FLAGS.url_target = "localhost:8001"

    if not FLAGS.url_draft:
        FLAGS.url_draft = FLAGS.url_target

    try:
        client_target = grpcclient.InferenceServerClient(url=FLAGS.url_target)
        client_draft = grpcclient.InferenceServerClient(
            url=FLAGS.url_draft) if (
                FLAGS.url_target != FLAGS.url_draft) else client_target
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)

    if (FLAGS.beam_width > 1):
        raise Exception(
            'Beam width > 1 is not yet supported with speculative decoding')

    output_text = run_speculative_inference(
        FLAGS.url_draft, FLAGS.url_target, FLAGS.prompt, FLAGS.output_len,
        FLAGS.num_draft_tokens, FLAGS.request_id, FLAGS.repetition_penalty,
        FLAGS.presence_penalty, FLAGS.frequency_penalty, FLAGS.temperature,
        FLAGS.stop_words, FLAGS.bad_words, FLAGS.end_id, FLAGS.pad_id,
        FLAGS.beam_width, FLAGS.preprocessor_model_name,
        FLAGS.draft_tensorrt_llm_model_name,
        FLAGS.target_tensorrt_llm_model_name, FLAGS.postprocessor_model_name,
        FLAGS.return_draft_model_draft_logits,
        FLAGS.return_target_model_accepted_token_logits, FLAGS.verbose)

    # Print the final text
    print("Final text:\n", output_text)

Expected behavior

"prompt" in "output" out

actual behavior

  • I get an assertion error with the message !mTokens.empty():

    File "/AutoEval/e2e_grpc_speculative_decoding_client.py", line 367, in run_speculative_inference
    raise RuntimeError(grpcclient.InferenceServerException(msg=response.error_message))
    RuntimeError: [TensorRT-LLM][ERROR] Assertion failed: !mTokens.empty() (/workspace/tensorrt_llm/cpp/tensorrt_llm/executor/decodingConfig.cpp:31)
    1 0x7e52d0b7bc64 tensorrt_llm::common::throwRuntimeError(char const*, int, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&) + 100
    2 0x7e52d3006738 tensorrt_llm::executor::ExternalDraftTokensConfig::ExternalDraftTokensConfig(std::vector<int, std::allocator >, std::optional<tensorrt_llm::executor::Tensor>, std::optional const&, std::optional const&) + 712
    3 0x7e544d685d72 triton::backend::inflight_batcher_llm::utils::getExternalDraftTokensConfigFromTensors(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, tensorrt_llm::batch_manager::NamedTensor, std::hash<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits, std::allocator > const, tensorrt_llm::batch_manager::NamedTensor> > > const&, bool) + 818
    4 0x7e544d68764d triton::backend::inflight_batcher_llm::utils::createRequestsFromInputTensors(std::vector<std::unordered_map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, tensorrt_llm::batch_manager::NamedTensor, std::hash<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits, std::allocator > const, tensorrt_llm::batch_manager::NamedTensor> > >, std::allocator<std::unordered_map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, tensorrt_llm::batch_manager::NamedTensor, std::hash<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits, std::allocator > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits, std::allocator > const, tensorrt_llm::batch_manager::NamedTensor> > > > > const&, bool, bool, bool, tensorrt_llm::executor::ModelType, tensorrt_llm::executor::RequestType, bool, bool) + 2813
    5 0x7e544d660490 triton::backend::inflight_batcher_llm::ModelInstanceState::createExecutorRequests(TRITONBACKEND_Request*, bool, bool, tensorrt_llm::executor::ModelType, bool, bool) + 144
    6 0x7e544d66c5a2 triton::backend::inflight_batcher_llm::ModelInstanceState::enqueue(TRITONBACKEND_Request**, unsigned int) + 434
    7 0x7e544d659bb5 TRITONBACKEND_ModelInstanceExecute + 101
    8 0x7e545984b384 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x1af384) [0x7e545984b384]
    9 0x7e545984b6fb /opt/tritonserver/bin/../lib/libtritonserver.so(+0x1af6fb) [0x7e545984b6fb]
    10 0x7e545996d76d /opt/tritonserver/bin/../lib/libtritonserver.so(+0x2d176d) [0x7e545996d76d]
    11 0x7e545984f384 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x1b3384) [0x7e545984f384]
    12 0x7e545afb7253 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7e545afb7253]
    13 0x7e5458c6bac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7e5458c6bac3]
    14 0x7e5458cfca04 clone + 68

additional notes

  • I tried all of the scripts mentioned in the documentation:

    • tensorrtllm_backend/inflight_batcher_llm/client/e2e_grpc_speculative_decoding_client.py
    • TensorRT-LLM/examples/run.py
  • A single request works fine, but the error usually occurs when multiple requests are sent (a guess at the trigger and a possible client-side guard are sketched below).
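  • My current (unverified) guess: the stack trace shows ExternalDraftTokensConfig asserting !mTokens.empty(), so the failure may happen when the draft model stops early and draft_tokens = draft_output_ids[len(input_ids):draft_seq_len] comes back empty while num_draft_tokens > 0; the client then still attaches an empty draft_input_ids tensor to the target request. A minimal client-side guard over the scripts above might look like this (a sketch, not a confirmed fix):

# Sketch: fall back to a plain target step when the draft model returned no new tokens,
# so that an empty draft_input_ids tensor is never sent.
draft_tokens = draft_output_ids[len(input_ids):draft_seq_len]
if draft_tokens.size == 0:
    draft_tokens = None  # get_trtllm_inputs() then omits the draft_input_ids tensor

target_inputs = get_trtllm_inputs(
    input_ids,
    len(input_ids),
    len(draft_tokens) + 1 if draft_tokens is not None else 1,
    draft_tokens,
    beam_width,
    temperature,
    repetition_penalty,
    presence_penalty,
    frequency_penalty,
    bad_words_ids,
    stop_words_ids,
    end_id,
    pad_id,
    return_target_model_accepted_token_logits=return_target_model_accepted_token_logits,
)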

@pimang62 pimang62 added the bug Something isn't working label Jan 13, 2025
@pimang62 pimang62 changed the title Inference error with using draft target model Inference error encountered while using the draft target model. Jan 13, 2025