-
Notifications
You must be signed in to change notification settings - Fork 68
/
Copy pathvgg19_keras.py
executable file
·58 lines (47 loc) · 2.02 KB
/
vgg19_keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
from keras.applications.vgg19 import preprocess_input
from ops import L1_loss
class VGGLoss(tf.keras.Model):
    """Perceptual loss: weighted L1 distance between VGG19 feature maps.

    Both inputs are mapped from [-1, 1] to [0, 255], passed through the Keras
    VGG19 ``preprocess_input`` (RGB -> BGR, ImageNet mean subtraction), and
    fed to ``Vgg19``, which returns a list of feature maps. The loss is the
    weighted sum of ``L1_loss`` between corresponding feature maps of ``x``
    and ``y``. Gradients are blocked on the ``y`` side, so ``y`` acts as a
    fixed target.
    """

    # Per-feature-map weights, shallowest map first (deeper maps weigh more).
    DEFAULT_LAYER_WEIGHTS = (1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0)

    def __init__(self, layer_weights=None):
        """Build the loss and its VGG19 feature extractor.

        Args:
            layer_weights: Optional sequence of floats, one per feature map
                returned by ``Vgg19`` (shallowest first). Defaults to
                ``DEFAULT_LAYER_WEIGHTS``.
        """
        super(VGGLoss, self).__init__(name='VGGLoss')
        self.vgg = Vgg19()
        if layer_weights is None:
            layer_weights = self.DEFAULT_LAYER_WEIGHTS
        self.layer_weights = [float(w) for w in layer_weights]

    @staticmethod
    def _to_vgg_input(images):
        """Map [-1, 1] images to [0, 255] and apply VGG19 preprocessing."""
        pixels = ((images + 1) / 2) * 255.0
        # Use the tf.keras preprocessing that matches the tf.keras VGG19
        # model built in Vgg19, rather than the standalone `keras` package.
        return tf.keras.applications.vgg19.preprocess_input(pixels)

    def call(self, x, y):
        """Return the weighted perceptual distance between ``x`` and ``y``.

        Args:
            x: Image batch with values expected in [-1, 1] (presumably
                channels-last RGB; TODO confirm against callers).
            y: Target image batch with the same layout as ``x``. No gradient
                flows into ``y`` through this loss.

        Returns:
            Sum over feature maps of ``weight_i * L1_loss(x_i, y_i)``.
            The exact type depends on ``L1_loss`` (presumably a scalar
            tensor).

        Raises:
            ValueError: If the number of layer weights differs from the number
                of feature maps produced by ``self.vgg``.
        """
        x_feats = self.vgg(self._to_vgg_input(x))
        y_feats = self.vgg(self._to_vgg_input(y))
        if len(x_feats) != len(self.layer_weights):
            raise ValueError(
                f"VGGLoss has {len(self.layer_weights)} layer weights but "
                f"the VGG extractor returned {len(x_feats)} feature maps")
        loss = 0
        for weight, x_feat, y_feat in zip(self.layer_weights, x_feats, y_feats):
            # Target features are treated as constants.
            loss += weight * L1_loss(x_feat, tf.stop_gradient(y_feat))
        return loss
class Vgg19(tf.keras.Model):
    """Five-stage feature extractor built from Keras' pretrained VGG19.

    The convolutional trunk of ``tf.keras.applications.vgg19.VGG19`` is cut
    into five sequential slices, and ``call`` returns the output of each
    slice. Under the Keras VGG19 layer ordering, the slices end at
    ``block1_conv1``, ``block2_conv1``, ``block3_conv1``, ``block4_conv1`` and
    ``block5_conv1`` respectively. Keras VGG19 conv layers carry a built-in
    ReLU, so these are the relu1_1 ... relu5_1 activations.
    """

    # Half-open [start, stop) index ranges into ``VGG19(...).layers`` for the
    # five slices. Index 0 is the InputLayer and is skipped.
    _SLICE_BOUNDS = ((1, 2), (2, 5), (5, 8), (8, 13), (13, 18))

    def __init__(self, trainable=False, weights='imagenet'):
        """Load VGG19 and split its layers into ``slice1`` ... ``slice5``.

        Args:
            trainable: If falsy, the pretrained VGG19 layers are frozen.
                Otherwise they are left trainable.
            weights: Forwarded to ``VGG19(weights=...)``. Accepts
                ``'imagenet'`` (default, downloads pretrained weights),
                ``None`` (random init), or a path to a weights file.
        """
        super(Vgg19, self).__init__(name='Vgg19')
        backbone = tf.keras.applications.vgg19.VGG19(
            weights=weights, include_top=False)
        if not trainable:
            # Setting trainable on the parent model propagates to its layers,
            # which are shared with the slices built below.
            backbone.trainable = False
        layers = backbone.layers
        slices = [tf.keras.Sequential(layers[start:stop])
                  for start, stop in self._SLICE_BOUNDS]
        # Keep the original attribute names (checkpoints and callers use them).
        (self.slice1, self.slice2, self.slice3,
         self.slice4, self.slice5) = slices

    def call(self, x):
        """Run ``x`` through the five slices in sequence.

        Args:
            x: Input already preprocessed for VGG19 (e.g. via
                ``preprocess_input``). Assumed channels-last; TODO confirm.

        Returns:
            List of the five slice outputs, shallowest first. Each slice
            consumes the previous slice's output.
        """
        features = []
        h = x
        for block in (self.slice1, self.slice2, self.slice3,
                      self.slice4, self.slice5):
            h = block(h)
            features.append(h)
        return features