diff --git a/sahi/models/ultralytics.py b/sahi/models/ultralytics.py
index e9df51f6..dc215a21 100644
--- a/sahi/models/ultralytics.py
+++ b/sahi/models/ultralytics.py
@@ -67,9 +67,13 @@ def perform_inference(self, image: np.ndarray):
 
         if self.image_size is not None:
             kwargs = {"imgsz": self.image_size, **kwargs}
 
-        prediction_result = self.model(image[:, :, ::-1], **kwargs)  # YOLOv8 expects numpy arrays to have BGR
+        if isinstance(image, list):
+            # batched slices are already converted to BGR by the caller, so pass them through as-is
+            prediction_result = self.model(image, **kwargs)
+        else:
+            prediction_result = self.model(image[:, :, ::-1], **kwargs)  # YOLOv8 expects numpy arrays to have BGR
 
         if self.has_mask:
             if not prediction_result[0].masks:
                 prediction_result[0].masks = Masks(
@@ -109,7 +113,11 @@ def perform_inference(self, image: np.ndarray):
             prediction_result = [result.boxes.data for result in prediction_result]
 
         self._original_predictions = prediction_result
-        self._original_shape = image.shape
+        if isinstance(image, list):
+            # for a batch of slices, record the shape of the first one
+            self._original_shape = image[0].shape
+        else:
+            self._original_shape = image.shape
 
     @property
     def category_names(self):
diff --git a/sahi/predict.py b/sahi/predict.py
index 60d834c6..82f92a91 100644
--- a/sahi/predict.py
+++ b/sahi/predict.py
@@ -5,6 +5,7 @@
+import math
 import os
 import time
 from typing import Generator, List, Optional, Union
 
 from PIL import Image
 
@@ -106,10 +107,9 @@ def get_prediction(
     durations_in_seconds = dict()
 
-    # read image as pil
-    image_as_pil = read_image_as_pil(image)
-
     # get prediction
     time_start = time.time()
-    detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
+    # image may be a single RGB array or a list of BGR slice arrays;
+    # the Ultralytics wrapper handles both cases
+    detection_model.perform_inference(image)
     time_end = time.time() - time_start
     durations_in_seconds["prediction"] = time_end
@@ -126,7 +126,6 @@ def get_prediction(
 
     # postprocess matching predictions
     if postprocess is not None:
         object_prediction_list = postprocess(object_prediction_list)
-
     time_end = time.time() - time_start
     durations_in_seconds["postprocess"] = time_end
@@ -159,6 +158,7 @@ def get_sliced_prediction(
     auto_slice_resolution: bool = True,
     slice_export_prefix: Optional[str] = None,
     slice_dir: Optional[str] = None,
+    num_batch: int = 1,
     exclude_classes_by_name: Optional[List[str]] = None,
     exclude_classes_by_id: Optional[List[int]] = None,
 ) -> PredictionResult:
@@ -225,8 +225,6 @@ def get_sliced_prediction(
     # for profiling
     durations_in_seconds = dict()
 
-    # currently only 1 batch supported
-    num_batch = 1
     # create slices from full image
     time_start = time.time()
     slice_image_result = slice_image(
@@ -260,7 +258,8 @@ def get_sliced_prediction(
     )
 
     # create prediction input
-    num_group = int(num_slices / num_batch)
+    # use ceil so that a trailing partial batch is still processed
+    num_group = math.ceil(num_slices / num_batch)
     if verbose == 1 or verbose == 2:
         tqdm.write(f"Performing prediction on {num_slices} slices.")
     object_prediction_list = []
@@ -270,25 +269,31 @@ def get_sliced_prediction(
         image_list = []
         shift_amount_list = []
         for image_ind in range(num_batch):
-            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
+            if (group_ind * num_batch + image_ind) >= num_slices:
+                # the final group may contain fewer than num_batch slices
+                break
+            # convert each RGB slice to BGR here, since batched arrays are
+            # passed to the model without per-image conversion
+            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind][:, :, ::-1])
             shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
 
         # perform batch prediction
+        num_full = len(image_list)
         prediction_result = get_prediction(
-            image=image_list[0],
+            image=image_list,
             detection_model=detection_model,
-            shift_amount=shift_amount_list[0],
-            full_shape=[
+            shift_amount=shift_amount_list,
+            full_shape=[[
                 slice_image_result.original_image_height,
                 slice_image_result.original_image_width,
-            ],
+            ]] * num_full,
             exclude_classes_by_name=exclude_classes_by_name,
             exclude_classes_by_id=exclude_classes_by_id,
         )
         # convert sliced predictions to full predictions
-        for object_prediction in prediction_result.object_prediction_list:
-            if object_prediction:  # if not empty
-                object_prediction_list.append(object_prediction.get_shifted_object_prediction())
+        for object_prediction_per_image in prediction_result.object_prediction_list:
+            for object_prediction in object_prediction_per_image:
+                object_prediction_list.append(object_prediction.get_shifted_object_prediction())
 
         # merge matching predictions during sliced prediction
         if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
@@ -296,7 +301,9 @@ def get_sliced_prediction(
     # perform standard prediction
     if num_slices > 1 and perform_standard_pred:
         prediction_result = get_prediction(
-            image=image,
+            # read the image explicitly (get_prediction no longer does) and
+            # convert RGB -> BGR to match the batched slice path
+            image=[np.asarray(read_image_as_pil(image))[:, :, ::-1]],
             detection_model=detection_model,
             shift_amount=[0, 0],
             full_shape=[
@@ -307,7 +314,9 @@ def get_sliced_prediction(
             exclude_classes_by_name=exclude_classes_by_name,
             exclude_classes_by_id=exclude_classes_by_id,
         )
-        object_prediction_list.extend(prediction_result.object_prediction_list)
+        # predictions come back grouped per image, so flatten before extending
+        for object_prediction_per_image in prediction_result.object_prediction_list:
+            object_prediction_list.extend(object_prediction_per_image)
 
     # merge matching predictions
     if len(object_prediction_list) > 1:
@@ -408,6 +417,7 @@ def predict(
     verbose: int = 1,
     return_dict: bool = False,
     force_postprocess_type: bool = False,
+    num_batch: int = 1,
     exclude_classes_by_name: Optional[List[str]] = None,
     exclude_classes_by_id: Optional[List[int]] = None,
     **kwargs,
@@ -610,6 +620,7 @@ def predict(
                 postprocess_match_threshold=postprocess_match_threshold,
                 postprocess_class_agnostic=postprocess_class_agnostic,
                 verbose=1 if verbose else 0,
+                num_batch=num_batch,
                 exclude_classes_by_name=exclude_classes_by_name,
                 exclude_classes_by_id=exclude_classes_by_id,
             )
diff --git a/sahi/prediction.py b/sahi/prediction.py
index a73b0cce..daf18a88 100644
--- a/sahi/prediction.py
+++ b/sahi/prediction.py
@@ -164,8 +164,13 @@ def __init__(
         image: Union[Image.Image, str, np.ndarray],
         durations_in_seconds: Dict[str, Any] = dict(),
     ):
-        self.image: Image.Image = read_image_as_pil(image)
-        self.image_width, self.image_height = self.image.size
+        if isinstance(image, list):
+            # batched result: keep the list of arrays; numpy shape is (height, width, channels)
+            self.image = image
+            self.image_height, self.image_width = self.image[0].shape[:2]
+        else:
+            self.image: Image.Image = read_image_as_pil(image)
+            self.image_width, self.image_height = self.image.size
         self.object_prediction_list: List[ObjectPrediction] = object_prediction_list
         self.durations_in_seconds = durations_in_seconds
 
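
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new num_batch argument would be exercised end to end. The checkpoint name, thresholds, slice sizes, and batch size are illustrative assumptions, and the model_type string may differ between sahi versions; only num_batch itself comes from this diff. Note that get_prediction no longer calls read_image_as_pil after this change, so direct callers of get_prediction should pass a numpy array (or a list of arrays) rather than a file path.

    from sahi import AutoDetectionModel
    from sahi.predict import get_sliced_prediction

    # hypothetical checkpoint and thresholds, for illustration only
    detection_model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path="yolov8n.pt",
        confidence_threshold=0.4,
        device="cuda:0",
    )

    result = get_sliced_prediction(
        "demo_data/small-vehicles1.jpeg",  # hypothetical input image
        detection_model,
        slice_height=512,
        slice_width=512,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
        num_batch=8,  # new in this patch: forward 8 slices per model call
    )
    print(f"{len(result.object_prediction_list)} objects detected")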