# predict.py — SRCNN super-resolution prediction helpers
# (original listing: 80 lines, 56 loc, 2.08 KB)
import numpy as np
import math
import os
from tensorflow.keras.models import model_from_json
from skimage.measure import compare_ssim as ssim
import cv2
from matplotlib import pyplot as plt
from prepare_images import psnr,mse,compare_images,prepare_images
# define necessary image processing functions
def modcrop(img, scale):
    """Crop ``img`` so its height and width are exact multiples of ``scale``.

    Rows and columns are removed from the bottom and right edges only, so the
    top-left corner stays aligned with the uncropped image. Any trailing axes
    (e.g. colour channels) are kept unchanged.

    Args:
        img: Array whose first two axes are height and width.
        scale: Positive integer divisor for both spatial dimensions.

    Returns:
        A view of ``img`` with spatial shape ``(h - h % scale, w - w % scale)``.
    """
    height, width = img.shape[:2]
    height -= height % scale
    width -= width % scale
    # Both slices start at 0; the column slice used to start at 1, which
    # dropped the first column and left a width that was not a multiple
    # of ``scale``.
    return img[:height, :width]
def shave(image, border):
    """Trim ``border`` pixels from every side of ``image``.

    Args:
        image: Array whose first two axes are height and width; any trailing
            axes (e.g. colour channels) are kept unchanged.
        border: Non-negative number of pixels to remove from each edge.

    Returns:
        A view of ``image`` of spatial shape ``(h - 2*border, w - 2*border)``
        (empty if the border consumes the whole image).
    """
    height, width = image.shape[:2]
    # Use explicit end indices: the negative-index form ``image[0:-0]`` is
    # empty, so ``border == 0`` used to return an empty array.
    return image[border:height - border, border:width - border]
# define main prediction function
def predict(image, img_name):
    """Run SRCNN on a degraded image and score it against its reference.

    Args:
        image: Degraded BGR image. It must have the same height and width as
            the reference image, because both are scored against each other.
            NOTE(review): presumably produced by ``prepare_images``; confirm
            against the caller.
        img_name: File name of the reference image inside ``static/input/``.

    Returns:
        Tuple ``(ref, degraded, output, scores)``. ``ref``, ``degraded`` and
        ``output`` are BGR images cropped to a multiple of 3 and then shaved
        by 6 pixels on every side. ``scores`` is
        ``[compare_images(degraded, ref), compare_images(output, ref)]``.

    Raises:
        FileNotFoundError: If the reference image cannot be read.
    """
    # Load the SRCNN architecture and its trained weights.
    # NOTE(review): the model is rebuilt on every call; consider caching it
    # if request latency matters.
    with open('model.json', "r") as json_file:
        srcnn = model_from_json(json_file.read())
    srcnn.load_weights('3051crop_weight_200.h5')

    ref_path = 'static/input/{}'.format(img_name)
    ref = cv2.imread(ref_path)
    if ref is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(
            'could not read reference image: {}'.format(ref_path))

    # Crop both images so their sizes are multiples of the scale factor (3).
    ref = modcrop(ref, 3)
    degraded = modcrop(image, 3).astype(np.uint8)

    # Super-resolve the *degraded* image. The original code processed the
    # reference, which made the "output vs ref" score meaningless. The
    # network is fed only the Y (luminance) channel, normalised to [0, 1].
    temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)
    Y = temp[np.newaxis, :, :, 0:1].astype(float) / 255  # shape (1, H, W, 1)

    pre = srcnn.predict(Y, batch_size=1)
    # Back to the 8-bit range. Clip before casting so that out-of-range
    # values saturate instead of wrapping around.
    pre = np.clip(pre * 255, 0, 255).astype(np.uint8)

    # Crop 6 px per side so ``temp`` matches the prediction's spatial size,
    # then swap in the predicted Y channel and convert back to BGR.
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, :, :, 0]
    output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)

    # Remove the same border from the reference and degraded images so all
    # three images align for scoring.
    ref = shave(ref.astype(np.uint8), 6)
    degraded = shave(degraded, 6)

    scores = [compare_images(degraded, ref), compare_images(output, ref)]
    return ref, degraded, output, scores