forked from Goutam-Kelam/iSalGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathiSalGan_inference.py
More file actions
101 lines (74 loc) · 3.66 KB
/
iSalGan_inference.py
File metadata and controls
101 lines (74 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import os
import torch
from PIL import Image
#from torch.autograd import Variable
from torchvision import transforms
#from config import msra10k_path
#from config import ecssd_path
from config import dutomron_path,ecssd_path,msra10k_path
#from config import ecssd_path, hkuis_path, pascals_path, sod_path, dutomron_path
from misc import check_mkdir, crf_refine, AvgMeter, cal_precision_recall_mae, cal_fmeasure
#from model import R3Net
from iSalGan_generator import iSalGan
# Fix the RNG seed for reproducible inference runs.
torch.manual_seed(2019)
# set which gpu to use
torch.cuda.set_device(0)
# the following two args specify the location of the file of trained model (pth extension)
# you should have the pth file in the folder './$ckpt_path$/$exp_name$'
ckpt_path = './ckpt'
exp_name = 'iSalGan'
# Inference options consumed by main().
args = {
    'snapshot': '6000',  # your snapshot filename (exclude extension name)
    'crf_refine': True,  # whether to use crf to refine results
    'save_results': False  # whether to save the resulting masks
}
# Preprocessing for input images: tensor conversion + ImageNet mean/std
# normalization (the standard torchvision values).
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Converts a prediction tensor back to a PIL image for saving/CRF refinement.
to_pil = transforms.ToPILImage()
#to_test = {'ecssd': ecssd_path, 'hkuis': hkuis_path, 'pascal': pascals_path, 'sod': sod_path, 'dutomron': dutomron_path}
#to_test = {'ecssd': ecssd_path}
#to_test = {'dutomron': dutomron_path}
# Datasets to evaluate: name -> root directory of .jpg images / .png masks.
to_test = {'msra10k': msra10k_path}
def main():
    """Run iSalGan saliency inference over each dataset in ``to_test``.

    For every ``.jpg`` image under a dataset root, predicts a saliency map,
    optionally refines it with a CRF, compares it against the ``.png``
    ground-truth mask, and prints per-dataset F-measure and MAE. Masks are
    optionally saved when ``args['save_results']`` is true.
    """
    net = iSalGan().cuda()

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load("/home/gautam/Project/iSalGan/ckpt/iSalGan/iSalGan_generator.pth"))
    net.eval()

    results = {}
    with torch.no_grad():
        for name, root in to_test.items():
            # One precision/recall meter per binarization threshold (0..255).
            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            save_dir = os.path.join(ckpt_path, exp_name,
                                    '(%s) %s_%s' % (exp_name, name, args['snapshot']))
            if args['save_results']:
                check_mkdir(save_dir)

            img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))

                img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                img_var = img_transform(img).unsqueeze(0).cuda()

                # The generator returns 3 outputs; index 2 is the combined
                # saliency map used as the final prediction.
                predict = net(img_var)
                prediction = predict[2]
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                gt = np.array(Image.open(os.path.join(root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, (p, r) in enumerate(zip(precision, recall)):
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    # Bug fix: save into the directory created by check_mkdir
                    # above, rather than an unrelated hard-coded absolute path
                    # that may not exist.
                    Image.fromarray(prediction).save(
                        os.path.join(save_dir, '{}.png'.format(img_name)))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])
            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)


if __name__ == '__main__':
    main()