|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
| 2 | +import numpy as np |
| 3 | +from typing import Dict |
| 4 | +import torch |
| 5 | +from torch import nn |
| 6 | +from torch.nn import functional as F |
| 7 | + |
| 8 | +from detectron2.layers import ShapeSpec, cat |
| 9 | +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY |
| 10 | + |
| 11 | +from .point_features import ( |
| 12 | + get_uncertain_point_coords_on_grid, |
| 13 | + get_uncertain_point_coords_with_randomness, |
| 14 | + point_sample, |
| 15 | +) |
| 16 | +from .point_head import build_point_head |
| 17 | + |
| 18 | + |
def calculate_uncertainty(sem_seg_logits):
    """
    Score every location of `sem_seg_logits` by how close its two strongest class logits are.

    The score is (second highest logit) - (highest logit) along dim 1, so it is always <= 0
    and approaches 0 where the two best classes are nearly tied.

    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size
            and C is the number of classes (dim 1). The values are logits. C must be at
            least 2 because the top two entries along dim 1 are taken.

    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    ranked_logits, _ = torch.topk(sem_seg_logits, k=2, dim=1)
    # Length-1 slices keep the class dimension, so no unsqueeze is needed afterwards.
    best = ranked_logits[:, 0:1]
    runner_up = ranked_logits[:, 1:2]
    return runner_up - best
| 34 | + |
| 35 | + |
@SEM_SEG_HEADS_REGISTRY.register()
class PointRendSemSegHead(nn.Module):
    """
    A semantic segmentation head that combines a head set in
    `MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME` and a point head set in `MODEL.POINT_HEAD.NAME`.

    The coarse head produces the initial logits; the point head re-predicts logits at
    selected points from the coarse logits and the features listed in
    `MODEL.POINT_HEAD.IN_FEATURES`.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Args:
            cfg: config node; `MODEL.SEM_SEG_HEAD.*` and `MODEL.POINT_HEAD.*` entries are read.
            input_shape (Dict[str, ShapeSpec]): shape of every available input feature,
                keyed by feature name.
        """
        super().__init__()

        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE

        self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
            cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
        )(cfg, input_shape)
        self._init_point_head(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """Read the point-head settings from `cfg` and build `self.point_head`."""
        # fmt: off
        assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES, (
            "MODEL.SEM_SEG_HEAD.NUM_CLASSES and MODEL.POINT_HEAD.NUM_CLASSES must match"
        )
        feature_channels             = {k: v.channels for k, v in input_shape.items()}
        self.in_features             = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.train_num_points        = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.oversample_ratio        = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        self.subdivision_steps       = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.subdivision_num_points  = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        # The point head consumes all selected features stacked along the channel axis,
        # so its input width is the sum of their channel counts. The builtin `sum` keeps
        # this a plain Python int rather than a numpy scalar.
        in_channels = sum(feature_channels[f] for f in self.in_features)
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

    def _sample_fine_grained_features(self, features, point_coords):
        """
        Sample every feature in `self.in_features` at `point_coords` and stack the results
        along dim 1, the channel axis that the point head's summed `in_channels` expects.
        """
        return cat(
            [
                point_sample(features[in_feature], point_coords, align_corners=False)
                for in_feature in self.in_features
            ],
            # Without an explicit dim the per-feature samples would be joined along dim 0
            # (the batch axis), which does not match the summed `in_channels` of the point
            # head whenever more than one input feature is configured.
            dim=1,
        )

    def forward(self, features, targets=None):
        """
        Args:
            features: mapping from feature name to feature tensor; it is passed to the coarse
                head and indexed with the names in `self.in_features`.
            targets: ground-truth labels, used only in training mode.
                NOTE(review): presumably an (N, H, W) integer label map -- confirm with caller.

        Returns:
            In training mode: `(None, losses)`, where `losses` holds the coarse head's losses
            plus `"loss_sem_seg_point"`.
            In inference mode: `(sem_seg_logits, {})`, where `sem_seg_logits` are the coarse
            logits after `self.subdivision_steps` rounds of 2x upsampling and point refinement.
        """
        coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)

        if self.training:
            losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)

            # Only the choice of points is excluded from autograd; the features sampled at
            # those points below still carry gradients.
            with torch.no_grad():
                point_coords = get_uncertain_point_coords_with_randomness(
                    coarse_sem_seg_logits,
                    calculate_uncertainty,
                    self.train_num_points,
                    self.oversample_ratio,
                    self.importance_sample_ratio,
                )
            coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
            fine_grained_features = self._sample_fine_grained_features(features, point_coords)
            point_logits = self.point_head(fine_grained_features, coarse_features)

            # Labels are sampled with "nearest" so that class ids are never interpolated;
            # the float round-trip exists only to feed the labels through point_sample.
            point_targets = (
                point_sample(
                    targets.unsqueeze(1).to(torch.float),
                    point_coords,
                    mode="nearest",
                    align_corners=False,
                )
                .squeeze(1)
                .to(torch.long)
            )
            losses["loss_sem_seg_point"] = F.cross_entropy(
                point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
            )
            return None, losses

        sem_seg_logits = coarse_sem_seg_logits.clone()
        for _ in range(self.subdivision_steps):
            sem_seg_logits = F.interpolate(
                sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
            )
            uncertainty_map = calculate_uncertainty(sem_seg_logits)
            point_indices, point_coords = get_uncertain_point_coords_on_grid(
                uncertainty_map, self.subdivision_num_points
            )
            fine_grained_features = self._sample_fine_grained_features(features, point_coords)
            # Coarse features always come from the original coarse prediction, not from the
            # progressively refined `sem_seg_logits`.
            coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
            point_logits = self.point_head(fine_grained_features, coarse_features)

            # put sem seg point predictions to the right places on the upsampled grid.
            N, C, H, W = sem_seg_logits.shape
            # The same flat grid index is used for every class channel.
            point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
            sem_seg_logits = (
                sem_seg_logits.reshape(N, C, H * W)
                .scatter_(2, point_indices, point_logits)
                .view(N, C, H, W)
            )
        return sem_seg_logits, {}
0 commit comments