model_.py
"""
all credits to @nizhib
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models
import torch.nn.functional as F
import math
import torch.utils.model_zoo as model_zoo
from fastai.conv_learner import *
nonlinearity = nn.ReLU


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, n_filters):
        super().__init__()

        # B, C, H, W -> B, C/4, H, W
        self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
        self.norm1 = nn.BatchNorm2d(in_channels // 4)
        self.relu1 = nonlinearity(inplace=True)

        # B, C/4, H, W -> B, C/4, 2H, 2W (stride-2 transposed conv upsamples)
        self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3,
                                          stride=2, padding=1, output_padding=1)
        self.norm2 = nn.BatchNorm2d(in_channels // 4)
        self.relu2 = nonlinearity(inplace=True)

        # B, C/4, 2H, 2W -> B, n_filters, 2H, 2W
        self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nonlinearity(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.deconv2(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu3(x)
        return x
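
# Shape check, as an illustrative sketch (not part of the original file): a DecoderBlock
# maps in_channels to n_filters channels and doubles the spatial resolution, e.g.
# >>> block = DecoderBlock(256, 128)
# >>> block(torch.randn(1, 256, 16, 16)).shape
# torch.Size([1, 128, 32, 32])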


class LinkNet34(nn.Module):
    def __init__(self, num_classes, num_channels=3):
        super().__init__()
        assert num_channels == 3, ("num_channels is not used yet; to support a number of "
                                   "input channels other than 3, change the first conv layer")
        filters = [64, 128, 256, 512]

        # Encoder: pretrained ResNet-34 body with its classification head cut off
        # self.resnet = models.resnet34(pretrained=True)
        f = resnet34  # PyTorch model constructor (via fastai)
        cut, lr_cut = model_meta[f]  # layer index where the head is cut off, etc.
        layers = cut_model(f(True), cut)
        self.rn = nn.Sequential(*layers)

        self.firstconv = self.rn[0]
        self.firstbn = self.rn[1]
        self.firstrelu = self.rn[2]
        self.firstmaxpool = self.rn[3]
        self.encoder1 = self.rn[4]
        self.encoder2 = self.rn[5]
        self.encoder3 = self.rn[6]
        self.encoder4 = self.rn[7]

        # Decoder
        self.decoder4 = DecoderBlock(filters[3], filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        # Final classifier
        self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2)
        self.finalrelu1 = nonlinearity(inplace=True)
        self.finalconv2 = nn.Conv2d(32, 32, 3)
        self.finalrelu2 = nonlinearity(inplace=True)
        self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1)

    # noinspection PyCallingNonCallable
    def forward(self, x):
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder with skip connections
        d4 = self.decoder4(e4) + e3
        # d4 = e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        # Final classification
        f1 = self.finaldeconv1(d1)
        f2 = self.finalrelu1(f1)
        f3 = self.finalconv2(f2)
        f4 = self.finalrelu2(f3)
        f5 = self.finalconv3(f4)
        return f5
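

if __name__ == "__main__":
    # Minimal smoke test; an illustrative sketch rather than part of the original file.
    # It assumes the fastai 0.7 conv_learner environment imported above is available,
    # a PyTorch version where plain tensors can be fed to modules (0.4+), and that
    # pretrained ResNet-34 weights can be downloaded by f(True) in the constructor.
    model = LinkNet34(num_classes=1)
    model.eval()
    out = model(torch.randn(1, 3, 256, 256))
    # The decoder path upsamples back to the input resolution,
    # so out.shape is (1, num_classes, 256, 256).
    print(out.shape)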