Skip to content

Commit 76e543e

Browse files
RaymondCM authored and facebook-github-bot committed
Training visualisation for RetinaNet (#890)
Summary: Purely additive PR. Adds a training visualization option to the RetinaNet class, similar to the one in the rcnn, that logs the best-scoring predictions every cfg.VIS_PERIOD storage iterations. Pull Request resolved: #890. Reviewed By: rbgirshick. Differential Revision: D20065554. Pulled By: ppwwyyxx. fbshipit-source-id: 800631d38faf7cc93eb43cf62cf74187a52bb76d
1 parent d362da6 commit 76e543e

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

detectron2/modeling/meta_arch/retinanet.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
22
import logging
33
import math
4+
import numpy as np
45
from typing import List
56
import torch
67
from fvcore.nn import sigmoid_focal_loss_jit, smooth_l1_loss
@@ -77,6 +78,9 @@ def __init__(self, cfg):
7778
self.topk_candidates = cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST
7879
self.nms_threshold = cfg.MODEL.RETINANET.NMS_THRESH_TEST
7980
self.max_detections_per_image = cfg.TEST.DETECTIONS_PER_IMAGE
81+
# Vis parameters
82+
self.vis_period = cfg.VIS_PERIOD
83+
self.input_format = cfg.INPUT.FORMAT
8084
# fmt: on
8185

8286
self.backbone = build_backbone(cfg)
@@ -108,6 +112,44 @@ def __init__(self, cfg):
108112
self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small
109113
self.loss_normalizer_momentum = 0.9
110114

115+
def visualize_training(self, batched_inputs, results):
    """
    Log a ground-truth vs. prediction comparison image for one training sample.

    Only the first element of the batch is rendered. Its image is drawn twice,
    once with the ground-truth boxes and once with at most 20 predicted boxes;
    the two renderings are stacked vertically (ground truth on top) and handed
    to the current event storage.

    Args:
        batched_inputs (list): a list that contains input to the model. The
            visualized element is read through its "image" and "instances" keys.
        results (List[Instances]): a list of #images elements.
    """
    from detectron2.utils.visualizer import Visualizer

    assert len(batched_inputs) == len(
        results
    ), "Cannot visualize inputs and results of different sizes"
    event_storage = get_event_storage()
    max_boxes = 20

    # Only a single image is visualized: the first one in the batch.
    sample = batched_inputs[0]
    prediction = results[0]

    pixels = sample["image"].cpu().numpy()
    assert pixels.shape[0] == 3, "Images should have 3 channels."
    if self.input_format == "BGR":
        # Reverse the channel axis so the image is drawn as RGB.
        pixels = pixels[::-1, :, :]
    # Channel-first -> channel-last, the layout handed to the Visualizer.
    pixels = pixels.transpose(1, 2, 0)

    # Top half: the image with its ground-truth boxes.
    anno_img = (
        Visualizer(pixels, None)
        .overlay_instances(boxes=sample["instances"].gt_boxes)
        .get_image()
    )

    # Bottom half: the image with the leading `max_boxes` predicted boxes.
    # The slice does no sorting; it relies on the order `results` arrives in.
    processed = detector_postprocess(prediction, pixels.shape[0], pixels.shape[1])
    box_array = processed.pred_boxes.tensor.detach().cpu().numpy()
    prop_img = (
        Visualizer(pixels, None)
        .overlay_instances(boxes=box_array[:max_boxes])
        .get_image()
    )

    # Stack vertically, then go back to channel-first before logging.
    stacked = np.vstack((anno_img, prop_img)).transpose(2, 0, 1)
    vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
    event_storage.put_image(vis_name, stacked)
152+
111153
def forward(self, batched_inputs):
112154
"""
113155
Args:
@@ -144,7 +186,15 @@ def forward(self, batched_inputs):
144186

145187
if self.training:
146188
gt_classes, gt_anchors_reg_deltas = self.get_ground_truth(anchors, gt_instances)
147-
return self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta)
189+
losses = self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta)
190+
191+
if self.vis_period > 0:
192+
storage = get_event_storage()
193+
if storage.iter % self.vis_period == 0:
194+
results = self.inference(box_cls, box_delta, anchors, images.image_sizes)
195+
self.visualize_training(batched_inputs, results)
196+
197+
return losses
148198
else:
149199
results = self.inference(box_cls, box_delta, anchors, images.image_sizes)
150200
processed_results = []

0 commit comments

Comments
 (0)