Skip to content

Support paged kv cache for benchmarks #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: release/0.5.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion benchmarks/python/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,22 @@ def parse_arguments():
help=
'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
)
parser.add_argument(
'--use_inflight_batching',
action="store_true",
default=False,
help="Activates inflight batching mode of gptAttentionPlugin.")
parser.add_argument(
'--paged_kv_cache',
action="store_true",
default=False,
help=
'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
)
parser.add_argument('--tokens_per_block',
type=int,
default=16,
help='Number of tokens per block in paged KV cache')
parser.add_argument('--csv',
default=False,
action="store_true",
Expand Down Expand Up @@ -236,7 +252,10 @@ def main(args):
enable_fp8=args.enable_fp8,
fp8_kv_cache=args.fp8_kv_cache,
enable_cuda_graph=args.enable_cuda_graph,
enable_custom_all_reduce=args.enable_custom_all_reduce)
enable_custom_all_reduce=args.enable_custom_all_reduce,
paged_kv_cached=args.paged_kv_cache,
tokens_per_block=args.tokens_per_block,
use_inflight_batching=args.use_inflight_batching)
elif args.model in get_allowed_models(benchmark_type="bert"):
benchmarker = BERTBenchmark(args.engine_dir,
args.model,
Expand Down
30 changes: 29 additions & 1 deletion benchmarks/python/gpt_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.models import (fp8_quantize, smooth_quantize,
weight_only_quantize)
from tensorrt_llm.network import net_guard
Expand All @@ -50,6 +51,9 @@ def __init__(self,
max_output_len=None,
max_batch_size=None,
enable_custom_all_reduce=None,
paged_kv_cached=None,
tokens_per_block=None,
use_inflight_batching=None,
**kwargs):
super().__init__(engine_dir, model_name, dtype, output_dir)
self.batch_sizes = batch_sizes
Expand Down Expand Up @@ -94,6 +98,9 @@ def __init__(self,
self.enable_context_fmha = True
self.quant_mode = QuantMode(0)
self.remove_input_padding = is_plugin_mode
self.paged_kv_cache = paged_kv_cached
self.tokens_per_block = tokens_per_block
self.use_inflight_batching = use_inflight_batching

for key, value in get_build_config(model_name).items():
setattr(self, key, value)
Expand Down Expand Up @@ -124,6 +131,20 @@ def __init__(self,
if kwargs.get('force_num_layer_1', False):
self.num_layers = 1

if self.use_inflight_batching:
if not self.use_gpt_attention_plugin:
self.use_gpt_attention_plugin = 'float16'
logger.info(
f"Using GPT attention plugin for inflight batching mode. Setting to default '{self.use_gpt_attention_plugin}'"
)
if not self.remove_input_padding:
self.remove_input_padding = True
logger.info(
"Using remove input padding for inflight batching mode.")
if not self.paged_kv_cache:
self.paged_kv_cache = True
logger.info("Using paged KV cache for inflight batching mode.")

if self.use_smooth_quant:
self.quant_mode = QuantMode.use_smooth_quant(
self.per_token, self.per_channel)
Expand Down Expand Up @@ -153,6 +174,9 @@ def __init__(self,
remove_input_padding=self.remove_input_padding,
quant_mode=self.quant_mode,
use_custom_all_reduce=self.enable_custom_all_reduce,
paged_kv_cache=self.paged_kv_cache,
tokens_per_block=self.tokens_per_block,
dtype=self.dtype,
)
if model_name == 'chatglm_6b':
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
Expand Down Expand Up @@ -204,6 +228,8 @@ def prepare_inputs(self, config):

def build(self):
builder = Builder()
int8_trt_flag = self.quant_mode.has_act_and_weight_quant() or (
not self.paged_kv_cache and self.quant_mode.has_int8_kv_cache())
builder_config = builder.create_builder_config(
name=self.model_name,
precision=self.dtype,
Expand All @@ -221,7 +247,7 @@ def build(self):
max_batch_size=self.max_batch_size,
max_input_len=self.max_input_len,
max_output_len=self.max_output_len,
int8=self.quant_mode.has_act_and_weight_quant(),
int8=int8_trt_flag,
fp8=self.quant_mode.has_fp8_qdq(),
quant_mode=self.quant_mode,
use_refit=self.refit,
Expand Down Expand Up @@ -388,6 +414,8 @@ def build(self):
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if self.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if self.paged_kv_cache:
network.plugin_config.enable_paged_kv_cache(self.tokens_per_block)

# Quantization plugins.
if self.use_smooth_quant:
Expand Down