-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path erfnet_model.py
More file actions
111 lines (89 loc) · 3.76 KB
/
erfnet_model.py
File metadata and controls
111 lines (89 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
import torch.nn.functional as F
class DownsamplerBlock(nn.Module):
    """Halve spatial resolution by concatenating a strided conv with a max-pool.

    A 3x3 stride-2 convolution produces ``out_channels - in_channels`` maps and
    a 2x2 stride-2 max-pool passes the ``in_channels`` input maps through; the
    two are concatenated along dim 1 (conv maps first, pooled maps last) and
    batch-normalised. For an ``(N, in_channels, H, W)`` input the output is
    ``(N, out_channels, ceil(H/2), ceil(W/2))``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor. Must be at least
            ``in_channels`` because the pooled branch already supplies
            ``in_channels`` of them (when equal, the conv branch is empty).

    Raises:
        ValueError: if ``out_channels < in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        if out_channels < in_channels:
            raise ValueError(
                f"out_channels ({out_channels}) must be >= in_channels ({in_channels})"
            )
        self.conv = nn.Conv2d(in_channels, out_channels - in_channels, kernel_size=3,
                              stride=2, padding=1, bias=True)
        # The strided conv above yields ceil(H/2) x ceil(W/2). ceil_mode=True makes
        # the pool match it, so odd-sized inputs concatenate instead of raising.
        # For even sizes the result is identical to the default floor mode.
        self.pool = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = torch.cat([self.conv(x), self.pool(x)], 1)
        # NOTE(review): the reference ERFNet applies ReLU after this BN; this
        # implementation returns the BN output directly — confirm that is intended.
        return self.bn(output)
class UpsamplerBlock(nn.Module):
    """Double spatial resolution with a strided transposed convolution + BN.

    A 3x3 ``ConvTranspose2d`` with stride 2, padding 1 and output_padding 1
    maps ``(N, in_channels, H, W)`` to ``(N, out_channels, 2H, 2W)``; the
    result is batch-normalised. No activation is applied in this block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=True,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        upsampled = self.conv(x)
        normalized = self.bn(upsampled)
        return normalized
class NonBottleneck1D(nn.Module):
    """Residual block built from factorised (3x1 then 1x3) convolutions.

    Two factorised conv pairs are applied; the second pair is dilated by
    ``dilation`` along its kernel axis, with matching padding so the spatial
    size is preserved. BN follows each 1x3 conv, ``Dropout2d`` optionally
    follows the second BN, and the input is added back before the final ReLU.

    Args:
        channels: number of input and output channels.
        dropout_prob: ``Dropout2d`` probability; when it is 0 the dropout
            module is not called at all.
        dilation: dilation of the second 3x1 / 1x3 conv pair.
    """

    def __init__(self, channels, dropout_prob=0.1, dilation=1):
        super().__init__()
        d = dilation
        # Keep this construction order: it determines parameter-initialisation
        # RNG consumption and the parameter order seen by optimisers.
        self.conv1 = nn.Conv2d(channels, channels, (3, 1), stride=1,
                               padding=(1, 0), bias=True)
        self.conv2 = nn.Conv2d(channels, channels, (1, 3), stride=1,
                               padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, channels, (3, 1), stride=1,
                               padding=(d, 0), dilation=(d, 1), bias=True)
        self.conv4 = nn.Conv2d(channels, channels, (1, 3), stride=1,
                               padding=(0, d), dilation=(1, d), bias=True)
        self.bn2 = nn.BatchNorm2d(channels)
        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, x):
        residual = x
        out = F.relu(self.conv1(x))
        out = F.relu(self.bn1(self.conv2(out)))
        out = F.relu(self.conv3(out))
        out = self.bn2(self.conv4(out))
        if self.dropout.p:
            out = self.dropout(out)
        return F.relu(out + residual)
class ERFNet(nn.Module):
    """ERFNet-style encoder-decoder producing per-pixel output scores.

    Structure: an initial ``DownsamplerBlock`` (in_channels -> 16); an encoder
    with two more downsamplers and ``NonBottleneck1D`` residual blocks (the
    last four dilated by 2, 4, 8, 16); a decoder with two ``UpsamplerBlock``
    stages; and a 1x1 convolution to ``num_classes`` channels. No activation
    is applied to the final output.

    Resolution: the input is downsampled three times (to 1/8) but upsampled
    only twice, and the final conv is 1x1 with stride 1, so by default the
    output is at HALF the input resolution. For an ``(N, in_channels, H, W)``
    input with H and W divisible by 8, the result is
    ``(N, num_classes, H/2, W/2)``.
    NOTE(review): the reference ERFNet ends with a stride-2 transposed conv
    that restores full resolution; confirm the half-resolution output here is
    intended. Set ``upsample_output=True`` to resize the output to H x W.

    Args:
        in_channels: number of input channels.
        num_classes: number of output channels.
        upsample_output: if True, bilinearly resize the output to the input's
            spatial size. Defaults to False, which keeps the half-resolution
            output described above.
    """

    def __init__(self, in_channels=4, num_classes=1, upsample_output=False):
        super().__init__()
        # Plain attribute: not a parameter/buffer, so the state_dict is unchanged.
        self.upsample_output = upsample_output
        # Initial block: in_channels -> 16 channels, stride 2.
        self.initial_block = DownsamplerBlock(in_channels, 16)
        # Encoder: two more stride-2 downsamplers; the 128-channel stage uses
        # growing dilations to enlarge the receptive field.
        self.encoder = nn.Sequential(
            DownsamplerBlock(16, 64),
            NonBottleneck1D(64, 0.03, 1),
            NonBottleneck1D(64, 0.03, 1),
            NonBottleneck1D(64, 0.03, 1),
            NonBottleneck1D(64, 0.03, 1),
            NonBottleneck1D(64, 0.03, 1),
            DownsamplerBlock(64, 128),
            NonBottleneck1D(128, 0.3, 2),
            NonBottleneck1D(128, 0.3, 4),
            NonBottleneck1D(128, 0.3, 8),
            NonBottleneck1D(128, 0.3, 16),
        )
        # Decoder: two x2 upsamplers (128 -> 64 -> 16 channels), no dropout.
        self.decoder = nn.Sequential(
            UpsamplerBlock(128, 64),
            NonBottleneck1D(64, 0, 1),
            NonBottleneck1D(64, 0, 1),
            UpsamplerBlock(64, 16),
            NonBottleneck1D(16, 0, 1),
            NonBottleneck1D(16, 0, 1),
        )
        # Final 1x1 conv to num_classes channels (does not change resolution).
        self.final = nn.Conv2d(16, num_classes, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        input_size = x.shape[-2:]
        x = self.initial_block(x)  # 1/2 resolution (for H, W divisible by 8)
        x = self.encoder(x)        # 1/8 resolution
        x = self.decoder(x)        # back to 1/2 resolution
        x = self.final(x)
        if self.upsample_output:
            x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
        return x
# Factory hook: the training script was modified to build its model through this function.
def get_model(in_channels=4, num_classes=1, **kwargs):
    """Construct and return a new ``ERFNet``.

    Args:
        in_channels: number of input channels (default 4).
        num_classes: number of output channels (default 1).
        **kwargs: any additional keyword arguments are forwarded unchanged to
            the ``ERFNet`` constructor.

    Returns:
        A new ``ERFNet`` instance.
    """
    return ERFNet(in_channels=in_channels, num_classes=num_classes, **kwargs)