-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloadVGG.py
64 lines (54 loc) · 2.52 KB
/
loadVGG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 20 20:59:47 2018
@author: nicemo
"""
import numpy as np
import scipy.io
import tensorflow as tf
# Path to the 'verydeep-19' VGG weights file (MATLAB .mat format; presumably
# ImageNet-pretrained, per the name). VGG.__init__ reads it with
# scipy.io.loadmat. A relative path is resolved against the current working
# directory, not this module's location.
VGG_FILENAME = 'imagenet-vgg-verydeep-19.mat'
class VGG(object):
    """Build a VGG-19 style convolutional feature extractor from a .mat file.

    Convolution weights and biases are read from a MATLAB ``.mat`` file and
    embedded in the graph as TensorFlow constants (not trainable variables).
    Down-sampling between stages uses 2x2 average pooling. After ``load()``
    runs, every layer's output tensor is available as an attribute named
    after the layer (``conv1_1`` ... ``conv5_4``, ``avgpool1`` ... ``avgpool5``).
    """

    # Layer plan consumed by load(), in network order. Each entry is
    # (attribute name, index into the .mat 'layers' array); an index of None
    # marks a 2x2 average-pooling layer, which reads no weights. The layers
    # form a single chain: each one consumes the previous layer's output.
    # NOTE(review): the skipped indices presumably hold the ReLU/pooling
    # entries of the .mat file -- confirm against the weights file.
    _LAYER_PLAN = (
        ('conv1_1', 0), ('conv1_2', 2), ('avgpool1', None),
        ('conv2_1', 5), ('conv2_2', 7), ('avgpool2', None),
        ('conv3_1', 10), ('conv3_2', 12), ('conv3_3', 14), ('conv3_4', 16),
        ('avgpool3', None),
        ('conv4_1', 19), ('conv4_2', 21), ('conv4_3', 23), ('conv4_4', 25),
        ('avgpool4', None),
        ('conv5_1', 28), ('conv5_2', 30), ('conv5_3', 32), ('conv5_4', 34),
        ('avgpool5', None),
    )

    def __init__(self, input_img, vgg_filename=None):
        """Load the weights file and remember the network input.

        Args:
            input_img: Tensor fed to the first convolution. Presumably a 4-D
                NHWC image batch with 3 channels -- TODO confirm with callers.
            vgg_filename: Path of the ``.mat`` weights file. When omitted, the
                module-level ``VGG_FILENAME`` is used (looked up at call time).
        """
        if vgg_filename is None:
            vgg_filename = VGG_FILENAME
        self.Vgg_layers = scipy.io.loadmat(vgg_filename)['layers']
        self.input_img = input_img
        # Per-channel mean pixel value, shaped (1, 1, 1, 3) so it broadcasts
        # over an image batch. This class never applies it; presumably callers
        # subtract/add it around the network themselves -- confirm.
        self.mean_pixel = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))

    def _weight(self, layer_idx, expected_layer_name):
        """Return ``(W, b)`` for entry ``layer_idx`` of the loaded layers.

        ``W`` is returned exactly as stored in the file; ``b`` is flattened to
        a 1-D array so it broadcasts onto the last axis of the conv output.

        Raises:
            ValueError: if the name stored for this entry differs from
                ``expected_layer_name`` (guards against a wrong index). An
                explicit raise is used because ``assert`` is stripped under -O.
        """
        # The repeated [0] indexing unwraps the nested arrays returned by
        # scipy.io.loadmat; field 0 holds the layer name and field 2 holds
        # the (weights, bias) pair.
        layer = self.Vgg_layers[0][layer_idx][0][0]
        layer_name = layer[0][0]
        if layer_name != expected_layer_name:
            raise ValueError(
                'VGG layer {} is named {!r}, expected {!r}'.format(
                    layer_idx, layer_name, expected_layer_name))
        W = layer[2][0][0]
        b = layer[2][0][1]
        return W, b.reshape(b.size)

    def conv2d_relu(self, prev_layer, layer_idx, layer_name):
        """Add a stride-1, 'SAME'-padded convolution + bias + ReLU layer.

        Weights come from ``_weight(layer_idx, layer_name)`` and are embedded
        as constants. The output tensor is stored on ``self`` as the
        attribute ``layer_name``.
        """
        with tf.variable_scope(layer_name):
            W, b = self._weight(layer_idx, layer_name)
            W = tf.constant(W, name='weights')
            b = tf.constant(b, name='bias')
            conv2d = tf.nn.conv2d(prev_layer, filter=W, strides=[1, 1, 1, 1], padding='SAME')
            out = tf.nn.relu(conv2d + b)
        setattr(self, layer_name, out)

    def avgpool(self, prev_layer, layer_name):
        """Add a 2x2, stride-2, 'SAME'-padded average-pooling layer.

        The output tensor is stored on ``self`` as the attribute
        ``layer_name``.
        """
        with tf.variable_scope(layer_name):
            out = tf.nn.avg_pool(prev_layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        setattr(self, layer_name, out)

    def load(self):
        """Build every layer in ``_LAYER_PLAN``, starting from ``input_img``.

        Each layer consumes the previous layer's output; afterwards each
        layer's tensor is reachable as ``self.<layer name>``.
        """
        prev_layer = self.input_img
        for layer_name, layer_idx in self._LAYER_PLAN:
            if layer_idx is None:
                self.avgpool(prev_layer, layer_name)
            else:
                self.conv2d_relu(prev_layer, layer_idx, layer_name)
            prev_layer = getattr(self, layer_name)