-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
61 lines (51 loc) · 2.48 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from options.test_options import TestOptions
from data.data_loader import DataLoader
from models.DIFL_model import DIFLModel
import numpy as np
class Tester():
def __init__(self):
# Parse test options. Note test code only supports nThreads=1 and batchSize=1
self.opt = TestOptions().parse()
self.opt.nThreads = 1
self.opt.batchSize = 1
self.dataset = DataLoader(self.opt)
self.model = DIFLModel(self.opt)
# Read groundtruth poses of database from txt files for each slice, note that the poses has already been
# transformed to R,t, not original R,c for CMU-Seasons dataset
self.split_file = os.path.join(self.opt.dataroot, 's' + str(self.opt.which_slice),
'pose_new_s' + str(self.opt.which_slice) + '.txt')
self.names = np.loadtxt(self.split_file, dtype=str, delimiter=' ', skiprows=0, usecols=(0))
with open(self.split_file, 'r') as pose_file:
self.poses = pose_file.read().splitlines()
if self.opt.test_using_cos:
metric_mode = "cos"
else:
metric_mode = "l2"
# Open the result txt file
self.result_file = open(self.opt.results_dir + self.opt.name + "_" + str(self.opt.which_epoch) + '_s' + str(
self.opt.which_slice) + "_" + metric_mode + ".txt", 'w')
def test(self):
for i, data in enumerate(self.dataset):
if not self.opt.serial_test and i >= self.opt.how_many:
break
self.model.set_input(data)
if self.opt.test_after_pca:
retrieved_path = self.model.test_pca()
else:
retrieved_path = self.model.test()
img_path = self.model.get_image_paths()
if retrieved_path != "database":
# find and write the corresponding pose for every retrieved path
for k in range(len(self.names)):
if self.names[k].split('/')[-1] == retrieved_path.split('/')[-1]:
self.result_file.write(
img_path[0].split('/')[-1] + self.poses[k][len(self.poses[k].split(' ')[0]):] + '\n')
print('Now %s' % img_path[0].split('/')[-1])
else:
print('Building up database... %s' % img_path[0].split('/')[-1])
self.result_file.close()
print("Done slice {}".format(self.opt.which_slice))
if __name__ == "__main__":
    # Script entry point: build the tester from the command-line options and run it.
    Tester().test()