
Commit ccd7f37

refine model config extract
Signed-off-by: diwei sun <[email protected]>
1 parent e33f1b8 commit ccd7f37


2 files changed: +289 −154 lines changed

Lines changed: 166 additions & 153 deletions
@@ -1,180 +1,193 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from __future__ import annotations
 
-import random
-import time
+import itertools
+from typing import Optional
 
 import torch
-from tabulate import tabulate
+import triton
+import random
+from torch import Tensor
 
-from tests import register_ops as ops
-from tests.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
+from tests import register_ops as vllm_ops
+from tests.utils import (
+    check_ipex_availability,
+    create_kv_caches_with_random,
+    parse_args,
+)
 
+HAS_IPEX = check_ipex_availability()
 
-@torch.inference_mode()
-def run_benchmark(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
+if HAS_IPEX:
+    import intel_extension_for_pytorch as ipex
+
+
+def reshape_and_cache_vllm(
+    key: Tensor,
+    value: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    slot_mapping: Tensor,
     kv_cache_dtype: str,
-    num_iters: int,
-    device: str = "xpu",
-) -> float:
-    """Return latency (seconds) for given num_tokens."""
+    k_scale: Optional[float] = None,
+    v_scale: Optional[float] = None,
+) -> None:
+    """vLLM's fused kernel for reshaping and caching K/V tensors."""
+    vllm_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
+                               kv_cache_dtype, k_scale, v_scale)
+
+
+def reshape_and_cache_ipex(
+    key: Tensor,
+    value: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    slot_mapping: Tensor,
+    kv_cache_dtype: str,
+    k_scale: Optional[float] = None,
+    v_scale: Optional[float] = None,
+) -> None:
+    """IPEX native implementation using ipex.llm.modules.PagedAttention."""
+    if not HAS_IPEX:
+        raise RuntimeError("IPEX is not available")
+    assert kv_cache_dtype == "auto", "IPEX reshape_and_cache uses 'auto' mode"
+
+    ipex.llm.modules.PagedAttention.reshape_and_cache(
+        key, value, key_cache, value_cache, slot_mapping
+    )
+
 
-    if kv_cache_dtype == "fp8" and head_size % 16:
-        raise ValueError(
-            "fp8 kv-cache requires head_size to be a multiple of 16.")
+def get_benchmark(
+    dtype: torch.dtype,
+    device: str = "xpu",
+):
+
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["num_tokens", "num_heads", "head_size", "block_size", "num_blocks"],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["vllm", "ipex"] if HAS_IPEX else ["vllm"],
+            line_names=["vLLM", "IPEX"] if HAS_IPEX else ["vLLM"],
+            styles=[("blue", "-"), ("red", "-")] if HAS_IPEX else [("blue", "-")],
+            ylabel="latency (us)",
+            plot_name="reshape_and_cache-benchmark",
+            args={},
+        )
+    )
+    @torch.inference_mode()
+    def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider, kv_cache_dtype="auto"):
+
+        if kv_cache_dtype == "fp8" and head_size % 16:
+            raise ValueError(
+                "fp8 kv-cache requires head_size to be a multiple of 16.")
 
-    seed = 42
-    random.seed(seed)
-    torch.manual_seed(seed)
-    torch.set_default_device(device)
+        torch.manual_seed(42)
+        torch.set_default_device(device)
 
-    # create random key / value tensors [T, H, D].
-    key = torch.randn(num_tokens,
+        key = torch.randn(num_tokens,
                       num_heads,
                       head_size,
                       dtype=dtype,
                       device=device)
-    value = torch.randn_like(key)
-
-    # prepare the slot mapping.
-    # each token is assigned a unique slot in the KV-cache.
-    num_slots = block_size * num_blocks
-    if num_tokens > num_slots:
-        raise ValueError(
-            "num_tokens cannot exceed the total number of cache slots")
-    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping_lst,
+        value = torch.randn_like(key)
+        num_slots = block_size * num_blocks
+        if num_tokens > num_slots:
+            raise ValueError(
+                "num_tokens cannot exceed the total number of cache slots")
+        slot_mapping_lst = random.sample(range(num_slots), num_tokens)
+        slot_mapping = torch.tensor(slot_mapping_lst,
                                 dtype=torch.long,
                                 device=device)
 
-    num_layers = 1  # for simplicity, we use a single layer
-    key_caches, value_caches = create_kv_caches_with_random(
-        num_blocks,
-        block_size,
-        num_layers,
-        num_heads,
-        head_size,
-        kv_cache_dtype,
-        dtype,
-        device=device,
-    )
-    key_cache, value_cache = key_caches[0], value_caches[0]
+        num_layers = 1  # for simplicity, we use a single layer
+        key_caches, value_caches = create_kv_caches_with_random(
+            num_blocks,
+            block_size,
+            num_layers,
+            num_heads,
+            head_size,
+            kv_cache_dtype,
+            dtype,
+            device=device,
+        )
+        key_cache, value_cache = key_caches[0], value_caches[0]
 
-    # compute per-kernel scaling factors for fp8 conversion (if used).
-    k_scale = (key.amax() / 64.0).to(torch.float32)
-    v_scale = (value.amax() / 64.0).to(torch.float32)
+        # compute per-kernel scaling factors for fp8 conversion (if used).
+        k_scale = (key.amax() / 64.0).to(torch.float32)
+        v_scale = (value.amax() / 64.0).to(torch.float32)
 
-    def run_xpu_benchmark(n_iters: int) -> float:
-        nonlocal key, value, key_cache, value_cache, slot_mapping
-        torch.xpu.synchronize()
-        start = time.perf_counter()
-        for _ in range(n_iters):
-            ops.reshape_and_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                slot_mapping,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-            )
         torch.xpu.synchronize()
-        end = time.perf_counter()
-        return (end - start) / n_iters
-
-    # warm-up
-    run_xpu_benchmark(3)
-
-    lat = run_xpu_benchmark(num_iters)
-
-    # free tensors to mitigate OOM when sweeping
-    del key, value, key_cache, value_cache, slot_mapping
-    torch.xpu.empty_cache()
-
-    return lat
-
-
-def main(args):
-    rows = []
-    for exp in range(1, 12):
-        n_tok = 2**exp
-        lat = run_benchmark(
-            num_tokens=n_tok,
-            num_heads=args.num_heads,
-            head_size=args.head_size,
-            block_size=args.block_size,
-            num_blocks=args.num_blocks,
-            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
-            kv_cache_dtype=args.kv_cache_dtype,
-            num_iters=args.iters,
-            device="xpu",
+        # Warm up
+        for _ in range(5):
+            if provider == "vllm":
+                reshape_and_cache_vllm(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+            elif provider == "ipex" and HAS_IPEX:
+                reshape_and_cache_ipex(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+
+        # Benchmark
+        quantiles = [0.5, 0.2, 0.8]
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: {
+                "vllm": reshape_and_cache_vllm,
+                "ipex": reshape_and_cache_ipex
+            }[provider](
+                key, value, key_cache, value_cache, slot_mapping,
+                kv_cache_dtype, k_scale, v_scale,
+            ),
+            quantiles=quantiles,
         )
-        rows.append([
-            n_tok,
-            args.num_heads,
-            args.head_size,
-            args.block_size,
-            args.num_blocks,
-            args.dtype,
-            args.kv_cache_dtype,
-            f"{lat * 1e6:.3f}",
-        ])
-    print(
-        tabulate(
-            rows,
-            headers=[
-                "num_tokens",
-                "num_heads",
-                "head_size",
-                "block_size",
-                "num_blocks",
-                "dtype",
-                "kv_cache_dtype",
-                "latency (us)",
-            ],
-        ))
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 
+    return benchmark
 
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--num-heads", type=int, default=8)
-    parser.add_argument(
-        "--head-size",
-        type=int,
-        choices=[64, 80, 96, 112, 120, 128, 192, 256],
-        default=128,
-    )
-    parser.add_argument("--block-size",
-                        type=int,
-                        choices=[16, 32, 64],
-                        default=64)
-    parser.add_argument("--num-blocks", type=int, default=1024)
-
-    parser.add_argument(
-        "--dtype",
-        type=str,
-        choices=["half", "bfloat16"],
-        default="half",
-    )
 
-    parser.add_argument(
-        "--kv-cache-dtype",
-        type=str,
-        choices=["auto", "fp8", "fp8_e4m3", "fp8_e5m2"],
-        default="auto",
+if __name__ == "__main__":
+    args = parse_args()
+
+    device = "xpu"
+
+    print("Benchmark Configuration:")
+    print(f"  Num Heads: {args.head_num_range}")
+    print(f"  Head Size: {args.head_size}")
+    print(f"  Block Size: {args.block_size}")
+    print(f"  Num Blocks: {args.num_blocks}")
+    print(f"  Data Type: {args.dtype}")
+    print(f"  KV Cache Dtype: auto (IPEX & vLLM)")
+    print(f"  Device: {device}")
+    if HAS_IPEX:
+        print(f"✅ IPEX {ipex.__version__} is available.")
+    else:
+        print("⚠️ IPEX not available. Only benchmarking vLLM.")
+
+    num_token_range = [2**i for i in range(1, 12)]
+    head_num_range = args.head_num_range
+    head_size_range = [args.head_size]
+    block_size_range = [args.block_size]
+    num_blocks_range = [args.num_blocks]
+    configs = list(
+        itertools.product(num_token_range, head_num_range, head_size_range, block_size_range, num_blocks_range))
+
+    benchmark = get_benchmark(
+        dtype=args.dtype,
+        device=device,
     )
-
-    parser.add_argument("--iters", type=int, default=100)
-    args = parser.parse_args()
-
-    main(args)
+    benchmark.run(print_data=True, save_path=None)
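
Note: the script above imports check_ipex_availability and parse_args from tests.utils; that module is the second file changed in this commit, and its diff is not shown here. As orientation only, the following is a hypothetical sketch of what those helpers would need to provide for the benchmark to run. The field names (head_num_range, head_size, block_size, num_blocks, dtype) are taken from how the script uses them, but every default value and the exact argument spelling are assumptions, not the commit's actual code.

# Hypothetical sketch only -- not the contents of tests/utils.py in this commit.
import argparse

import torch


def check_ipex_availability() -> bool:
    """Assumed behaviour: True if intel_extension_for_pytorch can be imported."""
    try:
        import intel_extension_for_pytorch  # noqa: F401
        return True
    except ImportError:
        return False


def parse_args() -> argparse.Namespace:
    """Assumed behaviour: expose the fields the benchmark script reads."""
    parser = argparse.ArgumentParser(description="reshape_and_cache benchmark")
    # head_num_range is swept by itertools.product, so it is a list of ints.
    parser.add_argument("--head-num-range", type=int, nargs="+",
                        default=[8, 16, 32])
    parser.add_argument("--head-size", type=int, default=128)
    parser.add_argument("--block-size", type=int, default=64)
    parser.add_argument("--num-blocks", type=int, default=1024)
    # get_benchmark() passes args.dtype straight into torch.randn, so it must
    # already be a torch.dtype at this point.
    parser.add_argument("--dtype",
                        type=lambda s: {"half": torch.float16,
                                        "bfloat16": torch.bfloat16}[s],
                        default=torch.float16)
    return parser.parse_args()

With helpers along these lines, running the benchmark script with no arguments sweeps num_tokens from 2 to 2048 for each head count, and triton.testing prints one latency column per provider (vLLM, plus IPEX when it is available).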
