Skip to content

Commit 158a672

Browse files
committed
Small fixes
1 parent 5815e57 commit 158a672

File tree

4 files changed

+32
-10
lines changed

4 files changed

+32
-10
lines changed

models/turbine_models/custom_models/torchbench/__init__.py

Whitespace-only changes.

models/turbine_models/custom_models/torchbench/cmd_opts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def is_valid_file(arg):
4646
p.add_argument(
4747
"--model_lists",
4848
type=Path,
49-
nargs="*"
49+
nargs="*",
5050
help="path to a JSON list of models to benchmark. One or more paths.",
5151
default=["torchbench_models.json", "timm_models.json", "torchvision_models.json"],
5252
)

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,11 @@ def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args, get_baseline=Fal
148148
return model_name, model, forward_args
149149

150150

151-
'''
151+
"""
152152
Imports models from torchbench model tooling, exports them with turbine AOT, and does simple benchmarking.
153-
'''
153+
"""
154+
155+
154156
@torch.no_grad()
155157
def benchmark_torchbench_model(
156158
model_id,
@@ -199,7 +201,7 @@ def benchmark_torchbench_model(
199201
)
200202
return vmfb_path
201203

202-
if compare_vs_eager:
204+
if compare_vs_eager:
203205
model_name, model, forward_args, golden, baseline = get_model_and_inputs(
204206
model_id, batch_size, tb_dir, tb_args, get_baseline=True
205207
)
@@ -316,13 +318,28 @@ def _run_iter(runner, inputs):
316318
res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
317319
return res, time.time() - start
318320

321+
319322
def do_compare(shark_results, shark_latency, golden_results, golden_latency):
320-
numerics_pass_fail = np.allclose(shark_results.to_host(), golden_results.clone().cpu().numpy(), rtol=1e-4, atol=1e-4)
323+
numerics_pass_fail = np.allclose(
324+
shark_results.to_host(),
325+
golden_results.clone().cpu().numpy(),
326+
rtol=1e-4,
327+
atol=1e-4,
328+
)
321329
speedup = golden_latency / shark_latency
322330
return speedup, numerics_pass_fail
323331

332+
324333
def run_benchmark(
325-
device, vmfb_path, weights_path, example_args, model_id, csv_path, iters, golden=None, baseline=None,
334+
device,
335+
vmfb_path,
336+
weights_path,
337+
example_args,
338+
model_id,
339+
csv_path,
340+
iters,
341+
golden=None,
342+
baseline=None,
326343
):
327344
if "rocm" in device:
328345
device = "hip" + device.split("rocm")[-1]
@@ -344,7 +361,13 @@ def run_benchmark(
344361
if os.path.exists(csv_path):
345362
needs_header = False
346363
with open(csv_path, "a") as csvfile:
347-
fieldnames = ["model", "avg_latency", "avg_iter_per_sec", "speedup_over_eager", "numerics"]
364+
fieldnames = [
365+
"model",
366+
"avg_latency",
367+
"avg_iter_per_sec",
368+
"speedup_over_eager",
369+
"numerics",
370+
]
348371
data = [
349372
{
350373
"model": model_id,
@@ -422,7 +445,7 @@ def run_main(model_id, args, tb_dir, tb_args):
422445
from turbine_models.custom_models.torchbench.cmd_opts import args, unknown
423446
import json
424447

425-
torchbench_models_dict = json.load(args.model_list_json
448+
torchbench_models_dict = json.load(args.model_list_json)
426449
for list in args.model_lists:
427450
torchbench_models_dict = json.load(list)
428451
with open(args.models_json, "r") as f:

models/turbine_models/custom_models/torchbench/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
"--iree-hip-waves-per-eu=2",
2525
"--iree-execution-model=async-external",
2626
],
27-
"preprocess_default": [
28-
]
27+
"preprocess_default": [],
2928
}
3029
GFX11_flags = {
3130
"all": [

0 commit comments

Comments
 (0)