feat/add latency support for trtllm bench #3730

Open · wants to merge 2 commits into base: main
173 changes: 138 additions & 35 deletions tensorrt_llm/bench/benchmark/low_latency.py
@@ -9,22 +9,26 @@
import yaml
from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
optgroup)
from huggingface_hub import snapshot_download

from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
from tensorrt_llm.llmapi import CapacitySchedulerPolicy
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode

# isort: off
from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine
from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS
# isort: on
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
initialize_tokenizer)
initialize_tokenizer,
update_metadata_for_multimodal)
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import SamplingParams

@@ -38,15 +42,25 @@
readable=True,
path_type=Path,
resolve_path=True),
required=True,
default=None,
help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option("--backend",
type=click.Choice(ALL_SUPPORTED_BACKENDS),
default="pytorch",
help="Set to 'pytorch' for pytorch path. Default is cpp path.")
@optgroup.option(
"--kv_cache_free_gpu_mem_fraction",
type=float,
default=.90,
help="The percentage of memory to use for KV Cache after model load.",
)
@optgroup.option(
"--max_seq_len",
type=int,
default=None,
help="Maximum sequence length.",
)
@optgroup.group(
"Engine Input Configuration",
help="Input configuration for driving the engine.",
@@ -60,6 +74,20 @@
default=None,
help="Pass in a dataset file for parsing instead of stdin.",
)
@optgroup.option(
"--modality",
type=click.Choice(["image", "video"]),
default=None,
help="Modality of the multimodal requests.",
)
@optgroup.option(
"--max_input_len",
type=int,
default=4096,
help=
"Maximum input sequence length to use for multimodal models. This is used only when --modality "
"is specified since the actual number of vision tokens is unknown before the model is run.",
)
@optgroup.option(
"--num_requests",
type=int,
@@ -73,6 +101,24 @@
default=2,
help="Number of requests warm up benchmark.",
)
@optgroup.option(
"--tp",
type=int,
default=1,
help="tensor parallelism size",
)
@optgroup.option(
"--pp",
type=int,
default=1,
help="pipeline parallelism size",
)
@optgroup.option(
"--ep",
type=int,
default=None,
help="expert parallelism size",
)
@optgroup.group("Request Load Control Options",
cls=MutuallyExclusiveOptionGroup,
help="Limits how requests are loaded.")
@@ -142,11 +188,11 @@ def latency_command(
concurrency: int = params.pop("concurrency")
beam_width: int = params.pop("beam_width")
warmup: int = params.get("warmup")
# Engine configuration parsing
exec_settings, build_cfg = get_settings_from_engine(engine_dir)
exec_settings["model"] = model
engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
engine_max_seq_len = build_cfg["max_seq_len"]
modality: str = params.pop("modality")
max_input_len: int = params.pop("max_input_len")
max_seq_len: int = params.pop("max_seq_len")
backend: str = params.get("backend")
model_type = get_model_config(model, checkpoint_path).model_type

# Runtime Options
kv_cache_percent = params.pop("kv_cache_free_gpu_mem_fraction")
@@ -157,6 +203,56 @@
iteration_log: Path = params.pop("iteration_log")
iteration_writer = IterationWriter(iteration_log)

# Initialize the HF tokenizer for the specified model.
tokenizer = initialize_tokenizer(checkpoint_path)

# Dataset Loading and Preparation
with open(dataset_path, "r") as dataset:
metadata, requests = create_dataset_from_stream(
tokenizer,
dataset,
num_requests=num_requests,
model_dir=checkpoint_path,
model_type=model_type,
modality=modality,
max_input_seq_len_for_multimodal=max_input_len)

metadata.dataset_path = dataset_path

if modality is None:
# Log dataset info
# NOTE: This table is only accurate for non-multimodal models.
# The accurate table for multimodal models will be logged after the benchmark is done.
logger.info(metadata.get_summary_for_print())

# Engine configuration parsing for PyTorch backend
kwargs = {}
if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
) != "cpp":
if bench_env.checkpoint_path is None:
snapshot_download(model)

exec_settings = get_settings(params, metadata, bench_env.model,
bench_env.checkpoint_path)
kwargs_max_sql = max_seq_len or metadata.max_sequence_length
logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
kwargs["max_seq_len"] = kwargs_max_sql
else:
assert max_seq_len is None, (
"max_seq_len is not a runtime parameter for C++ backend")
exec_settings, build_cfg = get_settings_from_engine(engine_dir)
engine_max_seq_len = build_cfg["max_seq_len"]

if metadata.max_sequence_length > engine_max_seq_len:
raise RuntimeError(
f"Engine supports a max sequence of {engine_max_seq_len}. Provided "
"dataset contains a maximum sequence of "
f"{metadata.max_sequence_length}. Please rebuild a new engine to"
"support this dataset.")

exec_settings["model"] = model
engine_tokens = exec_settings["settings_config"]["max_num_tokens"]

# Update configuration with runtime options
exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
exec_settings["settings_config"]["max_batch_size"] = 1
@@ -187,49 +283,46 @@ def latency_command(
# Construct the runtime configuration dataclass.
runtime_config = RuntimeConfig(**exec_settings)

# Initialize the HF tokenizer for the specified model.
ignore_eos = True if runtime_config.decoding_config.decoding_mode == SpeculativeDecodingMode.NONE else False
tokenizer = initialize_tokenizer(checkpoint_path)
eos_id = tokenizer.eos_token_id if not ignore_eos else -1
pad_id = tokenizer.pad_token_id if not ignore_eos else -1
llm = None
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = backend

# Dataset Loading and Preparation
with open(dataset_path, "r") as dataset:
metadata, requests = create_dataset_from_stream(
tokenizer, dataset, num_requests=num_requests)
metadata.dataset_path = dataset_path
try:
logger.info("Setting up latency benchmark.")

if metadata.max_sequence_length > engine_max_seq_len:
raise RuntimeError(
f"Engine supports a max sequence of {engine_max_seq_len}. Provided "
"dataset contains a maximum sequence of "
f"{metadata.max_sequence_length}. Please rebuild a new engine to"
"support this dataset.")
if "pytorch_backend_config" in kwargs and iteration_log is not None:
kwargs["pytorch_backend_config"].enable_iter_perf_stats = True

logger.info(metadata.get_summary_for_print())
logger.info("Running experimental latency benchmark.")
if runtime_config.backend == 'pytorch':
llm = PyTorchLLM(**kwargs)
else:
llm = LLM(**kwargs)

llm = None
kwargs = runtime_config.get_llm_args()
ignore_eos = True if runtime_config.decoding_config.decoding_mode == SpeculativeDecodingMode.NONE else False
eos_id = tokenizer.eos_token_id if not ignore_eos else -1
pad_id = tokenizer.pad_token_id if not ignore_eos else -1

try:
sampling_params = SamplingParams(
end_id=eos_id,
pad_id=pad_id,
n=beam_width,
use_beam_search=beam_width > 1,
)
post_proc_params = None # No detokenization
llm = LLM(**kwargs)

# Perform warmup if requested.
if warmup > 0:
logger.info("Setting up for warmup...")
warmup_dataset = generate_warmup_dataset(requests, warmup)
logger.info("Running warmup.")
asyncio.run(
async_benchmark(llm, sampling_params, post_proc_params,
warmup_dataset, False, concurrency))
async_benchmark(llm,
sampling_params,
post_proc_params,
warmup_dataset,
False,
concurrency,
modality=modality))
# WAR: IterationResult is a singleton tied to the executor.
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
# we must reset it to ensure it attaches to the correct event loop.
@@ -238,11 +331,21 @@ def latency_command(

with iteration_writer.capture():
statistics = asyncio.run(
async_benchmark(llm, sampling_params, post_proc_params,
requests, True, concurrency,
iteration_writer.full_address))
async_benchmark(llm,
sampling_params,
post_proc_params,
requests,
True,
concurrency,
iteration_writer.full_address,
modality=modality))

logger.info(f"Benchmark done. Reporting results...")

if modality is not None:
# For multimodal models, we need to update the metadata with the correct input lengths
metadata = update_metadata_for_multimodal(metadata, statistics)

report_utility = ReportUtility(statistics, metadata, runtime_config,
logger, kwargs, True)
if report_json:
7 changes: 4 additions & 3 deletions tensorrt_llm/bench/benchmark/throughput.py
@@ -44,7 +44,7 @@
help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option("--backend",
type=click.Choice(["pytorch", "_autodeploy"]),
type=click.Choice(ALL_SUPPORTED_BACKENDS),
default=None,
help="Set to 'pytorch' for pytorch path. Default is cpp path.")
@optgroup.option(
@@ -293,10 +293,11 @@ def throughput_command(
logger.info(metadata.get_summary_for_print())

# Engine configuration parsing
if backend and backend.lower() in ["pytorch", "_autodeploy"]:
if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
) != "cpp":
# If we're dealing with a model name, perform a snapshot download to
# make sure we have a local copy of the model.
if checkpoint_path is None:
if bench_env.checkpoint_path is None:
Collaborator:
Re-iterating the comment from the other PR.

@danielafrimi I ran into an issue with the previous code where the model is already downloaded, but the code still tries to re-download it because bench_env.checkpoint_path is None even though bench_env.model contains the local path.

Can you please elaborate on why you think the current code is wrong?
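
One possible guard (only a sketch, under the assumption that bench_env.model may already hold a local directory path) would be to skip the download whenever that path exists on disk:

from pathlib import Path

from huggingface_hub import snapshot_download

# Hypothetical check: only download when no local checkpoint is configured
# and the model reference is not already a directory on disk.
if bench_env.checkpoint_path is None and not Path(bench_env.model).is_dir():
    snapshot_download(bench_env.model)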

snapshot_download(model)

exec_settings = get_settings(params, metadata, bench_env.model,
2 changes: 2 additions & 0 deletions tensorrt_llm/bench/benchmark/utils/general.py
@@ -22,6 +22,8 @@
QuantAlgo.NVFP4.value: "fp8",
}

ALL_SUPPORTED_BACKENDS = ["pytorch", "_autodeploy", "cpp"]


def get_settings_from_engine(
engine_path: Path