diff --git a/tensorrt_llm/bench/benchmark/low_latency.py b/tensorrt_llm/bench/benchmark/low_latency.py
index 490ac62f4f..7c5b924c92 100644
--- a/tensorrt_llm/bench/benchmark/low_latency.py
+++ b/tensorrt_llm/bench/benchmark/low_latency.py
@@ -9,11 +9,14 @@ import yaml
 from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
                                 optgroup)
+from huggingface_hub import snapshot_download
 
+from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
 from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset
 from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
+from tensorrt_llm.bench.build.build import get_model_config
 from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
 from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
 from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
@@ -21,10 +24,11 @@ from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 
 # isort: off
-from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine
+from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS
 # isort: on
 from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
-                                           initialize_tokenizer)
+                                           initialize_tokenizer,
+                                           update_metadata_for_multimodal)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.sampling_params import SamplingParams
@@ -38,15 +42,25 @@
                     readable=True,
                     path_type=Path,
                     resolve_path=True),
-    required=True,
+    default=None,
     help="Path to a serialized TRT-LLM engine.",
 )
+@optgroup.option("--backend",
+                 type=click.Choice(ALL_SUPPORTED_BACKENDS),
+                 default="pytorch",
+                 help="Backend to use for the benchmark. Default is 'pytorch'.")
 @optgroup.option(
     "--kv_cache_free_gpu_mem_fraction",
     type=float,
     default=.90,
     help="The percentage of memory to use for KV Cache after model load.",
 )
+@optgroup.option(
+    "--max_seq_len",
+    type=int,
+    default=None,
+    help="Maximum sequence length.",
+)
 @optgroup.group(
     "Engine Input Configuration",
     help="Input configuration for driving the engine.",
@@ -60,6 +74,20 @@
     default=None,
     help="Pass in a dataset file for parsing instead of stdin.",
 )
+@optgroup.option(
+    "--modality",
+    type=click.Choice(["image", "video"]),
+    default=None,
+    help="Modality of the multimodal requests.",
+)
+@optgroup.option(
+    "--max_input_len",
+    type=int,
+    default=4096,
+    help=
+    "Maximum input sequence length to use for multimodal models. This is used only when --modality "
+    "is specified since the actual number of vision tokens is unknown before the model is run.",
+)
 @optgroup.option(
     "--num_requests",
     type=int,
@@ -73,6 +101,24 @@
     default=2,
     help="Number of requests warm up benchmark.",
 )
+@optgroup.option(
+    "--tp",
+    type=int,
+    default=1,
+    help="Tensor parallelism size.",
+)
+@optgroup.option(
+    "--pp",
+    type=int,
+    default=1,
+    help="Pipeline parallelism size.",
+)
+@optgroup.option(
+    "--ep",
+    type=int,
+    default=None,
+    help="Expert parallelism size.",
+)
 @optgroup.group("Request Load Control Options",
                 cls=MutuallyExclusiveOptionGroup,
                 help="Limits how requests are loaded.")
@@ -142,11 +188,11 @@ def latency_command(
     concurrency: int = params.pop("concurrency")
     beam_width: int = params.pop("beam_width")
     warmup: int = params.get("warmup")
-    # Engine configuration parsing
-    exec_settings, build_cfg = get_settings_from_engine(engine_dir)
-    exec_settings["model"] = model
-    engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
-    engine_max_seq_len = build_cfg["max_seq_len"]
+    modality: str = params.pop("modality")
+    max_input_len: int = params.pop("max_input_len")
+    max_seq_len: int = params.pop("max_seq_len")
+    backend: str = params.get("backend")
+    model_type = get_model_config(model, checkpoint_path).model_type
 
     # Runtime Options
     kv_cache_percent = params.pop("kv_cache_free_gpu_mem_fraction")
@@ -157,6 +203,56 @@
     iteration_log: Path = params.pop("iteration_log")
     iteration_writer = IterationWriter(iteration_log)
 
+    # Initialize the HF tokenizer for the specified model.
+    tokenizer = initialize_tokenizer(checkpoint_path)
+
+    # Dataset Loading and Preparation
+    with open(dataset_path, "r") as dataset:
+        metadata, requests = create_dataset_from_stream(
+            tokenizer,
+            dataset,
+            num_requests=num_requests,
+            model_dir=checkpoint_path,
+            model_type=model_type,
+            modality=modality,
+            max_input_seq_len_for_multimodal=max_input_len)
+
+        metadata.dataset_path = dataset_path
+
+    if modality is None:
+        # Log dataset info
+        # NOTE: This table is only accurate for non-multimodal models.
+        # The accurate table for multimodal models will be logged after the benchmark is done.
+        logger.info(metadata.get_summary_for_print())
+
+    # Engine configuration parsing for PyTorch backend
+    kwargs = {}
+    if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
+    ) != "cpp":
+        if bench_env.checkpoint_path is None:
+            snapshot_download(model)
+
+        exec_settings = get_settings(params, metadata, bench_env.model,
+                                     bench_env.checkpoint_path)
+        kwargs_max_sql = max_seq_len or metadata.max_sequence_length
+        logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
+        kwargs["max_seq_len"] = kwargs_max_sql
+    else:
+        assert max_seq_len is None, (
+            "max_seq_len is not a runtime parameter for C++ backend")
+        exec_settings, build_cfg = get_settings_from_engine(engine_dir)
+        engine_max_seq_len = build_cfg["max_seq_len"]
+
+        if metadata.max_sequence_length > engine_max_seq_len:
+            raise RuntimeError(
+                f"Engine supports a max sequence of {engine_max_seq_len}. Provided "
+                "dataset contains a maximum sequence of "
+                f"{metadata.max_sequence_length}. Please rebuild a new engine to "
+                "support this dataset.")
+
+    exec_settings["model"] = model
+    engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
+
     # Update configuration with runtime options
     exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
     exec_settings["settings_config"]["max_batch_size"] = 1
@@ -187,32 +283,25 @@
     # Construct the runtime configuration dataclass.
     runtime_config = RuntimeConfig(**exec_settings)
 
-    # Initialize the HF tokenizer for the specified model.
-    ignore_eos = True if runtime_config.decoding_config.decoding_mode == SpeculativeDecodingMode.NONE else False
-    tokenizer = initialize_tokenizer(checkpoint_path)
-    eos_id = tokenizer.eos_token_id if not ignore_eos else -1
-    pad_id = tokenizer.pad_token_id if not ignore_eos else -1
+    llm = None
+    kwargs = kwargs | runtime_config.get_llm_args()
+    kwargs['backend'] = backend
 
-    # Dataset Loading and Preparation
-    with open(dataset_path, "r") as dataset:
-        metadata, requests = create_dataset_from_stream(
-            tokenizer, dataset, num_requests=num_requests)
-        metadata.dataset_path = dataset_path
+    try:
+        logger.info("Setting up latency benchmark.")
 
-    if metadata.max_sequence_length > engine_max_seq_len:
-        raise RuntimeError(
-            f"Engine supports a max sequence of {engine_max_seq_len}. Provided "
-            "dataset contains a maximum sequence of "
-            f"{metadata.max_sequence_length}. Please rebuild a new engine to"
-            "support this dataset.")
+        if "pytorch_backend_config" in kwargs and iteration_log is not None:
+            kwargs["pytorch_backend_config"].enable_iter_perf_stats = True
 
-    logger.info(metadata.get_summary_for_print())
-    logger.info("Running experimental latency benchmark.")
+        if runtime_config.backend == 'pytorch':
+            llm = PyTorchLLM(**kwargs)
+        else:
+            llm = LLM(**kwargs)
 
-    llm = None
-    kwargs = runtime_config.get_llm_args()
+        ignore_eos = True if runtime_config.decoding_config.decoding_mode == SpeculativeDecodingMode.NONE else False
+        eos_id = tokenizer.eos_token_id if not ignore_eos else -1
+        pad_id = tokenizer.pad_token_id if not ignore_eos else -1
 
-    try:
         sampling_params = SamplingParams(
             end_id=eos_id,
             pad_id=pad_id,
@@ -220,7 +309,6 @@
             use_beam_search=beam_width > 1,
         )
         post_proc_params = None  # No detokenization
-        llm = LLM(**kwargs)
 
         # Perform warmup if requested.
         if warmup > 0:
@@ -228,8 +316,13 @@
             warmup_dataset = generate_warmup_dataset(requests, warmup)
             logger.info("Running warmup.")
             asyncio.run(
-                async_benchmark(llm, sampling_params, post_proc_params,
-                                warmup_dataset, False, concurrency))
+                async_benchmark(llm,
+                                sampling_params,
+                                post_proc_params,
+                                warmup_dataset,
+                                False,
+                                concurrency,
+                                modality=modality))
             # WAR: IterationResult is a singleton tied to the executor.
             # Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
             # we must reset it to ensure it attaches to the correct event loop.
@@ -238,11 +331,21 @@
         with iteration_writer.capture():
             statistics = asyncio.run(
-                async_benchmark(llm, sampling_params, post_proc_params,
-                                requests, True, concurrency,
-                                iteration_writer.full_address))
+                async_benchmark(llm,
+                                sampling_params,
+                                post_proc_params,
+                                requests,
+                                True,
+                                concurrency,
+                                iteration_writer.full_address,
+                                modality=modality))
 
         logger.info(f"Benchmark done. Reporting results...")
+
+        if modality is not None:
+            # For multimodal models, we need to update the metadata with the correct input lengths
+            metadata = update_metadata_for_multimodal(metadata, statistics)
+
         report_utility = ReportUtility(statistics, metadata, runtime_config,
                                        logger, kwargs, True)
         if report_json:
diff --git a/tensorrt_llm/bench/benchmark/throughput.py b/tensorrt_llm/bench/benchmark/throughput.py
index fd9ad5016e..295e17ebfd 100755
--- a/tensorrt_llm/bench/benchmark/throughput.py
+++ b/tensorrt_llm/bench/benchmark/throughput.py
@@ -44,7 +44,7 @@
     help="Path to a serialized TRT-LLM engine.",
 )
 @optgroup.option("--backend",
-                 type=click.Choice(["pytorch", "_autodeploy"]),
+                 type=click.Choice(ALL_SUPPORTED_BACKENDS),
                  default=None,
                  help="Set to 'pytorch' for pytorch path. Default is cpp path.")
 @optgroup.option(
@@ -293,10 +293,11 @@ def throughput_command(
     logger.info(metadata.get_summary_for_print())
 
     # Engine configuration parsing
-    if backend and backend.lower() in ["pytorch", "_autodeploy"]:
+    if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
+    ) != "cpp":
         # If we're dealing with a model name, perform a snapshot download to
         # make sure we have a local copy of the model.
-        if checkpoint_path is None:
+        if bench_env.checkpoint_path is None:
             snapshot_download(model)
 
         exec_settings = get_settings(params, metadata, bench_env.model,
diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py
index a02a7a52af..0baca0094c 100755
--- a/tensorrt_llm/bench/benchmark/utils/general.py
+++ b/tensorrt_llm/bench/benchmark/utils/general.py
@@ -22,6 +22,8 @@
     QuantAlgo.NVFP4.value: "fp8",
 }
 
+ALL_SUPPORTED_BACKENDS = ["pytorch", "_autodeploy", "cpp"]
+
 
 def get_settings_from_engine(
         engine_path: Path