
Commit 9763402

Alexander Kirillov authored and facebook-github-bot committed

Cropping for semantic segmentation models

Summary: Cropping for semantic segmentation

Reviewed By: ppwwyyxx
Differential Revision: D19453097
fbshipit-source-id: a64d462e068d22b7e3be3703c16df3839474ea71

1 parent e7aa570 · commit 9763402

File tree: 5 files changed (+129 / -3 lines)

detectron2/data/transforms/transform_gen.py

Lines changed: 1 addition & 1 deletion

@@ -342,7 +342,7 @@ def get_crop_size(self, image_size):
             ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
             return int(h * ch + 0.5), int(w * cw + 0.5)
         elif self.crop_type == "absolute":
-            return self.crop_size
+            return (min(self.crop_size[0], h), min(self.crop_size[1], w))
         else:
             NotImplementedError("Unknown crop type {}".format(self.crop_type))
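The practical effect of the clamping is that an "absolute" crop size larger than the (possibly resized) input no longer produces an out-of-bounds window; before this change the same call would trip the size assertion in RandomCrop.get_transform. A minimal sketch, with illustrative sizes:

import numpy as np
from detectron2.data import transforms as T

# Request a 512x512 absolute crop from a 300x400 image; with this commit
# get_crop_size clamps the window to the image, so the whole image is kept.
crop_gen = T.RandomCrop("absolute", (512, 512))
image = np.zeros((300, 400, 3), dtype=np.uint8)
print(crop_gen.get_crop_size(image.shape[:2]))                  # (300, 400)
print(crop_gen.get_transform(image).apply_image(image).shape)   # (300, 400, 3)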

projects/PointRend/point_rend/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -2,3 +2,4 @@
 from .config import add_pointrend_config
 from .coarse_mask_head import CoarseMaskHead
 from .roi_heads import PointRendROIHeads
+from .dataset_mapper import SemSegDatasetMapper

projects/PointRend/point_rend/config.py

Lines changed: 4 additions & 0 deletions

@@ -8,6 +8,10 @@ def add_pointrend_config(cfg):
     """
     Add config for PointRend.
     """
+    # We retry random cropping until no single category in semantic segmentation GT occupies more
+    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
+    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
+
     # Names of the input feature maps to be used by a coarse mask head.
     cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",)
     cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024
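For context, the new key is set from a training config like any other crop option. A minimal sketch, assuming projects/PointRend is on PYTHONPATH; the concrete crop sizes and threshold below are illustrative:

from detectron2.config import get_cfg
from point_rend import add_pointrend_config

cfg = get_cfg()
add_pointrend_config(cfg)
cfg.INPUT.CROP.ENABLED = True
cfg.INPUT.CROP.TYPE = "absolute"
cfg.INPUT.CROP.SIZE = (512, 1024)
# Retry cropping while a single (non-ignored) category covers more than 75% of the crop.
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 0.75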
projects/PointRend/point_rend/dataset_mapper.py (new file)

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import copy
+import logging
+import numpy as np
+import torch
+from fvcore.common.file_io import PathManager
+from fvcore.transforms.transform import CropTransform
+from PIL import Image
+
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+
+"""
+This file contains the mapping that's applied to "dataset dicts" for semantic segmentation models.
+Unlike the default DatasetMapper, this mapper uses cropping as the last transformation.
+"""
+
+__all__ = ["SemSegDatasetMapper"]
+
+
+class SemSegDatasetMapper:
+    """
+    A callable which takes a dataset dict in Detectron2 Dataset format,
+    and maps it into a format used by semantic segmentation models.
+
+    The callable currently does the following:
+
+    1. Reads the image from "file_name"
+    2. Applies geometric transforms to the image and annotation
+    3. Finds and applies a suitable crop to the image and annotation
+    4. Prepares the image and annotation as Tensors
+    """
+
+    def __init__(self, cfg, is_train=True):
+        if cfg.INPUT.CROP.ENABLED and is_train:
+            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
+            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
+        else:
+            self.crop_gen = None
+
+        self.tfm_gens = utils.build_transform_gen(cfg, is_train)
+
+        # fmt: off
+        self.img_format               = cfg.INPUT.FORMAT
+        self.single_category_max_area = cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA
+        self.ignore_value             = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
+        # fmt: on
+
+        self.is_train = is_train
+
+    def __call__(self, dataset_dict):
+        """
+        Args:
+            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+        Returns:
+            dict: a format that builtin models in detectron2 accept
+        """
+        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
+        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+        utils.check_image_size(dataset_dict, image)
+        assert "sem_seg_file_name" in dataset_dict
+
+        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
+        if self.is_train:
+            with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
+                sem_seg_gt = Image.open(f)
+                sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
+            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
+            if self.crop_gen:
+                image, sem_seg_gt = crop_transform(
+                    image,
+                    sem_seg_gt,
+                    self.crop_gen,
+                    self.single_category_max_area,
+                    self.ignore_value,
+                )
+            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
+
+        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+        # Therefore it's important to use torch.Tensor.
+        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+
+        if not self.is_train:
+            dataset_dict.pop("sem_seg_file_name", None)
+            return dataset_dict
+
+        return dataset_dict
+
+
+def crop_transform(image, sem_seg, crop_gen, single_category_max_area, ignore_value):
+    """
+    Find a cropping window such that no single category occupies more than
+    `single_category_max_area` in `sem_seg`. The function retries random cropping 10 times max.
+    """
+    if single_category_max_area >= 1.0:
+        crop_tfm = crop_gen.get_transform(image)
+        sem_seg_temp = crop_tfm.apply_segmentation(sem_seg)
+    else:
+        h, w = sem_seg.shape
+        crop_size = crop_gen.get_crop_size((h, w))
+        for _ in range(10):
+            y0 = np.random.randint(h - crop_size[0] + 1)
+            x0 = np.random.randint(w - crop_size[1] + 1)
+            sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
+            labels, cnt = np.unique(sem_seg_temp, return_counts=True)
+            cnt = cnt[labels != ignore_value]
+            if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < single_category_max_area:
+                break
+        crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
+    image = crop_tfm.apply_image(image)
+    return image, sem_seg_temp
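The retry logic in crop_transform can be checked in isolation with a small self-contained sketch; the synthetic arrays, sizes, and threshold below are made up for illustration, and the script assumes projects/PointRend is on PYTHONPATH:

import numpy as np
from detectron2.data import transforms as T
from point_rend.dataset_mapper import crop_transform

H, W = 512, 1024
image = np.zeros((H, W, 3), dtype=np.uint8)
sem_seg = np.zeros((H, W), dtype=np.uint8)
sem_seg[:, W // 2 :] = 1  # two categories, split down the middle

crop_gen = T.RandomCrop("absolute", (256, 256))
# Retries up to 10 times until no single category covers >= 75% of the 256x256 crop
# (or gives up and keeps the last candidate window).
cropped_image, cropped_gt = crop_transform(image, sem_seg, crop_gen, 0.75, ignore_value=255)
print(cropped_image.shape, np.unique(cropped_gt, return_counts=True))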

projects/PointRend/train_net.py

Lines changed: 10 additions & 2 deletions

@@ -13,7 +13,7 @@
 import detectron2.utils.comm as comm
 from detectron2.checkpoint import DetectionCheckpointer
 from detectron2.config import get_cfg
-from detectron2.data import MetadataCatalog
+from detectron2.data import MetadataCatalog, build_detection_train_loader
 from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
 from detectron2.evaluation import (
     CityscapesInstanceEvaluator,
@@ -24,7 +24,7 @@
     verify_results,
 )
 
-from point_rend import add_pointrend_config
+from point_rend import SemSegDatasetMapper, add_pointrend_config
 
 
 class Trainer(DefaultTrainer):
@@ -71,6 +71,14 @@ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
             return evaluator_list[0]
         return DatasetEvaluators(evaluator_list)
 
+    @classmethod
+    def build_train_loader(cls, cfg):
+        if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
+            mapper = SemSegDatasetMapper(cfg, True)
+        else:
+            mapper = None
+        return build_detection_train_loader(cfg, mapper=mapper)
+
 
 def setup(args):
     """
