
Commit d362da6

ppwwyyxx authored and facebook-github-bot committed
use EMA normalizer in RetinaNet
Summary: fix #868

Reviewed By: rbgirshick, alexander-kirillov

Differential Revision: D20062149

fbshipit-source-id: 5fcf0537730f3b3f7217fde42f8b97e506eab316
1 parent 037823e commit d362da6

File tree: 1 file changed, +18 −3 lines

detectron2/modeling/meta_arch/retinanet.py

Lines changed: 18 additions & 3 deletions
@@ -8,6 +8,7 @@
 
 from detectron2.layers import ShapeSpec, batched_nms, cat
 from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
+from detectron2.utils.events import get_event_storage
 from detectron2.utils.logger import log_first_n
 
 from ..anchor_generator import build_anchor_generator
@@ -98,6 +99,15 @@ def __init__(self, cfg):
         self.normalizer = lambda x: (x - pixel_mean) / pixel_std
         self.to(self.device)
 
+        """
+        In Detectron1, loss is normalized by number of foreground samples in the batch.
+        When batch size is 1 per GPU, #foreground has a large variance and
+        using it leads to lower performance. Here we maintain an EMA of #foreground to
+        stabilize the normalizer.
+        """
+        self.loss_normalizer = 100  # initialize with any reasonable #fg that's not too small
+        self.loss_normalizer_momentum = 0.9
+
     def forward(self, batched_inputs):
         """
         Args:
@@ -172,7 +182,12 @@ def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits, pred_anchor_deltas):
 
         valid_idxs = gt_classes >= 0
         foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
-        num_foreground = foreground_idxs.sum()
+        num_foreground = foreground_idxs.sum().item()
+        get_event_storage().put_scalar("num_foreground", num_foreground)
+        self.loss_normalizer = (
+            self.loss_normalizer_momentum * self.loss_normalizer
+            + (1 - self.loss_normalizer_momentum) * num_foreground
+        )
 
         gt_classes_target = torch.zeros_like(pred_class_logits)
         gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1
@@ -184,15 +199,15 @@ def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits, pred_anchor_deltas):
             alpha=self.focal_loss_alpha,
             gamma=self.focal_loss_gamma,
             reduction="sum",
-        ) / max(1, num_foreground)
+        ) / max(1, self.loss_normalizer)
 
         # regression loss
         loss_box_reg = smooth_l1_loss(
             pred_anchor_deltas[foreground_idxs],
             gt_anchors_deltas[foreground_idxs],
             beta=self.smooth_l1_loss_beta,
             reduction="sum",
-        ) / max(1, num_foreground)
+        ) / max(1, self.loss_normalizer)
 
         return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
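For context, the sketch below isolates the EMA update this diff introduces. It is a standalone illustration, not detectron2 code: the `ema_update` helper and the simulated foreground counts are hypothetical, chosen only to show how a momentum-0.9 moving average damps the per-batch swings in #foreground described in the commit message.

# Standalone sketch of the EMA loss normalizer from the diff above.
# `ema_update` and the simulated counts are illustrative only.
import random

def ema_update(normalizer, num_foreground, momentum=0.9):
    # One EMA step: keep 90% of the running value, blend in 10% of the new count.
    return momentum * normalizer + (1 - momentum) * num_foreground

loss_normalizer = 100  # same initialization as in the commit
for step in range(5):
    # With batch size 1 per GPU, the raw foreground count can swing widely.
    num_foreground = random.randint(5, 400)
    loss_normalizer = ema_update(loss_normalizer, num_foreground)
    # Summed losses are then divided by max(1, loss_normalizer)
    # rather than max(1, num_foreground), giving a much steadier divisor.
    print(step, num_foreground, round(loss_normalizer, 1))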
