-
Notifications
You must be signed in to change notification settings - Fork 152
/
Copy pathmodel.lua
81 lines (67 loc) · 2.3 KB
/
model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
-- from net-orig, add more to fc and make another deconv layer
require 'nn'
require 'cunn'
require 'cudnn'
-- Top-level container for the model. It is filled layer by layer below when
-- building from scratch, or replaced wholesale by torch.load(opt.retrain)
-- when a previously trained network is requested.
local net = nn.Sequential()
--- He / MSR-style initialisation for the 3-D convolutions inside `net`.
-- Each matched module gets weights drawn from N(0, std) with
-- std = sqrt(2 / n), n = kT * kW * kH * nOutputPlane. Its bias, when the
-- module has one, is set to zero.
-- Module:findModules matches the exact torch.typename (e.g.
-- 'cudnn.VolumetricConvolution'). A bare 'VolumetricConvolution' therefore
-- matches nothing, so both backends' full type names are searched by default.
-- @param net         container whose modules are initialised in place
-- @param type_names  optional list of exact module type names to initialise
--                    (default: nn and cudnn VolumetricConvolution)
-- @return net (the same object, for chaining)
local function MSRinit(net, type_names)
   type_names = type_names
      or { 'nn.VolumetricConvolution', 'cudnn.VolumetricConvolution' }
   for _, name in ipairs(type_names) do
      for _, m in ipairs(net:findModules(name)) do
         local n = m.kT * m.kW * m.kH * m.nOutputPlane
         m.weight:normal(0, math.sqrt(2 / n))
         -- convolutions created without a bias have m.bias == nil
         if m.bias then
            m.bias:zero()
         end
      end
   end
   return net
end
-- create net
-- Build the network from scratch, or load a previously trained one when
-- opt.retrain names a file. This relies on globals set by the caller: `opt`
-- (opt.retrain is read) and `num_classes` (needed only for a fresh network).
-- NOTE(review): the per-stage output sizes below are the original author's;
-- they imply an input volume of roughly 1 x 62 x 31 x 31. Confirm against
-- the data loader.

-- Append one feature-extraction stage to `seq`:
--   4x3x3 conv with stride 2 in every dimension -> BN -> ReLU,
--   then two 1x1x1 convs, each followed by BN -> ReLU,
--   then VolumetricDropout(0.2).
local function add_conv_stage(seq, n_in, n_out)
   seq:add(cudnn.VolumetricConvolution(n_in, n_out, 4, 3, 3, 2, 2, 2))
   seq:add(cudnn.VolumetricBatchNormalization(n_out))
   seq:add(cudnn.ReLU())
   for _ = 1, 2 do
      seq:add(cudnn.VolumetricConvolution(n_out, n_out, 1, 1, 1))
      seq:add(cudnn.VolumetricBatchNormalization(n_out))
      seq:add(cudnn.ReLU())
   end
   seq:add(nn.VolumetricDropout(0.2))
end

if opt.retrain == nil or opt.retrain == '' then
   assert(num_classes ~= nil, 'num_classes must be set before building the network')
   local nf0, nf1, nf2 = 16, 32, 64   -- feature maps per stage
   add_conv_stage(net, 1, nf0)        -- output nf0 x 30x15x15
   add_conv_stage(net, nf0, nf1)      -- output nf1 x 14x7x7
   add_conv_stage(net, nf1, nf2)      -- output nf2 x 6x3x3
   -- flatten the final nf2 x 6x3x3 volume: 6*3*3 = 54 values per feature map
   local flat = nf2 * 54
   local bf = 1200                    -- width of the hidden fully connected layer
   net:add(nn.View(flat))
   net:add(nn.Linear(flat, bf))
   net:add(cudnn.ReLU())
   net:add(nn.Dropout(0.5))
   -- num_classes * 62 outputs, reshaped to num_classes x 1 x 62
   net:add(nn.Linear(bf, num_classes * 62))
   net:add(nn.View(num_classes, 1, 62))
   MSRinit(net)
else -- preload network
   assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
   print('loading previously trained network: ' .. opt.retrain)
   net = torch.load(opt.retrain)
end
-- move supported modules (fresh or loaded) onto the cudnn backend
cudnn.convert(net, cudnn)
print('net:')
print(net)
return net