-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathevaluate.py
96 lines (71 loc) · 2.59 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
import os
import time
import numpy as np
from scipy.misc import imread, imresize, imsave
import torch
from torch.autograd import Variable
import torch.utils.data as data
from util import eval_forward, evaluate, get_models, set_eval, save_numpy_array_as_image
from torchvision import transforms
from dataset import get_loader
def save_codes(name, codes):
    """Bit-pack a binary code array and save it to ``<name>.codes.npz``.

    Each entry is mapped via ``(x + 1) // 2`` after an ``int8`` cast, so -1
    becomes 0 and 1 stays 1. The codes are presumably +/-1 binarizer output
    (TODO confirm against the model). The flattened bits are packed with
    ``np.packbits``, which zero-pads the final byte. The pre-packing shape is
    stored alongside the bits so a reader can unpack and trim them.

    Args:
        name: output path prefix. ``np.savez_compressed`` appends ``.npz``,
            so the file written is ``name + '.codes.npz'``.
        codes: array-like supporting ``.astype`` (e.g. a NumPy array).
    """
    # The previous version printed the entire code array here for every
    # image, which flooded stdout during evaluation; that debug print is gone.
    codes = (codes.astype(np.int8) + 1) // 2
    export = np.packbits(codes.reshape(-1))
    np.savez_compressed(
        name + '.codes',
        shape=codes.shape,
        codes=export)
def save_output_images(name, ex_imgs):
    """Write every image in ``ex_imgs`` to ``<name>_iterNN.png``.

    Files are numbered from 1 in iteration order, zero-padded to two digits.
    The actual encoding is delegated to ``save_numpy_array_as_image``.
    """
    for iter_num, image in enumerate(ex_imgs, 1):
        out_path = '%s_iter%02d.png' % (name, iter_num)
        save_numpy_array_as_image(out_path, image)
def finish_batch(args, filenames, original, out_imgs,
                 losses, code_batch, output_suffix):
    """Save per-example artifacts for one batch and score its outputs.

    For each example, this optionally writes its codes (``args.save_codes``)
    and its output images (``args.save_out_img``) under
    ``args.out_dir/output_suffix``. It then scores the outputs against the
    original with ``evaluate``.

    Args:
        args: parsed options. Reads ``out_dir``, ``save_codes`` and
            ``save_out_img``.
        filenames: one path per example. Only the base name is used to name
            the output files.
        original: originals with the example index on axis 0.
        out_imgs: outputs with the example index on axis 1. Iterating it
            yields one array per step (presumably per coding iteration,
            given the ``_iterNN`` file naming), each indexed by example on
            axis 0.
        losses: the batch's losses. This exact object is appended once per
            example.
        code_batch: 5-D codes with the example index on axis 1.
        output_suffix: sub-directory of ``args.out_dir`` to write into.

    Returns:
        ``(all_losses, all_msssim, all_psnr)``: three lists, each with one
        entry per filename.
    """
    # The output directories are the same for every example, so join once.
    codes_dir = os.path.join(args.out_dir, output_suffix, 'codes')
    images_dir = os.path.join(args.out_dir, output_suffix, 'images')

    all_losses, all_msssim, all_psnr = [], [], []
    for ex_idx, filename in enumerate(filenames):
        # basename strips directories using the platform's separators;
        # split('/') missed Windows-style paths.
        base_name = os.path.basename(filename)
        if args.save_codes:
            save_codes(
                os.path.join(codes_dir, base_name),
                code_batch[:, ex_idx, :, :, :]
            )
        if args.save_out_img:
            save_output_images(
                os.path.join(images_dir, base_name),
                out_imgs[:, ex_idx, :, :, :]
            )
        # `None` keeps a leading batch axis of size 1 for evaluate().
        msssim, psnr = evaluate(
            original[None, ex_idx],
            [out_img[None, ex_idx] for out_img in out_imgs])
        # The batch loss is repeated per example, so a later mean over these
        # lists weights each batch by its number of examples.
        all_losses.append(losses)
        all_msssim.append(msssim)
        all_psnr.append(psnr)
    return all_losses, all_msssim, all_psnr
def _ensure_eval_dirs(out_dir, output_suffix):
    """Create ``out_dir/output_suffix/{codes,images}`` if they do not exist."""
    for sub_dir in ('codes', 'images'):
        cur_eval_dir = os.path.join(out_dir, output_suffix, sub_dir)
        if os.path.isdir(cur_eval_dir):
            continue
        print("Creating directory %s." % cur_eval_dir)
        try:
            os.makedirs(cur_eval_dir)
        except OSError:
            # Another process may have created it between the check and the
            # call. Only re-raise if the directory still is not there.
            if not os.path.isdir(cur_eval_dir):
                raise


def _forward_without_grad(model, batch, ctx_frames, args):
    """Run ``eval_forward`` on a CUDA copy of ``batch`` with autograd disabled."""
    if hasattr(torch, 'no_grad'):
        # PyTorch >= 0.4 ignores `volatile=True` (it only warns), which would
        # record a full autograd graph during evaluation. no_grad() replaces it.
        with torch.no_grad():
            return eval_forward(model, (batch.cuda(), ctx_frames), args)
    # Older PyTorch: volatile Variables are how inference disables autograd.
    return eval_forward(
        model, (Variable(batch.cuda(), volatile=True), ctx_frames), args)


def run_eval(model, eval_loader, args, output_suffix=''):
    """Evaluate ``model`` over ``eval_loader`` and return averaged metrics.

    Output directories ``args.out_dir/output_suffix/{codes,images}`` are
    created first. Each batch is then run through ``eval_forward`` without
    autograd and handed to ``finish_batch``, which saves artifacts and
    returns per-example losses, MS-SSIM and PSNR.

    Args:
        model: model passed through to ``eval_forward``.
        eval_loader: iterable of ``(batch, ctx_frames, filenames)``. ``batch``
            must support ``.cuda()``.
        args: parsed options used here (``out_dir``) and by the helpers.
        output_suffix: sub-directory of ``args.out_dir`` for outputs.

    Returns:
        ``(losses, msssim, psnr)``: the mean over all examples (axis 0) of
        each per-example list. An empty loader yields NaN means, following
        NumPy's behavior for an empty mean.
    """
    _ensure_eval_dirs(args.out_dir, output_suffix)

    all_losses, all_msssim, all_psnr = [], [], []
    start_time = time.time()
    for i, (batch, ctx_frames, filenames) in enumerate(eval_loader):
        original, out_imgs, losses, code_batch = _forward_without_grad(
            model, batch, ctx_frames, args)

        losses, msssim, psnr = finish_batch(
            args, filenames, original, out_imgs,
            losses, code_batch, output_suffix)
        all_losses += losses
        all_msssim += msssim
        all_psnr += psnr

        if i % 10 == 0:
            print('\tevaluating iter %d (%f seconds)...' % (
                i, time.time() - start_time))

    return (np.array(all_losses).mean(axis=0),
            np.array(all_msssim).mean(axis=0),
            np.array(all_psnr).mean(axis=0))