33'''
44Author : LiAo
55Date : 2022-07-05 19:45:12
6- LastEditTime : 2022-07-16 23:41:05
6+ LastEditTime : 2022-07-17 15:55:17
77LastAuthor : LiAo
88Description : Please add file description
99'''
@@ -189,21 +189,22 @@ class APM(nn.Module):
189189
190190 def __init__ (self , in_chs : int , out_chs : int ):
191191 super (APM , self ).__init__ ()
192- self .conv_blocks_1 = ConvBlock (in_chs , 8 , kernel_size = 13 , stride = 3 , padding = 1 , dilation = 1 ,
192+ self .conv_blocks_1 = ConvBlock (in_chs , 4 , kernel_size = 13 , stride = 2 , padding = 1 , dilation = 1 ,
193193 norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = nn .MaxPool2d (3 , stride = 1 ))
194- self .conv_blocks_2 = ConvBlock (8 , 32 , kernel_size = 3 , stride = 1 , padding = 1 , dilation = 1 ,
195- norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = nn .AvgPool2d (3 , stride = 1 ))
196- self .conv_blocks_3 = ConvBlock (32 , 64 , kernel_size = 3 , stride = 1 , padding = 1 , dilation = 1 ,
194+ self .conv_blocks_2 = ConvBlock (4 , 16 , kernel_size = 5 , stride = 2 , padding = 1 , dilation = 1 ,
197195 norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = nn .MaxPool2d (3 , stride = 1 ))
198- self .res_blocks_1 = Residual (64 , 64 , kernel_size = [3 , 3 ], stride = [1 , 1 ], padding = [1 , 1 ],
196+ self .conv_blocks_3 = ConvBlock (16 , 32 , kernel_size = 3 , stride = 1 , padding = 1 , dilation = 1 ,
197+ norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = nn .MaxPool2d (3 , stride = 1 ))
198+ self .res_blocks_1 = Residual (32 , 32 , kernel_size = [3 , 3 ], stride = [1 , 1 ], padding = [1 , 1 ],
199199 downsample = False )
200- self .res_blocks_2 = Residual (64 , 64 , kernel_size = [3 , 3 ], stride = [1 , 1 ], padding = [1 , 1 ],
200+ self .res_blocks_2 = Residual (32 , 32 , kernel_size = [3 , 3 ], stride = [1 , 1 ], padding = [1 , 1 ],
201201 downsample = False )
202- self .cbam_blocks = CBAM (64 )
203- self .conv_blocks_4 = ConvBlock (64 , 16 , kernel_size = 3 , stride = 1 , padding = 1 , dilation = 1 ,
204- norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = None )
205- self .conv_blocks_5 = ConvBlock (16 , out_chs , kernel_size = 3 , stride = 1 , padding = 1 , dilation = 1 ,
206- norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = None )
202+ # self.res_blocks_2 = nn.Identity()
203+ self .cbam_blocks = CBAM (32 )
204+ self .conv_blocks_4 = ConvBlock (32 , 8 , kernel_size = 3 , stride = 1 , padding = 1 , dilation = 1 ,
205+ norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = nn .MaxPool2d (3 , stride = 1 ))
206+ self .conv_blocks_5 = ConvBlock (8 , out_chs , kernel_size = 3 , stride = 1 , padding = 1 , dilation = 1 ,
207+ norm_layer = nn .BatchNorm2d , act_layer = nn .SiLU , pool_layer = nn .MaxPool2d (3 , stride = 1 ))
207208
208209 def forward (self , x ):
209210 x = self .conv_blocks_1 (x )
@@ -241,10 +242,10 @@ def __init__(self, backbone='tf_efficientnetv2_b0', pretrain=True, num_classes=7
241242 self .pool_upsample = nn .Upsample (size = pool_size , mode = pool_type )
242243
243244 def forward (self , x ):
244- x = self .apm (x )
245+ apm_x = self .apm (x )
245246 x_upsample = self .pool_upsample (x )
246- x_max = self .max_pool (x )
247- x_avg = self .avg_pool (x )
247+ x_max = self .max_pool (apm_x )
248+ x_avg = self .avg_pool (apm_x )
248249 x = torch .concat ([x_upsample , x_avg , x_max ], dim = 1 )
249250 x = self .backbone (x )
250251 x = self .dropout (self .classifier (x ))
@@ -264,6 +265,27 @@ def test_apm():
264265
265266
266267def test_model_modify ():
267- net = MultiClassification (backbone = 'tf_efficientnetv2_b3 ' ,
268+ net = MultiClassification (backbone = 'tf_efficientnet_b3 ' ,
268269 pretrain = True , num_classes = 3 )
269270 print (net )
271+
272+
273+ def test_model_repvgg ():
274+ input = torch .randn (1 , 3 , 224 , 224 )
275+ from thop import profile , clever_format
276+ net = timm .create_model (model_name = 'repvgg_b0' ,
277+ pretrained = True , num_classes = 3 )
278+ print (net )
279+ macs , params = profile (net , inputs = (input , ))
280+ macs , params = clever_format ([macs , params ], "%.4f" )
281+ print (macs )
282+ print (params )
283+ # repvgg_b0
284+ # 3.3960G
285+ # 14.5056M
286+ # tf_efficientnetv2_b3
287+ # 1.5538G
288+ # 12.7488M
289+ # tf_efficientnet_b3
290+ # 990.3157M
291+ # 10.6707M
0 commit comments