-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
100 lines (71 loc) · 2.59 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import cv2
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
from utils import load_model
def test(parser_args):
    """Run a segmentation model over every PNG in a directory and save colour masks.

    Each ``*.png`` in ``parser_args.img_dir`` is processed in four steps:
    loaded with OpenCV, scaled to [0, 1], fed to the model, and its per-pixel
    argmax class rendered with a ``viridis`` palette of ``class_num`` colours.
    Masks are written to ``<save_dir>/<model file stem>/<image stem>_mask.png``.

    Args:
        parser_args: namespace (e.g. from argparse) providing ``model_path``,
            ``img_size`` (indexable as ``[rows, cols, channels]``),
            ``class_num``, ``save_dir``, ``model_type`` and ``img_dir``.
            An optional boolean ``plot_img`` attribute shows each
            input/prediction pair with matplotlib (defaults to ``False``).

    Raises:
        OSError: if a PNG found in ``img_dir`` cannot be decoded by OpenCV.
    """
    model_path = parser_args.model_path
    SIZE_X = parser_args.img_size[0]
    SIZE_Y = parser_args.img_size[1]
    IMG_CHANNELS = parser_args.img_size[2]
    n_classes = parser_args.class_num
    save_dir = parser_args.save_dir
    model_type = parser_args.model_type
    im_dir = parser_args.img_dir
    # Previously hard-coded to False; an optional attribute keeps old callers working.
    plot_img = getattr(parser_args, "plot_img", False)

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # takes the same (name, lut) arguments and works on old and new versions.
    viridis = plt.get_cmap("viridis", 256)
    # One RGB colour per class (alpha dropped): shape (n_classes, 3).
    COLORS = viridis(np.linspace(0, 1, n_classes))[..., :3]

    model = load_model(model_type, n_classes, SIZE_X, SIZE_Y, IMG_CHANNELS)
    model.load_weights(model_path)

    im_names = sorted(f for f in os.listdir(im_dir) if f.endswith(".png"))

    # Strip only the final extension so checkpoints such as
    # "weights.01-0.35.hdf5" and "weights.02-0.31.hdf5" get separate folders
    # instead of both collapsing to "weights" and overwriting each other.
    model_stem = os.path.splitext(os.path.basename(model_path))[0]
    full_save_dir = os.path.join(save_dir, model_stem)
    # makedirs also creates save_dir itself if it is missing.
    os.makedirs(full_save_dir, exist_ok=True)

    for n in im_names:
        stem, ext = os.path.splitext(n)
        save_path = os.path.join(full_save_dir, f"{stem}_mask{ext}")

        img_path = os.path.join(im_dir, n)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        # cv2.imread yields BGR channel order; the model is fed it unchanged,
        # presumably matching how it was trained — TODO confirm.
        X_test = img / 255.0
        test_img_input = np.expand_dims(X_test, 0)
        prediction = model.predict(test_img_input)
        # Assumes prediction is (1, rows, cols, n_classes) — argmax over classes.
        y_pred_argmax = np.argmax(prediction, axis=3)
        # Index the palette with the label map: (rows, cols) -> (rows, cols, 3).
        res_mask = COLORS[y_pred_argmax[0]]
        plt.imsave(save_path, res_mask)

        if plot_img:
            plt.subplot(221)
            plt.title('test_image')
            # Flip BGR -> RGB for display only; matplotlib expects RGB.
            plt.imshow(X_test[..., ::-1])
            plt.subplot(222)
            plt.title('prediction')
            plt.imshow(res_mask)
            plt.show()
# if __name__ == '__main__':
# test_img_number = 0
# test_img = X_test[test_img_number]
# test_img_input=np.expand_dims(test_img, 0)
# prediction = (model.predict(test_img_input))
# print(prediction.shape)
# predicted_img=np.argmax(prediction, axis=3)[0,:,:]
# plt.imshow(predicted_img)
# plt.show()
# plt.subplot(221)
# plt.imshow(X_test[0])
# plt.subplot(222)
# plt.imshow(X_test[1])
# #IOU
# y_pred=model.predict(X_test)
# y_pred_argmax=np.argmax(y_pred, axis=3)
# plt.subplot(223)
# plt.imshow(y_pred_argmax[0, ...])
# plt.subplot(224)
# plt.imshow(y_pred_argmax[1, ...])
# plt.show()