 torchbench_models_dict = {
     # "BERT_pytorch": {
     #     "dim": 128,
-    # },
+    # },  # Dynamo Export Issue
     # "Background_Matting": {
     #     "dim": 16,
-    # },
-    # "LearningToPaint": {
-    #     "dim": 1024,
-    # },
+    # },  # Transpose Bubbling Pattern Failed
+    "LearningToPaint": {
+        "dim": 1024,
+    },
     "alexnet": {
         "dim": 1024,
     },
-    # "densenet121": {
-    #     "dim": 64,
-    # },
+    "densenet121": {
+        "dim": 64,
+    },
     # "hf_Albert": {"dim": 32, "buffer_prefix": "albert"},
     # "hf_Bart": {
     #     "dim": 16,
@@ -131,17 +131,28 @@ def get_runner(tb_dir, tb_args):
     return runner


-def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args):
+def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args, get_baseline=False):
     runner = get_runner(tb_dir, tb_args)
-    return runner.load_model(
+    _, model_name, model, forward_args, _ = runner.load_model(
         "cuda:0",
         model_id,
         batch_size=batch_size,
     )
-
-
+    if get_baseline:
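+        # Time a single eager-mode forward pass to serve as the golden baseline.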
+        start_t = time.time()
+        res = runner.forward_pass(model, forward_args, collect_outputs=True)
+        baseline = time.time() - start_t
+        return model_name, model, forward_args, res, baseline
+    return model_name, model, forward_args
+
+
+# Imports models from torchbench model tooling, exports them with turbine AOT,
+# and does simple benchmarking.
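+# A minimal usage sketch (argument values are hypothetical; the remaining
+# keyword arguments are defined in the signature below):
+#   vmfb, weights, fwd_args, golden, baseline = benchmark_torchbench_model(
+#       "alexnet", tb_dir, tb_args, compare_vs_eager=True, ...
+#   )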
 @torch.no_grad()
-def export_torchbench_model(
+def benchmark_torchbench_model(
     model_id,
     tb_dir,
     tb_args,
@@ -159,6 +170,7 @@ def export_torchbench_model(
     input_mlir=None,
     weights_only=False,
     upload_ir=False,
+    compare_vs_eager=False,
 ):
     static_dim = torchbench_models_dict[model_id]["dim"]
     dtype = torch.float16 if precision == "fp16" else torch.float32
@@ -187,9 +199,16 @@ def export_torchbench_model(
         )
         return vmfb_path

-    _, model_name, model, forward_args, _ = get_model_and_inputs(
-        model_id, batch_size, tb_dir, tb_args
-    )
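+    # With compare_vs_eager set, also capture eager-mode golden outputs and latency.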
+    if compare_vs_eager:
+        model_name, model, forward_args, golden, baseline = get_model_and_inputs(
+            model_id, batch_size, tb_dir, tb_args, get_baseline=True
+        )
+    else:
+        model_name, model, forward_args = get_model_and_inputs(
+            model_id, batch_size, tb_dir, tb_args
+        )
+        golden = None
+        baseline = None

     if dtype == torch.float16:
         model = model.half()
@@ -275,7 +294,8 @@ class CompiledTorchbenchModel(CompiledModule):
     inst = CompiledTorchbenchModel(context=Context(), import_to="IMPORT")

     module = CompiledModule.get_mlir_module(inst)
-
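+    # Free GPU memory held by the eager model before compiling the module.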
+    model.to("cpu")
+    del model
     if compile_to != "vmfb":
         return str(module)
     else:
@@ -288,17 +308,21 @@ class CompiledTorchbenchModel(CompiledModule):
             return_path=not exit_on_vmfb,
             attn_spec=attn_spec,
         )
-        return vmfb_path, external_weight_path, forward_args
+        return vmfb_path, external_weight_path, forward_args, golden, baseline


 def _run_iter(runner, inputs):
     start = time.time()
     res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
     return res, time.time() - start

+def do_compare(shark_results, shark_latency, golden_results, golden_latency):
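+    # Element-wise closeness of IREE outputs vs. the eager golden outputs
+    # (assumes numpy is imported as np at module scope).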
+    numerics_pass_fail = np.allclose(
+        shark_results.to_host(), golden_results.clone().cpu().numpy(), rtol=1e-4, atol=1e-4
+    )
+    speedup = golden_latency / shark_latency
+    return speedup, numerics_pass_fail


 def run_benchmark(
-    device, vmfb_path, weights_path, example_args, model_id, csv_path, iters
+    device, vmfb_path, weights_path, example_args, model_id, csv_path, iters, golden=None, baseline=None,
 ):
     if "rocm" in device:
         device = "hip" + device.split("rocm")[-1]
@@ -311,16 +335,23 @@ def run_benchmark(
     avg_latency = sum(iter_latencies) / len(iter_latencies)
     it_per_sec = 1 / avg_latency

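+    # Speedup and numerics are only meaningful when an eager baseline was collected.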
+    if golden is not None and baseline is not None:
+        speedup, numerics_pass_fail = do_compare(results, avg_latency, golden, baseline)
+    else:
+        speedup, numerics_pass_fail = ("N/A", "N/A")
+
     needs_header = True
     if os.path.exists(csv_path):
         needs_header = False
     with open(csv_path, "a") as csvfile:
-        fieldnames = ["model", "avg_latency", "avg_iter_per_sec"]
+        fieldnames = ["model", "avg_latency", "avg_iter_per_sec", "speedup_over_eager", "numerics"]
         data = [
             {
                 "model": model_id,
                 "avg_latency": avg_latency,
                 "avg_iter_per_sec": it_per_sec,
+                "speedup_over_eager": speedup,
+                "numerics": numerics_pass_fail,
             }
         ]
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
@@ -346,7 +377,7 @@ def torch_to_iree(iree_runner, example_args):


 def run_main(model_id, args, tb_dir, tb_args):
     print(f"exporting {model_id}")
-    mod_str, weights_path, example_args = export_torchbench_model(
+    mod_str, weights_path, example_args, golden, baseline = benchmark_torchbench_model(
         model_id,
         tb_dir,
         tb_args,
@@ -361,6 +392,7 @@ def run_main(model_id, args, tb_dir, tb_args):
         decomp_attn=args.decomp_attn,
         attn_spec=args.attn_spec,
         input_mlir=args.input_mlir,
+        compare_vs_eager=args.compare_vs_torch,
     )
     if args.compile_to in ["torch", "mlir"]:
         safe_name = utils.create_safe_name(
@@ -379,6 +411,8 @@ def run_main(model_id, args, tb_dir, tb_args):
         model_id,
         args.output_csv,
         args.num_iters,
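+        # Eager-mode references; None unless compare_vs_torch was requested.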
+        golden,
+        baseline,
     )

     gc.collect()