Skip to content

Commit

Permalink
Reduce code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Jan 31, 2025
1 parent bcbe859 commit 879d05e
Showing 1 changed file with 3 additions and 26 deletions.
29 changes: 3 additions & 26 deletions cvat/apps/consensus/intersect_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import itertools
from abc import ABCMeta, abstractmethod
from collections.abc import Collection
from typing import Iterable, Optional, Sequence
from typing import Iterable, Sequence

import attrs

Expand All @@ -17,7 +17,7 @@
import datumaro.components.merge.intersect_merge
import numpy as np
from datumaro.components.errors import FailedLabelVotingError
from datumaro.util.annotation_util import find_instances, max_bbox, mean_bbox
from datumaro.util.annotation_util import mean_bbox
from datumaro.util.attrs_util import ensure_cls

from cvat.apps.engine.models import Label
Expand Down Expand Up @@ -121,27 +121,8 @@ def _for_type(t: dm.AnnotationType, **kwargs) -> AnnotationMatcher:
else:
raise AssertionError(f"Annotation type {t} is not supported")

instance_map = {}
for s in sources:
s_instances = find_instances(s)
for inst in s_instances:
inst_bbox = max_bbox(
[
a
for a in inst
if a.type
in {
dm.AnnotationType.polygon,
dm.AnnotationType.mask,
dm.AnnotationType.bbox,
}
]
)
for ann in inst:
instance_map[id(ann)] = [inst, inst_bbox]

self._mergers = {
t: _for_type(t, instance_map=instance_map, categories=self._categories)
t: _for_type(t, categories=self._categories)
for t in self.conf.included_annotation_types
}

Expand Down Expand Up @@ -357,17 +338,13 @@ class MaskMatcher(PolygonMatcher):

@attrs.define(kw_only=True, slots=False)
class PointsMatcher(ShapeMatcher):
sigma: Optional[list] = attrs.field(default=None)

def match_annotations_between_two_items(self, item_a, item_b):
matches, _, _, _, distances = self._comparator.match_points(item_a, item_b)
return matches, distances


@attrs.define(kw_only=True, slots=False)
class SkeletonMatcher(ShapeMatcher):
sigma: float = 0.1

def match_annotations_between_two_items(self, item_a, item_b):
matches, _, _, _, distances = self._comparator.match_skeletons(item_a, item_b)
return matches, distances
Expand Down

0 comments on commit 879d05e

Please sign in to comment.