-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathtesting.py
107 lines (68 loc) · 2.81 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import argparse
import random
import json
import numpy as np
import torch
from args import Configs
import logging
import data_loader
from models import TReS, Net
# Report the installed PyTorch version at import time.
print(f'torch version: {torch.__version__}')
def main(config, device):
    """Evaluate a previously trained TReS model on the configured test split.

    The run directory is ``<svpath><dataset>_<vesion>_<seed>/sv``. It is
    created if missing. The saved ``test_index_<vesion>_<seed>.json`` and
    ``train_index_<vesion>_<seed>.json`` files are then read from it. The
    function builds the test data loader, constructs the TReS solver and runs
    ``solver.test`` with ``pretrained=1``. The resulting SRCC/PLCC are printed.

    Args:
        config: Parsed ``Configs`` object. Fields read here: ``gpunum``,
            ``datapath``, ``dataset``, ``svpath``, ``vesion``, ``seed``,
            ``patch_size`` and ``test_patch_num``. The whole object is also
            forwarded to ``TReS``.
        device: Device object forwarded unchanged to the ``TReS`` solver.

    Returns:
        The ``(srcc, plcc)`` pair returned by ``solver.test``.

    Raises:
        ValueError: If ``config.dataset`` is not a supported dataset name.
        FileNotFoundError: If the saved index JSON files do not exist.
    """
    # All supported datasets read images from the same root, config.datapath.
    supported_datasets = (
        'live', 'csiq', 'tid2013', 'kadid10k', 'clive', 'koniq', 'fblive',
    )
    if config.dataset not in supported_datasets:
        raise ValueError(
            'Unsupported dataset {!r}; expected one of {}'.format(
                config.dataset, ', '.join(supported_datasets)))

    # NOTE(review): CUDA reads CUDA_VISIBLE_DEVICES when it first initialises.
    # The __main__ block calls torch.cuda.is_available() before this runs, so
    # this assignment may not actually restrict the visible GPUs. Confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpunum
    datapath = config.datapath

    print('Testing on {} dataset...'.format(config.dataset))

    # Plain concatenation (not os.path.join) keeps the exact path string that
    # training produced. config.svpath is used as a raw prefix.
    run_tag = '_'.join((config.dataset, str(config.vesion), str(config.seed)))
    svPath = config.svpath + run_tag + '/sv'
    os.makedirs(svPath, exist_ok=True)

    # A seed of 0 means "leave RNGs unseeded"; any other value is applied to
    # every RNG in use so evaluation is reproducible.
    if config.seed != 0:
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)
        np.random.seed(config.seed)
        random.seed(config.seed)

    pretrained_path = svPath + '/'
    print('path: {}'.format(pretrained_path))

    index_suffix = '_' + str(config.vesion) + '_' + str(config.seed) + '.json'
    with open(pretrained_path + 'test_index' + index_suffix,
              encoding='utf-8') as json_file:
        test_index = json.load(json_file)
    with open(pretrained_path + 'train_index' + index_suffix,
              encoding='utf-8') as json_file:
        train_index = json.load(json_file)

    test_loader = data_loader.DataLoader(config.dataset, datapath,
                                         test_index, config.patch_size,
                                         config.test_patch_num, istrain=False)
    test_data = test_loader.get_data()

    # Net (from models) is passed through; the solver decides how to use it.
    solver = TReS(config, device, svPath, datapath, train_index, test_index,
                  Net)

    # Version tag forwarded to solver.test. Its meaning is defined inside
    # TReS (not visible here); 1000 matches the value used previously.
    version_test_save = 1000
    srcc_computed, plcc_computed = solver.test(test_data, version_test_save,
                                               svPath, config.seed,
                                               pretrained=1)
    print('srcc_computed {}, plcc_computed {}'.format(srcc_computed,
                                                      plcc_computed))
    return srcc_computed, plcc_computed
def select_device(gpunum):
    """Return the torch device to evaluate on for a ``gpunum`` config string.

    A single GPU index (e.g. ``"0"`` or ``"10"``) selects that CUDA device
    when CUDA is available. In every other case the CPU device is returned:
    an empty string, a comma-separated list such as ``"0,1"``, any
    non-numeric value, or a machine without CUDA. A device is therefore
    always returned.
    """
    gpunum = str(gpunum).strip()
    # isdecimal() accepts exactly the characters int() can parse as digits.
    if torch.cuda.is_available() and gpunum.isdecimal():
        return torch.device("cuda", index=int(gpunum))
    return torch.device("cpu")


if __name__ == '__main__':
    config = Configs()
    print(config)
    main(config, select_device(config.gpunum))