forked from PRBonn/bonnetal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path make_deploy_model.py
executable file
·81 lines (74 loc) · 2.1 KB
/
make_deploy_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.
import argparse
import subprocess
import datetime
import yaml
from shutil import copyfile
import os
import shutil
import __init__ as booger
from tasks.segmentation.modules.traceSaver import *
if __name__ == '__main__':
parser = argparse.ArgumentParser("./make_deploy_model.py")
parser.add_argument(
'--path', '-p',
type=str,
required=True,
help='Directory to get the pretrained model. No default!'
)
parser.add_argument(
'--log', '-l',
type=str,
required=True,
help='Directory to put the new model. No default!'
)
parser.add_argument(
'--new_h',
type=int,
dest='new_h',
default=None,
help='Force Height to. Defaults to %(default)s',
)
parser.add_argument(
'--new_w',
type=int,
dest='new_w',
default=None,
help='Force Width to. Defaults to %(default)s',
)
FLAGS, unparsed = parser.parse_known_args()
# print summary of what we will do
print("----------")
print("INTERFACE:")
print("model path", FLAGS.path)
print("log dir", FLAGS.log)
print("Height force", FLAGS.new_h)
print("Width force", FLAGS.new_w)
print("----------\n")
print("Commit hash (training version): ", str(
subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip()))
print("----------\n")
# does model folder exist?
if FLAGS.path is not None:
if os.path.isdir(FLAGS.path):
print("model folder exists! Using model from %s" % (FLAGS.path))
else:
print("model folder doesnt exist!")
quit()
else:
print("No pretrained directory found.")
# create log folder
try:
if os.path.isdir(FLAGS.log):
shutil.rmtree(FLAGS.log)
os.makedirs(FLAGS.log)
except Exception as e:
print(e)
print("Error creating log directory. Check permissions!")
quit()
# create saver and start the exporting
onnx_maker = TraceSaver(FLAGS.path,
FLAGS.log,
(FLAGS.new_h, FLAGS.new_w)) # force image properties
onnx_maker.export()