Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OBB prediction and update Ultralytics demo notebook #1126

Merged
merged 5 commits into from
Mar 6, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix OBB prediction and update Ultralytics demo notebook
- Add OBB (Oriented Bounding Box) prediction example to inference notebook
- Enhance visualization for OBB predictions in cv utils
- Update AutoDetectionModel and prediction methods to support OBB models
- Bump package version to 0.11.22
- Improve demo notebook with additional test image and simplified imports
fcakyon committed Mar 6, 2025
commit e78e66c4ea96b16b9d7871fb9b79705aafb1b9ff
311 changes: 206 additions & 105 deletions demo/inference_for_ultralytics.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sahi"
version = "0.11.21"
version = "0.11.22"
readme = "README.md"
description = "A vision library for performing sliced inference on large images/small objects"
requires-python = ">=3.8"
4 changes: 3 additions & 1 deletion sahi/auto_model.py
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ def from_pretrained(

Args:
model_type: str
Name of the detection framework (example: "yolov5", "mmdet", "detectron2")
Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")
model_path: str
Path of the detection model (ex. 'model.pt')
config_path: str
@@ -58,8 +58,10 @@ def from_pretrained(
If True, automatically loads the model at initialization
image_size: int
Inference input size.

Returns:
Returns an instance of a DetectionModel

Raises:
ImportError: If given {model_type} framework is not installed
"""
6 changes: 3 additions & 3 deletions sahi/models/torchvision.py
Original file line number Diff line number Diff line change
@@ -178,13 +178,13 @@ def _create_object_prediction_list_from_original_predictions(

for ind in range(len(boxes)):
if masks is not None:
mask = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
segmentation = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
else:
mask = None
segmentation = None

object_prediction = ObjectPrediction(
bbox=boxes[ind],
segmentation=mask,
segmentation=segmentation,
category_id=int(category_ids[ind]),
category_name=self.category_mapping[str(int(category_ids[ind]))],
shift_amount=shift_amount,
4 changes: 2 additions & 2 deletions sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.cv import get_coco_segmentation_from_bool_mask, get_coco_segmentation_from_obb_points
from sahi.utils.cv import get_coco_segmentation_from_bool_mask
from sahi.utils.import_utils import check_requirements

logger = logging.getLogger(__name__)
@@ -207,7 +207,7 @@ def _create_object_prediction_list_from_original_predictions(
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
else: # is_obb
obb_points = masks_or_points[pred_ind] # Get OBB points for this prediction
segmentation = get_coco_segmentation_from_obb_points(obb_points)
segmentation = [obb_points.reshape(-1).tolist()]

if len(segmentation) == 0:
continue
11 changes: 8 additions & 3 deletions sahi/predict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

[CI annotation — GitHub Actions / ruff-format: check failure on line 1 in sahi/predict.py — "Would reformat"]

import logging
@@ -113,6 +113,9 @@
time_end = time.time() - time_start
durations_in_seconds["prediction"] = time_end

if full_shape is None:
full_shape = [image_as_pil.height, image_as_pil.width]

# process prediction
time_start = time.time()
# works only with 1 batch
@@ -239,19 +242,21 @@
overlap_width_ratio=overlap_width_ratio,
auto_slice_resolution=auto_slice_resolution,
)
from sahi.models.ultralytics import UltralyticsDetectionModel

num_slices = len(slice_image_result)
time_end = time.time() - time_start
durations_in_seconds["slice"] = time_end

if isinstance(detection_model, UltralyticsDetectionModel) and detection_model.is_obb:
# Only NMS is supported for OBB model outputs
postprocess_type = "NMS"

# init match postprocess instance
if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
raise ValueError(
f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
)
elif postprocess_type == "UNIONMERGE":
# deprecated in v0.9.3
raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.")
postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
postprocess = postprocess_constructor(
match_threshold=postprocess_match_threshold,
128 changes: 72 additions & 56 deletions sahi/utils/cv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

[CI annotation — GitHub Actions / ruff-format: check failure on line 1 in sahi/utils/cv.py — "Would reformat"]

import copy
@@ -540,68 +540,83 @@
# set text_size for category names
text_size = text_size or rect_th / 3

# add masks to image if present
# add masks or obb polygons to image if present
for object_prediction in object_prediction_list:
# deepcopy object_prediction_list so that original is not altered
object_prediction = object_prediction.deepcopy()
# visualize masks if present
if object_prediction.mask is not None:
# deepcopy mask so that original is not altered
mask = object_prediction.mask.bool_mask
# set color
if colors is not None:
color = colors(object_prediction.category.id)
# draw mask
rgb_mask = apply_color_mask(mask, color or (0, 0, 0))
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

# add bboxes to image if present
for object_prediction in object_prediction_list:
# deepcopy object_prediction_list so that original is not altered
object_prediction = object_prediction.deepcopy()

bbox = object_prediction.bbox.to_xyxy()
category_name = object_prediction.category.name
score = object_prediction.score.value

# arange label to be displayed
label = f"{object_prediction.category.name}"
if not hide_conf:
label += f" {object_prediction.score.value:.2f}"
# set color
if colors is not None:
color = colors(object_prediction.category.id)
# set bbox points
point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
# visualize boxes
cv2.rectangle(
image,
point1,
point2,
color=color or (0, 0, 0),
thickness=rect_th,
)

if not hide_labels:
# arange bounding box text location
label = f"{category_name}"

if not hide_conf:
label += f" {score:.2f}"

box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
0
] # label width, height
outside = point1[1] - box_height - 3 >= 0 # label fits outside box
point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
# add bounding box text
cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA) # filled
cv2.putText(
# visualize masks or obb polygons if present
has_mask = object_prediction.mask is not None
is_obb_pred = False
if has_mask:
segmentation = object_prediction.mask.segmentation
if len(segmentation) == 1 and len(segmentation[0]) == 8:
is_obb_pred = True

if is_obb_pred:
points = np.array(segmentation).reshape((-1, 1, 2)).astype(np.int32)
cv2.polylines(image, [points], isClosed=True, color=color or (0, 0, 0), thickness=rect_th)

if not hide_labels:
lowest_point = points[points[:, :, 1].argmax()][0]
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]
outside = lowest_point[1] - box_height - 3 >= 0
text_bg_point1 = (lowest_point[0], lowest_point[1] - box_height - 3 if outside else lowest_point[1] + 3)
text_bg_point2 = (lowest_point[0] + box_width, lowest_point[1])
cv2.rectangle(image, text_bg_point1, text_bg_point2, color or (0, 0, 0), thickness=-1, lineType=cv2.LINE_AA)
cv2.putText(
image,
label,
(lowest_point[0], lowest_point[1] - 2 if outside else lowest_point[1] + box_height + 2),
0,
text_size,
(255, 255, 255),
thickness=text_th,
)
else:
# draw mask
rgb_mask = apply_color_mask(object_prediction.mask.bool_mask, color or (0, 0, 0))
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

# add bboxes to image if is_obb_pred=False
if not is_obb_pred:
bbox = object_prediction.bbox.to_xyxy()

# set bbox points
point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
# visualize boxes
cv2.rectangle(
image,
label,
(point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
0,
text_size,
(255, 255, 255),
thickness=text_th,
point1,
point2,
color=color or (0, 0, 0),
thickness=rect_th,
)

if not hide_labels:
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
0
] # label width, height
outside = point1[1] - box_height - 3 >= 0 # label fits outside box
point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
# add bounding box text
cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA) # filled
cv2.putText(
image,
label,
(point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
0,
text_size,
(255, 255, 255),
thickness=text_th,
)

# export if output_dir is present
if output_dir is not None:
# export image with predictions
@@ -614,7 +629,7 @@
return {"image": image, "elapsed_time": elapsed_time}


def get_coco_segmentation_from_bool_mask(bool_mask):
def get_coco_segmentation_from_bool_mask(bool_mask: np.ndarray) -> List[List[float]]:
"""
Convert boolean mask to coco segmentation format
[
@@ -712,12 +727,13 @@
obb_points: np.ndarray
OBB points tensor from ultralytics.engine.results.OBB
Shape: (4, 2) containing 4 points with (x,y) coordinates each

Returns:
List[List[float]]: Polygon points in COCO format
[[x1, y1, x2, y2, x3, y3, x4, y4, x1, y1], [...], ...]
[[x1, y1, x2, y2, x3, y3, x4, y4], [...], ...]
"""
# Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format
points = obb_points.reshape(-1).tolist()
points = obb_points.reshape(-1).tolist()

# Create polygon from points and close it by repeating first point
polygons = []