# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
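"""Benchmark vLLM's fused reshape_and_cache op against IPEX's
PagedAttention.reshape_and_cache on XPU.

Sweeps num_tokens over powers of two with triton.testing.perf_report and
reports median / 20th / 80th-percentile latency in microseconds.
"""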

import itertools
import random
from typing import Optional

import torch
import triton
from torch import Tensor

from tests import register_ops as vllm_ops
from tests.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    check_ipex_availability,
    create_kv_caches_with_random,
    parse_args,
)

HAS_IPEX = check_ipex_availability()

if HAS_IPEX:
    import intel_extension_for_pytorch as ipex


def reshape_and_cache_vllm(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    slot_mapping: Tensor,
    kv_cache_dtype: str,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> None:
    """vLLM's fused kernel for reshaping and caching K/V tensors."""
    vllm_ops.reshape_and_cache(key, value, key_cache, value_cache,
                               slot_mapping, kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_ipex(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    slot_mapping: Tensor,
    kv_cache_dtype: str,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> None:
    """IPEX native implementation using ipex.llm.modules.PagedAttention.

    k_scale/v_scale are accepted only to keep the two provider signatures
    interchangeable; the IPEX kernel takes no explicit fp8 scales here.
    """
    if not HAS_IPEX:
        raise RuntimeError("IPEX is not available")
    assert kv_cache_dtype == "auto", "IPEX reshape_and_cache uses 'auto' mode"

    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slot_mapping
    )
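

# A minimal cross-check sketch, not part of the original benchmark. It assumes
# both kernels share the vLLM paged-cache layout (which the benchmark below
# also relies on, since it feeds the same caches to both providers).
def _sanity_check_providers(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    slot_mapping: Tensor,
) -> None:
    """Hypothetical helper: write the same tokens through both providers into
    cloned caches and verify the cached K/V entries match. Assumes
    kv_cache_dtype == "auto" (scales are unused in that mode)."""
    if not HAS_IPEX:
        return
    ones = torch.ones(1, dtype=torch.float32, device=key.device)
    k_vllm, v_vllm = key_cache.clone(), value_cache.clone()
    k_ipex, v_ipex = key_cache.clone(), value_cache.clone()
    reshape_and_cache_vllm(key, value, k_vllm, v_vllm, slot_mapping, "auto",
                           ones, ones)
    reshape_and_cache_ipex(key, value, k_ipex, v_ipex, slot_mapping, "auto")
    torch.testing.assert_close(k_vllm, k_ipex)
    torch.testing.assert_close(v_vllm, v_ipex)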


def get_benchmark(
    dtype: torch.dtype,
    device: str = "xpu",
):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens", "num_heads", "head_size", "block_size",
                     "num_blocks"],
            # `configs` is the module-level list built under __main__ below;
            # it must exist before get_benchmark() is called.
            x_vals=configs,
            line_arg="provider",
            line_vals=["vllm", "ipex"] if HAS_IPEX else ["vllm"],
            line_names=["vLLM", "IPEX"] if HAS_IPEX else ["vLLM"],
            styles=[("blue", "-"), ("red", "-")] if HAS_IPEX else [("blue", "-")],
            ylabel="latency (us)",
            plot_name="reshape_and_cache-benchmark",
            args={},
        )
    )
    @torch.inference_mode()
    def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks,
                  provider, kv_cache_dtype="auto"):
        if kv_cache_dtype == "fp8" and head_size % 16:
            raise ValueError(
                "fp8 kv-cache requires head_size to be a multiple of 16.")

        torch.manual_seed(42)
        # seed `random` too: random.sample() below drives slot assignment.
        random.seed(42)
        torch.set_default_device(device)

        # create random key/value tensors [T, H, D].
        key = torch.randn(num_tokens,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device=device)
        value = torch.randn_like(key)

        # prepare the slot mapping: each token is assigned a unique slot
        # in the KV cache.
        num_slots = block_size * num_blocks
        if num_tokens > num_slots:
            raise ValueError(
                "num_tokens cannot exceed the total number of cache slots")
        slot_mapping_lst = random.sample(range(num_slots), num_tokens)
        slot_mapping = torch.tensor(slot_mapping_lst,
                                    dtype=torch.long,
                                    device=device)

        num_layers = 1  # for simplicity, we use a single layer
        key_caches, value_caches = create_kv_caches_with_random(
            num_blocks,
            block_size,
            num_layers,
            num_heads,
            head_size,
            kv_cache_dtype,
            dtype,
            device=device,
        )
        key_cache, value_cache = key_caches[0], value_caches[0]

        # compute per-kernel scaling factors for fp8 conversion (if used).
        k_scale = (key.amax() / 64.0).to(torch.float32)
        v_scale = (value.amax() / 64.0).to(torch.float32)
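        # NB: with kv_cache_dtype="auto" (the only mode exercised here) the
        # scales are not applied; they matter only for the fp8 cache paths.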

        torch.xpu.synchronize()
        # Warm up: run the selected provider a few times before timing.
        for _ in range(5):
            if provider == "vllm":
                reshape_and_cache_vllm(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    slot_mapping,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            elif provider == "ipex" and HAS_IPEX:
                reshape_and_cache_ipex(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    slot_mapping,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                )

        # Benchmark: do_bench returns the median, 20th and 80th percentile
        # latencies (ms) for the requested quantiles.
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: {
                "vllm": reshape_and_cache_vllm,
                "ipex": reshape_and_cache_ipex,
            }[provider](
                key, value, key_cache, value_cache, slot_mapping,
                kv_cache_dtype, k_scale, v_scale,
            ),
            quantiles=quantiles,
        )
        # convert ms to us to match the ylabel.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    args = parse_args()

    device = "xpu"

    print("Benchmark Configuration:")
    print(f"  Num Heads: {args.head_num_range}")
    print(f"  Head Size: {args.head_size}")
    print(f"  Block Size: {args.block_size}")
    print(f"  Num Blocks: {args.num_blocks}")
    print(f"  Data Type: {args.dtype}")
    print("  KV Cache Dtype: auto (IPEX & vLLM)")
    print(f"  Device: {device}")
    if HAS_IPEX:
        print(f"✅ IPEX {ipex.__version__} is available.")
    else:
        print("⚠️ IPEX not available. Only benchmarking vLLM.")

    num_token_range = [2**i for i in range(1, 12)]
    head_num_range = args.head_num_range
    head_size_range = [args.head_size]
    block_size_range = [args.block_size]
    num_blocks_range = [args.num_blocks]
    configs = list(
        itertools.product(num_token_range, head_num_range, head_size_range,
                          block_size_range, num_blocks_range))
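    # Each config is a (num_tokens, num_heads, head_size, block_size,
    # num_blocks) tuple, e.g. (2, 8, 128, 64, 1024) -- illustrative values
    # only; the actual numbers come from parse_args().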

    # parse_args() may hand the dtype back as a string flag (e.g. "half");
    # map it to a torch.dtype before building the benchmark.
    dtype = args.dtype
    if isinstance(dtype, str):
        dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    benchmark = get_benchmark(
        dtype=dtype,
        device=device,
    )
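    # triton prints the table with print_data=True; point save_path at a
    # directory to also dump the results as CSV/PNG.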
    benchmark.run(print_data=True, save_path=None)