Skip to content

Commit 9725ea1

Browse files
Alexander Kirillovfacebook-github-bot
authored andcommitted
PointRend semantic segmentation
Summary: semantic segmentation PointRend Reviewed By: ppwwyyxx Differential Revision: D19350389 fbshipit-source-id: cec04422ae5b76d730257de336d871cf281b625a
1 parent 9763402 commit 9725ea1

File tree

7 files changed

+199
-0
lines changed

7 files changed

+199
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
2+
MODEL:
3+
META_ARCHITECTURE: "SemanticSegmentor"
4+
BACKBONE:
5+
FREEZE_AT: 0
6+
SEM_SEG_HEAD:
7+
NAME: "PointRendSemSegHead"
8+
POINT_HEAD:
9+
NUM_CLASSES: 54
10+
FC_DIM: 256
11+
NUM_FC: 3
12+
IN_FEATURES: ["p2"]
13+
TRAIN_NUM_POINTS: 1024
14+
SUBDIVISION_STEPS: 2
15+
SUBDIVISION_NUM_POINTS: 8192
16+
COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead"
17+
DATASETS:
18+
TRAIN: ("coco_2017_train_panoptic_stuffonly",)
19+
TEST: ("coco_2017_val_panoptic_stuffonly",)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
_BASE_: Base-PointRend-Semantic-FPN.yaml
2+
MODEL:
3+
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl
4+
RESNETS:
5+
DEPTH: 101
6+
SEM_SEG_HEAD:
7+
NUM_CLASSES: 19
8+
POINT_HEAD:
9+
NUM_CLASSES: 19
10+
TRAIN_NUM_POINTS: 2048
11+
SUBDIVISION_NUM_POINTS: 8192
12+
DATASETS:
13+
TRAIN: ("cityscapes_fine_sem_seg_train",)
14+
TEST: ("cityscapes_fine_sem_seg_val",)
15+
SOLVER:
16+
BASE_LR: 0.01
17+
STEPS: (40000, 55000)
18+
MAX_ITER: 65000
19+
IMS_PER_BATCH: 32
20+
INPUT:
21+
MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048)
22+
MIN_SIZE_TRAIN_SAMPLING: "choice"
23+
MIN_SIZE_TEST: 1024
24+
MAX_SIZE_TRAIN: 4096
25+
MAX_SIZE_TEST: 2048
26+
CROP:
27+
ENABLED: True
28+
TYPE: "absolute"
29+
SIZE: (512, 1024)
30+
SINGLE_CATEGORY_MAX_AREA: 0.75
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_BASE_: Base-PointRend-Semantic-FPN.yaml
2+
MODEL:
3+
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
4+
RESNETS:
5+
DEPTH: 50

projects/PointRend/point_rend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .coarse_mask_head import CoarseMaskHead
44
from .roi_heads import PointRendROIHeads
55
from .dataset_mapper import SemSegDatasetMapper
6+
from .semantic_seg import PointRendSemSegHead

projects/PointRend/point_rend/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@ def add_pointrend_config(cfg):
4343
cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False
4444
# If True, then coarse prediction features are used as input for each layer in PointRend's MLP.
4545
cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True
46+
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead"
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2+
import numpy as np
3+
from typing import Dict
4+
import torch
5+
from torch import nn
6+
from torch.nn import functional as F
7+
8+
from detectron2.layers import ShapeSpec, cat
9+
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
10+
11+
from .point_features import (
12+
get_uncertain_point_coords_on_grid,
13+
get_uncertain_point_coords_with_randomness,
14+
point_sample,
15+
)
16+
from .point_head import build_point_head
17+
18+
19+
def calculate_uncertainty(sem_seg_logits):
    """
    Estimate, for every location of `sem_seg_logits`, how uncertain the prediction is.

    Uncertainty is the second-highest logit minus the highest logit along the class
    dimension (dim 1). The value is therefore never positive, and it is closest to zero
    where the two best classes are nearly tied.

    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size
            and C is the number of classes. The values are logits. C must be at least 2,
            because the two largest entries along dim 1 are taken.

    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    # topk returns its values sorted in descending order: index 0 is the best class,
    # index 1 is the runner-up.
    best_two, _ = sem_seg_logits.topk(2, dim=1)
    winner = best_two[:, 0]
    runner_up = best_two[:, 1]
    # Re-insert the singleton "class" dimension that the indexing above removed.
    return torch.unsqueeze(runner_up - winner, 1)
34+
35+
36+
@SEM_SEG_HEADS_REGISTRY.register()
class PointRendSemSegHead(nn.Module):
    """
    A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME`
    and a point head set in `MODEL.POINT_HEAD.NAME`.

    The coarse head produces coarse logits. The point head then predicts logits at selected
    points from (a) backbone features sampled at those points and (b) the coarse logits sampled
    at the same points.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Args:
            cfg: config node; reads `MODEL.SEM_SEG_HEAD.*` and `MODEL.POINT_HEAD.*`.
            input_shape: maps feature name to its `ShapeSpec`.
        """
        super().__init__()

        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE

        # The coarse head is looked up by name in the same registry this class is registered in.
        self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
            cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
        )(cfg, input_shape)
        self._init_point_head(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """Read the point-head settings from `cfg` and build `self.point_head`."""
        # fmt: off
        # Both heads predict over the same label set, so their class counts must agree.
        assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        feature_channels             = {k: v.channels for k, v in input_shape.items()}
        self.in_features             = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.train_num_points        = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.oversample_ratio        = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        self.subdivision_steps       = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.subdivision_num_points  = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        # The point head consumes all selected feature maps stacked along the channel axis.
        # Cast to a plain int so that a NumPy scalar does not leak into ShapeSpec.
        in_channels = int(np.sum([feature_channels[f] for f in self.in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

    def _sample_fine_grained_features(self, features, point_coords):
        """Sample every feature map in `self.in_features` at `point_coords` and stack them."""
        # Concatenate along dim=1 (channels): the point head was built with
        # channels == sum of the per-feature channel counts. The default dim=0 would stack
        # along the batch axis instead, which is wrong as soon as more than one feature
        # map is listed in IN_FEATURES.
        return cat(
            [
                point_sample(features[in_feature], point_coords, align_corners=False)
                for in_feature in self.in_features
            ],
            dim=1,
        )

    def forward(self, features, targets=None):
        """
        Args:
            features: mapping from feature name to feature tensor; must contain every name in
                `self.in_features` plus whatever the coarse head reads.
            targets: ground-truth label tensor, required in training mode.
                NOTE(review): presumably (N, H, W) integer labels -- confirm against the caller.

        Returns:
            A pair `(sem_seg_logits, losses)`:
            in training `(None, dict of losses)`; in inference `(refined logits, {})`.
        """
        coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)

        if self.training:
            losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)

            # Point selection is not differentiated through.
            with torch.no_grad():
                point_coords = get_uncertain_point_coords_with_randomness(
                    coarse_sem_seg_logits,
                    calculate_uncertainty,
                    self.train_num_points,
                    self.oversample_ratio,
                    self.importance_sample_ratio,
                )
            coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
            fine_grained_features = self._sample_fine_grained_features(features, point_coords)
            point_logits = self.point_head(fine_grained_features, coarse_features)

            # Labels are sampled in "nearest" mode so they are never interpolated; the
            # float round-trip exists only because the sampling call is given a float tensor.
            point_targets = (
                point_sample(
                    targets.unsqueeze(1).to(torch.float),
                    point_coords,
                    mode="nearest",
                    align_corners=False,
                )
                .squeeze(1)
                .to(torch.long)
            )
            losses["loss_sem_seg_point"] = F.cross_entropy(
                point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
            )
            return None, losses

        # Inference: repeatedly upsample 2x and re-predict only the most uncertain locations.
        sem_seg_logits = coarse_sem_seg_logits.clone()
        for _ in range(self.subdivision_steps):
            sem_seg_logits = F.interpolate(
                sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
            )
            uncertainty_map = calculate_uncertainty(sem_seg_logits)
            point_indices, point_coords = get_uncertain_point_coords_on_grid(
                uncertainty_map, self.subdivision_num_points
            )
            fine_grained_features = self._sample_fine_grained_features(features, point_coords)
            # Coarse features always come from the original coarse logits, not the refined map.
            coarse_features = point_sample(
                coarse_sem_seg_logits, point_coords, align_corners=False
            )
            point_logits = self.point_head(fine_grained_features, coarse_features)

            # put sem seg point predictions to the right places on the upsampled grid.
            N, C, H, W = sem_seg_logits.shape
            point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
            sem_seg_logits = (
                sem_seg_logits.reshape(N, C, H * W)
                .scatter_(2, point_indices, point_logits)
                .view(N, C, H, W)
            )
        return sem_seg_logits, {}

projects/PointRend/train_net.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
COCOEvaluator,
2222
DatasetEvaluators,
2323
LVISEvaluator,
24+
SemSegEvaluator,
2425
verify_results,
2526
)
2627

@@ -51,6 +52,14 @@ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
5152
return LVISEvaluator(dataset_name, cfg, True, output_folder)
5253
if evaluator_type == "coco":
5354
return COCOEvaluator(dataset_name, cfg, True, output_folder)
55+
if evaluator_type == "sem_seg":
56+
return SemSegEvaluator(
57+
dataset_name,
58+
distributed=True,
59+
num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
60+
ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
61+
output_dir=output_folder,
62+
)
5463
if evaluator_type == "cityscapes_instance":
5564
assert (
5665
torch.cuda.device_count() >= comm.get_rank()

0 commit comments

Comments
 (0)