@@ -148,9 +148,11 @@ def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args, get_baseline=Fal
     return model_name, model, forward_args
 
 
-'''
+"""
 Imports models from torchbench model tooling, exports them with turbine AOT, and does simple benchmarking.
-'''
+"""
+
+
 @torch.no_grad()
 def benchmark_torchbench_model(
     model_id,
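For context on the hunk above: the module docstring now uses double quotes, and the benchmark entry point stays wrapped in @torch.no_grad(). Below is a small standalone sketch, not part of the diff (the toy model and input are made up), of what that decorator does during benchmarking:

import torch

@torch.no_grad()
def run_once(model, example_input):
    # Everything inside the decorated call runs with autograd disabled,
    # so the benchmark does not pay for gradient tracking.
    return model(example_input)

model = torch.nn.Linear(4, 2)
out = run_once(model, torch.randn(1, 4))
print(out.requires_grad)  # False: no autograd graph was recorded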
@@ -199,7 +201,7 @@ def benchmark_torchbench_model(
     )
     return vmfb_path
 
-    if compare_vs_eager:
+    if compare_vs_eager:
         model_name, model, forward_args, golden, baseline = get_model_and_inputs(
             model_id, batch_size, tb_dir, tb_args, get_baseline=True
         )
@@ -316,13 +318,28 @@ def _run_iter(runner, inputs):
     res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
     return res, time.time() - start
 
+
 def do_compare(shark_results, shark_latency, golden_results, golden_latency):
-    numerics_pass_fail = np.allclose(shark_results.to_host(), golden_results.clone().cpu().numpy(), rtol=1e-4, atol=1e-4)
+    numerics_pass_fail = np.allclose(
+        shark_results.to_host(),
+        golden_results.clone().cpu().numpy(),
+        rtol=1e-4,
+        atol=1e-4,
+    )
     speedup = golden_latency / shark_latency
     return speedup, numerics_pass_fail
 
+
 def run_benchmark(
-    device, vmfb_path, weights_path, example_args, model_id, csv_path, iters, golden=None, baseline=None,
+    device,
+    vmfb_path,
+    weights_path,
+    example_args,
+    model_id,
+    csv_path,
+    iters,
+    golden=None,
+    baseline=None,
 ):
     if "rocm" in device:
         device = "hip" + device.split("rocm")[-1]
@@ -344,7 +361,13 @@ def run_benchmark(
     if os.path.exists(csv_path):
         needs_header = False
     with open(csv_path, "a") as csvfile:
-        fieldnames = ["model", "avg_latency", "avg_iter_per_sec", "speedup_over_eager", "numerics"]
+        fieldnames = [
+            "model",
+            "avg_latency",
+            "avg_iter_per_sec",
+            "speedup_over_eager",
+            "numerics",
+        ]
         data = [
             {
                 "model": model_id,
@@ -422,7 +445,7 @@ def run_main(model_id, args, tb_dir, tb_args):
     from turbine_models.custom_models.torchbench.cmd_opts import args, unknown
     import json
 
-    torchbench_models_dict = json.load(args.model_list_json
+    torchbench_models_dict = json.load(args.model_list_json)
     for list in args.model_lists:
         torchbench_models_dict = json.load(list)
     with open(args.models_json, "r") as f:
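The last hunk fixes a missing closing parenthesis in the json.load call. As a small aside (an assumption about the surrounding script, not something shown in the diff), json.load expects a file-like object, so args.model_list_json is presumably already an open file handle; a standalone equivalent with an explicit open would look like this, with a hypothetical file name:

import json

# Hypothetical path standing in for args.model_list_json.
with open("torchbench_models.json", "r") as f:
    torchbench_models_dict = json.load(f)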