Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] New config type #787

Open
wants to merge 2 commits into
base: new_config
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions mmyolo/configs/_base_/default_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.engine.hooks import DetVisualizationHook
from mmdet.visualization import DetLocalVisualizer
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.runner import LogProcessor
from mmengine.visualization import LocalVisBackend

default_scope = None
default_hooks = dict(
timer=dict(type=IterTimerHook),
logger=dict(type=LoggerHook, interval=50),
param_scheduler=dict(type=ParamSchedulerHook),
checkpoint=dict(type=CheckpointHook, interval=1),
sampler_seed=dict(type=DistSamplerSeedHook),
visualization=dict(type=DetVisualizationHook))

env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)

vis_backends = [dict(type=LocalVisBackend)]
visualizer = dict(
type=DetLocalVisualizer, vis_backends=vis_backends, name='visualizer')
log_processor = dict(type=LogProcessor, window_size=50, by_epoch=True)

log_level = 'INFO'
load_from = None
resume = False

# file_client_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/': 's3://openmmlab/datasets/detection/',
# 'data/': 's3://openmmlab/datasets/detection/'
# }))
file_client_args = dict(backend='disk')
65 changes: 65 additions & 0 deletions mmyolo/configs/_base_/det_p5_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.ops import nms
from mmcv.transforms import Compose, LoadImageFromFile, TestTimeAug
from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs,
RandomFlip)
from mmdet.models import DetTTAModel

from mmyolo.datasets.transforms import LetterResize, YOLOv5KeepRatioResize

# TODO: Need to solve the problem of multiple file_client_args parameters
# _file_client_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/': 's3://openmmlab/datasets/detection/',
# 'data/': 's3://openmmlab/datasets/detection/'
# }))
_file_client_args = dict(backend='disk')

tta_model = dict(
type=DetTTAModel,
tta_cfg=dict(nms=dict(type=nms, iou_threshold=0.65), max_per_img=300))

img_scales = [(640, 640), (320, 320), (960, 960)]

# LoadImageFromFile
# / | \
# (RatioResize,LetterResize) (RatioResize,LetterResize) (RatioResize,LetterResize) # noqa
# / \ / \ / \
# RandomFlip RandomFlip RandomFlip RandomFlip RandomFlip RandomFlip # noqa
# | | | | | |
# LoadAnn LoadAnn LoadAnn LoadAnn LoadAnn LoadAnn
# | | | | | |
# PackDetIn PackDetIn PackDetIn PackDetIn PackDetIn PackDetIn # noqa

_multiscale_resize_transforms = [
dict(
type=Compose,
transforms=[
dict(type=YOLOv5KeepRatioResize, scale=s),
dict(
type=LetterResize,
scale=s,
allow_scale_up=False,
pad_val=dict(img=114))
]) for s in img_scales
]

tta_pipeline = [
dict(type=LoadImageFromFile, file_client_args=_file_client_args),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该是 backend_args

dict(
type=TestTimeAug,
transforms=[
_multiscale_resize_transforms,
[dict(type=RandomFlip, prob=1.),
dict(type=RandomFlip, prob=0.)],
[dict(type=LoadAnnotations, with_bbox=True)],
[
dict(
type=PackDetInputs,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param', 'flip',
'flip_direction'))
]
])
]
328 changes: 328 additions & 0 deletions mmyolo/configs/rtmdet/rtmdet_l_syncbn_fast_8xb32_300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
# Copyright (c) OpenMMLab. All rights reserved.
if '_base_':
from .._base_.default_runtime import *
from .._base_.det_p5_tta import *

from mmcv.transforms import RandomResize
from mmdet.datasets.transforms import (PackDetInputs, Pad, RandomCrop,
RandomFlip, Resize, YOLOXHSVRandomAug)
from mmdet.engine.hooks import PipelineSwitchHook
from mmdet.evaluation import CocoMetric
from mmdet.models import GIoULoss, QualityFocalLoss
from mmdet.models.task_modules import BboxOverlaps2D, MlvlPointGenerator
from mmengine.dataset import DefaultSampler
from mmengine.hooks import EMAHook
from mmengine.optim import CosineAnnealingLR, LinearLR, OptimWrapper
from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.nn import BatchNorm2d, SiLU
from torch.optim import AdamW

from mmyolo.datasets import (BatchShapePolicy, Mosaic, YOLOv5CocoDataset,
yolov5_collate)
from mmyolo.datasets.transforms import LoadAnnotations, YOLOv5MixUp
from mmyolo.models import (CSPNeXt, CSPNeXtPAFPN, ExpMomentumEMA, RTMDetHead,
RTMDetSepBNHeadModule, YOLODetector,
YOLOv5DetDataPreprocessor)
from mmyolo.models.task_modules.assigners import BatchDynamicSoftLabelAssigner
from mmyolo.models.task_modules.coders import DistancePointBBoxCoder

# -----data related-----
data_root = 'data/coco/'
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path

num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 10
# persistent_workers must be False if num_workers is 0.
persistent_workers = True

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
base_lr = 0.004
max_epochs = 300 # Maximum training epochs
# Change train_pipeline for final 20 epochs (stage 2)
num_epochs_stage2 = 20
model_test_cfg = dict(
# The config of multi-label for multi-class prediction.
multi_label=True,
# The number of boxes before NMS
nms_pre=30000,
score_thr=0.001, # Threshold to filter out boxes.
nms=dict(type=nms, iou_threshold=0.65), # NMS type and threshold
max_per_img=300) # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640) # width, height
# ratio range for random resize
random_resize_ratio_range = (0.1, 2.0)
# Cached images number in mosaic
mosaic_max_cached_images = 40
# Number of cached images in mixup
mixup_max_cached_images = 20
# Dataset type, this will be used to define the dataset
dataset_type = YOLOv5CocoDataset
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 10

# Config of batch shapes. Only on val.
batch_shapes_cfg = dict(
type=BatchShapePolicy,
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1.0
# The scaling factor that controls the width of the network structure
widen_factor = 1.0
# Strides of multi-scale prior box
strides = [8, 16, 32]

norm_cfg = dict(type=BatchNorm2d) # Normalization config

# -----train val related-----
lr_start_factor = 1.0e-5
dsl_topk = 13 # Number of bbox selected in each level
loss_cls_weight = 1.0
loss_bbox_weight = 2.0
qfl_beta = 2.0 # beta of QualityFocalLoss
weight_decay = 0.05

# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# validation intervals in stage 2
val_interval_stage2 = 1
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

# ===============================Unmodified in most cases====================
model = dict(
type=YOLODetector,
data_preprocessor=dict(
type=YOLOv5DetDataPreprocessor,
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
bgr_to_rgb=False),
backbone=dict(
type=CSPNeXt,
arch='P5',
expand_ratio=0.5,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
channel_attention=True,
norm_cfg=norm_cfg,
act_cfg=dict(type=SiLU, inplace=True)),
neck=dict(
type=CSPNeXtPAFPN,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
out_channels=256,
num_csp_blocks=3,
expand_ratio=0.5,
norm_cfg=norm_cfg,
act_cfg=dict(type=SiLU, inplace=True)),
bbox_head=dict(
type=RTMDetHead,
head_module=dict(
type=RTMDetSepBNHeadModule,
num_classes=num_classes,
in_channels=256,
stacked_convs=2,
feat_channels=256,
norm_cfg=norm_cfg,
act_cfg=dict(type=SiLU, inplace=True),
share_conv=True,
pred_kernel_size=1,
featmap_strides=strides),
prior_generator=dict(
type=MlvlPointGenerator, offset=0, strides=strides),
bbox_coder=dict(type=DistancePointBBoxCoder),
loss_cls=dict(
type=QualityFocalLoss,
use_sigmoid=True,
beta=qfl_beta,
loss_weight=loss_cls_weight),
loss_bbox=dict(type=GIoULoss, loss_weight=loss_bbox_weight)),
train_cfg=dict(
assigner=dict(
type=BatchDynamicSoftLabelAssigner,
num_classes=num_classes,
topk=dsl_topk,
iou_calculator=dict(type=BboxOverlaps2D)),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=model_test_cfg,
)

train_pipeline = [
dict(type=LoadImageFromFile, file_client_args=file_client_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=Mosaic,
img_scale=img_scale,
use_cached=True,
max_cached_images=mosaic_max_cached_images,
pad_val=114.0),
dict(
type=RandomResize,
# img_scale is (width, height)
scale=(img_scale[0] * 2, img_scale[1] * 2),
ratio_range=random_resize_ratio_range,
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=img_scale),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(
type=YOLOv5MixUp,
use_cached=True,
max_cached_images=mixup_max_cached_images),
dict(type=PackDetInputs)
]

train_pipeline_stage2 = [
dict(type=LoadImageFromFile, file_client_args=file_client_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=RandomResize,
scale=img_scale,
ratio_range=random_resize_ratio_range,
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=img_scale),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type=PackDetInputs)
]

test_pipeline = [
dict(type=LoadImageFromFile, file_client_args=file_client_args),
dict(type=YOLOv5KeepRatioResize, scale=img_scale),
dict(
type=LetterResize,
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type=LoadAnnotations, with_bbox=True, _scope_='mmdet'),
dict(
type=PackDetInputs,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]

train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
collate_fn=dict(type=yolov5_collate),
sampler=dict(type=DefaultSampler, shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=val_ann_file,
data_prefix=dict(img=val_data_prefix),
test_mode=True,
batch_shapes_cfg=batch_shapes_cfg,
pipeline=test_pipeline))

test_dataloader = val_dataloader

# Reduce evaluation time
val_evaluator = dict(
type=CocoMetric,
proposal_nums=(100, 1, 10),
ann_file=data_root + val_ann_file,
metric='bbox')
test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
type=OptimWrapper,
optimizer=dict(type=AdamW, lr=base_lr, weight_decay=weight_decay),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
dict(
type=LinearLR,
start_factor=lr_start_factor,
by_epoch=False,
begin=0,
end=1000),
dict(
# use cosine lr from 150 to 300 epoch
type=CosineAnnealingLR,
eta_min=base_lr * 0.05,
begin=max_epochs // 2,
end=max_epochs,
T_max=max_epochs // 2,
by_epoch=True,
convert_to_iter_based=True),
]

# hooks
default_hooks = dict(
checkpoint=dict(
type=CheckpointHook,
interval=save_checkpoint_intervals,
max_keep_ckpts=max_keep_ckpts # only keep latest 3 checkpoints
))

custom_hooks = [
dict(
type=EMAHook,
ema_type=ExpMomentumEMA,
momentum=0.0002,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type=PipelineSwitchHook,
switch_epoch=max_epochs - num_epochs_stage2,
switch_pipeline=train_pipeline_stage2)
]

train_cfg = dict(
type=EpochBasedTrainLoop,
max_epochs=max_epochs,
val_interval=save_checkpoint_intervals,
dynamic_intervals=[(max_epochs - num_epochs_stage2, val_interval_stage2)])

val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)
96 changes: 96 additions & 0 deletions mmyolo/configs/rtmdet/rtmdet_s_syncbn_fast_8xb32_300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
if '_base_':
from .rtmdet_l_syncbn_fast_8xb32_300e_coco import *

from mmengine.model import PretrainedInit

checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa

# ========================modified parameters======================
deepen_factor = 0.33
widen_factor = 0.5

# ratio range for random resize
random_resize_ratio_range = (0.5, 2.0)
# Number of cached images in mosaic
mosaic_max_cached_images = 40
# Number of cached images in mixup
mixup_max_cached_images = 20

# =======================Unmodified in most cases==================
model.update(
backbone=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
# Since the checkpoint includes CUDA:0 data,
# it must be forced to set map_location.
# Once checkpoint is fixed, it can be removed.
init_cfg=dict(
type=PretrainedInit,
prefix='backbone.',
checkpoint=checkpoint,
map_location='cpu')),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

train_pipeline = [
dict(type=LoadImageFromFile, file_client_args=file_client_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=Mosaic,
img_scale=img_scale,
use_cached=True,
max_cached_images=mosaic_max_cached_images,
pad_val=114.0),
dict(
type=RandomResize,
# img_scale is (width, height)
scale=(img_scale[0] * 2, img_scale[1] * 2),
ratio_range=random_resize_ratio_range, # note
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=img_scale),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(
type=YOLOv5MixUp,
use_cached=True,
max_cached_images=mixup_max_cached_images),
dict(type=PackDetInputs)
]

train_pipeline_stage2 = [
dict(type=LoadImageFromFile, file_client_args=file_client_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=RandomResize,
scale=img_scale,
ratio_range=random_resize_ratio_range, # note
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=img_scale),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type=PackDetInputs)
]

train_dataloader.update(dataset=dict(pipeline=train_pipeline))

custom_hooks = [
dict(
type=EMAHook,
ema_type=ExpMomentumEMA,
momentum=0.0002,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type=PipelineSwitchHook,
switch_epoch=max_epochs - num_epochs_stage2,
switch_pipeline=train_pipeline_stage2)
]
60 changes: 60 additions & 0 deletions mmyolo/configs/rtmdet/rtmdet_tiny_syncbn_fast_8xb32_300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
if '_base_':
from .rtmdet_s_syncbn_fast_8xb32_300e_coco import *

checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa

# ========================modified parameters======================
deepen_factor = 0.167
widen_factor = 0.375

# ratio range for random resize
random_resize_ratio_range = (0.5, 2.0)
# Number of cached images in mosaic
mosaic_max_cached_images = 20
# Number of cached images in mixup
mixup_max_cached_images = 10

# =======================Unmodified in most cases==================
model.update(
backbone=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
init_cfg=dict(checkpoint=checkpoint)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

train_pipeline = [
dict(type=LoadImageFromFile, file_client_args=file_client_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=Mosaic,
img_scale=img_scale,
use_cached=True,
max_cached_images=mosaic_max_cached_images, # note
random_pop=False, # note
pad_val=114.0),
dict(
type=RandomResize,
# img_scale is (width, height)
scale=(img_scale[0] * 2, img_scale[1] * 2),
ratio_range=random_resize_ratio_range,
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=img_scale),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(
type=YOLOv5MixUp,
use_cached=True,
random_pop=False,
max_cached_images=mixup_max_cached_images,
prob=0.5),
dict(type=PackDetInputs)
]

train_dataloader.update(dataset=dict(pipeline=train_pipeline))
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -19,3 +19,6 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
skip = *.ipynb
quiet-level = 3
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,tood,ba,warmup,elease,dota

[flake8]
per-file-ignores = mmyolo/configs/*: F401,F403,F405