Faster postprocessing #67

Open · wants to merge 1 commit into base: master
73 changes: 37 additions & 36 deletions segmentation/model/post_processing/instance_post_processing.py
@@ -5,7 +5,7 @@
 
 import torch
 import torch.nn.functional as F
-
+from collections import Counter
 from .semantic_post_processing import get_semantic_segmentation
 
 __all__ = ['find_instance_center', 'get_instance_segmentation', 'get_panoptic_segmentation']
@@ -39,15 +39,17 @@ def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=3, top_k=None):
     assert len(ctr_hmp.size()) == 2, 'Something is wrong with center heatmap dimension.'
 
     # find non-zero elements
-    ctr_all = torch.nonzero(ctr_hmp > 0)
+    ctr_all = torch.nonzero(ctr_hmp > 0, as_tuple=True)
+    centers = torch.stack(ctr_all, 1)
     if top_k is None:
-        return ctr_all
-    elif ctr_all.size(0) < top_k:
-        return ctr_all
+        return centers
+    elif len(centers) < top_k:
+        return centers
     else:
         # find top k centers.
-        top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k)
-        return torch.nonzero(ctr_hmp > top_k_scores[-1])
+        scores = ctr_hmp[ctr_all]
+        _, indices = torch.topk(scores, top_k)
+        return centers[indices]
 
 
 def group_pixels(ctr, offsets):
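Note (not part of the diff): the new top-k branch calls torch.nonzero once with as_tuple=True, reuses that index tuple to gather the candidate scores, and selects the k best by indexing, instead of re-thresholding the whole heatmap against the k-th score. Below is a toy sketch of the equivalence; the heatmap size, the value of top_k, and names such as new_centers / old_centers are invented for illustration. It also shows one behavioural difference: the old strict > comparison drops the point tied with the k-th score, so it typically returns k - 1 centers when scores are distinct, while the new path returns exactly top_k.

# Toy sketch, not from the PR: compare the new and old top-k center selection.
import torch

torch.manual_seed(0)
ctr_hmp = torch.rand(480, 640)             # hypothetical center heatmap (H, W)
ctr_hmp[ctr_hmp < 0.1] = -1                # mimic the thresholding done earlier in find_instance_center
top_k = 200                                # far fewer than the number of positive pixels here

# new path: gather coordinates once, then top-k over the scores at those coordinates
ctr_all = torch.nonzero(ctr_hmp > 0, as_tuple=True)      # tuple (rows, cols)
centers = torch.stack(ctr_all, 1)                         # (N, 2) y/x coordinates
scores = ctr_hmp[ctr_all]                                 # (N,) heatmap values at the candidates
_, indices = torch.topk(scores, top_k)
new_centers = centers[indices]                            # (top_k, 2), ordered by score

# old path: threshold the whole heatmap against the k-th highest score
top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k)
old_centers = torch.nonzero(ctr_hmp > top_k_scores[-1])   # strict >, so the k-th point is dropped

new_set = {tuple(c.tolist()) for c in new_centers}
old_set = {tuple(c.tolist()) for c in old_centers}
assert old_set <= new_set and len(new_set) == top_k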
@@ -138,43 +140,42 @@ def merge_semantic_and_instance(sem_seg, ins_seg, label_divisor, thing_list, stu
     """
     # In case thing mask does not align with semantic prediction
     pan_seg = torch.zeros_like(sem_seg) + void_label
+    tl = torch.tensor(thing_list).view(-1, 1, 1).cuda()
 
     thing_seg = ins_seg > 0
-    semantic_thing_seg = torch.zeros_like(sem_seg)
-    for thing_class in thing_list:
-        semantic_thing_seg[sem_seg == thing_class] = 1
+    semantic_thing_seg = (sem_seg == tl).sum(0, keepdim=True)
 
     # keep track of instance id for each class
-    class_id_tracker = {}
+    class_id_counter = Counter()
 
     # paste thing by majority voting
     instance_ids = torch.unique(ins_seg)
-    for ins_id in instance_ids:
-        if ins_id == 0:
-            continue
-        # Make sure only do majority voting within semantic_thing_seg
-        thing_mask = (ins_seg == ins_id) & (semantic_thing_seg == 1)
-        if torch.nonzero(thing_mask).size(0) == 0:
-            continue
-        class_id, _ = torch.mode(sem_seg[thing_mask].view(-1, ))
-        if class_id.item() in class_id_tracker:
-            new_ins_id = class_id_tracker[class_id.item()]
-        else:
-            class_id_tracker[class_id.item()] = 1
-            new_ins_id = 1
-        class_id_tracker[class_id.item()] += 1
-        pan_seg[thing_mask] = class_id * label_divisor + new_ins_id
+    instance_ids = instance_ids[instance_ids != 0]
+    if len(instance_ids) > 0:
+        instance_masks = (ins_seg == instance_ids.view(-1, 1, 1)) * semantic_thing_seg
+        sem_seg_oh = F.one_hot(sem_seg.squeeze()).float()
+        instance_masks = instance_masks.float()
+        instance_classes = torch.matmul(instance_masks.view(instance_masks.shape[0], -1),
+                                        sem_seg_oh.view(-1, sem_seg_oh.shape[-1])).argmax(1)
+        instance_ids = []
+        for c in instance_classes:
+            c = c.item()
+            class_id_counter[c] += 1
+            instance_ids += [class_id_counter[c]]
+        instance_ids = torch.tensor(instance_ids).cuda()
+        instance_masks.mul_((instance_classes * label_divisor + instance_ids).view(-1, 1, 1))
+        instance_masks = instance_masks.sum(0, keepdim=True)
+        instance_masks = instance_masks.long()
+        pan_seg[instance_masks != 0] = instance_masks[instance_masks != 0]
 
     # paste stuff to unoccupied area
     class_ids = torch.unique(sem_seg)
-    for class_id in class_ids:
-        if class_id.item() in thing_list:
-            # thing class
-            continue
-        # calculate stuff area
-        stuff_mask = (sem_seg == class_id) & (~thing_seg)
-        area = torch.nonzero(stuff_mask).size(0)
-        if area >= stuff_area:
-            pan_seg[stuff_mask] = class_id * label_divisor
+    class_ids = torch.tensor([x for x in class_ids if x.item() not in thing_list]).cuda()
+    stuff_masks = (sem_seg == class_ids.view(-1, 1, 1)) & (~thing_seg)
+    areas = stuff_masks.view(stuff_masks.shape[0], -1).sum(1)
+    stuff_masks = stuff_masks[areas >= stuff_area]
+    class_ids = class_ids[areas >= stuff_area]
+    stuff_seg = (stuff_masks * class_ids.view(-1, 1, 1) * label_divisor).sum(0, keepdim=True)
+    pan_seg[stuff_seg != 0] = stuff_seg[stuff_seg != 0]
 
     return pan_seg

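Note (not part of the diff): in the new merge_semantic_and_instance, the per-instance Python loop with torch.mode is replaced by one batched majority vote. All instance masks are built at once by broadcasting ins_seg against the stacked instance ids, the semantic map is one-hot encoded, and a single matmul between the flattened masks (K, H*W) and the one-hot map (H*W, C) counts, for each instance, how many of its pixels carry each semantic class; argmax over those counts is the majority class. The Counter then hands out a running per-class instance id, and the panoptic id stays class_id * label_divisor + instance_id. Below is a toy sketch of that voting step; the label maps, the label_divisor value, and the votes / panoptic_ids names are invented, and the real code additionally restricts the masks to thing pixels via semantic_thing_seg.

# Toy sketch, not from the PR: one-hot/matmul majority vote plus the
# Counter-based per-class instance numbering used in the new code path above.
import torch
import torch.nn.functional as F
from collections import Counter

sem_seg = torch.tensor([[1, 1, 1, 1],
                        [1, 1, 2, 2],
                        [7, 7, 7, 7]])       # (H, W) semantic labels (hypothetical)
ins_seg = torch.tensor([[0, 1, 1, 1],
                        [0, 1, 1, 1],
                        [0, 2, 2, 2]])       # (H, W) instance ids, 0 = no instance

instance_ids = torch.unique(ins_seg)
instance_ids = instance_ids[instance_ids != 0]                       # drop background
instance_masks = (ins_seg == instance_ids.view(-1, 1, 1)).float()    # (K, H, W), one mask per instance

sem_seg_oh = F.one_hot(sem_seg).float()                              # (H, W, C)
votes = torch.matmul(instance_masks.view(len(instance_ids), -1),     # (K, H*W) x (H*W, C) -> (K, C)
                     sem_seg_oh.view(-1, sem_seg_oh.shape[-1]))
instance_classes = votes.argmax(1)                                   # majority class per instance
print(instance_classes)                                              # tensor([1, 7])

# running per-class instance id and the final panoptic encoding
label_divisor = 1000                                                 # hypothetical divisor
class_id_counter = Counter()
panoptic_ids = []
for c in instance_classes.tolist():
    class_id_counter[c] += 1
    panoptic_ids.append(c * label_divisor + class_id_counter[c])
print(panoptic_ids)                                                  # [1001, 7001]

The stuff loop is vectorized the same way in the diff: a broadcast equality against the stacked stuff class ids builds every stuff mask at once, a per-mask pixel sum gives the area for the stuff_area test, and the surviving masks are scaled by class_id * label_divisor and summed back into a single map.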