diff --git a/src/supervision/annotators/core.py b/src/supervision/annotators/core.py index 8336e11ab..d3a9322b6 100644 --- a/src/supervision/annotators/core.py +++ b/src/supervision/annotators/core.py @@ -16,12 +16,12 @@ PENDING_TRACK_ID, ColorLookup, Trace, + _resolve_detection_color, _validate_labels, calculate_dynamic_kernel_size, calculate_dynamic_pixel_size, get_labels_text, hex_to_rgba, - resolve_color, resolve_text_background_xyxy, snap_boxes, wrap_text, @@ -256,7 +256,7 @@ def annotate( return scene for detection_idx in range(len(detections)): x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int) - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -347,7 +347,7 @@ def annotate( for detection_idx in range(len(detections)): obb = obb_boxes[detection_idx] - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -398,7 +398,7 @@ def _paint_masks_by_area( ) compact_mask = masks if isinstance(masks, CompactMask) else None for detection_idx in np.flip(np.argsort(detections.area)): - color_bgr = resolve_color( + color_bgr = _resolve_detection_color( color=color, detections=detections, detection_idx=detection_idx, @@ -584,7 +584,7 @@ def annotate( for detection_idx in range(len(detections)): mask = detections.mask[detection_idx] - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -673,7 +673,7 @@ def annotate( scene_with_boxes = scene.copy() for detection_idx in range(len(detections)): x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int) - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -872,7 +872,7 @@ def annotate( return scene for detection_idx in range(len(detections)): x1, _y1, x2, y2 = detections.xyxy[detection_idx].astype(int) - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -968,7 +968,7 @@ def annotate( return scene for detection_idx in range(len(detections)): x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int) - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -1064,7 +1064,7 @@ def annotate( x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int) center = ((x1 + x2) // 2, (y1 + y2) // 2) distance = sqrt((x1 - center[0]) ** 2 + (y1 - center[1]) ** 2) - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -1164,7 +1164,7 @@ def annotate( return scene xy = detections.get_anchors_coordinates(anchor=self.position) for detection_idx in range(len(detections)): - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -1176,7 +1176,7 @@ def annotate( cv2.circle(scene, center, self.radius, color.as_bgr(), -1) if self.outline_thickness: - outline_color = resolve_color( + outline_color = _resolve_detection_color( color=self.outline_color, detections=detections, detection_idx=detection_idx, @@ -1405,13 +1405,13 @@ def _draw_labels( ) for idx, label_property in enumerate(label_properties): - background_color = resolve_color( + background_color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=idx, color_lookup=color_lookup, ) - text_color = resolve_color( + text_color = _resolve_detection_color( color=self.text_color, detections=detections, detection_idx=idx, @@ -1718,13 +1718,13 @@ def _draw_labels( ) for idx, label_property in enumerate(label_properties): - background_color = resolve_color( + background_color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=idx, color_lookup=color_lookup, ) - text_color = resolve_color( + text_color = _resolve_detection_color( color=self.text_color, detections=detections, detection_idx=idx, @@ -2079,7 +2079,7 @@ def annotate( if tracker_id_val is None: continue tracker_id = int(tracker_id_val) - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=filtered_detections, detection_idx=detection_idx, @@ -2406,7 +2406,7 @@ def annotate( return scene xy = detections.get_anchors_coordinates(anchor=self.position) for detection_idx in range(len(detections)): - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -2426,7 +2426,7 @@ def annotate( cv2.fillPoly(scene, [vertices], color.as_bgr()) if self.outline_thickness: - outline_color = resolve_color( + outline_color = _resolve_detection_color( color=self.outline_color, detections=detections, detection_idx=detection_idx, @@ -2523,7 +2523,7 @@ def annotate( return scene for detection_idx in range(len(detections)): x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int) - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -2691,7 +2691,7 @@ def annotate( assert detections.confidence is not None # MyPy type hint value = detections.confidence[detection_idx] - color = resolve_color( + color = _resolve_detection_color( color=self.color, detections=detections, detection_idx=detection_idx, @@ -2882,7 +2882,7 @@ def annotate( anchor=anchor, crop_wh=crop_wh, position=self.position ) scene = overlay_image(image=scene, overlay=resized_crop, anchor=(x1, y1)) - color = resolve_color( + color = _resolve_detection_color( color=self.border_color, detections=detections, detection_idx=idx, diff --git a/src/supervision/annotators/utils.py b/src/supervision/annotators/utils.py index caa53757f..dc6046523 100644 --- a/src/supervision/annotators/utils.py +++ b/src/supervision/annotators/utils.py @@ -3,7 +3,7 @@ import re import textwrap from enum import Enum -from typing import Any +from typing import Any, cast import numpy as np import numpy.typing as npt @@ -22,26 +22,102 @@ class ColorLookup(Enum): """ Enumeration class to define strategies for mapping colors to annotations. - This enum supports three different lookup strategies: + This enum supports four different lookup strategies: - `INDEX`: Colors are determined by the index of the detection within the scene. - `CLASS`: Colors are determined by the class label of the detected object. - `TRACK`: Colors are determined by the tracking identifier of the object. + - `KEYPOINT`: Colors are determined by the keypoint index within each skeleton. + Only valid for keypoint annotators. """ INDEX = "index" CLASS = "class" TRACK = "track" + KEYPOINT = "keypoint" @classmethod def list(cls) -> list[str]: return list(map(lambda c: c.value, cls)) -def resolve_color_idx( +def _resolve_color_idx( + instance_idx: int, + color_lookup: ColorLookup, + count: int, + class_id: npt.NDArray[np.generic] | None = None, + tracker_id: npt.NDArray[np.generic] | None = None, + keypoint_idx: int | None = None, +) -> int: + """Resolve a palette index from raw field arrays. + + Low-level helper used by both detection and keypoint annotators. + + Args: + instance_idx: Index of the current detection or skeleton. + color_lookup: Strategy for mapping colors. + count: Total number of detections or skeletons. + class_id: Per-instance class IDs, required for ``CLASS``. + tracker_id: Per-instance tracker IDs, required for ``TRACK``. + keypoint_idx: Index of a keypoint within a skeleton, required for + ``KEYPOINT``. + + Returns: + An integer index suitable for ``ColorPalette.by_idx()``. + + Raises: + ValueError: If ``instance_idx`` is out of bounds for the given ``count``. + ValueError: If ``color_lookup`` is ``CLASS`` and ``class_id`` is ``None``. + ValueError: If ``color_lookup`` is ``TRACK`` and ``tracker_id`` is ``None``. + ValueError: If ``color_lookup`` is ``KEYPOINT`` and ``keypoint_idx`` is + ``None``. + ValueError: If ``color_lookup`` is an unsupported strategy. + """ + if instance_idx >= count: + raise ValueError( + f"Instance index {instance_idx} is out of bounds for length {count}" + ) + + if color_lookup == ColorLookup.INDEX: + return instance_idx + elif color_lookup == ColorLookup.CLASS: + if class_id is None: + raise ValueError( + "Could not resolve color by class because class_id is not available. " + "Try setting color_lookup to sv.ColorLookup.INDEX." + ) + return int(class_id[instance_idx]) + elif color_lookup == ColorLookup.TRACK: + if tracker_id is None: + raise ValueError( + "Could not resolve color by track because tracker_id is not available. " + "Make sure tracker_id is set on the input object." + ) + return int(tracker_id[instance_idx]) + elif color_lookup == ColorLookup.KEYPOINT: + if keypoint_idx is None: + raise ValueError( + "ColorLookup.KEYPOINT is only valid for keypoint annotators." + ) + return keypoint_idx + raise ValueError(f"Unsupported color lookup strategy: {color_lookup}") + + +def _resolve_detection_color_idx( detections: Detections, detection_idx: int, color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS, ) -> int: + """Resolve a palette index for a single detection. + + Args: + detections: The detections object. + detection_idx: Index of the current detection. + color_lookup: Strategy for mapping colors. Also accepts a custom + ``np.ndarray`` of integer indices. + + Returns: + An integer index suitable for ``ColorPalette.by_idx()``. + """ if detection_idx >= len(detections): raise ValueError( f"Detection index {detection_idx} " @@ -55,26 +131,27 @@ def resolve_color_idx( f"does not match length of detections {len(detections)}" ) return int(color_lookup[detection_idx]) - elif color_lookup == ColorLookup.INDEX: - return detection_idx - elif color_lookup == ColorLookup.CLASS: - if detections.class_id is None: - raise ValueError( - "Could not resolve color by class because " - "Detections do not have class_id. If using an annotator, " - "try setting color_lookup to sv.ColorLookup.INDEX or " - "sv.ColorLookup.TRACK." - ) - return int(detections.class_id[detection_idx]) - elif color_lookup == ColorLookup.TRACK: - if detections.tracker_id is None: - raise ValueError( - "Could not resolve color by track because " - "Detections do not have tracker_id. Did you call " - "tracker.update_with_detections(...) before annotating?" - ) - return int(detections.tracker_id[detection_idx]) - raise ValueError(f"Unsupported color lookup strategy: {color_lookup}") + + return _resolve_color_idx( + instance_idx=detection_idx, + color_lookup=color_lookup, + count=len(detections), + class_id=detections.class_id, + tracker_id=detections.tracker_id, + ) + + +@deprecated( # type: ignore[untyped-decorator] + target=_resolve_detection_color_idx, + deprecated_in="0.30.0", + remove_in="0.33.0", +) +def resolve_color_idx( # type: ignore[return] + detections: Detections, + detection_idx: int, + color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS, +) -> int: + void(detections, detection_idx, color_lookup) def resolve_text_background_xyxy( @@ -130,19 +207,42 @@ def resolve_text_background_xyxy( ) -def get_color_by_index(color: Color | ColorPalette, idx: int) -> Color: +def _get_color_by_index(color: Color | ColorPalette, idx: int) -> Color: if isinstance(color, ColorPalette): return color.by_idx(idx) return color -def resolve_color( +@deprecated( # type: ignore[untyped-decorator] + target=_get_color_by_index, + deprecated_in="0.30.0", + remove_in="0.33.0", +) +def get_color_by_index( # type: ignore[return] + color: Color | ColorPalette, idx: int +) -> Color: + void(color, idx) + + +def _resolve_detection_color( color: Color | ColorPalette, detections: Detections, detection_idx: int, color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS, ) -> Color: - idx = resolve_color_idx( + """Resolve the color for a single detection. + + Args: + color: A single color or a palette to pick from. + detections: The detections object. + detection_idx: Index of the current detection. + color_lookup: Strategy for mapping colors. Also accepts a custom + ``np.ndarray`` of integer indices. + + Returns: + The resolved ``Color``. + """ + idx = _resolve_detection_color_idx( detections=detections, detection_idx=detection_idx, color_lookup=color_lookup, @@ -153,7 +253,21 @@ def resolve_color( and idx == PENDING_TRACK_ID ): return PENDING_TRACK_COLOR - return get_color_by_index(color=color, idx=idx) + return _get_color_by_index(color=color, idx=idx) + + +@deprecated( # type: ignore[untyped-decorator] + target=_resolve_detection_color, + deprecated_in="0.30.0", + remove_in="0.33.0", +) +def resolve_color( # type: ignore[return] + color: Color | ColorPalette, + detections: Detections, + detection_idx: int, + color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS, +) -> Color: + void(color, detections, detection_idx, color_lookup) def wrap_text(text: Any, max_line_length: int | None = None) -> list[str]: @@ -326,7 +440,7 @@ def snap_boxes( bottom_shift = height - result[bottom_overflow, 3] result[bottom_overflow, 1:4:2] += bottom_shift[:, np.newaxis] - return result.astype(np.float32) # type: ignore + return cast(np.ndarray[Any, np.dtype[np.float32]], result.astype(np.float32)) class Trace: diff --git a/src/supervision/key_points/annotators.py b/src/supervision/key_points/annotators.py index 63a2b7d1a..9d05899b2 100644 --- a/src/supervision/key_points/annotators.py +++ b/src/supervision/key_points/annotators.py @@ -8,19 +8,49 @@ import numpy as np import numpy.typing as npt +from supervision.annotators.utils import ( + ColorLookup, + _get_color_by_index, + _resolve_color_idx, +) from supervision.detection.utils.boxes import pad_boxes, spread_out_boxes from supervision.draw.base import ImageType -from supervision.draw.color import Color +from supervision.draw.color import Color, ColorPalette from supervision.draw.utils import draw_rounded_rectangle from supervision.geometry.core import Rect from supervision.key_points.core import KeyPoints from supervision.key_points.skeletons import SKELETONS_BY_VERTEX_COUNT from supervision.utils.conversion import ensure_cv2_image_for_class_method +from supervision.utils.internal import warn_deprecated from supervision.utils.logger import _get_logger logger = _get_logger(__name__) +def _resolve_keypoint_color( + color: Color | ColorPalette, + color_lookup: ColorLookup, + key_points: KeyPoints, + instance_idx: int, + keypoint_idx: int = 0, +) -> Color: + """Resolve a single color for a keypoint annotation. + + Fast path: when *color* is a plain ``Color``, returns it directly. + """ + if isinstance(color, Color): + return color + + idx = _resolve_color_idx( + instance_idx=instance_idx, + color_lookup=color_lookup, + count=len(key_points), + class_id=key_points.class_id, + keypoint_idx=keypoint_idx, + ) + return _get_color_by_index(color=color, idx=idx) + + class BaseKeyPointAnnotator(ABC): @abstractmethod def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType: @@ -36,16 +66,24 @@ class VertexAnnotator(BaseKeyPointAnnotator): def __init__( self, - color: Color = Color.ROBOFLOW, + color: Color | ColorPalette = Color.ROBOFLOW, radius: int = 4, + color_lookup: ColorLookup = ColorLookup.CLASS, ) -> None: """ Args: - color: The color to use for annotating key points. - radius: The radius of the circles used to represent the key points. + color: The color or color palette to use for + annotating key points. + radius: Radius of the drawn key point circles. + color_lookup: Strategy for mapping colors to annotations. + Options are `INDEX` (per-skeleton index), `CLASS` + (per class_id), and `KEYPOINT` (per keypoint index within + each skeleton). Note: ``TRACK`` is not supported for + keypoint annotators. """ self.color = color self.radius = radius + self.color_lookup = color_lookup @ensure_cv2_image_for_class_method def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType: @@ -97,11 +135,18 @@ def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType: and not key_points.visible[detection_index, point_index] ): continue + color = _resolve_keypoint_color( + color=self.color, + color_lookup=self.color_lookup, + key_points=key_points, + instance_idx=detection_index, + keypoint_idx=point_index, + ) cv2.circle( img=scene, center=(int(x), int(y)), radius=self.radius, - color=self.color.as_bgr(), + color=color.as_bgr(), thickness=-1, ) @@ -116,26 +161,34 @@ class EdgeAnnotator(BaseKeyPointAnnotator): def __init__( self, - color: Color = Color.ROBOFLOW, + color: Color | ColorPalette = Color.ROBOFLOW, thickness: int = 2, edges: ( Sequence[tuple[int, int]] | dict[int, Sequence[tuple[int, int]]] | None ) = None, + color_lookup: ColorLookup = ColorLookup.CLASS, ) -> None: """ Args: - color: The color to use for the edges. - thickness: The thickness of the edges. + color: The color or color palette to use for + annotating edges. + thickness: Thickness of the edge lines. edges: The edges to draw. If set to ``None``, will attempt to auto-detect the skeleton by vertex count. A ``Sequence[tuple[int, int]]`` applies a single skeleton to every instance. A ``dict[int, Sequence[tuple[int, int]]]`` maps ``class_id`` to skeleton edges, enabling correct rendering for datasets with multiple skeleton types. + color_lookup: Strategy for mapping colors to annotations. + Options are `INDEX` (per-skeleton index), `CLASS` + (per class_id), and `KEYPOINT` (per keypoint index — + edge inherits the color of its first endpoint). Note: + ``TRACK`` is not supported for keypoint annotators. """ self.color = color self.thickness = thickness self.edges = edges + self.color_lookup = color_lookup @ensure_cv2_image_for_class_method def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType: @@ -247,11 +300,18 @@ def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType: ): continue + color = _resolve_keypoint_color( + color=self.color, + color_lookup=self.color_lookup, + key_points=key_points, + instance_idx=detection_index, + keypoint_idx=idx_a, + ) cv2.line( img=scene, pt1=(int(xy_a[0]), int(xy_a[1])), pt2=(int(xy_b[0]), int(xy_b[1])), - color=self.color.as_bgr(), + color=color.as_bgr(), thickness=self.thickness, ) @@ -708,34 +768,61 @@ class VertexLabelAnnotator: def __init__( self, - color: Color | list[Color] = Color.ROBOFLOW, - text_color: Color | list[Color] = Color.WHITE, + color: Color | list[Color] | ColorPalette = Color.ROBOFLOW, + text_color: Color | list[Color] | ColorPalette = Color.WHITE, text_scale: float = 0.5, text_thickness: int = 1, text_padding: int = 10, border_radius: int = 0, smart_position: bool = False, - ): + color_lookup: ColorLookup = ColorLookup.CLASS, + ) -> None: """ Args: - color: The color to use for each keypoint label. If a list is - provided, the colors will be used in order for each keypoint. - text_color: The color to use for the labels. If a list is - provided, the colors will be used in order for each keypoint. - text_scale: The scale of the text. - text_thickness: The thickness of the text. - text_padding: The padding around the text. - border_radius: The radius of the rounded corners of the boxes. - Set to a high value to produce circles. - smart_position: Spread out the labels to avoid overlap. + color: The color to use for each keypoint label. A single + ``Color`` applies uniformly. A ``ColorPalette`` selects + colors via the ``color_lookup`` strategy. Passing a + ``list[Color]`` is **deprecated since 0.30.0** (removed in + 0.33.0) — use ``ColorPalette`` with + ``ColorLookup.KEYPOINT`` instead. + text_color: The color to use for the label text. Accepts the + same types as ``color``. Passing a ``list[Color]`` is + **deprecated since 0.30.0** (removed in 0.33.0). + text_scale: Font scale for the text. + text_thickness: Thickness of the text characters. + text_padding: Padding around the text within its + background box. + border_radius: The radius to apply round edges. If the + selected value is higher than the lower dimension, + width or height, is clipped. + smart_position: Spread out the labels to avoid overlapping. + color_lookup: Strategy for mapping colors to annotations. + Options are `INDEX` (per-skeleton index), `CLASS` + (per class_id), and `KEYPOINT` (per keypoint index within + each skeleton). Note: ``TRACK`` is not supported for + keypoint annotators. """ + if isinstance(color, list): + warn_deprecated( + "Passing a list[Color] for 'color' in VertexLabelAnnotator " + "is deprecated since 0.30.0 and will be removed in 0.33.0. " + "Use ColorPalette with ColorLookup.KEYPOINT instead." + ) + if isinstance(text_color, list): + warn_deprecated( + "Passing a list[Color] for 'text_color' in " + "VertexLabelAnnotator is deprecated since 0.30.0 and will be " + "removed in 0.33.0. Use ColorPalette with " + "ColorLookup.KEYPOINT instead." + ) self.border_radius: int = border_radius - self.color: Color | list[Color] = color - self.text_color: Color | list[Color] = text_color + self.color: Color | list[Color] | ColorPalette = color + self.text_color: Color | list[Color] | ColorPalette = text_color self.text_scale: float = text_scale self.text_thickness: int = text_thickness self.text_padding: int = text_padding self.smart_position = smart_position + self.color_lookup = color_lookup def annotate( self, @@ -843,10 +930,6 @@ def annotate( int(key_points.class_id[i]) if key_points.class_id is not None else None ) instance_labels = self._resolve_labels(labels, points_count, class_id) - instance_colors = self._resolve_color_list(self.color, points_count) - instance_text_colors = self._resolve_color_list( - self.text_color, points_count - ) for j in range(points_count): if key_points.visible is not None: @@ -858,8 +941,16 @@ def annotate( anchor = (int(xy[j][0]), int(xy[j][1])) all_anchors.append(anchor) all_labels.append(instance_labels[j]) - all_colors.append(instance_colors[j]) - all_text_colors.append(instance_text_colors[j]) + all_colors.append( + self._resolve_label_color_legacy( + self.color, key_points, i, j, points_count + ) + ) + all_text_colors.append( + self._resolve_label_color_legacy( + self.text_color, key_points, i, j, points_count + ) + ) if not all_anchors: return scene @@ -956,17 +1047,48 @@ def _resolve_labels( ) return resolved - @staticmethod - def _resolve_color_list( - colors: Color | list[Color], + def _resolve_label_color_legacy( + self, + color_input: Color | list[Color] | ColorPalette, + key_points: KeyPoints, + instance_idx: int, + keypoint_idx: int, points_count: int, - ) -> list[Color]: - """Return a per-keypoint color list for a single instance.""" - if isinstance(colors, list): - if len(colors) != points_count: + ) -> Color: + """Backward-compatibility shim for resolving a label color. + + Handles the deprecated ``list[Color]`` input type by indexing + directly into the list. For ``Color`` and ``ColorPalette`` inputs, + delegates to ``_resolve_keypoint_color``. + + .. deprecated:: 0.30.0 + When ``list[Color]`` support is removed in 0.33.0, delete this + method and call ``_resolve_keypoint_color`` directly. The call + sites in ``annotate()`` should change from:: + + self._resolve_label_color_legacy(self.color, ...) + + to:: + + _resolve_keypoint_color( + color=self.color, + color_lookup=self.color_lookup, + key_points=key_points, + instance_idx=i, + keypoint_idx=keypoint_idx, + ) + """ + if isinstance(color_input, list): + if len(color_input) != points_count: raise ValueError( - f"Number of colors ({len(colors)}) must match " + f"Number of colors ({len(color_input)}) must match " f"number of key points ({points_count})." ) - return colors - return [colors] * points_count + return color_input[keypoint_idx] + return _resolve_keypoint_color( + color=color_input, + color_lookup=self.color_lookup, + key_points=key_points, + instance_idx=instance_idx, + keypoint_idx=keypoint_idx, + ) diff --git a/tests/annotators/test_utils.py b/tests/annotators/test_utils.py index 51642cfe5..b15151d0a 100644 --- a/tests/annotators/test_utils.py +++ b/tests/annotators/test_utils.py @@ -7,6 +7,7 @@ from supervision.annotators.utils import ( ColorLookup, + _resolve_color_idx, hex_to_rgba, is_valid_hex, resolve_color_idx, @@ -117,6 +118,167 @@ def test_resolve_color_idx( assert result == expected_result +_CLASS_IDS = np.array([5, 3, 7]) +_TRACKER_IDS = np.array([2, 6, 4]) + + +@pytest.mark.parametrize( + ( + "instance_idx", + "color_lookup", + "count", + "class_id", + "tracker_id", + "keypoint_idx", + "expected_result", + "exception", + ), + [ + pytest.param( + 0, + ColorLookup.INDEX, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + 0, + DoesNotRaise(), + id="index-first", + ), + pytest.param( + 1, + ColorLookup.INDEX, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + 1, + DoesNotRaise(), + id="index-second", + ), + pytest.param( + 0, + ColorLookup.CLASS, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + 5, + DoesNotRaise(), + id="class-first", + ), + pytest.param( + 1, + ColorLookup.CLASS, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + 3, + DoesNotRaise(), + id="class-second", + ), + pytest.param( + 0, + ColorLookup.TRACK, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + 2, + DoesNotRaise(), + id="track-first", + ), + pytest.param( + 1, + ColorLookup.TRACK, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + 6, + DoesNotRaise(), + id="track-second", + ), + pytest.param( + 0, + ColorLookup.KEYPOINT, + 3, + _CLASS_IDS, + _TRACKER_IDS, + 4, + 4, + DoesNotRaise(), + id="keypoint", + ), + pytest.param( + 3, + ColorLookup.INDEX, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + None, + pytest.raises(ValueError, match="out of bounds"), + id="out-of-bounds", + ), + pytest.param( + 0, + ColorLookup.CLASS, + 3, + None, + _TRACKER_IDS, + None, + None, + pytest.raises(ValueError, match="class_id"), + id="class-no-class-id", + ), + pytest.param( + 0, + ColorLookup.TRACK, + 3, + _CLASS_IDS, + None, + None, + None, + pytest.raises(ValueError, match="tracker_id"), + id="track-no-tracker-id", + ), + pytest.param( + 0, + ColorLookup.KEYPOINT, + 3, + _CLASS_IDS, + _TRACKER_IDS, + None, + None, + pytest.raises(ValueError, match="KEYPOINT"), + id="keypoint-no-keypoint-idx", + ), + ], +) +def test_resolve_color_idx_from_fields( + instance_idx: int, + color_lookup: ColorLookup, + count: int, + class_id: np.ndarray | None, + tracker_id: np.ndarray | None, + keypoint_idx: int | None, + expected_result: int | None, + exception: Exception, +) -> None: + with exception: + result = _resolve_color_idx( + instance_idx=instance_idx, + color_lookup=color_lookup, + count=count, + class_id=class_id, + tracker_id=tracker_id, + keypoint_idx=keypoint_idx, + ) + assert result == expected_result + + @pytest.mark.parametrize( ("text", "max_line_length", "expected_result", "exception"), [ diff --git a/tests/key_points/test_annotators.py b/tests/key_points/test_annotators.py index 8121c027c..38271953f 100644 --- a/tests/key_points/test_annotators.py +++ b/tests/key_points/test_annotators.py @@ -468,37 +468,163 @@ def test_resolve_labels_raises(self, labels, points_count, class_id, match): with pytest.raises(ValueError, match=match): sv.VertexLabelAnnotator._resolve_labels(labels, points_count, class_id) + +class TestVertexAnnotatorColorLookup: + """Verify VertexAnnotator respects each ColorLookup strategy with a ColorPalette.""" + + @pytest.fixture + def key_points_with_class(self) -> sv.KeyPoints: + """Two-instance, three-keypoint set with class_id set.""" + return sv.KeyPoints( + xy=np.array( + [ + [[20.0, 20.0], [40.0, 40.0], [60.0, 60.0]], + [[25.0, 25.0], [45.0, 45.0], [65.0, 65.0]], + ], + dtype=np.float32, + ), + class_id=np.array([0, 1], dtype=int), + ) + @pytest.mark.parametrize( - ("colors", "points_count", "expected"), + "color_lookup", [ - pytest.param( - sv.Color.RED, - 3, - [sv.Color.RED, sv.Color.RED, sv.Color.RED], - id="single-color-expands", - ), - pytest.param( - [sv.Color.RED, sv.Color.GREEN, sv.Color.BLUE], - 3, - [sv.Color.RED, sv.Color.GREEN, sv.Color.BLUE], - id="list-returns-as-is", - ), + pytest.param(sv.ColorLookup.INDEX, id="index"), + pytest.param(sv.ColorLookup.CLASS, id="class"), + pytest.param(sv.ColorLookup.KEYPOINT, id="keypoint"), ], ) - def test_resolve_color_list_returns_expected(self, colors, points_count, expected): - result = sv.VertexLabelAnnotator._resolve_color_list(colors, points_count) - assert result == expected + def test_annotate_with_color_palette_returns_ndarray( + self, scene, key_points_with_class, color_lookup + ): + """ColorPalette + each ColorLookup produces a modified ndarray output.""" + annotator = sv.VertexAnnotator( + color=sv.ColorPalette.DEFAULT, + radius=5, + color_lookup=color_lookup, + ) + result = annotator.annotate( + scene=scene.copy(), key_points=key_points_with_class + ) + + assert isinstance(result, np.ndarray) + assert result.shape == scene.shape + assert not np.array_equal(result, scene) + + def test_annotate_class_lookup_raises_when_class_id_none(self, scene): + """CLASS strategy raises ValueError when key_points.class_id is None.""" + key_points = sv.KeyPoints( + xy=np.array([[[30.0, 30.0], [50.0, 50.0]]], dtype=np.float32), + ) + annotator = sv.VertexAnnotator( + color=sv.ColorPalette.DEFAULT, + color_lookup=sv.ColorLookup.CLASS, + ) + + with pytest.raises(ValueError, match="class_id"): + annotator.annotate(scene=scene.copy(), key_points=key_points) + + +class TestEdgeAnnotatorColorLookup: + """Verify EdgeAnnotator respects each ColorLookup strategy with a ColorPalette.""" + + @pytest.fixture + def key_points_triangle(self) -> sv.KeyPoints: + """Single-instance, three-vertex triangle useful with explicit edges.""" + return sv.KeyPoints( + xy=np.array( + [[[10.0, 10.0], [80.0, 10.0], [45.0, 80.0]]], + dtype=np.float32, + ), + class_id=np.array([0], dtype=int), + ) @pytest.mark.parametrize( - ("colors", "points_count"), + "color_lookup", [ - pytest.param( - [sv.Color.RED, sv.Color.GREEN], - 3, - id="list-wrong-length", + pytest.param(sv.ColorLookup.INDEX, id="index"), + pytest.param(sv.ColorLookup.CLASS, id="class"), + pytest.param(sv.ColorLookup.KEYPOINT, id="keypoint"), + ], + ) + def test_annotate_with_color_palette_returns_ndarray( + self, scene, key_points_triangle, color_lookup + ): + """ColorPalette + each ColorLookup produces a modified ndarray output.""" + annotator = sv.EdgeAnnotator( + color=sv.ColorPalette.DEFAULT, + thickness=2, + edges=[(1, 2), (2, 3), (1, 3)], + color_lookup=color_lookup, + ) + result = annotator.annotate(scene=scene.copy(), key_points=key_points_triangle) + + assert isinstance(result, np.ndarray) + assert result.shape == scene.shape + assert not np.array_equal(result, scene) + + def test_annotate_class_lookup_raises_when_class_id_none(self, scene): + """CLASS strategy raises ValueError when key_points.class_id is None.""" + key_points = sv.KeyPoints( + xy=np.array([[[10.0, 10.0], [80.0, 10.0]]], dtype=np.float32), + ) + annotator = sv.EdgeAnnotator( + color=sv.ColorPalette.DEFAULT, + edges=[(1, 2)], + color_lookup=sv.ColorLookup.CLASS, + ) + + with pytest.raises(ValueError, match="class_id"): + annotator.annotate(scene=scene.copy(), key_points=key_points) + + +class TestVertexLabelAnnotatorColorLookup: + """Verify VertexLabelAnnotator respects each ColorLookup strategy.""" + + @pytest.fixture + def key_points_with_class(self) -> sv.KeyPoints: + """Two-instance, two-keypoint set with class_id set.""" + return sv.KeyPoints( + xy=np.array( + [[[20.0, 20.0], [60.0, 60.0]], [[25.0, 25.0], [65.0, 65.0]]], + dtype=np.float32, ), + class_id=np.array([0, 1], dtype=int), + ) + + @pytest.mark.parametrize( + "color_lookup", + [ + pytest.param(sv.ColorLookup.INDEX, id="index"), + pytest.param(sv.ColorLookup.CLASS, id="class"), + pytest.param(sv.ColorLookup.KEYPOINT, id="keypoint"), ], ) - def test_resolve_color_list_wrong_length_raises(self, colors, points_count): - with pytest.raises(ValueError, match="Number of colors"): - sv.VertexLabelAnnotator._resolve_color_list(colors, points_count) + def test_annotate_with_color_palette_returns_ndarray( + self, scene, key_points_with_class, color_lookup + ): + """ColorPalette + each ColorLookup produces a modified ndarray output.""" + annotator = sv.VertexLabelAnnotator( + color=sv.ColorPalette.DEFAULT, + color_lookup=color_lookup, + ) + result = annotator.annotate( + scene=scene.copy(), key_points=key_points_with_class + ) + + assert isinstance(result, np.ndarray) + assert result.shape == scene.shape + assert not np.array_equal(result, scene) + + def test_annotate_class_lookup_raises_when_class_id_none(self, scene): + """CLASS strategy raises ValueError when key_points.class_id is None.""" + key_points = sv.KeyPoints( + xy=np.array([[[30.0, 30.0], [50.0, 50.0]]], dtype=np.float32), + ) + annotator = sv.VertexLabelAnnotator( + color=sv.ColorPalette.DEFAULT, + color_lookup=sv.ColorLookup.CLASS, + ) + + with pytest.raises(ValueError, match="class_id"): + annotator.annotate(scene=scene.copy(), key_points=key_points)