|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
| 2 | +import copy |
| 3 | +import logging |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +from fvcore.common.file_io import PathManager |
| 7 | +from fvcore.transforms.transform import CropTransform |
| 8 | +from PIL import Image |
| 9 | + |
| 10 | +from detectron2.data import detection_utils as utils |
| 11 | +from detectron2.data import transforms as T |
| 12 | + |
| 13 | +""" |
| 14 | +This file contains the mapping that's applied to "dataset dicts" for semantic segmentation models. |
| 15 | +Unlike the default DatasetMapper this mapper uses cropping as the last transformation. |
| 16 | +""" |
| 17 | + |
| 18 | +__all__ = ["SemSegDatasetMapper"] |
| 19 | + |
| 20 | + |
class SemSegDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by semantic segmentation models.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Find and applies suitable cropping to the image and annotation
    4. Prepare image and annotation to Tensors
    """

    def __init__(self, cfg, is_train=True):
        """
        Args:
            cfg: a detectron2 config node; reads ``INPUT.*`` and
                ``MODEL.SEM_SEG_HEAD.IGNORE_VALUE``.
            is_train (bool): whether the mapper is used for training.
                Cropping and ground-truth loading only happen in training.
        """
        # Cropping is applied *after* the other transforms (see __call__),
        # unlike the default DatasetMapper, and only during training.
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
        else:
            self.crop_gen = None

        self.tfm_gens = utils.build_transform_gen(cfg, is_train)

        # fmt: off
        self.img_format = cfg.INPUT.FORMAT
        self.single_category_max_area = cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA
        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
        # fmt: on

        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
                Must contain a "sem_seg_file_name" key.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)
        assert "sem_seg_file_name" in dataset_dict

        # Apply the geometric transforms first; cropping (if enabled) is applied
        # last, inside crop_transform, so the category-area check sees the final crop.
        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        if self.is_train:
            with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
                sem_seg_gt = Image.open(f)
                # Materialize the lazy PIL image before the file is closed.
                sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
            if self.crop_gen:
                image, sem_seg_gt = crop_transform(
                    image,
                    sem_seg_gt,
                    self.crop_gen,
                    self.single_category_max_area,
                    self.ignore_value,
                )
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
        else:
            # Inference does not load the ground truth; drop the path so it is
            # not carried along in the returned dict.
            dataset_dict.pop("sem_seg_file_name", None)

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        return dataset_dict
| 90 | + |
| 91 | + |
def crop_transform(image, sem_seg, crop_gen, single_category_max_area, ignore_value):
    """
    Crop `image` and `sem_seg` with `crop_gen`, retrying random windows (at
    most 10 times) until no single category covers a fraction of the crop
    greater than or equal to `single_category_max_area`. If no such window is
    found, the last sampled window is used.
    """
    if single_category_max_area >= 1.0:
        # No area constraint: one random crop is always acceptable.
        crop_tfm = crop_gen.get_transform(image)
        sem_seg_temp = crop_tfm.apply_segmentation(sem_seg)
    else:
        height, width = sem_seg.shape
        crop_h, crop_w = crop_gen.get_crop_size((height, width))
        for _attempt in range(10):
            # Sample a window uniformly over all valid top-left corners.
            y0 = np.random.randint(height - crop_h + 1)
            x0 = np.random.randint(width - crop_w + 1)
            sem_seg_temp = sem_seg[y0 : y0 + crop_h, x0 : x0 + crop_w]
            labels, counts = np.unique(sem_seg_temp, return_counts=True)
            counts = counts[labels != ignore_value]
            # At least two non-ignored categories, none of them dominating.
            acceptable = (
                len(counts) > 1 and np.max(counts) / np.sum(counts) < single_category_max_area
            )
            if acceptable:
                break
        crop_tfm = CropTransform(x0, y0, crop_w, crop_h)
    image = crop_tfm.apply_image(image)
    return image, sem_seg_temp
0 commit comments