Skip to content

Commit

Permalink
Update common.py
Browse files Browse the repository at this point in the history
  • Loading branch information
guiguiniu authored Apr 23, 2024
1 parent 2eca89a commit 3d784be
Showing 1 changed file with 1 addition and 54 deletions.
55 changes: 1 addition & 54 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,59 +2062,6 @@ def fuseforward(self, x):
import torch
import torch.nn as nn

class WeightedSum(nn.Module):
def __init__(self, num_weights=4):
super(WeightedSum, self).__init__()
self.num_weights = num_weights
self.weights = nn.Parameter(torch.ones(num_weights), requires_grad=True)
self.attention = nn.Sequential(
nn.Linear(num_weights, num_weights), # 注意力模块,这里使用一个全连接层
nn.Sigmoid() # 注意力分数通过Sigmoid函数归一化
)

def forward(self, x):
attention_scores = self.attention(self.weights) # 计算注意力分数
weighted_sum = sum(x[i] * attention_scores[i] for i in range(self.num_weights)) # 注意力加权和
return weighted_sum

class MBConv(nn.Module):
def __init__(self, c1, c2, s, expand_ratio=1, use_se=True):
super(MBConv, self).__init__()
assert s in [1, 2]

hidden_dim = round(c1 * expand_ratio)
self.identity = s == 1 and c1 == c2
if use_se:
self.conv = nn.Sequential(
# pw
nn.Conv2d(c1, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, s, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
SELayer(c1, hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, c2, 1, 1, 0, bias=False),
nn.BatchNorm2d(c2),
)
else:
self.conv = nn.Sequential(
# fused
nn.Conv2d(c1, hidden_dim, 3, s, 1, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, c2, 1, 1, 0, bias=False),
nn.BatchNorm2d(c2),
)

def forward(self, x):
if self.identity:
return x + self.conv(x)
else:
return self.conv(x)

class SELayer(nn.Module):
def __init__(self, inp, oup, reduction=4):
Expand All @@ -2131,4 +2078,4 @@ def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
return x * y

0 comments on commit 3d784be

Please sign in to comment.