@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import math
+import numpy as np
from typing import List
import torch
from fvcore.nn import sigmoid_focal_loss_jit, smooth_l1_loss
@@ -77,6 +78,9 @@ def __init__(self, cfg):
        self.topk_candidates = cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST
        self.nms_threshold = cfg.MODEL.RETINANET.NMS_THRESH_TEST
        self.max_detections_per_image = cfg.TEST.DETECTIONS_PER_IMAGE
+        # Vis parameters
+        self.vis_period = cfg.VIS_PERIOD
+        self.input_format = cfg.INPUT.FORMAT
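+        # A VIS_PERIOD of 0 keeps visualization disabled; INPUT.FORMAT is
+        # consulted in visualize_training() to recover an RGB image for drawing.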
        # fmt: on

        self.backbone = build_backbone(cfg)
@@ -108,6 +112,44 @@ def __init__(self, cfg):
        self.loss_normalizer = 100  # initialize with any reasonable #fg that's not too small
        self.loss_normalizer_momentum = 0.9

+    def visualize_training(self, batched_inputs, results):
+        """
+        A function used to visualize ground truth images and final network predictions.
+        It shows the ground truth bounding boxes on the original image, and up to 20
+        of the highest-scoring predicted boxes on the same image.
+
+        Args:
+            batched_inputs (list): a list that contains the inputs to the model.
+            results (List[Instances]): a list of #images elements, one set of
+                predicted instances per image.
+        """
+        from detectron2.utils.visualizer import Visualizer
+
+        assert len(batched_inputs) == len(
+            results
+        ), "Cannot visualize inputs and results of different sizes"
+        storage = get_event_storage()
+        max_boxes = 20
+
+        image_index = 0  # only visualize a single image
+        img = batched_inputs[image_index]["image"].cpu().numpy()
+        assert img.shape[0] == 3, "Images should have 3 channels."
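+        # Visualizer draws on an RGB, channels-last image: flip the channel axis
+        # if the model consumes BGR input, then transpose CHW -> HWC.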
+        if self.input_format == "BGR":
+            img = img[::-1, :, :]
+        img = img.transpose(1, 2, 0)
+        v_gt = Visualizer(img, None)
+        v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes)
+        anno_img = v_gt.get_image()
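+        # Rescale the predictions to the (height, width) of the image being drawn.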
+        processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1])
+        predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy()
+
+        v_pred = Visualizer(img, None)
+        v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes])
+        prop_img = v_pred.get_image()
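+        # Stack the two renderings vertically (GT on top, predictions below) and
+        # convert HWC -> CHW, which is what EventStorage.put_image() expects.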
+        vis_img = np.vstack((anno_img, prop_img))
+        vis_img = vis_img.transpose(2, 0, 1)
+        vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
+        storage.put_image(vis_name, vis_img)
+
    def forward(self, batched_inputs):
        """
        Args:
@@ -144,7 +186,15 @@ def forward(self, batched_inputs):

        if self.training:
            gt_classes, gt_anchors_reg_deltas = self.get_ground_truth(anchors, gt_instances)
-            return self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta)
+            losses = self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta)
+
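+            # Every vis_period iterations, decode the current predictions and
+            # store a stacked GT/prediction image in the EventStorage.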
+            if self.vis_period > 0:
+                storage = get_event_storage()
+                if storage.iter % self.vis_period == 0:
+                    results = self.inference(box_cls, box_delta, anchors, images.image_sizes)
+                    self.visualize_training(batched_inputs, results)
+
+            return losses
        else:
            results = self.inference(box_cls, box_delta, anchors, images.image_sizes)
            processed_results = []