-
Notifications
You must be signed in to change notification settings - Fork 216
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add two stage dab deformable detr * update two stage criterion * dino * Add two-stage dab-deformable-detr (#46) * add two stage * update two stage with warmup * update warmup * update model init * refine dab-deformable-two-stage model config * refine dino project * delete redundant files * add readme for dino * refine dino config Co-authored-by: SlongLiu <[email protected]> Co-authored-by: hao zhang <[email protected]> Co-authored-by: Shilong Liu <[email protected]> Co-authored-by: ntianhe ren <[email protected]>
- Loading branch information
1 parent
32856ea
commit 48cba8d
Showing
12 changed files
with
1,302 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
## DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection | ||
|
||
Hao Zhang, Feng Li, Shilong Liu, Lei Zhang, Hang Su, Jun Zhu, Lionel M. Ni, Heung-Yeung Shum | ||
|
||
[[`arXiv`](https://arxiv.org/abs/2203.03605)] [[`BibTeX`](#citing-dino)] | ||
|
||
<div align="center"> | ||
<img src="./assets/dino_arch.png"/> | ||
</div><br/> | ||
|
||
|
||
## Training | ||
All configs can be trained with: | ||
```bash | ||
cd detrex | ||
python tools/train_net.py --config-file projects/dino/configs/path/to/config.py --num-gpus 8 | ||
``` | ||
By default, we use 8 GPUs with a total batch size of 16 for training. | ||
|
||
## Evaluation | ||
Model evaluation can be done as follows: | ||
```bash | ||
cd detrex | ||
python tools/train_net.py --config-file projects/dino/configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint | ||
``` | ||
|
||
|
||
## Citing DINO | ||
If you find our work helpful for your research, please consider citing the following BibTeX entry. | ||
|
||
```bibtex | ||
@misc{zhang2022dino, | ||
title={DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection}, | ||
author={Hao Zhang and Feng Li and Shilong Liu and Lei Zhang and Hang Su and Jun Zhu and Lionel M. Ni and Heung-Yeung Shum}, | ||
year={2022}, | ||
eprint={2203.03605}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CV} | ||
} | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from detrex.config import get_config | ||
from .models.dino_r50 import model | ||
|
||
# Each component below starts from a shared default loaded through
# detrex's `get_config` and is then adjusted for this experiment.

# ---- training loop ---------------------------------------------------------
train = get_config("common/train.py").train
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/dino_r50_50ep"
train.max_iter = 375000
# Gradient clipping is switched on with an L2 norm capped at 0.1.
train.clip_grad.enabled = True
train.clip_grad.params.max_norm = 0.1
train.clip_grad.params.norm_type = 2

# ---- optimizer and schedule ------------------------------------------------
optimizer = get_config("common/optim.py").AdamW
optimizer.weight_decay = 1e-4
# Per-parameter LR factor: 0.1 when the module name contains "backbone",
# 1 for everything else.
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1

# The schedule is used exactly as exported by the common config.
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep

# ---- data ------------------------------------------------------------------
dataloader = get_config("common/data/coco_detr.py").dataloader
dataloader.train.num_workers = 16
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import copy | ||
import torch.nn as nn | ||
|
||
from detectron2.modeling.backbone import ResNet, BasicStem | ||
from detectron2.layers import ShapeSpec | ||
from detectron2.config import LazyCall as L | ||
|
||
from detrex.modeling.matcher import HungarianMatcher | ||
from detrex.modeling.neck import ChannelMapper | ||
from detrex.layers import PositionEmbeddingSine | ||
|
||
from projects.dino.modeling import ( | ||
DINO, | ||
DINOTransformer, | ||
DINOTransformerEncoder, | ||
DINOTransformerDecoder, | ||
DINOCriterion, | ||
) | ||
|
||
# Number of feature levels. It is a plain Python value (not a config node)
# that is passed to the neck, encoder, decoder and transformer below so the
# four of them cannot drift apart.
num_feature_levels = 4

# Lazy (not yet instantiated) description of the DINO detector.
#
# String values of the form "${..key}" are OmegaConf relative interpolations:
# they resolve to `key` on the enclosing `model` node. They are used for values
# that appear in more than one sub-config, so overriding `model.num_classes`,
# `model.num_queries` or `model.as_two_stage` updates every dependent entry.
model = L(DINO)(
    backbone=L(ResNet)(
        stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=50,
            stride_in_1x1=False,
            norm="FrozenBN",
        ),
        out_features=["res3", "res4", "res5"],
        freeze_at=1,
    ),
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
        offset=-0.5,
    ),
    neck=L(ChannelMapper)(
        # Channel counts of the three backbone outputs listed in `out_features`.
        input_shapes={
            "res3": ShapeSpec(channels=512),
            "res4": ShapeSpec(channels=1024),
            "res5": ShapeSpec(channels=2048),
        },
        in_features=["res3", "res4", "res5"],
        out_channels=256,
        # One output per feature level. This is more outputs than inputs;
        # presumably ChannelMapper synthesizes the extra level -- confirm
        # against its implementation.
        num_outs=num_feature_levels,
        norm_layer=L(nn.GroupNorm)(num_groups=32, num_channels=256),
    ),
    transformer=L(DINOTransformer)(
        encoder=L(DINOTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=6,
            post_norm=False,
            num_feature_levels=num_feature_levels,
        ),
        decoder=L(DINOTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=6,
            return_intermediate=True,
            use_dab=True,
            num_feature_levels=num_feature_levels,
        ),
        as_two_stage="${..as_two_stage}",
        num_feature_levels=num_feature_levels,
        # Follows `model.num_queries` instead of repeating the literal 900.
        two_stage_num_proposals="${..num_queries}",
    ),
    num_classes=80,
    num_queries=900,
    aux_loss=True,
    as_two_stage=True,
    criterion=L(DINOCriterion)(
        # Follows `model.num_classes` instead of repeating the literal 80.
        num_classes="${..num_classes}",
        matcher=L(HungarianMatcher)(
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            cost_class_type="focal_loss_cost",
            alpha=0.25,
            gamma=2.0,
        ),
        # Un-suffixed loss weights. The "_dn" entries mirror the weights of
        # their un-suffixed counterparts.
        weight_dict={
            "loss_class": 1,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
            "loss_class_dn": 1,
            "loss_bbox_dn": 5.0,
            "loss_giou_dn": 2.0,
        },
        losses=[
            "class",
            "boxes",
        ],
        loss_class_type="focal_loss",
        alpha=0.25,
        gamma=2.0,
        two_stage_binary_cls=False,
    ),
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    dn_number=100,
    label_noise_ratio=0.2,
    box_noise_scale=1.0,
    device="cuda",
)
|
||
# Expand the criterion's loss weights with suffixed copies.
#
# A snapshot of the un-suffixed weights is taken first; both expansions below
# derive their keys from this snapshot rather than from the dict being grown,
# so the "_enc" entries never pick up a per-layer suffix.
base_weight_dict = copy.deepcopy(model.criterion.weight_dict)

if model.aux_loss:
    # One "<name>_<i>" copy of every base weight for each decoder layer
    # except the last one.
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {
        f"{loss_name}_{layer_idx}": loss_weight
        for layer_idx in range(model.transformer.decoder.num_layers - 1)
        for loss_name, loss_weight in base_weight_dict.items()
    }
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict

if model.as_two_stage:
    # One "<name>_enc" copy of every base weight.
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {
        f"{loss_name}_enc": loss_weight
        for loss_name, loss_weight in base_weight_dict.items()
    }
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# coding=utf-8 | ||
# Copyright 2022 The IDEA Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from .dino_transformer import ( | ||
DINOTransformer, | ||
DINOTransformerEncoder, | ||
DINOTransformerDecoder, | ||
) | ||
from .dino import DINO | ||
from .dino_criterion import DINOCriterion |
Oops, something went wrong.