
Commit 2dcde10 (1 parent: d30e6f2)

add ipex and model config for rmsnorm op (#28)

* add ipex and model config for rmsnorm op
* use model dtype if model set
* fix yapf style
* fix pre-commit issue
* fix: IPEX availability check and undefined variable issues
* delete useless comments

Signed-off-by: Liu, Wenjun <[email protected]>

File tree: 3 files changed (+301, -23 lines)

benchmark/benchmark_rmsnorm.py

Lines changed: 169 additions & 22 deletions
@@ -9,6 +9,12 @@
 from torch import nn

 from tests import register_ops as vllm_ops
+from tests.utils import check_ipex_availability, get_model_config
+
+HAS_IPEX = check_ipex_availability()
+
+if HAS_IPEX:
+    import intel_extension_for_pytorch as ipex


 class HuggingFaceRMSNorm(nn.Module):
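The two helpers imported here live in tests/utils.py, which is likely the third changed file and is not shown in this view. The availability probe is presumably just a guarded import; a minimal sketch, assuming that shape:

# Hypothetical sketch of check_ipex_availability (the real helper is in
# tests/utils.py, outside this view); such probes are typically a guarded import.
def check_ipex_availability() -> bool:
    try:
        import intel_extension_for_pytorch  # noqa: F401
        return True
    except ImportError:
        return False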
@@ -112,6 +118,36 @@ def rmsnorm_vllm(
     return output


+def rmsnorm_ipex(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    eps: float = 1e-6,
+):
+    """IPEX implementation using ipex.llm.functional.rms_norm"""
+    if not HAS_IPEX:
+        raise RuntimeError("IPEX is not available")
+
+    orig_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+
+    if residual is not None:
+        residual = residual.view(-1, residual.shape[-1])
+        if hasattr(ipex.llm.functional, 'fused_add_rms_norm'):
+            output, residual_out = ipex.llm.functional.fused_add_rms_norm(
+                x, residual, weight, eps)
+            output = (output.view(orig_shape), residual_out.view(orig_shape))
+        else:
+            x = x + residual
+            output = ipex.llm.functional.rms_norm(x, weight, eps)
+            output = (output.view(orig_shape), x.view(orig_shape))
+    else:
+        output = ipex.llm.functional.rms_norm(x, weight, eps)
+        output = output.view(orig_shape)
+
+    return output
+
+
 def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     dtype = torch.bfloat16
     x = torch.randn(batch_size,
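For readers checking the math: rmsnorm_ipex is expected to match the naive definition y = x * rsqrt(mean(x^2, dim=-1) + eps) * weight, with the residual folded into x first on the fused path. A minimal pure-PyTorch reference for spot-checking, not part of this commit, mirroring what the file's HuggingFaceRMSNorm computes:

import torch

def rmsnorm_ref(x, weight, residual=None, eps=1e-6):
    # Fold the residual in first, as the fused_add_rms_norm path does.
    if residual is not None:
        x = x + residual
    # Normalize in float32 for stability, then cast back to the input dtype.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return (out, x) if residual is not None else out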
@@ -136,42 +172,49 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     print(f"Naive output={output_naive}")
     print(f"vLLM output={output_vllm}")

+    if HAS_IPEX:
+        try:
+            output_ipex = rmsnorm_ipex(
+                x.clone(), weight,
+                residual.clone() if residual is not None else None)
+            if use_residual:
+                output_ipex = output_ipex[0]
+            print(f"IPEX output={output_ipex}")
+
+            if torch.allclose(output_naive, output_ipex, atol=1e-2, rtol=1e-2):
+                print("✅ IPEX implementation matches naive")
+            else:
+                print("❌ IPEX implementation differs from naive")
+        except Exception as e:
+            print(f"❌ IPEX implementation failed: {e}")
+
     if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
         print("✅ All implementations match")
     else:
         print("❌ Implementations differ")


-batch_size_range = [2**i for i in range(0, 7, 2)]
-seq_length_range = [2**i for i in range(6, 10, 1)]
-head_num_range = [
-    32, #Llama
-    40, #Qwen 14B/32B
-    48,
-    64, #Llama 2 70B
-    128, # Deepseek R1
-]
-configs = list(
-    itertools.product(head_num_range, batch_size_range, seq_length_range))
-
-
-def get_benchmark(use_residual):
+def get_benchmark(use_residual, dtype):

     @triton.testing.perf_report(
         triton.testing.Benchmark(
             x_names=["head_num", "batch_size", "seq_len"],
             x_vals=[tuple(_) for _ in configs],
             line_arg="provider",
-            line_vals=["huggingface", "vllm", "t.compile"],
-            line_names=["HuggingFace", "vLLM", "t.compile"],
-            styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
+            line_vals=["huggingface", "vllm", "t.compile", "ipex"]
+            if HAS_IPEX else ["huggingface", "vllm", "t.compile"],
+            line_names=["HuggingFace", "vLLM", "t.compile", "IPEX"]
+            if HAS_IPEX else ["HuggingFace", "vLLM", "t.compile"],
+            styles=[("blue", "-"), ("green", "-"), ("orange", "-"),
+                    ("red", "-")] if HAS_IPEX else [("blue", "-"),
+                                                    ("green", "-"),
+                                                    ("orange", "-")],
             ylabel="us",
             plot_name=
             f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
             args={},
         ))
     def benchmark(head_num, batch_size, seq_len, provider):
-        dtype = torch.bfloat16
         hidden_size = head_num * 128  # assuming head_dim = 128

         x = torch.randn(batch_size,
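The conditional line_vals/line_names/styles expressions above are yapf-wrapped and hard to scan; an equivalent construction that appends the IPEX entries once would read more directly. A sketch (illustration only, not what the commit does):

# Equivalent provider-list construction (illustration only).
line_vals = ["huggingface", "vllm", "t.compile"] + (["ipex"] if HAS_IPEX else [])
line_names = ["HuggingFace", "vLLM", "t.compile"] + (["IPEX"] if HAS_IPEX else [])
styles = [("blue", "-"), ("green", "-"),
          ("orange", "-")] + ([("red", "-")] if HAS_IPEX else [])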
@@ -202,6 +245,15 @@ def benchmark(head_num, batch_size, seq_len, provider):
                 ),
                 quantiles=quantiles,
             )
+        elif provider == "ipex" and HAS_IPEX:
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rmsnorm_ipex(
+                    x.clone(),
+                    weight,
+                    residual.clone() if residual is not None else None,
+                ),
+                quantiles=quantiles,
+            )
         else:
             ms, min_ms, max_ms = triton.testing.do_bench(
                 lambda: rmsnorm_vllm(
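Note on the timing call: when triton.testing.do_bench is given a quantiles argument, it returns one measurement per requested quantile, which the script unpacks as (ms, min_ms, max_ms). The quantiles list itself is defined elsewhere in benchmark() and is unchanged by this diff.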
@@ -216,9 +268,7 @@ def benchmark(head_num, batch_size, seq_len, provider):
     return benchmark


-if __name__ == "__main__":
-    import argparse
-
+def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--batch-size",
@@ -238,6 +288,42 @@ def benchmark(head_num, batch_size, seq_len, provider):
         default=4096,
         help="Hidden size (2nd dimension) of the sequence",
     )
+    parser.add_argument(
+        "--intermediate-size",
+        type=int,
+        default=None,
+        help="Intermediate size for FFN layers",
+    )
+    parser.add_argument(
+        "--num-groups",
+        type=int,
+        default=None,
+        help="Number of expert groups for MoE models",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default=torch.bfloat16,
+        help="Data type from model config",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=None,
+        help="Model name to load configuration from",
+    )
+    parser.add_argument("--head-num-range",
+                        type=int,
+                        nargs='+',
+                        default=[12, 32, 40, 48, 64, 96, 128],
+                        help=("Range of attention head numbers to test/use. "
+                              "Default: 12 32 40 48 64 96 128"))
+    parser.add_argument(
+        "--tp-size",
+        type=int,
+        default=1,
+        help="Tensor parallelism size",
+    )
     parser.add_argument("--use-residual",
                         action="store_true",
                         help="Whether to use residual connection")
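One quirk worth flagging in the new --dtype flag: it is declared with type=str but a default of torch.bfloat16, so the later `if args.dtype is None` branch that would pull the dtype from the model config can never fire, and a value passed on the command line arrives as a plain string. A small normalization step, not part of this commit, would make both paths behave:

# Hypothetical normalization for args.dtype (not part of this commit).
import torch

DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def normalize_dtype(value):
    # Accept either a torch.dtype (the current default) or a CLI string.
    if isinstance(value, torch.dtype) or value is None:
        return value
    return DTYPE_MAP[value]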
@@ -250,6 +336,67 @@ def benchmark(head_num, batch_size, seq_len, provider):

     args = parser.parse_args()

+    if args.model_name:
+        model_config = get_model_config(args.model_name, args.tp_size)
+
+        if args.hidden_size == 4096:
+            args.hidden_size = model_config["hidden_size"]
+
+        if args.intermediate_size is None:
+            args.intermediate_size = model_config["intermediate_size"]
+
+        if args.num_groups is None:
+            args.num_groups = model_config["num_groups"]
+
+        if args.dtype is None:
+            args.dtype = model_config["dtype"]
+
+        if args.head_num_range == [12, 32, 40, 48, 64, 96, 128]:
+            model_heads = model_config.get("num_attention_heads", 32)
+            if model_heads not in args.head_num_range:
+                args.head_num_range.append(model_heads)
+                args.head_num_range.sort()
+                print(
+                    f"Added model's head number {model_heads} to head_num_range"
+                )
+
+        print(f"Using model configuration from: {args.model_name}")
+        print(f"Updated hidden_size: {args.hidden_size}")
+        print(f"Updated intermediate_size: {args.intermediate_size}")
+        print(f"Updated num_groups: {args.num_groups}")
+        print(f"Updated head_num_range: {args.head_num_range}")
+        print(f"Updated dtype: {args.dtype}")
+
+    return args
+
+
+if __name__ == "__main__":
+
+    import argparse
+
+    args = parse_args()
+
+    print("Final configuration:")
+    print(f"  Batch size: {args.batch_size}")
+    print(f"  Sequence length: {args.seq_len}")
+    print(f"  Hidden size: {args.hidden_size}")
+    print(f"  Intermediate size: {args.intermediate_size}")
+    print(f"  Number of groups: {args.num_groups}")
+    print(f"  Data type: {args.dtype}")
+    print(f"  Use residual: {args.use_residual}")
+
+    batch_size_range = [2**i for i in range(0, 7, 2)]
+    seq_length_range = [2**i for i in range(6, 10, 1)]
+    head_num_range = args.head_num_range
+    configs = list(
+        itertools.product(head_num_range, batch_size_range, seq_length_range))
+
+    if HAS_IPEX:
+        print("✅ IPEX is available")
+        print(f"IPEX version: {ipex.__version__}")
+    else:
+        print("⚠️ IPEX is not available, skipping IPEX benchmarks")
+
     # Run correctness test
     calculate_diff(
         batch_size=args.batch_size,
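get_model_config is also defined in tests/utils.py and not shown in this view. Given the new transformers dependency added to requirements.txt, a plausible shape is a thin wrapper over AutoConfig; this is an assumption, and the key handling below is hypothetical:

# Plausible sketch of get_model_config (assumption; the real helper lives
# in tests/utils.py, outside this view).
from transformers import AutoConfig

def get_model_config(model_name: str, tp_size: int) -> dict:
    cfg = AutoConfig.from_pretrained(model_name)
    return {
        "hidden_size": cfg.hidden_size,
        # Per-rank sizes under tensor parallelism.
        "intermediate_size": cfg.intermediate_size // tp_size,
        "num_attention_heads": cfg.num_attention_heads // tp_size,
        # MoE expert-group count; absent on dense models.
        "num_groups": getattr(cfg, "n_group", None),
        "dtype": getattr(cfg, "torch_dtype", None),
    }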
@@ -259,6 +406,6 @@ def benchmark(head_num, batch_size, seq_len, provider):
     )

     # Get the benchmark function with proper use_residual setting
-    benchmark = get_benchmark(args.use_residual)
+    benchmark = get_benchmark(args.use_residual, args.dtype)
     # Run performance benchmark
     benchmark.run(print_data=True, save_path=args.save_path)
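With these changes a typical run would look like `python benchmark/benchmark_rmsnorm.py --model-name <hf-model-id> --tp-size 1 --use-residual` (hypothetical invocation): hidden size, head count, and dtype are then seeded from the model's Hugging Face config instead of the hard-coded defaults, and an IPEX series appears in the correctness check and the plot whenever intel_extension_for_pytorch is importable.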

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -20,4 +20,5 @@ pytest

 # dependencies
 matplotlib # in benchmark, introduced by triton
-pandas # in benchmark, introduced by triton
+pandas # in benchmark, introduced by triton
+transformers # for model config
