Support DINO (#51)
* add two stage dab deformable detr

* update two stage criterion

* dino

* Add two-stage dab-deformable-detr (#46)

* add two stage

* update two stage with warmup

* update warmup

* update model init

* refine dab-deformable-two-stage model config

* refine dino project

* delete redundant files

* add readme for dino

* refine dino config

Co-authored-by: SlongLiu <[email protected]>
Co-authored-by: hao zhang <[email protected]>
Co-authored-by: Shilong Liu <[email protected]>
Co-authored-by: ntianhe ren <[email protected]>
5 people authored Sep 7, 2022
1 parent 32856ea commit 48cba8d
Showing 12 changed files with 1,302 additions and 63 deletions.
@@ -1,26 +1,26 @@
import copy
from detrex.config import get_config

from detectron2.config import LazyCall as L

from detrex.config import get_config
from detrex.modeling.matcher import HungarianMatcher

from projects.dab_deformable_detr.modeling import (
    DabDeformableDETR,
    DabDeformableDetrTransformerEncoder,
    DabDeformableDetrTransformerDecoder,
    DabDeformableDetrTransformer,
    TwoStageCriterion,
)

from projects.dab_deformable_detr.modeling import TwoStageCriterion

from .dab_deformable_detr_r50_50ep import (
    model,
    train,
    dataloader,
    optimizer,
    lr_multiplier,
)

from .models.dab_deformable_detr_r50 import model


# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/dab_deformable_detr_r50_two_stage_50ep"


# set model
# modify model config
model.as_two_stage = True
model.criterion = L(TwoStageCriterion)(
    num_classes=80,
@@ -48,7 +48,7 @@
)


# set aux loss weight dict
# set aux loss weight dict for two stage deformable model
base_weight_dict = copy.deepcopy(model.criterion.weight_dict)
if model.aux_loss:
    weight_dict = model.criterion.weight_dict
@@ -63,26 +63,3 @@
    aux_weight_dict.update({k + "_enc": v for k, v in base_weight_dict.items()})
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict



dataloader = get_config("common/data/coco_detr.py").dataloader
optimizer = get_config("common/optim.py").AdamW
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep
train = get_config("common/train.py").train

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/dab_deformable_detr_r50_50ep"
train.max_iter = 375000
train.clip_grad.enabled = True
train.clip_grad.params.max_norm = 0.1
train.clip_grad.params.norm_type = 2
train.seed = 42

# modify optimizer config
optimizer.weight_decay = 1e-4
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1

# modify dataloader config
dataloader.train.num_workers = 16
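A note on the pattern used throughout these configs: detectron2's `LazyCall` records a callable and its arguments as a plain config node instead of constructing the object, which is why fields like `model.criterion.weight_dict` can still be rewritten above before anything is built. A minimal sketch, using a made-up `Matcher` class as a stand-in for real targets like `HungarianMatcher`:

```python
from detectron2.config import LazyCall as L, instantiate

class Matcher:
    """Hypothetical stand-in for a real config target such as HungarianMatcher."""
    def __init__(self, cost_class):
        self.cost_class = cost_class

node = L(Matcher)(cost_class=2.0)  # a config node, not a Matcher instance
node.cost_class = 5.0              # fields stay overridable, as in the configs above
matcher = instantiate(node)        # Matcher.__init__ runs only here
assert matcher.cost_class == 5.0
```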
29 changes: 4 additions & 25 deletions projects/dab_deformable_detr/modeling/dab_deformable_detr.py
@@ -94,36 +94,14 @@ def __init__(
# hack implementation for two-stage
if self.as_two_stage:
    self.transformer.decoder.class_embed = self.class_embed

# hack implementation for iterative bounding box refinement
self.transformer.decoder.bbox_embed = self.bbox_embed

# self.init_weights()

if self.as_two_stage:
    for bbox_embed_layer in self.bbox_embed:
        nn.init.constant_(bbox_embed_layer.layers[-1].bias.data[2:], 0.0)

# def init_weights(self):
# prior_prob = 0.01
# bias_value = -math.log((1 - prior_prob) / prior_prob)

# for class_embed_layer in self.class_embed:
# class_embed_layer.bias.data = torch.ones(self.num_classes) * bias_value

# for bbox_embed_layer in self.bbox_embed:
# nn.init.constant_(bbox_embed_layer.layers[-1].weight.data, 0)
# nn.init.constant_(bbox_embed_layer.layers[-1].bias.data, 0)

# for _, neck_layer in self.neck.named_modules():
# if isinstance(neck_layer, nn.Conv2d):
# nn.init.xavier_uniform_(neck_layer.weight, gain=1)
# nn.init.constant_(neck_layer.bias, 0)

# nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)

# if self.as_two_stage:
# for bbox_embed_layer in self.bbox_embed:
# nn.init.constant_(bbox_embed_layer.layers[-1].bias.data[2:], 0.0)

def forward(self, batched_inputs):

@@ -151,7 +129,7 @@ def forward(self, batched_inputs):
    F.interpolate(img_masks[None], size=feat.shape[-2:]).to(torch.bool).squeeze(0)
)
multi_level_position_embeddings.append(self.position_embedding(multi_level_masks[-1]))

if self.as_two_stage:
    query_embeds = None
else:
@@ -199,7 +177,8 @@ def forward(self, batched_inputs):
interm_coord = enc_reference
interm_class = self.class_embed[-1](enc_state)
output["enc_outputs"] = {"pred_logits": interm_class, "pred_boxes": interm_coord}

if self.training:
    gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
    targets = self.prepare_targets(gt_instances)
@@ -249,6 +249,8 @@ def __init__(
if self.as_two_stage:
    self.enc_output = nn.Linear(self.embed_dim, self.embed_dim)
    self.enc_output_norm = nn.LayerNorm(self.embed_dim)
    self.pos_trans = nn.Linear(self.embed_dim * 2, self.embed_dim * 2)
    self.pos_trans_norm = nn.LayerNorm(self.embed_dim)

self.init_weights()

@@ -407,6 +409,7 @@ def forward(
)
# output_memory: bs, num_tokens, c
# output_proposals: bs, num_tokens, 4. unsigmoided.
# output_proposals: bs, num_tokens, 4

enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
enc_outputs_coord_unact = (
@@ -428,7 +431,6 @@
    output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])
)
target = target_unact.detach()

else:
    reference_points = query_embed[..., self.embed_dim :].sigmoid()
    target = query_embed[..., : self.embed_dim]
40 changes: 40 additions & 0 deletions projects/dino/README.md
@@ -0,0 +1,40 @@
## DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection

Hao Zhang, Feng Li, Shilong Liu, Lei Zhang, Hang Su, Jun Zhu, Lionel M. Ni, Heung-Yeung Shum

[[`arXiv`](https://arxiv.org/abs/2203.03605)] [[`BibTeX`](#citing-dino)]

<div align="center">
<img src="./assets/dino_arch.png"/>
</div><br/>


## Training
All configs can be trained with:
```bash
cd detrex
python tools/train_net.py --config-file projects/dino/configs/path/to/config.py --num-gpus 8
```
By default, we train with 8 GPUs and a total batch size of 16.
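The lazy config system also accepts dotted `key=value` overrides on the command line, so settings can be changed without editing the config file. A sketch with illustrative values, assuming the standard detectron2 dataloader keys:

```bash
cd detrex
python tools/train_net.py --config-file projects/dino/configs/dino_r50_50ep.py \
    --num-gpus 4 \
    train.max_iter=180000 \
    dataloader.train.total_batch_size=8
```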

## Evaluation
Model evaluation can be done as follows:
```bash
cd detrex
python tools/train_net.py --config-file projects/dino/configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint
```
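Evaluation can be distributed across GPUs the same way; a sketch, assuming a locally trained checkpoint at the path below:

```bash
cd detrex
python tools/train_net.py --config-file projects/dino/configs/dino_r50_50ep.py \
    --num-gpus 8 --eval-only \
    train.init_checkpoint=./output/dino_r50_50ep/model_final.pth
```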


## Citing DINO
If you find our work helpful for your research, please consider citing the following BibTeX entry.

```BibTeX
@misc{zhang2022dino,
title={DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection},
author={Hao Zhang and Feng Li and Shilong Liu and Lei Zhang and Hang Su and Jun Zhu and Lionel M. Ni and Heung-Yeung Shum},
year={2022},
eprint={2203.03605},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
Binary file added projects/dino/assets/dino_arch.png
22 changes: 22 additions & 0 deletions projects/dino/configs/dino_r50_50ep.py
@@ -0,0 +1,22 @@
from detrex.config import get_config
from .models.dino_r50 import model

dataloader = get_config("common/data/coco_detr.py").dataloader
optimizer = get_config("common/optim.py").AdamW
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep
train = get_config("common/train.py").train

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/dino_r50_50ep"
train.max_iter = 375000
train.clip_grad.enabled = True
train.clip_grad.params.max_norm = 0.1
train.clip_grad.params.norm_type = 2

# modify optimizer config
optimizer.weight_decay = 1e-4
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1

# modify dataloader config
dataloader.train.num_workers = 16
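Since configs are ordinary Python modules, a variant schedule can be derived by importing this one and overriding fields. A hypothetical 12-epoch config, assuming `common/coco_schedule.py` also exposes `lr_multiplier_12ep` alongside the `lr_multiplier_50ep` used above:

```python
# hypothetical projects/dino/configs/dino_r50_12ep.py
from detrex.config import get_config
from .dino_r50_50ep import dataloader, model, optimizer, train

# assumed to exist alongside lr_multiplier_50ep
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_12ep

train.max_iter = 90000  # ~12 COCO epochs at total batch size 16
train.output_dir = "./output/dino_r50_12ep"
```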
130 changes: 130 additions & 0 deletions projects/dino/configs/models/dino_r50.py
@@ -0,0 +1,130 @@
import copy
import torch.nn as nn

from detectron2.modeling.backbone import ResNet, BasicStem
from detectron2.layers import ShapeSpec
from detectron2.config import LazyCall as L

from detrex.modeling.matcher import HungarianMatcher
from detrex.modeling.neck import ChannelMapper
from detrex.layers import PositionEmbeddingSine

from projects.dino.modeling import (
    DINO,
    DINOTransformer,
    DINOTransformerEncoder,
    DINOTransformerDecoder,
    DINOCriterion,
)

num_feature_levels = 4

model = L(DINO)(
    backbone=L(ResNet)(
        stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=50,
            stride_in_1x1=False,
            norm="FrozenBN",
        ),
        out_features=["res3", "res4", "res5"],
        freeze_at=1,
    ),
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
        offset=-0.5,
    ),
    neck=L(ChannelMapper)(
        input_shapes={
            "res3": ShapeSpec(channels=512),
            "res4": ShapeSpec(channels=1024),
            "res5": ShapeSpec(channels=2048),
        },
        in_features=["res3", "res4", "res5"],
        out_channels=256,
        num_outs=4,
        norm_layer=L(nn.GroupNorm)(num_groups=32, num_channels=256),
    ),
    transformer=L(DINOTransformer)(
        encoder=L(DINOTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=6,
            post_norm=False,
            num_feature_levels=num_feature_levels,
        ),
        decoder=L(DINOTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=6,
            return_intermediate=True,
            use_dab=True,
            num_feature_levels=num_feature_levels,
        ),
        as_two_stage="${..as_two_stage}",
        num_feature_levels=num_feature_levels,
        two_stage_num_proposals=900,
    ),
    num_classes=80,
    num_queries=900,
    aux_loss=True,
    as_two_stage=True,
    criterion=L(DINOCriterion)(
        num_classes=80,
        matcher=L(HungarianMatcher)(
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            cost_class_type="focal_loss_cost",
            alpha=0.25,
            gamma=2.0,
        ),
        weight_dict={
            "loss_class": 1,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
            "loss_class_dn": 1,
            "loss_bbox_dn": 5.0,
            "loss_giou_dn": 2.0,
        },
        losses=[
            "class",
            "boxes",
        ],
        loss_class_type="focal_loss",
        alpha=0.25,
        gamma=2.0,
        two_stage_binary_cls=False,
    ),
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    dn_number=100,
    label_noise_ratio=0.2,
    box_noise_scale=1.0,
    device="cuda",
)

# set aux loss weight dict
base_weight_dict = copy.deepcopy(model.criterion.weight_dict)
if model.aux_loss:
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {}
    for i in range(model.transformer.decoder.num_layers - 1):
        aux_weight_dict.update({k + f"_{i}": v for k, v in base_weight_dict.items()})
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict
if model.as_two_stage:
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {}
    aux_weight_dict.update({k + "_enc": v for k, v in base_weight_dict.items()})
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict
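To make the expansion above concrete, the following self-contained sketch reproduces the key layout it generates (the denoising `_dn` entries are dropped for brevity): one `_0`…`_4` copy per intermediate decoder layer, plus one `_enc` copy for the encoder proposals.

```python
import copy

base = {"loss_class": 1, "loss_bbox": 5.0, "loss_giou": 2.0}
weight_dict = copy.deepcopy(base)
num_decoder_layers = 6  # matches model.transformer.decoder.num_layers above

# auxiliary losses for all but the final decoder layer
for i in range(num_decoder_layers - 1):
    weight_dict.update({k + f"_{i}": v for k, v in base.items()})
# two-stage: one extra copy for the encoder output
weight_dict.update({k + "_enc": v for k, v in base.items()})

print(sorted(weight_dict.keys()))
# ['loss_bbox', 'loss_bbox_0', ..., 'loss_bbox_enc', 'loss_class', ...]
```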
23 changes: 23 additions & 0 deletions projects/dino/modeling/__init__.py
@@ -0,0 +1,23 @@
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .dino_transformer import (
    DINOTransformer,
    DINOTransformerEncoder,
    DINOTransformerDecoder,
)
from .dino import DINO
from .dino_criterion import DINOCriterion