Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/supervision/annotators/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
PENDING_TRACK_ID,
ColorLookup,
Trace,
_resolve_detection_color,
_validate_labels,
calculate_dynamic_kernel_size,
calculate_dynamic_pixel_size,
get_labels_text,
hex_to_rgba,
resolve_color,
resolve_text_background_xyxy,
snap_boxes,
wrap_text,
Expand Down Expand Up @@ -256,7 +256,7 @@ def annotate(
return scene
for detection_idx in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -347,7 +347,7 @@ def annotate(

for detection_idx in range(len(detections)):
obb = obb_boxes[detection_idx]
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -398,7 +398,7 @@ def _paint_masks_by_area(
)
compact_mask = masks if isinstance(masks, CompactMask) else None
for detection_idx in np.flip(np.argsort(detections.area)):
color_bgr = resolve_color(
color_bgr = _resolve_detection_color(
color=color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -584,7 +584,7 @@ def annotate(

for detection_idx in range(len(detections)):
mask = detections.mask[detection_idx]
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -673,7 +673,7 @@ def annotate(
scene_with_boxes = scene.copy()
for detection_idx in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -872,7 +872,7 @@ def annotate(
return scene
for detection_idx in range(len(detections)):
x1, _y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -968,7 +968,7 @@ def annotate(
return scene
for detection_idx in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def annotate(
x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
center = ((x1 + x2) // 2, (y1 + y2) // 2)
distance = sqrt((x1 - center[0]) ** 2 + (y1 - center[1]) ** 2)
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -1164,7 +1164,7 @@ def annotate(
return scene
xy = detections.get_anchors_coordinates(anchor=self.position)
for detection_idx in range(len(detections)):
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand All @@ -1176,7 +1176,7 @@ def annotate(

cv2.circle(scene, center, self.radius, color.as_bgr(), -1)
if self.outline_thickness:
outline_color = resolve_color(
outline_color = _resolve_detection_color(
color=self.outline_color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -1405,13 +1405,13 @@ def _draw_labels(
)

for idx, label_property in enumerate(label_properties):
background_color = resolve_color(
background_color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=idx,
color_lookup=color_lookup,
)
text_color = resolve_color(
text_color = _resolve_detection_color(
color=self.text_color,
detections=detections,
detection_idx=idx,
Expand Down Expand Up @@ -1718,13 +1718,13 @@ def _draw_labels(
)

for idx, label_property in enumerate(label_properties):
background_color = resolve_color(
background_color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=idx,
color_lookup=color_lookup,
)
text_color = resolve_color(
text_color = _resolve_detection_color(
color=self.text_color,
detections=detections,
detection_idx=idx,
Expand Down Expand Up @@ -2079,7 +2079,7 @@ def annotate(
if tracker_id_val is None:
continue
tracker_id = int(tracker_id_val)
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=filtered_detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -2406,7 +2406,7 @@ def annotate(
return scene
xy = detections.get_anchors_coordinates(anchor=self.position)
for detection_idx in range(len(detections)):
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand All @@ -2426,7 +2426,7 @@ def annotate(

cv2.fillPoly(scene, [vertices], color.as_bgr())
if self.outline_thickness:
outline_color = resolve_color(
outline_color = _resolve_detection_color(
color=self.outline_color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -2523,7 +2523,7 @@ def annotate(
return scene
for detection_idx in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -2691,7 +2691,7 @@ def annotate(
assert detections.confidence is not None # MyPy type hint
value = detections.confidence[detection_idx]

color = resolve_color(
color = _resolve_detection_color(
color=self.color,
detections=detections,
detection_idx=detection_idx,
Expand Down Expand Up @@ -2882,7 +2882,7 @@ def annotate(
anchor=anchor, crop_wh=crop_wh, position=self.position
)
scene = overlay_image(image=scene, overlay=resized_crop, anchor=(x1, y1))
color = resolve_color(
color = _resolve_detection_color(
color=self.border_color,
detections=detections,
detection_idx=idx,
Expand Down
153 changes: 127 additions & 26 deletions src/supervision/annotators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,96 @@ class ColorLookup(Enum):
"""
Enumeration class to define strategies for mapping colors to annotations.

This enum supports three different lookup strategies:
This enum supports four different lookup strategies:
- `INDEX`: Colors are determined by the index of the detection within the scene.
- `CLASS`: Colors are determined by the class label of the detected object.
- `TRACK`: Colors are determined by the tracking identifier of the object.
- `KEYPOINT`: Colors are determined by the keypoint index within each skeleton.
Only valid for keypoint annotators.
"""

INDEX = "index"
CLASS = "class"
TRACK = "track"
KEYPOINT = "keypoint"

@classmethod
def list(cls) -> list[str]:
return list(map(lambda c: c.value, cls))


def resolve_color_idx(
def _resolve_color_idx(
instance_idx: int,
color_lookup: ColorLookup,
count: int,
class_id: npt.NDArray[np.int_] | None = None,
tracker_id: npt.NDArray[np.int_] | None = None,
keypoint_idx: int | None = None,
) -> int:
"""Resolve a palette index from raw field arrays.

Low-level helper used by both detection and keypoint annotators.

Args:
instance_idx: Index of the current detection or skeleton.
color_lookup: Strategy for mapping colors.
count: Total number of detections or skeletons.
class_id: Per-instance class IDs, required for ``CLASS``.
tracker_id: Per-instance tracker IDs, required for ``TRACK``.
keypoint_idx: Index of a keypoint within a skeleton, required for
``KEYPOINT``.

Returns:
An integer index suitable for ``ColorPalette.by_idx()``.
"""
if instance_idx >= count:
raise ValueError(
f"Instance index {instance_idx} is out of bounds for length {count}"
)

if color_lookup == ColorLookup.INDEX:
return instance_idx
elif color_lookup == ColorLookup.CLASS:
if class_id is None:
raise ValueError(
"Could not resolve color by class because "
"class_id is not available. Try setting "
"color_lookup to sv.ColorLookup.INDEX."
)
return int(class_id[instance_idx])
elif color_lookup == ColorLookup.TRACK:
if tracker_id is None:
raise ValueError(
"Could not resolve color by track because "
"tracker_id is not available. Make sure that the "
"Detections object contains tracker_id data."
)
Comment on lines +83 to +88
return int(tracker_id[instance_idx])
elif color_lookup == ColorLookup.KEYPOINT:
if keypoint_idx is None:
raise ValueError(
"ColorLookup.KEYPOINT is only valid for keypoint annotators."
)
return keypoint_idx
raise ValueError(f"Unsupported color lookup strategy: {color_lookup}")


def _resolve_detection_color_idx(
detections: Detections,
detection_idx: int,
color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS,
) -> int:
"""Resolve a palette index for a single detection.

Args:
detections: The detections object.
detection_idx: Index of the current detection.
color_lookup: Strategy for mapping colors. Also accepts a custom
``np.ndarray`` of integer indices.

Returns:
An integer index suitable for ``ColorPalette.by_idx()``.
"""
if detection_idx >= len(detections):
raise ValueError(
f"Detection index {detection_idx} "
Expand All @@ -55,26 +125,27 @@ def resolve_color_idx(
f"does not match length of detections {len(detections)}"
)
return int(color_lookup[detection_idx])
elif color_lookup == ColorLookup.INDEX:
return detection_idx
elif color_lookup == ColorLookup.CLASS:
if detections.class_id is None:
raise ValueError(
"Could not resolve color by class because "
"Detections do not have class_id. If using an annotator, "
"try setting color_lookup to sv.ColorLookup.INDEX or "
"sv.ColorLookup.TRACK."
)
return int(detections.class_id[detection_idx])
elif color_lookup == ColorLookup.TRACK:
if detections.tracker_id is None:
raise ValueError(
"Could not resolve color by track because "
"Detections do not have tracker_id. Did you call "
"tracker.update_with_detections(...) before annotating?"
)
return int(detections.tracker_id[detection_idx])
raise ValueError(f"Unsupported color lookup strategy: {color_lookup}")

return _resolve_color_idx(
instance_idx=detection_idx,
color_lookup=color_lookup,
count=len(detections),
class_id=detections.class_id,
tracker_id=detections.tracker_id,
)


@deprecated( # type: ignore[untyped-decorator]
target=_resolve_detection_color_idx,
deprecated_in="0.30.0",
remove_in="0.33.0",
)
def resolve_color_idx( # type: ignore[return]
detections: Detections,
detection_idx: int,
color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS,
) -> int:
void(detections, detection_idx, color_lookup)


def resolve_text_background_xyxy(
Expand Down Expand Up @@ -130,19 +201,35 @@ def resolve_text_background_xyxy(
)


def get_color_by_index(color: Color | ColorPalette, idx: int) -> Color:
def _get_color_by_index(color: Color | ColorPalette, idx: int) -> Color:
if isinstance(color, ColorPalette):
return color.by_idx(idx)
return color


def resolve_color(
def get_color_by_index(color: Color | ColorPalette, idx: int) -> Color:
return _get_color_by_index(color=color, idx=idx)


def _resolve_detection_color(
color: Color | ColorPalette,
detections: Detections,
detection_idx: int,
color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS,
) -> Color:
idx = resolve_color_idx(
"""Resolve the color for a single detection.

Args:
color: A single color or a palette to pick from.
detections: The detections object.
detection_idx: Index of the current detection.
color_lookup: Strategy for mapping colors. Also accepts a custom
``np.ndarray`` of integer indices.

Returns:
The resolved ``Color``.
"""
idx = _resolve_detection_color_idx(
detections=detections,
detection_idx=detection_idx,
color_lookup=color_lookup,
Expand All @@ -153,7 +240,21 @@ def resolve_color(
and idx == PENDING_TRACK_ID
):
return PENDING_TRACK_COLOR
return get_color_by_index(color=color, idx=idx)
return _get_color_by_index(color=color, idx=idx)


@deprecated( # type: ignore[untyped-decorator]
target=_resolve_detection_color,
deprecated_in="0.30.0",
remove_in="0.33.0",
)
def resolve_color( # type: ignore[return]
color: Color | ColorPalette,
detections: Detections,
detection_idx: int,
color_lookup: ColorLookup | npt.NDArray[np.int_] = ColorLookup.CLASS,
) -> Color:
void(color, detections, detection_idx, color_lookup)


def wrap_text(text: Any, max_line_length: int | None = None) -> list[str]:
Expand Down
Loading
Loading