Skip to content

Commit 5815e57

Browse files
committed
Split models list up into groups, pipe through CLI option
1 parent 896c221 commit 5815e57

File tree

5 files changed

+76
-8
lines changed

5 files changed

+76
-8
lines changed

models/turbine_models/custom_models/torchbench/cmd_opts.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ def is_valid_file(arg):
4343
help="model ID as it appears in the torchbench models text file lists, or 'all' for batch export",
4444
default="all",
4545
)
46+
p.add_argument(
47+
"--model_lists",
48+
type=Path,
49+
nargs="*"
50+
help="path to a JSON list of models to benchmark. One or more paths.",
51+
default=["torchbench_models.json", "timm_models.json", "torchvision_models.json"],
52+
)
4653
p.add_argument(
4754
"--external_weights_dir",
4855
type=str,

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
import csv
3535

36-
torchbench_models_dict = {
36+
torchbench_models_all = {
3737
# "BERT_pytorch": {
3838
# "dim": 128,
3939
# }, # Dynamo Export Issue
@@ -420,10 +420,17 @@ def run_main(model_id, args, tb_dir, tb_args):
420420

421421
if __name__ == "__main__":
422422
from turbine_models.custom_models.torchbench.cmd_opts import args, unknown
423-
424-
tb_dir = setup_torchbench_cwd()
425-
if args.model_id.lower() == "all":
426-
for name in torchbench_models_dict.keys():
427-
run_main(name, args, tb_dir, unknown)
428-
else:
429-
run_main(args.model_id, args, tb_dir, unknown)
423+
import json
424+
425+
torchbench_models_dict = json.load(args.model_list_json
426+
for list in args.model_lists:
427+
torchbench_models_dict = json.load(list)
428+
with open(args.models_json, "r") as f:
429+
torchbench_models_dict = json.load(file)
430+
431+
tb_dir = setup_torchbench_cwd()
432+
if args.model_id.lower() == "all":
433+
for name in torchbench_models_dict.keys():
434+
run_main(name, args, tb_dir, unknown)
435+
else:
436+
run_main(args.model_id, args, tb_dir, unknown)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"timm_efficientnet": {
3+
"dim": 128
4+
},
5+
"timm_regnet": {
6+
"dim": 128
7+
},
8+
"timm_resnest": {
9+
"dim": 256
10+
},
11+
"timm_vovnet": {
12+
"dim": 128
13+
}
14+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"pytorch_unet": {
3+
"dim": 8
4+
}
5+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"LearningToPaint": {
3+
"dim": 1024
4+
},
5+
"alexnet": {
6+
"dim": 1024
7+
},
8+
"densenet121": {
9+
"dim": 64
10+
},
11+
"mnasnet1_0": {
12+
"dim": 256
13+
},
14+
"mobilenet_v2": {
15+
"dim": 128
16+
},
17+
"mobilenet_v3_large": {
18+
"dim": 256
19+
},
20+
"resnet18": {
21+
"dim": 512
22+
},
23+
"resnet50": {
24+
"dim": 128
25+
},
26+
"resnext50_32x4d": {
27+
"dim": 128
28+
},
29+
"shufflenet_v2_x1_0": {
30+
"dim": 512
31+
},
32+
"squeezenet1_1": {
33+
"dim": 512
34+
}
35+
}

0 commit comments

Comments
 (0)