Commit c7d9d6b

Add files via upload
Fixed some bugs and redundant code……
1 parent a9b6b9d commit c7d9d6b

File tree

3 files changed: +571 −310 lines changed

MMRAN.py

Lines changed: 388 additions & 0 deletions
@@ -0,0 +1,388 @@
import torch
from torch import nn
import torch.nn.functional as F
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# The following is the version without the extra 1x1 convolution
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),  # BN layer added
            # nn.GroupNorm(64, out_ch),  # with small batch sizes, replacing BN with GN can improve accuracy somewhat
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            # nn.GroupNorm(64, out_ch),
            nn.ReLU(inplace=True)
        )
        # 1x1 projection so the residual shortcut matches the output channel count
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch)
            # nn.GroupNorm(64, out_ch),
        )

    def forward(self, input):
        out = self.conv(input)
        out = out + self.shortcut(input)
        out = F.relu(out)
        return out

# Spatial squeeze-and-excitation (sSE)
class sSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()

    def forward(self, U):
        q = self.Conv1x1(U)  # U: [bs, c, h, w] -> q: [bs, 1, h, w]
        q = self.norm(q)
        return U * q  # broadcast over the channel dimension


# Channel squeeze-and-excitation (cSE) using both average- and max-pooled descriptors
class cSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.Conv_Squeeze = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, bias=False)
        self.Conv_Excitation = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, U):
        z = self.avgpool(U)          # [bs, c, h, w] -> [bs, c, 1, 1]
        z = self.Conv_Squeeze(z)     # -> [bs, c/2, 1, 1]
        z = self.Conv_Excitation(z)  # -> [bs, c, 1, 1]
        z = self.norm(z)
        x = self.maxpool(U)          # same squeeze/excitation path on the max-pooled descriptor
        x = self.Conv_Squeeze(x)
        x = self.Conv_Excitation(x)
        x = self.norm(x)
        x = z + x
        return U * x.expand_as(U)

# Multi-scale convolution module
class MultiScaleModule(nn.Module):
    def __init__(self, in_channels):
        super(MultiScaleModule, self).__init__()
        # Split the channels evenly across the four branches
        branch_channels = in_channels // 4
        if in_channels % 4 != 0:
            raise ValueError(f"in_channels ({in_channels}) must be divisible by 4 for MultiScaleModule.")

        self.conv0 = nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, bias=False)  # 1x1 conv
        self.conv1 = nn.Conv2d(in_channels, branch_channels, kernel_size=3, padding=1, bias=False)  # 3x3 conv
        self.conv2 = nn.Conv2d(in_channels, branch_channels, kernel_size=5, padding=2, bias=False)  # 5x5 conv
        self.conv3 = nn.Conv2d(in_channels, branch_channels, kernel_size=7, padding=3, bias=False)  # 7x7 conv
        self.norm = nn.BatchNorm2d(in_channels)  # normalize the concatenated result

    def forward(self, x):
        # Four parallel convolutions
        F0 = self.conv0(x)  # 1x1 conv
        F1 = self.conv1(x)  # 3x3 conv
        F2 = self.conv2(x)  # 5x5 conv
        F3 = self.conv3(x)  # 7x7 conv
        # Concatenate F0, F1, F2, F3 along the channel dimension
        F_out = torch.cat([F0, F1, F2, F3], dim=1)  # [B, C, H, W]
        F_out = self.norm(F_out)
        return F_out

# scSE-style attention module combined with MultiScaleModule
class MRAM(nn.Module):
    def __init__(self, in_channels):
        super(MRAM, self).__init__()
        self.multi_scale = MultiScaleModule(in_channels)  # multi-scale convolution module
        self.cSE = cSE(in_channels)
        self.sSE = sSE(in_channels)

    def forward(self, U):
        U = self.multi_scale(U)  # multi-scale convolution first
        U_cse = self.cSE(U)      # channel attention
        U_sse = self.sSE(U_cse)  # spatial attention
        return U_sse + U         # residual connection to the multi-scale features

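
# A minimal shape-check sketch for MRAM (illustrative; the helper name below is not part of
# the original code). It assumes the channel count is divisible by 4, which MultiScaleModule
# requires, and shows that MRAM preserves the [B, C, H, W] shape of its input.
def _mram_shape_check():
    attention = MRAM(64)
    dummy = torch.randn(1, 64, 32, 32)
    out = attention(dummy)
    assert out.shape == dummy.shape  # the attention output keeps the input shape
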
class MMRAN(nn.Module):
    def __init__(self, in_ch, out_ch, reduction_factor=4):
        super(MMRAN, self).__init__()
        # reduction_factor may only be 1, 2, or 4
        if reduction_factor not in [1, 2, 4]:
            raise ValueError(f"Invalid reduction_factor: {reduction_factor}. It must be 1, 2, or 4.")
        factor = reduction_factor
        print(f"Factor={factor} (default=4), all channels of the convolutional layers will be reduced to 1/{factor}.")

        # Channel widths are scaled by the factor
        self.conv1 = DoubleConv(in_ch, 64 // factor)             # originally 64
        self.conv2 = DoubleConv(64 // factor, 128 // factor)     # originally 128
        self.conv3 = DoubleConv(128 // factor, 256 // factor)    # originally 256
        self.conv4 = DoubleConv(256 // factor, 512 // factor)    # originally 512
        self.conv5 = DoubleConv(512 // factor, 1024 // factor)   # originally 1024

        self.pool = nn.MaxPool2d(2)  # shared pooling layer

        # Upsampling (decoder) branch
        self.up6 = nn.ConvTranspose2d(1024 // factor, 512 // factor, 2, stride=2)  # originally 1024 -> 512
        self.conv6 = DoubleConv(1024 // factor, 512 // factor)                     # originally 1024 -> 512

        self.up7 = nn.ConvTranspose2d(512 // factor, 256 // factor, 2, stride=2)   # originally 512 -> 256
        self.conv7 = DoubleConv(512 // factor, 256 // factor)                      # originally 512 -> 256

        self.up8 = nn.ConvTranspose2d(256 // factor, 128 // factor, 2, stride=2)   # originally 256 -> 128
        self.conv8 = DoubleConv(256 // factor, 128 // factor)                      # originally 256 -> 128

        self.up9 = nn.ConvTranspose2d(128 // factor, 64 // factor, 2, stride=2)    # originally 128 -> 64
        self.conv9 = DoubleConv(128 // factor, 64 // factor)                       # originally 128 -> 64

        self.conv10 = nn.Conv2d(64 // factor, out_ch, 1)  # output channel count is unchanged

        self.num_levels = 4
        self.pool_type = 'max_pool'

        # Downsampling (classification) branch
        self.conv11 = DoubleConv(1024 // factor, 512 // factor)  # originally 1024 -> 512
        self.conv12 = DoubleConv(512 // factor, 256 // factor)   # originally 512 -> 256
        self.conv13 = DoubleConv(256 // factor, 128 // factor)   # originally 256 -> 128
        self.conv14 = DoubleConv(128 // factor, 64 // factor)    # originally 128 -> 64

        self.fc1 = nn.Linear(1920 // factor, 100)  # originally 1920, scaled by the same factor
        self.fc2 = nn.Linear(100, 3)               # 3-class classification

        self.c_se1 = MRAM(64 // factor)
        self.c_se2 = MRAM(128 // factor)
        self.c_se3 = MRAM(256 // factor)
        self.c_se4 = MRAM(512 // factor)

    def forward(self, x):
        x = self.conv1(x)    # 512 * 512 * (32/64)
        att1 = self.c_se1(x)
        x = self.pool(x)     # 256 * 256 * (32/64)

        x = self.conv2(x)    # 256 * 256 * (64/128)
        att2 = self.c_se2(x)
        x = self.pool(x)     # 128 * 128 * (64/128)

        x = self.conv3(x)    # 128 * 128 * (128/256)
        att3 = self.c_se3(x)
        x = self.pool(x)     # 64 * 64 * (128/256)

        x = self.conv4(x)    # 64 * 64 * (256/512)
        att4 = self.c_se4(x)
        x = self.pool(x)     # 32 * 32 * (256/512)

        x = self.conv5(x)    # 32 * 32 * (512/1024)

        # This module was not used in the paper, but you can add it to improve performance
        # x = self.psp(x)  # multi-scale fusion at the bottom of the network

        # Upsampling (segmentation) path
        x_up = self.up6(x)                     # 64 * 64 * (256/512)
        x_up = torch.cat([x_up, att4], dim=1)  # 64 * 64 * (512/1024)
        x_up = self.conv6(x_up)                # 64 * 64 * (256/512)

        x_up = self.up7(x_up)                  # 128 * 128 * (128/256)
        x_up = torch.cat([x_up, att3], dim=1)  # 128 * 128 * (256/512)
        x_up = self.conv7(x_up)                # 128 * 128 * (128/256)

        x_up = self.up8(x_up)                  # 256 * 256 * (64/128)
        x_up = torch.cat([x_up, att2], dim=1)  # 256 * 256 * (128/256)
        x_up = self.conv8(x_up)                # 256 * 256 * (64/128)

        x_up = self.up9(x_up)                  # 512 * 512 * (32/64)
        x_up = torch.cat([x_up, att1], dim=1)  # 512 * 512 * (64/128)
        x_up = self.conv9(x_up)                # 512 * 512 * (32/64)

        seg_output = self.conv10(x_up)         # 512 * 512 * out_ch

        # CNN (classification) path
        x = self.conv11(x)   # 32 * 32 * (256/512)
        x = self.pool(x)     # 16 * 16 * (256/512)
        x = self.conv12(x)   # 16 * 16 * (128/256)

        # This module was not used in the paper, but you can add it to improve performance
        # x = self.psp2(x)
        x = self.pool(x)     # 8 * 8 * (128/256)
        x = self.conv13(x)   # 8 * 8 * (64/128)
        x = self.pool(x)     # 4 * 4 * (64/128)
        x = self.conv14(x)   # 4 * 4 * (32/64)

        # SPP layer (stateless, so it can be constructed here)
        spp_layer = SPPLayer(self.num_levels, self.pool_type)
        x = spp_layer(x)

        x = F.relu(self.fc1(x))
        cls_output = self.fc2(x)

        return seg_output, cls_output

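
# A minimal smoke-test sketch for MMRAN (illustrative; the helper name and the dummy batch
# are assumptions). The 3-channel 512x512 input follows the spatial sizes annotated in
# forward(); the input side must be divisible by 16 so that the skip connections line up.
def _mmran_smoke_test():
    model = MMRAN(in_ch=3, out_ch=1, reduction_factor=4).to(device)
    x = torch.randn(1, 3, 512, 512, device=device)
    seg_output, cls_output = model(x)
    print(seg_output.shape)  # expected: [1, 1, 512, 512] segmentation map
    print(cls_output.shape)  # expected: [1, 3] classification logits
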
class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True):
        """
        focal_loss: -alpha * (1 - yi)**gamma * ce_loss(xi, yi)
        A step-by-step implementation of the focal loss.
        :param alpha: class weight. If alpha is a list, it gives per-class weights; if it is a
                      scalar, the weights become [alpha, 1-alpha, 1-alpha, ...], which is commonly
                      used in detection to suppress the background class (0.25 in RetinaNet).
        :param gamma: focusing parameter for hard/easy examples (2 in RetinaNet).
        :param num_classes: number of classes
        :param size_average: if True, return the mean loss, otherwise the sum
        """

        super(focal_loss, self).__init__()
        self.size_average = size_average
        if isinstance(alpha, list):
            assert len(alpha) == num_classes  # alpha as a list of size [num_classes] gives fine-grained per-class weights
            print("Focal_loss alpha = {}, fine-tuning the weight assigned to each class".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha < 1  # a scalar alpha down-weights the first class (the background class in detection)
            print(" --- Focal_loss alpha = {} --- ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1 - alpha)  # alpha becomes [alpha, 1-alpha, 1-alpha, ...] of size [num_classes]
        self.gamma = gamma

    def forward(self, preds, labels):
        """
        Compute the focal loss.
        :param preds: predicted logits, size [B, N, C] (detection) or [B, C] (classification),
                      where B is the batch size, N the number of boxes and C the number of classes
        :param labels: ground-truth classes, size [B, N] or [B]
        :return: the focal loss
        """
        # assert preds.dim() == 2 and labels.dim() == 1
        preds = preds.view(-1, preds.size(-1))
        # Use a local copy so self.alpha keeps the full per-class weights across calls
        alpha = self.alpha.to(preds.device)
        preds_softmax = F.softmax(preds,
                                  dim=1)  # softmax (not log_softmax) is used here because the probabilities themselves are needed below
        preds_logsoft = torch.log(preds_softmax)
        preds_softmax = preds_softmax.gather(1, labels.view(-1, 1))  # this implements nll_loss (cross entropy = log_softmax + nll)
        preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
        alpha = alpha.gather(0, labels.view(-1))
        loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma),
                          preds_logsoft)  # torch.pow((1 - preds_softmax), self.gamma) is the (1 - pt)**gamma term
        loss = torch.mul(alpha, loss.t())
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

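
# A minimal usage sketch for focal_loss (illustrative; the dummy batch below is an
# assumption). It feeds [B, C] logits, such as the classification head produces, and [B]
# integer labels, matching the shapes documented in forward().
def _focal_loss_example():
    criterion = focal_loss(alpha=0.25, gamma=2, num_classes=3)
    logits = torch.randn(8, 3)          # [B, C] raw classification logits
    labels = torch.randint(0, 3, (8,))  # [B] ground-truth class indices
    print(criterion(logits, labels).item())
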
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()
        self.epsilon = 1e-5

    def forward(self, predict, target):
        assert predict.size() == target.size(), "the size of predict and target must be equal."
        num = predict.size(0)

        pre = torch.sigmoid(predict).view(num, -1)
        tar = target.view(num, -1)

        intersection = (pre * tar).sum(-1).sum()  # the product of prediction and target acts as the intersection
        union = (pre + tar).sum(-1).sum()

        score = 1 - 2 * (intersection + self.epsilon) / (union + self.epsilon)

        return score

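
# A minimal usage sketch for DiceLoss (illustrative; the shapes are assumptions). The loss
# takes raw segmentation logits (sigmoid is applied inside) and a binary mask of the same size.
def _dice_loss_example():
    criterion = DiceLoss()
    seg_logits = torch.randn(2, 1, 64, 64)               # raw seg_output-style logits
    masks = torch.randint(0, 2, (2, 1, 64, 64)).float()  # binary ground-truth masks
    print(criterion(seg_logits, masks).item())
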
class SPPLayer(torch.nn.Module):

    def __init__(self, num_levels, pool_type='max_pool'):
        super(SPPLayer, self).__init__()

        self.num_levels = num_levels
        self.pool_type = pool_type

    def forward(self, x):
        # num: the number of samples
        # c: the number of channels
        # h: height
        # w: width
        num, c, h, w = x.size()
        # print(x.size())
        for i in range(self.num_levels):
            level = i + 1

            '''
            The equation is explained on the following site:
            http://www.cnblogs.com/marsggbo/p/8572846.html#autoid-0-0-0
            '''
            kernel_size = (math.ceil(h / level), math.ceil(w / level))
            stride = (math.floor(h / level), math.floor(w / level))
            pooling = (math.floor((kernel_size[0] * level - h + 1) / 2), math.floor((kernel_size[1] * level - w + 1) / 2))

            # pad the input so that this level produces exactly level x level bins
            zero_pad = torch.nn.ZeroPad2d((pooling[1], pooling[1], pooling[0], pooling[0]))
            x_new = zero_pad(x)

            # update kernel and stride for the padded size
            h_new = 2 * pooling[0] + h
            w_new = 2 * pooling[1] + w

            kernel_size = (math.ceil(h_new / level), math.ceil(w_new / level))
            stride = (math.floor(h_new / level), math.floor(w_new / level))

            # choose the pooling type
            if self.pool_type == 'max_pool':
                try:
                    tensor = F.max_pool2d(x_new, kernel_size=kernel_size, stride=stride).view(num, -1)
                except Exception as e:
                    print(str(e))
                    print(x.size())
                    print(level)
            else:
                tensor = F.avg_pool2d(x_new, kernel_size=kernel_size, stride=stride).view(num, -1)

            # flatten and concatenate the pyramid levels
            if i == 0:
                x_flatten = tensor.view(num, -1)
            else:
                x_flatten = torch.cat((x_flatten, tensor.view(num, -1)), 1)
        return x_flatten

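
# An illustrative check (helper name assumed) that SPPLayer returns a fixed-length vector
# regardless of the input spatial size: with num_levels=4 each channel contributes
# 1 + 4 + 9 + 16 = 30 bins, which is where the 1920 (= 30 * 64) input size of fc1 in MMRAN
# comes from.
def _spp_shape_check():
    spp = SPPLayer(num_levels=4, pool_type='max_pool')
    for hw in (4, 8, 13):
        out = spp(torch.randn(2, 16, hw, hw))
        print(out.shape)  # expected: [2, 480] = [2, 16 * 30] for every input size
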
# PSP modules. The two modules below were not used in the paper, but you can also use them
# in the network; they help the classification accuracy to some extent.
class PSPModule(nn.Module):
    def __init__(self, features, out_features, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features, size):
        # prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        prior = nn.AdaptiveMaxPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)  # the first version had no 1x1 conv here, and accuracy still improved noticeably
        return nn.Sequential(prior, conv)
        # return nn.Sequential(prior)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = [F.interpolate(stage(feats), size=(h, w), mode='bilinear', align_corners=False) for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))  # dim=1 concatenates along the channel dimension
        return self.relu(bottle)

class PSPModule2(nn.Module):
    def __init__(self, features, out_features, size=(1, 2, 3, 6)):
        super().__init__()
        self.pool1 = nn.MaxPool2d(1)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(3)
        self.pool4 = nn.MaxPool2d(6)
        self.bottleneck = nn.Conv2d(features * 4, out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):  # x: 512 * 64
        p1 = F.interpolate(self.pool1(x), size=[16, 16])  # 512 * 64
        p2 = F.interpolate(self.pool2(x), size=[16, 16])  # 256 * 64
        p3 = F.interpolate(self.pool3(x), size=[16, 16])  # 170 * 64
        p4 = F.interpolate(self.pool4(x), size=[16, 16])  # 85 * 64
        x = self.bottleneck(torch.cat([p1, p2, p3, p4], 1))
        return self.relu(x)
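

# An illustrative end-to-end sketch (the joint loss combination below is an assumption, not
# something this file prescribes): DiceLoss supervises the segmentation head and focal_loss
# supervises the classification head, summed with an arbitrary 1:1 weighting.
if __name__ == "__main__":
    model = MMRAN(in_ch=3, out_ch=1, reduction_factor=4)
    images = torch.randn(1, 3, 512, 512)
    masks = torch.randint(0, 2, (1, 1, 512, 512)).float()
    labels = torch.randint(0, 3, (1,))
    seg_out, cls_out = model(images)
    loss = DiceLoss()(seg_out, masks) + focal_loss(num_classes=3)(cls_out, labels)
    print(loss.item())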
