Skip to content

Commit

Permalink
add solov2_r101vd model (PaddlePaddle#3286)
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill authored Jun 5, 2021
1 parent 32abf1a commit 1adb26e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 23 deletions.
1 change: 1 addition & 0 deletions configs/solov2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo
| SOLOv2 (Paper) | X101-DCN-FPN | True | 3x | 42.4 | 5.9 | V100 | - | - |
| SOLOv2 | R50-FPN | False | 1x | 35.5 | 21.9 | V100 | [model](https://paddledet.bj.bcebos.com/models/solov2_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_fpn_1x_coco.yml) |
| SOLOv2 | R50-FPN | True | 3x | 38.0 | 21.9 | V100 | [model](https://paddledet.bj.bcebos.com/models/solov2_r50_fpn_3x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_fpn_3x_coco.yml) |
| SOLOv2 | R101vd-FPN | True | 3x | 42.7 | 12.1 | V100 | [model](https://paddledet.bj.bcebos.com/models/solov2_r101_vd_fpn_3x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r101_vd_fpn_3x_coco.yml) |

**Notes:**

Expand Down
1 change: 0 additions & 1 deletion configs/solov2/_base_/solov2_r50_fpn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ SOLOv2:

ResNet:
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [0,1,2,3]
num_stages: 4
Expand Down
66 changes: 66 additions & 0 deletions configs/solov2/solov2_r101_vd_fpn_3x_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
_BASE_: [
'../datasets/coco_instance.yml',
'../runtime.yml',
'_base_/solov2_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/solov2_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights: output/solov2_r101_vd_fpn_3x_coco/model_final
epoch: 36
use_ema: true
ema_decay: 0.9998

ResNet:
depth: 101
variant: d
freeze_at: 0
return_idx: [0,1,2,3]
dcn_v2_stages: [1,2,3]
num_stages: 4

SOLOv2Head:
seg_feat_channels: 512
stacked_convs: 4
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 256
solov2_loss: SOLOv2Loss
mask_nms: MaskMatrixNMS
dcn_v2_stages: [0, 1, 2, 3]

SOLOv2MaskHead:
mid_channels: 128
out_channels: 256
start_level: 0
end_level: 3
use_dcn_in_tower: True


LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [24, 33]
- !LinearWarmup
start_factor: 0.
steps: 2000

TrainReader:
sample_transforms:
- Decode: {}
- Poly2Mask: {}
- RandomResize: {interp: 1,
target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]],
keep_ratio: True}
- RandomFlip: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- Gt2Solov2Target: {num_grids: [40, 36, 24, 16, 12],
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]],
coord_sigma: 0.2}
batch_size: 2
shuffle: true
drop_last: true
47 changes: 25 additions & 22 deletions ppdet/modeling/heads/solov2_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,39 +43,39 @@ class SOLOv2MaskHead(nn.Layer):
end_level (int): The position where the input ends.
use_dcn_in_tower (bool): Whether to use dcn in tower or not.
"""
__shared__ = ['norm_type']

def __init__(self,
in_channels=256,
mid_channels=128,
out_channels=256,
start_level=0,
end_level=3,
use_dcn_in_tower=False):
use_dcn_in_tower=False,
norm_type='gn'):
super(SOLOv2MaskHead, self).__init__()
assert start_level >= 0 and end_level >= start_level
self.in_channels = in_channels
self.out_channels = out_channels
self.mid_channels = mid_channels
self.use_dcn_in_tower = use_dcn_in_tower
self.range_level = end_level - start_level + 1
# TODO: add DeformConvNorm
conv_type = [ConvNormLayer]
self.conv_func = conv_type[0]
if self.use_dcn_in_tower:
self.conv_func = conv_type[1]
self.use_dcn = True if self.use_dcn_in_tower else False
self.convs_all_levels = []
self.norm_type = norm_type
for i in range(start_level, end_level + 1):
conv_feat_name = 'mask_feat_head.convs_all_levels.{}'.format(i)
conv_pre_feat = nn.Sequential()
if i == start_level:
conv_pre_feat.add_sublayer(
conv_feat_name + '.conv' + str(i),
self.conv_func(
ConvNormLayer(
ch_in=self.in_channels,
ch_out=self.mid_channels,
filter_size=3,
stride=1,
norm_type='gn'))
use_dcn=self.use_dcn,
norm_type=self.norm_type))
self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat)
self.convs_all_levels.append(conv_pre_feat)
else:
Expand All @@ -87,12 +87,13 @@ def __init__(self,
ch_in = self.mid_channels
conv_pre_feat.add_sublayer(
conv_feat_name + '.conv' + str(j),
self.conv_func(
ConvNormLayer(
ch_in=ch_in,
ch_out=self.mid_channels,
filter_size=3,
stride=1,
norm_type='gn'))
use_dcn=self.use_dcn,
norm_type=self.norm_type))
conv_pre_feat.add_sublayer(
conv_feat_name + '.conv' + str(j) + 'act', nn.ReLU())
conv_pre_feat.add_sublayer(
Expand All @@ -105,12 +106,13 @@ def __init__(self,
conv_pred_name = 'mask_feat_head.conv_pred.0'
self.conv_pred = self.add_sublayer(
conv_pred_name,
self.conv_func(
ConvNormLayer(
ch_in=self.mid_channels,
ch_out=self.out_channels,
filter_size=1,
stride=1,
norm_type='gn'))
use_dcn=self.use_dcn,
norm_type=self.norm_type))

def forward(self, inputs):
"""
Expand Down Expand Up @@ -165,7 +167,7 @@ class SOLOv2Head(nn.Layer):
mask_nms (object): MaskMatrixNMS instance.
"""
__inject__ = ['solov2_loss', 'mask_nms']
__shared__ = ['num_classes']
__shared__ = ['norm_type', 'num_classes']

def __init__(self,
num_classes=80,
Expand All @@ -179,7 +181,8 @@ def __init__(self,
solov2_loss=None,
score_threshold=0.1,
mask_threshold=0.5,
mask_nms=None):
mask_nms=None,
norm_type='gn'):
super(SOLOv2Head, self).__init__()
self.num_classes = num_classes
self.in_channels = in_channels
Expand All @@ -194,33 +197,33 @@ def __init__(self,
self.mask_nms = mask_nms
self.score_threshold = score_threshold
self.mask_threshold = mask_threshold
self.norm_type = norm_type

conv_type = [ConvNormLayer]
self.conv_func = conv_type[0]
self.kernel_pred_convs = []
self.cate_pred_convs = []
for i in range(self.stacked_convs):
if i in self.dcn_v2_stages:
self.conv_func = conv_type[1]
use_dcn = True if i in self.dcn_v2_stages else False
ch_in = self.in_channels + 2 if i == 0 else self.seg_feat_channels
kernel_conv = self.add_sublayer(
'bbox_head.kernel_convs.' + str(i),
self.conv_func(
ConvNormLayer(
ch_in=ch_in,
ch_out=self.seg_feat_channels,
filter_size=3,
stride=1,
norm_type='gn'))
use_dcn=use_dcn,
norm_type=self.norm_type))
self.kernel_pred_convs.append(kernel_conv)
ch_in = self.in_channels if i == 0 else self.seg_feat_channels
cate_conv = self.add_sublayer(
'bbox_head.cate_convs.' + str(i),
self.conv_func(
ConvNormLayer(
ch_in=ch_in,
ch_out=self.seg_feat_channels,
filter_size=3,
stride=1,
norm_type='gn'))
use_dcn=use_dcn,
norm_type=self.norm_type))
self.cate_pred_convs.append(cate_conv)

self.solo_kernel = self.add_sublayer(
Expand Down

0 comments on commit 1adb26e

Please sign in to comment.