forked from isalirezag/TReS
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
119 lines (77 loc) · 3.22 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import argparse
import random
import json
import numpy as np
import torch
from args import Configs
import logging
from models import TReS, Net
# Report the installed PyTorch version at import time, which helps when
# comparing results across environments.
print('torch version:', torch.__version__)
# Size of the index pool that is randomly split into train/test for each
# supported dataset. Indices 0..n-1 are handed to the solver; what a single
# index refers to (one image, or one reference image with its distorted
# versions) is decided by the data loaders -- NOTE(review): confirm there.
_DATASET_INDEX_COUNTS = {
    'live': 29,
    'csiq': 30,
    'kadid10k': 80,
    'tid2013': 25,
    'clive': 1162,
    'koniq': 10073,
    'fblive': 39810,
}

# Fraction of the shuffled index pool assigned to the training split.
_TRAIN_FRACTION = 0.8


def _split_train_test(indices, train_fraction=_TRAIN_FRACTION):
    """Shuffle *indices* in place and split them into ``(train, test)`` lists.

    The first ``round(train_fraction * len(indices))`` shuffled entries form
    the training split and the remainder the test split. The shuffle uses the
    global ``random`` state, so it is reproducible only if that was seeded.
    """
    random.shuffle(indices)
    cut = int(round(train_fraction * len(indices)))
    return indices[:cut], indices[cut:]


def _dump_json(path, obj):
    """Write *obj* to *path* as JSON, overwriting any existing file."""
    with open(path, 'w') as json_file:
        json.dump(obj, json_file)


def _log_performance(sv_path, dataset, plcc, srcc):
    """Append the dataset name and best PLCC/SROCC to ``<sv_path>/LogPerformance.log``.

    The file handler is detached and closed afterwards, so repeated calls in
    one process neither leak file handles nor write duplicated lines.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(sv_path + '/LogPerformance.log')
    handler.setFormatter(
        logging.Formatter('%(asctime)s : %(levelname)s : %(name)s : %(message)s'))
    logger.addHandler(handler)
    try:
        logger.info(dataset)
        logger.info('Best PLCC: {}, SROCC: {}'.format(plcc, srcc))
        logger.info('---------------------------')
    finally:
        logger.removeHandler(handler)
        handler.close()


def main(config, device):
    """Train and evaluate TReS on ``config.dataset`` using a random 80/20 split.

    Creates ``<svpath><dataset>_<vesion>_<seed>/sv``, saves the train and test
    index lists there as JSON, trains the solver, and appends the best
    PLCC/SROCC to ``LogPerformance.log`` in that directory. Also sets the
    ``CUDA_VISIBLE_DEVICES`` environment variable to ``config.gpunum``.

    Args:
        config: parsed ``Configs`` object; reads ``gpunum``, ``dataset``,
            ``datapath``, ``svpath``, ``vesion`` and ``seed``.
        device: ``torch.device`` passed through to the solver.

    Raises:
        KeyError: if ``config.dataset`` is not a supported dataset name.
    """
    # NOTE(review): this only takes effect if CUDA has not been initialised
    # yet; the __main__ entry point queries torch.cuda.is_available() before
    # calling main, which may already have initialised it -- confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpunum

    print('Training and Testing on {} dataset...'.format(config.dataset))
    # Fail fast, before any output directory is created.
    if config.dataset not in _DATASET_INDEX_COUNTS:
        raise KeyError('unsupported dataset {!r}; expected one of {}'.format(
            config.dataset, sorted(_DATASET_INDEX_COUNTS)))

    sv_path = (config.svpath + config.dataset + '_' + str(config.vesion)
               + '_' + str(config.seed) + '/' + 'sv')
    os.makedirs(sv_path, exist_ok=True)

    # A seed of 0 means "do not seed": the split and training are not reproducible.
    if config.seed != 0:
        print('we are using the seed = {}'.format(config.seed))
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)
        np.random.seed(config.seed)
        random.seed(config.seed)

    # Randomly select 80% of the indices for training and the rest for testing.
    train_index, test_index = _split_train_test(
        list(range(_DATASET_INDEX_COUNTS[config.dataset])))

    suffix = str(config.vesion) + '_' + str(config.seed) + '.json'
    _dump_json(sv_path + '/' + 'train_index_' + suffix, train_index)
    _dump_json(sv_path + '/' + 'test_index_' + suffix, test_index)

    # Every supported dataset is read from the same root, config.datapath.
    solver = TReS(config, device, sv_path, config.datapath,
                  train_index, test_index, Net)
    srcc_computed, plcc_computed = solver.train(config.seed, sv_path)

    _log_performance(sv_path, config.dataset, plcc_computed, srcc_computed)
def select_device(gpunum):
    """Return the ``torch.device`` to train on for the given GPU-id string.

    Uses ``cuda:<gpunum>`` when CUDA is available and *gpunum* is a single
    character (one GPU id such as ``'0'``); in every other case the CPU is
    used, so a device is always returned.

    Raises:
        ValueError: if CUDA is available and *gpunum* is a single non-digit
            character (from ``int(gpunum)``).
    """
    if torch.cuda.is_available() and len(gpunum) == 1:
        return torch.device("cuda", index=int(gpunum))
    return torch.device("cpu")


if __name__ == '__main__':
    config = Configs()
    print(config)
    main(config, select_device(config.gpunum))