diff --git a/mmdet/models/utils/misc.py b/mmdet/models/utils/misc.py index 2cf429153ba..743ad013f08 100644 --- a/mmdet/models/utils/misc.py +++ b/mmdet/models/utils/misc.py @@ -333,7 +333,7 @@ def filter_scores_and_topk(scores, score_thr, topk, results=None): scores = scores[valid_mask] valid_idxs = torch.nonzero(valid_mask) - num_topk = min(topk, valid_idxs.size(0)) + num_topk = min(topk if topk != -1 else valid_idxs.size(0), valid_idxs.size(0)) # torch.sort is actually faster than .topk (at least on GPUs) scores, idxs = scores.sort(descending=True) scores = scores[:num_topk]