From e33f1b85d4db8277dca9577f7863c44b134cd7ed Mon Sep 17 00:00:00 2001 From: diwei sun Date: Thu, 18 Sep 2025 00:35:05 -0700 Subject: [PATCH 1/5] signoff: refine test_utils to allow all benchmark script utilizes it Signed-off-by: diwei sun --- benchmark/benchmark_rmsnorm.py | 445 +++++---------------------------- 1 file changed, 58 insertions(+), 387 deletions(-) diff --git a/benchmark/benchmark_rmsnorm.py b/benchmark/benchmark_rmsnorm.py index 4513c07..e46f298 100644 --- a/benchmark/benchmark_rmsnorm.py +++ b/benchmark/benchmark_rmsnorm.py @@ -1,411 +1,82 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools -from typing import Optional, Union +import time +from argparse import ArgumentParser import torch -import triton -from torch import nn -from tests import register_ops as vllm_ops -from tests.utils import check_ipex_availability, get_model_config +from tests.ops.layernorm_op import RMSNorm +from tests.utils import STR_DTYPE_TO_TORCH_DTYPE -HAS_IPEX = check_ipex_availability() -if HAS_IPEX: - import intel_extension_for_pytorch as ipex +@torch.inference_mode() +def main( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int = 0, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: + torch.set_default_device("xpu") + layer = RMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + residual = torch.randn_like(x) * scale if add_residual else None -class HuggingFaceRMSNorm(nn.Module): + def run_xpu_benchmark(num_iters: int) -> float: + torch.xpu.synchronize() - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps + start_time = time.perf_counter() - def forward( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = x.dtype - x = x.to(torch.float32) - if residual is not None: - x = x + residual.to(torch.float32) - residual = x.to(orig_dtype) + for _ in range(num_iters): + layer(x, residual) + torch.xpu.synchronize() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - x = x.to(orig_dtype) * self.weight - if residual is None: - return x - else: - return x, residual + end_time = time.perf_counter() + return (end_time - start_time) / num_iters -def rmsnorm_naive( - x: torch.Tensor, - weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - eps: float = 1e-6, -): - naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) - naive_norm.weight = nn.Parameter(weight) - naive_norm = naive_norm.to(x.device) + # Warmup. + print("Warming up...") + run_benchmark = run_xpu_benchmark + run_benchmark(num_iters=num_warmup_iters) - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - if residual is not None: - residual = residual.view(-1, residual.shape[-1]) + # Benchmark. 
+ latency = run_benchmark(num_iters=num_iters) + print(f"Kernel running time: {latency * 1000000:.3f} us") - output = naive_norm(x, residual) - if isinstance(output, tuple): - output = (output[0].view(orig_shape), output[1].view(orig_shape)) - else: - output = output.view(orig_shape) - return output - - -@torch.compile -def rmsnorm_compile(x: torch.Tensor, - weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - eps: float = 1e-6): - """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype - x = x.to(torch.float32) - if residual is not None: - x = x + residual.to(torch.float32) - residual = x.to(orig_dtype) - - x_var = x - variance = x_var.pow(2).mean(dim=-1, keepdim=True) - - x = x * torch.rsqrt(variance + eps) - x = x.to(orig_dtype) - x = x * weight - if residual is None: - return x - else: - return x, residual - - -def rmsnorm_vllm( - x: torch.Tensor, - weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - eps: float = 1e-6, -): - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - if residual is not None: - residual = residual.view(-1, residual.shape[-1]) - - if residual is not None: - vllm_ops.fused_add_rms_norm(x, residual, weight, eps) - output = (x, residual) - else: - out = torch.empty_like(x) - vllm_ops.rms_norm(out, x, weight, eps) - output = out - - if isinstance(output, tuple): - output = (output[0].view(orig_shape), output[1].view(orig_shape)) - else: - output = output.view(orig_shape) - return output - - -def rmsnorm_ipex( - x: torch.Tensor, - weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, - eps: float = 1e-6, -): - """IPEX implementation using ipex.llm.functional.rms_norm""" - if not HAS_IPEX: - raise RuntimeError("IPEX is not available") - - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - - if residual is not None: - residual = residual.view(-1, residual.shape[-1]) - if hasattr(ipex.llm.functional, 'fused_add_rms_norm'): - output, residual_out = ipex.llm.functional.fused_add_rms_norm( - x, residual, weight, eps) - output = (output.view(orig_shape), residual_out.view(orig_shape)) - else: - x = x + residual - output = ipex.llm.functional.rms_norm(x, weight, eps) - output = (output.view(orig_shape), x.view(orig_shape)) - else: - output = ipex.llm.functional.rms_norm(x, weight, eps) - output = output.view(orig_shape) - - return output - - -def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): - dtype = torch.bfloat16 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="xpu") - weight = torch.ones(hidden_size, dtype=dtype, device="xpu") - residual = torch.randn_like(x) if use_residual else None - - output_naive = rmsnorm_naive( - x.clone(), weight, - residual.clone() if residual is not None else None) - output_vllm = rmsnorm_vllm( - x.clone(), weight, - residual.clone() if residual is not None else None) - - if use_residual: - output_naive = output_naive[0] - output_vllm = output_vllm[0] - - print(f"Naive output={output_naive}") - print(f"vLLM output={output_vllm}") - - if HAS_IPEX: - try: - output_ipex = rmsnorm_ipex( - x.clone(), weight, - residual.clone() if residual is not None else None) - if use_residual: - output_ipex = output_ipex[0] - print(f"IPEX output={output_ipex}") - - if torch.allclose(output_naive, output_ipex, atol=1e-2, rtol=1e-2): - print("✅ IPEX implementation matches naive") - else: - print("❌ IPEX implementation differs from naive") - except Exception as e: - print(f"❌ IPEX implementation failed: {e}") - - if 
torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): - print("✅ All implementations match") - else: - print("❌ Implementations differ") - - -def get_benchmark(use_residual, dtype): - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["head_num", "batch_size", "seq_len"], - x_vals=[tuple(_) for _ in configs], - line_arg="provider", - line_vals=["huggingface", "vllm", "t.compile", "ipex"] - if HAS_IPEX else ["huggingface", "vllm", "t.compile"], - line_names=["HuggingFace", "vLLM", "t.compile", "IPEX"] - if HAS_IPEX else ["HuggingFace", "vLLM", "t.compile"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-"), - ("red", "-")] if HAS_IPEX else [("blue", "-"), - ("green", "-"), - ("orange", "-")], - ylabel="us", - plot_name= - f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", - args={}, - )) - def benchmark(head_num, batch_size, seq_len, provider): - hidden_size = head_num * 128 # assuming head_dim = 128 - - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="xpu") - weight = torch.ones(hidden_size, dtype=dtype, device="xpu") - residual = torch.randn_like(x) if use_residual else None - - quantiles = [0.5, 0.2, 0.8] - - if provider == "huggingface": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: rmsnorm_naive( - x.clone(), - weight, - residual.clone() if residual is not None else None, - ), - quantiles=quantiles, - ) - elif provider == "t.compile": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: rmsnorm_compile( - x.clone(), - weight, - residual.clone() if residual is not None else None, - ), - quantiles=quantiles, - ) - elif provider == "ipex" and HAS_IPEX: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: rmsnorm_ipex( - x.clone(), - weight, - residual.clone() if residual is not None else None, - ), - quantiles=quantiles, - ) - else: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: rmsnorm_vllm( - x.clone(), - weight, - residual.clone() if residual is not None else None, - ), - quantiles=quantiles, - ) - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - return benchmark - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--batch-size", - type=int, - default=4, - help="Batch size", - ) - parser.add_argument( - "--seq-len", - type=int, - default=128, - help="Sequence length", - ) - parser.add_argument( - "--hidden-size", - type=int, - default=4096, - help="Hidden size (2nd dimension) of the sequence", - ) - parser.add_argument( - "--intermediate-size", - type=int, - default=None, - help="Intermediate size for FFN layers", - ) - parser.add_argument( - "--num-groups", - type=int, - default=None, - help="Number of expert groups for MoE models", - ) - parser.add_argument( - "--dtype", - type=str, - default=torch.bfloat16, - help="Data type from model config", - ) - parser.add_argument( - "--model-name", - type=str, - default=None, - help="Model name to load configuration from", - ) - parser.add_argument("--head-num-range", +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the layernorm kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--add-residual", action="store_true") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num-warmup-iters", type=int, default=5) + 
parser.add_argument("--num-iters", type=int, - nargs='+', - default=[12, 32, 40, 48, 64, 96, 128], - help=("Range of attention head numbers to test/use. " - "Default: 12 32 40 48 64 96 128")) - parser.add_argument( - "--tp-size", - type=int, - default=1, - help="Tensor parallelism size", - ) - parser.add_argument("--use-residual", - action="store_true", - help="Whether to use residual connection") - parser.add_argument( - "--save-path", - type=str, - default="./configs/rmsnorm/", - help="Path to save rmsnorm benchmark results", - ) + default=100, + help="Number of benchmark iterations. ") args = parser.parse_args() + print(args) - if args.model_name: - model_config = get_model_config(args.model_name, args.tp_size) - - if args.hidden_size == 4096: - args.hidden_size = model_config["hidden_size"] - - if args.intermediate_size is None: - args.intermediate_size = model_config["intermediate_size"] - - if args.num_groups is None: - args.num_groups = model_config["num_groups"] - - if args.dtype is None: - args.dtype = model_config["dtype"] - - if args.head_num_range == [12, 32, 40, 48, 64, 96, 128]: - model_heads = model_config.get("num_attention_heads", 32) - if model_heads not in args.head_num_range: - args.head_num_range.append(model_heads) - args.head_num_range.sort() - print( - f"Added model's head number {model_heads} to head_num_range" - ) - - print(f"Using model configuration from: {args.model_name}") - print(f"Updated hidden_size: {args.hidden_size}") - print(f"Updated intermediate_size: {args.intermediate_size}") - print(f"Updated num_groups: {args.num_groups}") - print(f"Updated head_num_range: {args.head_num_range}") - print(f"Updated dtype: {args.dtype}") - - return args - - -if __name__ == "__main__": - - import argparse - - args = parse_args() - - print("Final configuration:") - print(f" Batch size: {args.batch_size}") - print(f" Sequence length: {args.seq_len}") - print(f" Hidden size: {args.hidden_size}") - print(f" Intermediate size: {args.intermediate_size}") - print(f" Number of groups: {args.num_groups}") - print(f" Data type: {args.dtype}") - print(f" Use residual: {args.use_residual}") - - batch_size_range = [2**i for i in range(0, 7, 2)] - seq_length_range = [2**i for i in range(6, 10, 1)] - head_num_range = args.head_num_range - configs = list( - itertools.product(head_num_range, batch_size_range, seq_length_range)) - - if HAS_IPEX: - print("✅ IPEX is available") - print(f"IPEX version: {ipex.__version__}") - else: - print("⚠️ IPEX is not available, skipping IPEX benchmarks") - - # Run correctness test - calculate_diff( - batch_size=args.batch_size, - seq_len=args.seq_len, + main( + num_tokens=args.num_tokens, hidden_size=args.hidden_size, - use_residual=args.use_residual, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, ) - - # Get the benchmark function with proper use_residual setting - benchmark = get_benchmark(args.use_residual, args.dtype) - # Run performance benchmark - benchmark.run(print_data=True, save_path=args.save_path) From ccd7f371fd1cb88df01f7b6c84dbf55e63489cb9 Mon Sep 17 00:00:00 2001 From: diwei sun Date: Thu, 18 Sep 2025 00:35:41 -0700 Subject: [PATCH 2/5] refine model config extract Signed-off-by: diwei sun --- benchmark/benchmark_reshape_and_cache.py | 319 ++++++++++++----------- tests/utils.py | 124 ++++++++- 2 files changed, 289 insertions(+), 154 deletions(-) diff --git a/benchmark/benchmark_reshape_and_cache.py 
b/benchmark/benchmark_reshape_and_cache.py index c6d96e7..95fffd6 100644 --- a/benchmark/benchmark_reshape_and_cache.py +++ b/benchmark/benchmark_reshape_and_cache.py @@ -1,180 +1,193 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations -import random -import time +import itertools +from typing import Optional import torch -from tabulate import tabulate +import triton +import random +from torch import Tensor -from tests import register_ops as ops -from tests.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random +from tests import register_ops as vllm_ops +from tests.utils import ( + check_ipex_availability, + create_kv_caches_with_random, + parse_args, +) +HAS_IPEX = check_ipex_availability() -@torch.inference_mode() -def run_benchmark( - num_tokens: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, +if HAS_IPEX: + import intel_extension_for_pytorch as ipex + + +def reshape_and_cache_vllm( + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, kv_cache_dtype: str, - num_iters: int, - device: str = "xpu", -) -> float: - """Return latency (seconds) for given num_tokens.""" + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, +) -> None: + """vLLM's fused kernel for reshaping and caching K/V tensors.""" + vllm_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale) + + +def reshape_and_cache_ipex( + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + slot_mapping: Tensor, + kv_cache_dtype: str, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, +) -> None: + """IPEX native implementation using ipex.llm.modules.PagedAttention.""" + if not HAS_IPEX: + raise RuntimeError("IPEX is not available") + assert kv_cache_dtype == "auto", "IPEX reshape_and_cache uses 'auto' mode" + + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping + ) + - if kv_cache_dtype == "fp8" and head_size % 16: - raise ValueError( - "fp8 kv-cache requires head_size to be a multiple of 16.") +def get_benchmark( + dtype: torch.dtype, + device: str = "xpu", +): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "num_heads", "head_size", "block_size", "num_blocks"], + x_vals=configs, + line_arg="provider", + line_vals=["vllm", "ipex"] if HAS_IPEX else ["vllm"], + line_names=["vLLM", "IPEX"] if HAS_IPEX else ["vLLM"], + styles=[("blue", "-"), ("red", "-")] if HAS_IPEX else [("blue", "-")], + ylabel="latency (us)", + plot_name="reshape_and_cache-benchmark", + args={}, + ) + ) + @torch.inference_mode() + def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider, kv_cache_dtype="auto"): + + if kv_cache_dtype == "fp8" and head_size % 16: + raise ValueError( + "fp8 kv-cache requires head_size to be a multiple of 16.") - seed = 42 - random.seed(seed) - torch.manual_seed(seed) - torch.set_default_device(device) + torch.manual_seed(42) + torch.set_default_device(device) - # create random key / value tensors [T, H, D]. - key = torch.randn(num_tokens, + key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device) - value = torch.randn_like(key) - - # prepare the slot mapping. - # each token is assigned a unique slot in the KV-cache. 
- num_slots = block_size * num_blocks - if num_tokens > num_slots: - raise ValueError( - "num_tokens cannot exceed the total number of cache slots") - slot_mapping_lst = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, + value = torch.randn_like(key) + num_slots = block_size * num_blocks + if num_tokens > num_slots: + raise ValueError( + "num_tokens cannot exceed the total number of cache slots") + slot_mapping_lst = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) - num_layers = 1 # for simplicity, we use a single layer - key_caches, value_caches = create_kv_caches_with_random( - num_blocks, - block_size, - num_layers, - num_heads, - head_size, - kv_cache_dtype, - dtype, - device=device, - ) - key_cache, value_cache = key_caches[0], value_caches[0] + num_layers = 1 # for simplicity, we use a single layer + key_caches, value_caches = create_kv_caches_with_random( + num_blocks, + block_size, + num_layers, + num_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] - # compute per-kernel scaling factors for fp8 conversion (if used). - k_scale = (key.amax() / 64.0).to(torch.float32) - v_scale = (value.amax() / 64.0).to(torch.float32) + # compute per-kernel scaling factors for fp8 conversion (if used). + k_scale = (key.amax() / 64.0).to(torch.float32) + v_scale = (value.amax() / 64.0).to(torch.float32) - def run_xpu_benchmark(n_iters: int) -> float: - nonlocal key, value, key_cache, value_cache, slot_mapping - torch.xpu.synchronize() - start = time.perf_counter() - for _ in range(n_iters): - ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - slot_mapping, - kv_cache_dtype, - k_scale, - v_scale, - ) torch.xpu.synchronize() - end = time.perf_counter() - return (end - start) / n_iters - - # warm-up - run_xpu_benchmark(3) - - lat = run_xpu_benchmark(num_iters) - - # free tensors to mitigate OOM when sweeping - del key, value, key_cache, value_cache, slot_mapping - torch.xpu.empty_cache() - - return lat - - -def main(args): - rows = [] - for exp in range(1, 12): - n_tok = 2**exp - lat = run_benchmark( - num_tokens=n_tok, - num_heads=args.num_heads, - head_size=args.head_size, - block_size=args.block_size, - num_blocks=args.num_blocks, - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - kv_cache_dtype=args.kv_cache_dtype, - num_iters=args.iters, - device="xpu", + # Warm up + for _ in range(5): + if provider == "vllm": + reshape_and_cache_vllm( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + elif provider == "ipex" and HAS_IPEX: + reshape_and_cache_ipex( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + # Benchmark + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: { + "vllm": reshape_and_cache_vllm, + "ipex": reshape_and_cache_ipex + }[provider]( + key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale, + ), + quantiles=quantiles, ) - rows.append([ - n_tok, - args.num_heads, - args.head_size, - args.block_size, - args.num_blocks, - args.dtype, - args.kv_cache_dtype, - f"{lat * 1e6:.3f}", - ]) - print( - tabulate( - rows, - headers=[ - "num_tokens", - "num_heads", - "head_size", - "block_size", - "num_blocks", - "dtype", - "kv_cache_dtype", - "latency (us)", - ], - )) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + 
return benchmark -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--num-heads", type=int, default=8) - parser.add_argument( - "--head-size", - type=int, - choices=[64, 80, 96, 112, 120, 128, 192, 256], - default=128, - ) - parser.add_argument("--block-size", - type=int, - choices=[16, 32, 64], - default=64) - parser.add_argument("--num-blocks", type=int, default=1024) - - parser.add_argument( - "--dtype", - type=str, - choices=["half", "bfloat16"], - default="half", - ) - parser.add_argument( - "--kv-cache-dtype", - type=str, - choices=["auto", "fp8", "fp8_e4m3", "fp8_e5m2"], - default="auto", +if __name__ == "__main__": + args = parse_args() + + device = "xpu" + + print("Benchmark Configuration:") + print(f" Num Heads: {args.head_num_range}") + print(f" Head Size: {args.head_size}") + print(f" Block Size: {args.block_size}") + print(f" Num Blocks: {args.num_blocks}") + print(f" Data Type: {args.dtype}") + print(f" KV Cache Dtype: auto (IPEX & vLLM)") + print(f" Device: {device}") + if HAS_IPEX: + print(f"✅ IPEX {ipex.__version__} is available.") + else: + print("⚠️ IPEX not available. Only benchmarking vLLM.") + + num_token_range = [2**i for i in range(1, 12)] + head_num_range = args.head_num_range + head_size_range = [args.head_size] + block_size_range = [args.block_size] + num_blocks_range = [args.num_blocks] + configs = list( + itertools.product(num_token_range, head_num_range, head_size_range, block_size_range, num_blocks_range)) + + benchmark = get_benchmark( + dtype=args.dtype, + device=device, ) - - parser.add_argument("--iters", type=int, default=100) - args = parser.parse_args() - - main(args) + benchmark.run(print_data=True, save_path=None) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index 0559bac..9efe821 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from typing import Any, Optional, Union +import argparse import numpy as np import torch from torch._prims_common import TensorLikeType @@ -350,4 +351,125 @@ def check_ipex_availability(): return True else: print("Warning: IPEX not available, skipping IPEX benchmarks") - return False \ No newline at end of file + return False + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument( + "--intermediate-size", + type=int, + default=None, + help="Intermediate size for FFN layers", + ) + parser.add_argument( + "--num-groups", + type=int, + default=None, + help="Number of expert groups for MoE models", + ) + parser.add_argument( + "--dtype", + type=str, + default=torch.bfloat16, + help="Data type from model config", + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="Model name to load configuration from", + ) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) + parser.add_argument("--num-blocks", type=int, default=1024) + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8", "fp8_e4m3", "fp8_e5m2"], + default="auto", + ) + parser.add_argument("--block-size", + type=int, + choices=[16, 32, 64], + default=64) + 
parser.add_argument("--head-num-range", + type=int, + nargs='+', + default=[12, 32, 40, 48, 64, 96, 128], + help=("Range of attention head numbers to test/use. " + "Default: 12 32 40 48 64 96 128")) + parser.add_argument( + "--tp-size", + type=int, + default=1, + help="Tensor parallelism size", + ) + parser.add_argument("--use-residual", + action="store_true", + help="Whether to use residual connection") + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + if args.model_name: + model_config = get_model_config(args.model_name, args.tp_size) + + if args.hidden_size == 4096: + args.hidden_size = model_config["hidden_size"] + + if args.intermediate_size is None: + args.intermediate_size = model_config["intermediate_size"] + + if args.num_groups is None: + args.num_groups = model_config["num_groups"] + + if args.dtype is None: + args.dtype = model_config["dtype"] + + if args.head_size is None: + args.head_size = model_config["head_dim"] + + if args.head_num_range == [12, 32, 40, 48, 64, 96, 128]: + model_heads = model_config.get("num_attention_heads", 32) + if model_heads not in args.head_num_range: + args.head_num_range.append(model_heads) + args.head_num_range.sort() + print( + f"Added model's head number {model_heads} to head_num_range" + ) + + print(f"Using model configuration from: {args.model_name}") + print(f"Updated hidden_size: {args.hidden_size}") + print(f"Updated intermediate_size: {args.intermediate_size}") + print(f"Updated num_groups: {args.num_groups}") + print(f"Updated head_num_range: {args.head_num_range}") + print(f"Updated dtype: {args.dtype}") + + return args From 92cd917dacf288e4f8f0ba52f2c6b8b34b53ac29 Mon Sep 17 00:00:00 2001 From: diwei sun Date: Thu, 18 Sep 2025 22:59:50 -0700 Subject: [PATCH 3/5] fix for rmsnorm Signed-off-by: diwei sun --- benchmark/benchmark_rmsnorm.py | 347 +++++++++++++++++++++++++++------ 1 file changed, 286 insertions(+), 61 deletions(-) diff --git a/benchmark/benchmark_rmsnorm.py b/benchmark/benchmark_rmsnorm.py index e46f298..d58742f 100644 --- a/benchmark/benchmark_rmsnorm.py +++ b/benchmark/benchmark_rmsnorm.py @@ -1,82 +1,307 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from argparse import ArgumentParser +import itertools +from typing import Optional, Union import torch +import triton +from torch import nn -from tests.ops.layernorm_op import RMSNorm -from tests.utils import STR_DTYPE_TO_TORCH_DTYPE +from tests import register_ops as vllm_ops +from tests.utils import check_ipex_availability, parse_args +HAS_IPEX = check_ipex_availability() -@torch.inference_mode() -def main( - num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, - seed: int = 0, - num_warmup_iters: int = 5, - num_iters: int = 100, -) -> None: - torch.set_default_device("xpu") +if HAS_IPEX: + import intel_extension_for_pytorch as ipex - layer = RMSNorm(hidden_size).to(dtype=dtype) - layer.weight.data.normal_(mean=1.0, std=0.1) - scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype) - x *= scale - residual = torch.randn_like(x) * scale if add_residual else None - def run_xpu_benchmark(num_iters: int) -> float: - torch.xpu.synchronize() +class HuggingFaceRMSNorm(nn.Module): - start_time = time.perf_counter() + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = 
nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - for _ in range(num_iters): - layer(x, residual) - torch.xpu.synchronize() + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) - end_time = time.perf_counter() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual - return (end_time - start_time) / num_iters - # Warmup. - print("Warming up...") - run_benchmark = run_xpu_benchmark - run_benchmark(num_iters=num_warmup_iters) +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) - # Benchmark. - latency = run_benchmark(num_iters=num_iters) - print(f"Kernel running time: {latency * 1000000:.3f} us") + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +@torch.compile +def rmsnorm_compile(x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6): + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + x_var = x + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + eps) + x = x.to(orig_dtype) + x = x * weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_ipex( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + """IPEX implementation using ipex.llm.functional.rms_norm""" + if not HAS_IPEX: + raise RuntimeError("IPEX is not available") + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + if hasattr(ipex.llm.functional, 'fused_add_rms_norm'): + output, residual_out = ipex.llm.functional.fused_add_rms_norm( + x, residual, weight, eps) + output = (output.view(orig_shape), residual_out.view(orig_shape)) + else: + x = x + residual + output = ipex.llm.functional.rms_norm(x, weight, eps) + output = (output.view(orig_shape), x.view(orig_shape)) + else: + output = ipex.llm.functional.rms_norm(x, 
weight, eps) + output = output.view(orig_shape) + + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="xpu") + weight = torch.ones(hidden_size, dtype=dtype, device="xpu") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_vllm = rmsnorm_vllm( + x.clone(), weight, + residual.clone() if residual is not None else None) + + if use_residual: + output_naive = output_naive[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"vLLM output={output_vllm}") + + if HAS_IPEX: + try: + output_ipex = rmsnorm_ipex( + x.clone(), weight, + residual.clone() if residual is not None else None) + if use_residual: + output_ipex = output_ipex[0] + print(f"IPEX output={output_ipex}") + + if torch.allclose(output_naive, output_ipex, atol=1e-2, rtol=1e-2): + print("✅ IPEX implementation matches naive") + else: + print("❌ IPEX implementation differs from naive") + except Exception as e: + print(f"❌ IPEX implementation failed: {e}") + + if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +def get_benchmark(use_residual, dtype): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[tuple(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "vllm", "t.compile", "ipex"] + if HAS_IPEX else ["huggingface", "vllm", "t.compile"], + line_names=["HuggingFace", "vLLM", "t.compile", "IPEX"] + if HAS_IPEX else ["HuggingFace", "vLLM", "t.compile"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-"), + ("red", "-")] if HAS_IPEX else [("blue", "-"), + ("green", "-"), + ("orange", "-")], + ylabel="us", + plot_name= + f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + args={}, + )) + def benchmark(head_num, batch_size, seq_len, provider): + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="xpu") + weight = torch.ones(hidden_size, dtype=dtype, device="xpu") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "t.compile": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_compile( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "ipex" and HAS_IPEX: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_ipex( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark if __name__ == "__main__": - parser = ArgumentParser(description="Benchmark the layernorm kernel.") - parser.add_argument("--num-tokens", type=int, default=4096) - parser.add_argument("--hidden-size", 
type=int, default=8192) - parser.add_argument("--add-residual", action="store_true") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--num-warmup-iters", type=int, default=5) - parser.add_argument("--num-iters", - type=int, - default=100, - help="Number of benchmark iterations. ") - - args = parser.parse_args() - print(args) - - main( - num_tokens=args.num_tokens, + + args = parse_args() + + print("Final configuration:") + print(f" Batch size: {args.batch_size}") + print(f" Sequence length: {args.seq_len}") + print(f" Hidden size: {args.hidden_size}") + print(f" Intermediate size: {args.intermediate_size}") + print(f" Number of groups: {args.num_groups}") + print(f" Data type: {args.dtype}") + print(f" Use residual: {args.use_residual}") + + batch_size_range = [2**i for i in range(0, 7, 2)] + seq_length_range = [2**i for i in range(6, 10, 1)] + head_num_range = args.head_num_range + configs = list( + itertools.product(head_num_range, batch_size_range, seq_length_range)) + + if HAS_IPEX: + print("✅ IPEX is available") + print(f"IPEX version: {ipex.__version__}") + else: + print("⚠️ IPEX is not available, skipping IPEX benchmarks") + + # Run correctness test + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, hidden_size=args.hidden_size, - add_residual=args.add_residual, - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - seed=args.seed, - num_warmup_iters=args.num_warmup_iters, - num_iters=args.num_iters, + use_residual=args.use_residual, ) + + # Get the benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual, args.dtype) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) From 1baae13549bbc035e96cb756905834a831260301 Mon Sep 17 00:00:00 2001 From: diwei sun Date: Tue, 23 Sep 2025 01:06:20 -0700 Subject: [PATCH 4/5] format fix --- benchmark/benchmark_reshape_and_cache.py | 95 ++++++++++++++---------- tests/utils.py | 3 +- 2 files changed, 56 insertions(+), 42 deletions(-) diff --git a/benchmark/benchmark_reshape_and_cache.py b/benchmark/benchmark_reshape_and_cache.py index 95fffd6..6fdf56e 100644 --- a/benchmark/benchmark_reshape_and_cache.py +++ b/benchmark/benchmark_reshape_and_cache.py @@ -2,19 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +import random from typing import Optional import torch import triton -import random from torch import Tensor from tests import register_ops as vllm_ops -from tests.utils import ( - check_ipex_availability, - create_kv_caches_with_random, - parse_args, -) +from tests.utils import (check_ipex_availability, create_kv_caches_with_random, + parse_args) HAS_IPEX = check_ipex_availability() @@ -33,8 +30,8 @@ def reshape_and_cache_vllm( v_scale: Optional[float] = None, ) -> None: """vLLM's fused kernel for reshaping and caching K/V tensors.""" - vllm_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale) + vllm_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, k_scale, v_scale) def reshape_and_cache_ipex( @@ -52,9 +49,9 @@ def reshape_and_cache_ipex( raise RuntimeError("IPEX is not available") assert kv_cache_dtype == "auto", "IPEX reshape_and_cache uses 'auto' mode" - ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slot_mapping - ) + 
ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, + value_cache, + slot_mapping) def get_benchmark( @@ -64,20 +61,29 @@ def get_benchmark( @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["num_tokens", "num_heads", "head_size", "block_size", "num_blocks"], + x_names=[ + "num_tokens", "num_heads", "head_size", "block_size", + "num_blocks" + ], x_vals=configs, line_arg="provider", line_vals=["vllm", "ipex"] if HAS_IPEX else ["vllm"], line_names=["vLLM", "IPEX"] if HAS_IPEX else ["vLLM"], - styles=[("blue", "-"), ("red", "-")] if HAS_IPEX else [("blue", "-")], + styles=[("blue", "-"), + ("red", "-")] if HAS_IPEX else [("blue", "-")], ylabel="latency (us)", plot_name="reshape_and_cache-benchmark", args={}, - ) - ) + )) @torch.inference_mode() - def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider, kv_cache_dtype="auto"): - + def benchmark(num_tokens, + num_heads, + head_size, + block_size, + num_blocks, + provider, + kv_cache_dtype="auto"): + if kv_cache_dtype == "fp8" and head_size % 16: raise ValueError( "fp8 kv-cache requires head_size to be a multiple of 16.") @@ -86,10 +92,10 @@ def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider torch.set_default_device(device) key = torch.randn(num_tokens, - num_heads, - head_size, - dtype=dtype, - device=device) + num_heads, + head_size, + dtype=dtype, + device=device) value = torch.randn_like(key) num_slots = block_size * num_blocks if num_tokens > num_slots: @@ -97,8 +103,8 @@ def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider "num_tokens cannot exceed the total number of cache slots") slot_mapping_lst = random.sample(range(num_slots), num_tokens) slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + dtype=torch.long, + device=device) num_layers = 1 # for simplicity, we use a single layer key_caches, value_caches = create_kv_caches_with_random( @@ -122,24 +128,24 @@ def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider for _ in range(5): if provider == "vllm": reshape_and_cache_vllm( - key, - value, - key_cache, - value_cache, + key, + value, + key_cache, + value_cache, slot_mapping, - kv_cache_dtype, - k_scale, + kv_cache_dtype, + k_scale, v_scale, ) elif provider == "ipex" and HAS_IPEX: reshape_and_cache_ipex( - key, - value, - key_cache, - value_cache, + key, + value, + key_cache, + value_cache, slot_mapping, - kv_cache_dtype, - k_scale, + kv_cache_dtype, + k_scale, v_scale, ) @@ -150,8 +156,14 @@ def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider "vllm": reshape_and_cache_vllm, "ipex": reshape_and_cache_ipex }[provider]( - key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale, + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, ), quantiles=quantiles, ) @@ -171,7 +183,7 @@ def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider print(f" Block Size: {args.block_size}") print(f" Num Blocks: {args.num_blocks}") print(f" Data Type: {args.dtype}") - print(f" KV Cache Dtype: auto (IPEX & vLLM)") + print(" KV Cache Dtype: auto (IPEX & vLLM)") print(f" Device: {device}") if HAS_IPEX: print(f"✅ IPEX {ipex.__version__} is available.") @@ -184,10 +196,11 @@ def benchmark(num_tokens, num_heads, head_size, block_size, num_blocks, provider block_size_range = [args.block_size] num_blocks_range = [args.num_blocks] configs = list( - 
itertools.product(num_token_range, head_num_range, head_size_range, block_size_range, num_blocks_range)) - + itertools.product(num_token_range, head_num_range, head_size_range, + block_size_range, num_blocks_range)) + benchmark = get_benchmark( dtype=args.dtype, device=device, ) - benchmark.run(print_data=True, save_path=None) \ No newline at end of file + benchmark.run(print_data=True, save_path=None) diff --git a/tests/utils.py b/tests/utils.py index 9efe821..12c8669 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 +import argparse import random import unittest from collections.abc import Sequence from typing import Any, Optional, Union -import argparse import numpy as np import torch from torch._prims_common import TensorLikeType @@ -353,6 +353,7 @@ def check_ipex_availability(): print("Warning: IPEX not available, skipping IPEX benchmarks") return False + def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( From 91e64727449491d4263c725c21370b48e0b39fb9 Mon Sep 17 00:00:00 2001 From: diwei sun Date: Tue, 23 Sep 2025 01:14:59 -0700 Subject: [PATCH 5/5] refine utils --- tests/utils.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 12c8669..4c98a49 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -442,29 +442,12 @@ def parse_args(): if args.model_name: model_config = get_model_config(args.model_name, args.tp_size) - if args.hidden_size == 4096: - args.hidden_size = model_config["hidden_size"] - - if args.intermediate_size is None: - args.intermediate_size = model_config["intermediate_size"] - - if args.num_groups is None: - args.num_groups = model_config["num_groups"] - - if args.dtype is None: - args.dtype = model_config["dtype"] - - if args.head_size is None: - args.head_size = model_config["head_dim"] - - if args.head_num_range == [12, 32, 40, 48, 64, 96, 128]: - model_heads = model_config.get("num_attention_heads", 32) - if model_heads not in args.head_num_range: - args.head_num_range.append(model_heads) - args.head_num_range.sort() - print( - f"Added model's head number {model_heads} to head_num_range" - ) + args.hidden_size = model_config["hidden_size"] + args.intermediate_size = model_config["intermediate_size"] + args.num_groups = model_config["num_groups"] + args.dtype = model_config["dtype"] + args.head_size = model_config["head_dim"] + args.head_num_range = [model_config.get("num_attention_heads", 32)] print(f"Using model configuration from: {args.model_name}") print(f"Updated hidden_size: {args.hidden_size}")
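
A minimal usage sketch (not part of the patches above) of how a benchmark script is expected to consume the shared tests.utils.parse_args() once this series lands. The head_dim factor of 128 and the --model-name / --use-residual flags come from the patched scripts themselves; the driver code below is illustrative only, not the authors' implementation.

# Hypothetical driver sketch; assumes the refactored tests/utils.py from
# PATCH 2/5 and PATCH 5/5 is importable from the repo root.
from tests.utils import parse_args

# e.g. invoked as: python benchmark/benchmark_rmsnorm.py --model-name <model> --use-residual
args = parse_args()

# When --model-name is given, parse_args() overrides hidden_size, head_size,
# dtype and head_num_range from the model config (PATCH 5/5 behaviour);
# otherwise the CLI defaults are used.
hidden_sizes = [head_num * 128 for head_num in args.head_num_range]  # head_dim assumed 128, as in benchmark_rmsnorm.py
print(f"dtype={args.dtype}, residual={args.use_residual}, hidden sizes to sweep: {hidden_sizes}")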