66import matplotlib
77import matplotlib .pyplot as plt
88import numpy as np
9+ import numpy .typing as npt
910
1011from supervision .dataset .core import DetectionDataset
1112from supervision .detection .core import Detections
1516
1617def detections_to_tensor (
1718 detections : Detections , with_confidence : bool = False
18- ) -> np .ndarray :
19+ ) -> npt . NDArray [ np .float32 ] :
1920 """
2021 Convert Supervision Detections to numpy tensors for further computation
2122
@@ -40,12 +41,13 @@ def detections_to_tensor(
4041 )
4142 arrays_to_concat .append (np .expand_dims (detections .confidence , 1 ))
4243
43- result : np .ndarray = np .concatenate (arrays_to_concat , axis = 1 )
44+ result : npt . NDArray [ np .float32 ] = np .concatenate (arrays_to_concat , axis = 1 )
4445 return result
4546
4647
4748def validate_input_tensors (
48- predictions : list [np .ndarray ], targets : list [np .ndarray ]
49+ predictions : list [npt .NDArray [np .float32 ]],
50+ targets : list [npt .NDArray [np .float32 ]],
4951) -> None :
5052 """
5153 Checks for shape consistency of input tensors.
@@ -89,7 +91,7 @@ class ConfusionMatrix:
8991 Detections with lower IoU will be classified as `FP`.
9092 """
9193
92- matrix : np .ndarray
94+ matrix : npt . NDArray [ np .int32 ]
9395 classes : list [str ]
9496 conf_threshold : float
9597 iou_threshold : float
@@ -162,8 +164,8 @@ def from_detections(
162164 @classmethod
163165 def from_tensors (
164166 cls ,
165- predictions : list [np .ndarray ],
166- targets : list [np .ndarray ],
167+ predictions : list [npt . NDArray [ np .float32 ] ],
168+ targets : list [npt . NDArray [ np .float32 ] ],
167169 classes : list [str ],
168170 conf_threshold : float = 0.3 ,
169171 iou_threshold : float = 0.5 ,
@@ -237,12 +239,12 @@ def from_tensors(
237239
238240 @staticmethod
239241 def evaluate_detection_batch (
240- predictions : np .ndarray ,
241- targets : np .ndarray ,
242+ predictions : npt . NDArray [ np .float32 ] ,
243+ targets : npt . NDArray [ np .float32 ] ,
242244 num_classes : int ,
243245 conf_threshold : float ,
244246 iou_threshold : float ,
245- ) -> np .ndarray :
247+ ) -> npt . NDArray [ np .int32 ] :
246248 """
247249 Calculate confusion matrix for a batch of detections for a single image.
248250
@@ -307,11 +309,13 @@ def evaluate_detection_batch(
307309 for i , detection_class_value in enumerate (detection_classes ):
308310 if not any (matched_detection_idx == i ):
309311 result_matrix [num_classes , detection_class_value ] += 1 # FP
310- final_result_matrix : np .ndarray = result_matrix
312+ final_result_matrix : npt . NDArray [ np .int32 ] = result_matrix
311313 return final_result_matrix
312314
313315 @staticmethod
314- def _drop_extra_matches (matches : np .ndarray ) -> np .ndarray :
316+ def _drop_extra_matches (
317+ matches : npt .NDArray [np .float32 ],
318+ ) -> npt .NDArray [np .float32 ]:
315319 """
316320 Deduplicate matches. If there are multiple matches for the same true or
317321 predicted box, only the one with the highest IoU is kept.
@@ -321,13 +325,14 @@ def _drop_extra_matches(matches: np.ndarray) -> np.ndarray:
321325 matches = matches [np .unique (matches [:, 1 ], return_index = True )[1 ]]
322326 matches = matches [matches [:, 2 ].argsort ()[::- 1 ]]
323327 matches = matches [np .unique (matches [:, 0 ], return_index = True )[1 ]]
324- return matches
328+ result : npt .NDArray [np .float32 ] = matches
329+ return result
325330
326331 @classmethod
327332 def benchmark (
328333 cls ,
329334 dataset : DetectionDataset ,
330- callback : Callable [[np .ndarray ], Detections ],
335+ callback : Callable [[npt . NDArray [ np .uint8 ] ], Detections ],
331336 conf_threshold : float = 0.3 ,
332337 iou_threshold : float = 0.5 ,
333338 ) -> ConfusionMatrix :
@@ -510,7 +515,7 @@ class MeanAveragePrecision:
510515 map50_95 : float
511516 map50 : float
512517 map75 : float
513- per_class_ap50_95 : np .ndarray
518+ per_class_ap50_95 : npt . NDArray [ np .float64 ]
514519
515520 @classmethod
516521 def from_detections (
@@ -566,7 +571,7 @@ def from_detections(
566571 def benchmark (
567572 cls ,
568573 dataset : DetectionDataset ,
569- callback : Callable [[np .ndarray ], Detections ],
574+ callback : Callable [[npt . NDArray [ np .uint8 ] ], Detections ],
570575 ) -> MeanAveragePrecision :
571576 """
572577 Calculate mean average precision from dataset and callback function.
@@ -612,8 +617,8 @@ def callback(image: np.ndarray) -> sv.Detections:
612617 @classmethod
613618 def from_tensors (
614619 cls ,
615- predictions : list [np .ndarray ],
616- targets : list [np .ndarray ],
620+ predictions : list [npt . NDArray [ np .float32 ] ],
621+ targets : list [npt . NDArray [ np .float32 ] ],
617622 ) -> MeanAveragePrecision :
618623 """
619624 Calculate Mean Average Precision based on predicted and ground-truth
@@ -704,7 +709,10 @@ def from_tensors(
704709 )
705710
706711 @staticmethod
707- def compute_average_precision (recall : np .ndarray , precision : np .ndarray ) -> float :
712+ def compute_average_precision (
713+ recall : npt .NDArray [np .float64 ],
714+ precision : npt .NDArray [np .float64 ],
715+ ) -> float :
708716 """
709717 Compute the average precision using 101-point interpolation (COCO), given
710718 the recall and precision curves.
@@ -732,16 +740,18 @@ def compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> floa
732740 interpolated_precision , interpolated_recall_levels
733741 )
734742 else :
735- average_precision = np . trapz ( # type: ignore[attr-defined]
743+ average_precision = getattr ( np , " trapz" )(
736744 interpolated_precision , interpolated_recall_levels
737745 )
738746
739747 return float (average_precision )
740748
741749 @staticmethod
742750 def _match_detection_batch (
743- predictions : np .ndarray , targets : np .ndarray , iou_thresholds : np .ndarray
744- ) -> np .ndarray :
751+ predictions : npt .NDArray [np .float32 ],
752+ targets : npt .NDArray [np .float32 ],
753+ iou_thresholds : npt .NDArray [np .float32 ],
754+ ) -> npt .NDArray [np .bool_ ]:
745755 """
746756 Match predictions with target labels based on IoU levels.
747757
@@ -778,17 +788,17 @@ def _match_detection_batch(
778788 matches = matches [np .unique (matches [:, 0 ], return_index = True )[1 ]]
779789
780790 correct [matches [:, 1 ].astype (int ), i ] = True
781- result : np .ndarray = correct
791+ result : npt . NDArray [ np .bool_ ] = correct
782792 return result
783793
784794 @staticmethod
785795 def _average_precisions_per_class (
786- matches : np .ndarray ,
787- prediction_confidence : np .ndarray ,
788- prediction_class_ids : np .ndarray ,
789- true_class_ids : np .ndarray ,
796+ matches : npt . NDArray [ np .bool_ ] ,
797+ prediction_confidence : npt . NDArray [ np .float32 ] ,
798+ prediction_class_ids : npt . NDArray [ np .int32 ] ,
799+ true_class_ids : npt . NDArray [ np .int32 ] ,
790800 eps : float = 1e-16 ,
791- ) -> np .ndarray :
801+ ) -> npt . NDArray [ np .float64 ] :
792802 """
793803 Compute the average precision, given the recall and precision curves.
794804 Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
@@ -810,7 +820,9 @@ def _average_precisions_per_class(
810820 unique_classes , class_counts = np .unique (true_class_ids , return_counts = True )
811821 num_classes = unique_classes .shape [0 ]
812822
813- average_precisions = np .zeros ((num_classes , matches .shape [1 ]))
823+ average_precisions : npt .NDArray [np .float64 ] = np .zeros (
824+ (num_classes , matches .shape [1 ]), dtype = np .float64
825+ )
814826
815827 for class_idx , class_id in enumerate (unique_classes ):
816828 is_class = prediction_class_ids == class_id
@@ -832,5 +844,5 @@ def _average_precisions_per_class(
832844 )
833845 )
834846
835- result : np .ndarray = average_precisions
847+ result : npt . NDArray [ np .float64 ] = average_precisions
836848 return result
0 commit comments