from torch import nn

from tests import register_ops as vllm_ops
+from tests.utils import check_ipex_availability, get_model_config
+
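+# Probe for IPEX once at import time so the script still runs (minus the
+# IPEX provider) on machines where it is not installed.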
+HAS_IPEX = check_ipex_availability()
+
+if HAS_IPEX:
+    import intel_extension_for_pytorch as ipex


class HuggingFaceRMSNorm(nn.Module):
@@ -112,6 +118,36 @@ def rmsnorm_vllm(
    return output


+def rmsnorm_ipex(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    eps: float = 1e-6,
+):
+    """IPEX implementation using ipex.llm.functional.rms_norm.
+
+    Returns (output, residual) when a residual is supplied, otherwise just
+    the normalized output.
+    """
+    if not HAS_IPEX:
+        raise RuntimeError("IPEX is not available")
+
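+    # The IPEX kernels are invoked on 2D (tokens, hidden) views; the original
+    # shapes are restored before returning.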
+    orig_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+
+    if residual is not None:
+        residual = residual.view(-1, residual.shape[-1])
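+        # Prefer the fused add+norm kernel when this IPEX build exposes it;
+        # otherwise add the residual manually and use plain rms_norm.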
+        if hasattr(ipex.llm.functional, 'fused_add_rms_norm'):
+            output, residual_out = ipex.llm.functional.fused_add_rms_norm(
+                x, residual, weight, eps)
+            output = (output.view(orig_shape), residual_out.view(orig_shape))
+        else:
+            x = x + residual
+            output = ipex.llm.functional.rms_norm(x, weight, eps)
+            output = (output.view(orig_shape), x.view(orig_shape))
+    else:
+        output = ipex.llm.functional.rms_norm(x, weight, eps)
+        output = output.view(orig_shape)
+
+    return output
+
+
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    dtype = torch.bfloat16
    x = torch.randn(batch_size,
@@ -136,42 +172,49 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    print(f"Naive output={output_naive}")
    print(f"vLLM output={output_vllm}")

+    if HAS_IPEX:
+        try:
+            output_ipex = rmsnorm_ipex(
+                x.clone(), weight,
+                residual.clone() if residual is not None else None)
+            if use_residual:
+                output_ipex = output_ipex[0]
+            print(f"IPEX output={output_ipex}")
+
+            if torch.allclose(output_naive, output_ipex, atol=1e-2, rtol=1e-2):
+                print("✅ IPEX implementation matches naive")
+            else:
+                print("❌ IPEX implementation differs from naive")
+        except Exception as e:
+            print(f"❌ IPEX implementation failed: {e}")
+
    if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


-batch_size_range = [2**i for i in range(0, 7, 2)]
-seq_length_range = [2**i for i in range(6, 10, 1)]
-head_num_range = [
-    32,  # Llama
-    40,  # Qwen 14B/32B
-    48,
-    64,  # Llama 2 70B
-    128,  # Deepseek R1
-]
-configs = list(
-    itertools.product(head_num_range, batch_size_range, seq_length_range))
-
-
-def get_benchmark(use_residual):
+def get_benchmark(use_residual, dtype):

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["head_num", "batch_size", "seq_len"],
            x_vals=[tuple(_) for _ in configs],
            line_arg="provider",
-            line_vals=["huggingface", "vllm", "t.compile"],
-            line_names=["HuggingFace", "vLLM", "t.compile"],
-            styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
+            line_vals=["huggingface", "vllm", "t.compile", "ipex"]
+            if HAS_IPEX else ["huggingface", "vllm", "t.compile"],
+            line_names=["HuggingFace", "vLLM", "t.compile", "IPEX"]
+            if HAS_IPEX else ["HuggingFace", "vLLM", "t.compile"],
+            styles=[("blue", "-"), ("green", "-"), ("orange", "-"),
+                    ("red", "-")] if HAS_IPEX else [("blue", "-"),
+                                                    ("green", "-"),
+                                                    ("orange", "-")],
            ylabel="us",
            plot_name=
            f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
            args={},
        ))
    def benchmark(head_num, batch_size, seq_len, provider):
-        dtype = torch.bfloat16
        hidden_size = head_num * 128  # assuming head_dim = 128

        x = torch.randn(batch_size,
@@ -202,6 +245,15 @@ def benchmark(head_num, batch_size, seq_len, provider):
                ),
                quantiles=quantiles,
            )
+        elif provider == "ipex" and HAS_IPEX:
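+            # Clone per call so any in-place updates by the IPEX path do not
+            # contaminate the inputs across timed iterations.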
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rmsnorm_ipex(
+                    x.clone(),
+                    weight,
+                    residual.clone() if residual is not None else None,
+                ),
+                quantiles=quantiles,
+            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_vllm(
@@ -216,9 +268,7 @@ def benchmark(head_num, batch_size, seq_len, provider):
    return benchmark


-if __name__ == "__main__":
-    import argparse
-
+def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
@@ -238,6 +288,42 @@ def benchmark(head_num, batch_size, seq_len, provider):
        default=4096,
        help="Hidden size (2nd dimension) of the sequence",
    )
+    parser.add_argument(
+        "--intermediate-size",
+        type=int,
+        default=None,
+        help="Intermediate size for FFN layers",
+    )
+    parser.add_argument(
+        "--num-groups",
+        type=int,
+        default=None,
+        help="Number of expert groups for MoE models",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default=None,
+        help="Data type, e.g. bfloat16 (default: bfloat16, or the model "
+        "config's dtype when --model-name is given)",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=None,
+        help="Model name to load configuration from",
+    )
+    parser.add_argument("--head-num-range",
+                        type=int,
+                        nargs='+',
+                        default=[12, 32, 40, 48, 64, 96, 128],
+                        help=("Range of attention head numbers to test/use. "
+                              "Default: 12 32 40 48 64 96 128"))
+    parser.add_argument(
+        "--tp-size",
+        type=int,
+        default=1,
+        help="Tensor parallelism size",
+    )
    parser.add_argument("--use-residual",
                        action="store_true",
                        help="Whether to use residual connection")
@@ -250,6 +336,67 @@ def benchmark(head_num, batch_size, seq_len, provider):

    args = parser.parse_args()

+    if args.model_name:
+        model_config = get_model_config(args.model_name, args.tp_size)
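+        # Only override values still at their argparse defaults, so explicit
+        # command-line flags take precedence over the model config.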
+
+        if args.hidden_size == 4096:
+            args.hidden_size = model_config["hidden_size"]
+
+        if args.intermediate_size is None:
+            args.intermediate_size = model_config["intermediate_size"]
+
+        if args.num_groups is None:
+            args.num_groups = model_config["num_groups"]
+
+        if args.dtype is None:
+            args.dtype = model_config["dtype"]
+
+        if args.head_num_range == [12, 32, 40, 48, 64, 96, 128]:
+            model_heads = model_config.get("num_attention_heads", 32)
+            if model_heads not in args.head_num_range:
+                args.head_num_range.append(model_heads)
+                args.head_num_range.sort()
+                print(
+                    f"Added model's head number {model_heads} to head_num_range"
+                )
+
+        print(f"Using model configuration from: {args.model_name}")
+        print(f"Updated hidden_size: {args.hidden_size}")
+        print(f"Updated intermediate_size: {args.intermediate_size}")
+        print(f"Updated num_groups: {args.num_groups}")
+        print(f"Updated head_num_range: {args.head_num_range}")
+        print(f"Updated dtype: {args.dtype}")
+
+    # Resolve dtype: accept strings such as "bfloat16" from the command line
+    # and fall back to bfloat16 when nothing was specified.
+    if args.dtype is None:
+        args.dtype = torch.bfloat16
+    elif isinstance(args.dtype, str):
+        args.dtype = getattr(torch, args.dtype)
+
+    return args
+
+
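+# Example (flag values illustrative): run with --model-name <model> --tp-size 2
+# --use-residual to pull shapes from a model config instead of the defaults.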
+if __name__ == "__main__":
+    import argparse
+
+    args = parse_args()
+
+    print("Final configuration:")
+    print(f"  Batch size: {args.batch_size}")
+    print(f"  Sequence length: {args.seq_len}")
+    print(f"  Hidden size: {args.hidden_size}")
+    print(f"  Intermediate size: {args.intermediate_size}")
+    print(f"  Number of groups: {args.num_groups}")
+    print(f"  Data type: {args.dtype}")
+    print(f"  Use residual: {args.use_residual}")
+
+    batch_size_range = [2**i for i in range(0, 7, 2)]
+    seq_length_range = [2**i for i in range(6, 10, 1)]
+    head_num_range = args.head_num_range
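+    # Benchmark grid: every combination of head count, batch size and
+    # sequence length below is timed for each provider.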
+    configs = list(
+        itertools.product(head_num_range, batch_size_range, seq_length_range))
+
+    if HAS_IPEX:
+        print("✅ IPEX is available")
+        print(f"IPEX version: {ipex.__version__}")
+    else:
+        print("⚠️ IPEX is not available, skipping IPEX benchmarks")
+
    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
@@ -259,6 +406,6 @@ def benchmark(head_num, batch_size, seq_len, provider):
    )

    # Get the benchmark function with proper use_residual setting
-    benchmark = get_benchmark(args.use_residual)
+    benchmark = get_benchmark(args.use_residual, args.dtype)
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)