diff --git a/benchmarks/python/benchmark.py b/benchmarks/python/benchmark.py
index 2ead10b840..20fd4be4e8 100644
--- a/benchmarks/python/benchmark.py
+++ b/benchmarks/python/benchmark.py
@@ -182,6 +182,22 @@ def parse_arguments():
         help=
         'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
     )
+    parser.add_argument(
+        '--use_inflight_batching',
+        action="store_true",
+        default=False,
+        help="Activates inflight batching mode of gptAttentionPlugin.")
+    parser.add_argument(
+        '--paged_kv_cache',
+        action="store_true",
+        default=False,
+        help=
+        'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
+    )
+    parser.add_argument('--tokens_per_block',
+                        type=int,
+                        default=16,
+                        help='Number of tokens per block in paged KV cache')
     parser.add_argument('--csv',
                         default=False,
                         action="store_true",
@@ -236,7 +252,10 @@ def main(args):
             enable_fp8=args.enable_fp8,
             fp8_kv_cache=args.fp8_kv_cache,
             enable_cuda_graph=args.enable_cuda_graph,
-            enable_custom_all_reduce=args.enable_custom_all_reduce)
+            enable_custom_all_reduce=args.enable_custom_all_reduce,
+            paged_kv_cache=args.paged_kv_cache,
+            tokens_per_block=args.tokens_per_block,
+            use_inflight_batching=args.use_inflight_batching)
     elif args.model in get_allowed_models(benchmark_type="bert"):
         benchmarker = BERTBenchmark(args.engine_dir,
                                     args.model,
diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py
index eb5eb74fa8..d3e983025d 100644
--- a/benchmarks/python/gpt_benchmark.py
+++ b/benchmarks/python/gpt_benchmark.py
@@ -24,6 +24,7 @@
 from tensorrt_llm._utils import str_dtype_to_trt
 from tensorrt_llm.builder import Builder
 from tensorrt_llm.layers import PositionEmbeddingType
+from tensorrt_llm.logger import logger
 from tensorrt_llm.models import (fp8_quantize, smooth_quantize,
                                  weight_only_quantize)
 from tensorrt_llm.network import net_guard
@@ -50,6 +51,9 @@ def __init__(self,
                  max_output_len=None,
                  max_batch_size=None,
                  enable_custom_all_reduce=None,
+                 paged_kv_cache=None,
+                 tokens_per_block=None,
+                 use_inflight_batching=None,
                  **kwargs):
         super().__init__(engine_dir, model_name, dtype, output_dir)
         self.batch_sizes = batch_sizes
@@ -94,6 +98,9 @@ def __init__(self,
         self.enable_context_fmha = True
         self.quant_mode = QuantMode(0)
         self.remove_input_padding = is_plugin_mode
+        self.paged_kv_cache = paged_kv_cache
+        self.tokens_per_block = tokens_per_block
+        self.use_inflight_batching = use_inflight_batching
 
         for key, value in get_build_config(model_name).items():
             setattr(self, key, value)
@@ -124,6 +131,20 @@ def __init__(self,
         if kwargs.get('force_num_layer_1', False):
             self.num_layers = 1
 
+        if self.use_inflight_batching:
+            if not self.use_gpt_attention_plugin:
+                self.use_gpt_attention_plugin = 'float16'
+                logger.info(
+                    f"Using GPT attention plugin for inflight batching mode. Setting to default '{self.use_gpt_attention_plugin}'"
+                )
+            if not self.remove_input_padding:
+                self.remove_input_padding = True
+                logger.info(
+                    "Using remove input padding for inflight batching mode.")
+            if not self.paged_kv_cache:
+                self.paged_kv_cache = True
+                logger.info("Using paged KV cache for inflight batching mode.")
+
         if self.use_smooth_quant:
             self.quant_mode = QuantMode.use_smooth_quant(
                 self.per_token, self.per_channel)
@@ -153,6 +174,9 @@ def __init__(self,
             remove_input_padding=self.remove_input_padding,
             quant_mode=self.quant_mode,
             use_custom_all_reduce=self.enable_custom_all_reduce,
+            paged_kv_cache=self.paged_kv_cache,
+            tokens_per_block=self.tokens_per_block,
+            dtype=self.dtype,
         )
         if model_name == 'chatglm_6b':
             self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
@@ -204,6 +228,8 @@ def prepare_inputs(self, config):
 
     def build(self):
         builder = Builder()
+        int8_trt_flag = self.quant_mode.has_act_and_weight_quant() or (
+            not self.paged_kv_cache and self.quant_mode.has_int8_kv_cache())
         builder_config = builder.create_builder_config(
             name=self.model_name,
             precision=self.dtype,
@@ -221,7 +247,7 @@ def build(self):
             max_batch_size=self.max_batch_size,
             max_input_len=self.max_input_len,
             max_output_len=self.max_output_len,
-            int8=self.quant_mode.has_act_and_weight_quant(),
+            int8=int8_trt_flag,
             fp8=self.quant_mode.has_fp8_qdq(),
             quant_mode=self.quant_mode,
             use_refit=self.refit,
@@ -388,6 +414,8 @@ def build(self):
             network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
         if self.remove_input_padding:
             network.plugin_config.enable_remove_input_padding()
+        if self.paged_kv_cache:
+            network.plugin_config.enable_paged_kv_cache(self.tokens_per_block)
 
         # Quantization plugins.
         if self.use_smooth_quant:
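
Usage sketch (not part of the diff): with these changes applied, the new flags can be passed on the benchmark command line; the model name below is a placeholder and other required arguments are omitted. Passing --use_inflight_batching alone is also enough, since the new __init__ logic then forces the GPT attention plugin, remove_input_padding, and the paged KV cache on.

    python benchmarks/python/benchmark.py --model gpt_350m \
        --use_inflight_batching --paged_kv_cache --tokens_per_block 16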