forked from PRBonn/bonnetal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer_img.py
executable file
·161 lines (144 loc) · 4.52 KB
/
infer_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.
import argparse
import subprocess
import datetime
import os
import shutil
import cv2
import __init__ as booger
# choice of backends implemented
backend_choices = ["native", "caffe2", "tensorrt", "pytorch"]

if __name__ == '__main__':
    parser = argparse.ArgumentParser("./infer_img.py")
    parser.add_argument(
        '--image', '-i',
        nargs='+',
        type=str,
        required=True,
        help='Image to infer. No Default',
    )
    parser.add_argument(
        '--path', '-p',
        type=str,
        required=True,
        default=None,
        help='Directory to get the trained model.'
    )
    parser.add_argument(
        '--backend', '-b',
        type=str,
        required=False,
        # validate here so argparse rejects bad values with a proper usage
        # error instead of the later `assert` (which is stripped under -O)
        choices=backend_choices,
        default=backend_choices[0],
        help='Backend to use to infer. Defaults to %(default)s, and choices:{choices}'.format(
            choices=backend_choices),
    )
    parser.add_argument(
        '--workspace', '-w',
        type=int,
        required=False,
        default=1000000000,
        help='Workspace size for tensorRT. Defaults to %(default)s'
    )
    parser.add_argument(
        '--topk', '-k',
        type=int,
        required=False,
        default=1,
        help='Top predictions. Defaults to %(default)s'
    )
    parser.add_argument(
        '--calib_images', '-ci',
        nargs='+',
        type=str,
        required=False,
        default=None,
        help='Images to calibrate int8 inference. No Default',
    )
    FLAGS, unparsed = parser.parse_known_args()

    # print summary of what we will do
    print("----------")
    print("INTERFACE:")
    print("Image", FLAGS.image)
    print("model path", FLAGS.path)
    print("backend", FLAGS.backend)
    print("workspace", FLAGS.workspace)
    print("topk", FLAGS.topk)
    print("INT8 Calibration Images", FLAGS.calib_images)
    print("----------\n")
    # best-effort banner: the script may run outside a git checkout or on a
    # machine without git, so don't let an unhandled exception kill inference
    try:
        commit = subprocess.check_output(
            ['git', 'rev-parse', '--short', 'HEAD']).strip()
        print("Commit hash: ", str(commit))
    except (subprocess.CalledProcessError, OSError):
        print("Commit hash: unknown (not a git checkout, or git not installed)")
    print("----------\n")

    # normalize a single calibration image path into a list
    if FLAGS.calib_images is not None and not isinstance(FLAGS.calib_images, list):
        FLAGS.calib_images = [FLAGS.calib_images]

    # does model folder exist?
    if FLAGS.path is not None:
        if os.path.isdir(FLAGS.path):
            print("model folder exists! Using model from %s" % (FLAGS.path))
        else:
            print("model folder doesnt exist! Exiting...")
            quit()
    else:
        print("No pretrained directory found")
        quit()

    # create inference context for the desired backend
    # (argparse `choices` already guaranteed FLAGS.backend is valid)
    if FLAGS.backend == "tensorrt":
        # import and use tensorRT
        from tasks.classification.modules.userTensorRT import UserTensorRT
        user = UserTensorRT(FLAGS.path, FLAGS.workspace, FLAGS.calib_images)
    elif FLAGS.backend == "caffe2":
        # import and use caffe2
        from tasks.classification.modules.userCaffe2 import UserCaffe2
        user = UserCaffe2(FLAGS.path)
    elif FLAGS.backend == "pytorch":
        # import and use traced pytorch (comment previously said caffe2)
        from tasks.classification.modules.userPytorch import UserPytorch
        user = UserPytorch(FLAGS.path)
    else:
        # default to native pytorch
        from tasks.classification.modules.user import User
        user = User(FLAGS.path)

    # cv2 window that can be resized
    cv2.namedWindow('predictions', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('predictions', 960, 540)

    # normalize a single image path into a list (nargs='+' already yields a
    # list, but keep the defensive check)
    if not isinstance(FLAGS.image, list):
        images = [FLAGS.image]
    else:
        images = FLAGS.image

    for img in images:
        # visual separator between images
        print("*" * 80)
        # open
        cv_img = cv2.imread(img, cv2.IMREAD_COLOR)
        if cv_img is None:
            print("Can't open ", img)
            continue
        # infer the top-k classes for this image
        print("Inferring ", img)
        max_class, max_class_str = user.infer(cv_img, FLAGS.topk)
        # watermark each top-k class name on the bottom-left of the frame,
        # stacked upwards so entry 1 sits closest to the bottom edge
        h, w, d = cv_img.shape
        for i, c in enumerate(max_class_str, 1):
            watermark = "[" + str(i) + "]: " + c
            font_size, _ = cv2.getTextSize(watermark, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                           fontScale=0.75, thickness=1)
            cv2.putText(cv_img, watermark,
                        org=(10, h - 10 -
                             (2 * (len(max_class_str) - i) * font_size[1])),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.75,
                        color=(255, 255, 255),
                        thickness=1,
                        lineType=cv2.LINE_AA,
                        bottomLeftOrigin=False)
        # display the resulting frame; 'q' or ESC aborts the whole batch
        # (the previous `else: continue` after the break was a no-op)
        cv2.imshow('predictions', cv_img)
        ret = cv2.waitKey(0)
        if ret == ord('q') or ret == 27:
            break

    # release the GUI window on exit
    cv2.destroyAllWindows()