diff --git a/segmentation/model/post_processing/instance_post_processing.py b/segmentation/model/post_processing/instance_post_processing.py index 915144d..cf02f60 100755 --- a/segmentation/model/post_processing/instance_post_processing.py +++ b/segmentation/model/post_processing/instance_post_processing.py @@ -46,7 +46,7 @@ def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=3, top_k=None): return ctr_all else: # find top k centers. - top_k_scores, _ = torch.topk(torch.flatten(ctr_all), top_k) + top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k) return torch.nonzero(ctr_hmp > top_k_scores[-1])