diff --git a/sahi/postprocess/utils.py b/sahi/postprocess/utils.py index 1e10341b..c71075e5 100644 --- a/sahi/postprocess/utils.py +++ b/sahi/postprocess/utils.py @@ -4,6 +4,7 @@ import numpy as np import torch from shapely.geometry import MultiPolygon, Polygon +from shapely.geometry.collection import GeometryCollection from sahi.annotation import BoundingBox, Category, Mask from sahi.prediction import ObjectPrediction @@ -65,6 +66,71 @@ def tolist(self): return self.list +def repair_polygon(shapely_polygon: Polygon) -> Polygon: + """ + Fix polygons + :param shapely_polygon: Shapely polygon object + :return: + """ + if not shapely_polygon.is_valid: + fixed_polygon = shapely_polygon.buffer(0) + if fixed_polygon.is_valid: + if isinstance(fixed_polygon, Polygon): + return fixed_polygon + elif isinstance(fixed_polygon, MultiPolygon): + return max(fixed_polygon.geoms, key=lambda p: p.area) + elif isinstance(fixed_polygon, GeometryCollection): + polygons = [geom for geom in fixed_polygon.geoms if isinstance(geom, Polygon)] + return max(polygons, key=lambda p: p.area) if polygons else shapely_polygon + + return shapely_polygon + + +def repair_multipolygon(shapely_multipolygon: MultiPolygon) -> MultiPolygon: + """ + Fix invalid MultiPolygon objects + :param shapely_multipolygon: Imported shapely MultiPolygon object + :return: + """ + if not shapely_multipolygon.is_valid: + fixed_geometry = shapely_multipolygon.buffer(0) + + if fixed_geometry.is_valid: + if isinstance(fixed_geometry, MultiPolygon): + return fixed_geometry + elif isinstance(fixed_geometry, Polygon): + return MultiPolygon([fixed_geometry]) + elif isinstance(fixed_geometry, GeometryCollection): + polygons = [geom for geom in fixed_geometry.geoms if isinstance(geom, Polygon)] + return MultiPolygon(polygons) if polygons else shapely_multipolygon + + return shapely_multipolygon + + +def coco_segmentation_to_shapely(segmentation: Union[List, List[List]]): + """ + Fix segment data in COCO format + :param segmentation: segment data in COCO format + :return: + """ + if isinstance(segmentation, List) and all([not isinstance(seg, List) for seg in segmentation]): + segmentation = [segmentation] + elif isinstance(segmentation, List) and all([isinstance(seg, List) for seg in segmentation]): + pass + else: + raise ValueError("segmentation must be List or List[List]") + + polygon_list = [] + + for coco_polygon in segmentation: + point_list = list(zip(coco_polygon[::2], coco_polygon[1::2])) + shapely_polygon = Polygon(point_list) + polygon_list.append(repair_polygon(shapely_polygon)) + + shapely_multipolygon = repair_multipolygon(MultiPolygon(polygon_list)) + return shapely_multipolygon + + def object_prediction_list_to_torch(object_prediction_list: ObjectPredictionList) -> torch.tensor: """ Returns: @@ -166,6 +232,12 @@ def get_merged_mask(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Mask: # buffer(0) is a quickhack to fix invalid polygons most of the time poly1 = get_shapely_multipolygon(mask1.segmentation).buffer(0) poly2 = get_shapely_multipolygon(mask2.segmentation).buffer(0) + + if poly1.is_empty: + poly1 = coco_segmentation_to_shapely(mask1.segmentation) + if poly2.is_empty: + poly2 = coco_segmentation_to_shapely(mask2.segmentation) + union_poly = poly1.union(poly2) if not hasattr(union_poly, "geoms"): union_poly = MultiPolygon([union_poly])