@@ -66,37 +66,45 @@ def compute_class_recall(results):
     # Per class recall and precision
     class_recall_dict = {}
     class_precision_dict = {}
-    class_size = {}
+    class_size_dict = {}
 
     box_results = results[results.predicted_label.notna()]
     if box_results.empty:
         print("No predictions made")
         class_recall = None
         return class_recall
 
-    for name, group in box_results.groupby("true_label"):
-        class_recall_dict[name] = (
-            sum(group.true_label == group.predicted_label) / group.shape[0]
-        )
-        number_of_predictions = box_results[box_results.predicted_label == name].shape[
-            0
-        ]
-        if number_of_predictions == 0:
-            class_precision_dict[name] = 0
-        else:
-            class_precision_dict[name] = (
-                sum(group.true_label == group.predicted_label) / number_of_predictions
+    labels = set(box_results["predicted_label"].unique()).union(
+        box_results["true_label"].unique()
+    )
+    for label in labels:
+        ground_df = box_results[box_results["true_label"] == label]
+        n_ground_boxes = ground_df.shape[0]
+        if n_ground_boxes > 0:
+            class_recall_dict[label] = (
+                sum(ground_df.true_label == ground_df.predicted_label) / n_ground_boxes
             )
-        class_size[name] = group.shape[0]
-
-    class_recall = pd.DataFrame(
-        {
-            "label": class_recall_dict.keys(),
-            "recall": pd.Series(class_recall_dict),
-            "precision": pd.Series(class_precision_dict),
-            "size": pd.Series(class_size),
-        }
-    ).reset_index(drop=True)
+        pred_df = box_results[box_results["predicted_label"] == label]
+        n_pred_boxes = pred_df.shape[0]
+        if n_pred_boxes > 0:
+            class_precision_dict[label] = (
+                sum(pred_df.true_label == pred_df.predicted_label) / n_pred_boxes
+            )
+        class_size_dict[label] = n_ground_boxes
+
+    # fillna handles labels with 0 ground truths or 0 predictions
+    class_recall = (
+        pd.DataFrame(
+            {
+                # "label": class_recall_dict.keys(),
+                "recall": pd.Series(class_recall_dict),
+                "precision": pd.Series(class_precision_dict),
+                "size": pd.Series(class_size_dict),
+            }
+        )
+        .reset_index(names="label")
+        .fillna(0)
+    )
 
     return class_recall
 
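For context (not part of the diff itself): the refactor above replaces the groupby over ground-truth labels with a loop over the union of predicted and true labels, so a class that is never predicted, or never present in the ground truth, still gets a row, with its missing metric filled in as 0. A minimal sketch of that behaviour, assuming only a toy frame with `true_label`/`predicted_label` columns (the values below are made up, not from the repo's test data):

import pandas as pd

# Toy frame: "B" is never predicted, "C" never appears in the ground truth.
results = pd.DataFrame(
    {
        "true_label": ["A", "A", "B", "A"],
        "predicted_label": ["A", "C", "A", "A"],
    }
)
class_recall = compute_class_recall(results)
# Expected one row each for A, B and C: B's precision and C's recall are
# filled with 0 by .fillna(0), and C's size is 0 (no ground-truth boxes).
print(class_recall)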
@@ -160,9 +168,16 @@ def __evaluate_wrapper__(predictions, ground_df, iou_threshold, label_dict):
     results["results"]["true_label"] = results["results"]["true_label"].apply(
         lambda x: label_dict[x]
     )
+    # TODO: do we need to return the "predictions" in the results?
+    # TODO: DRY getting the proper predicted label column with `evaluate_boxes`
+    # pick the predicted label column, preferring the crop model's refined label
+    if "cropmodel_label" in predictions.columns:
+        pred_label_col = "cropmodel_label"
+    else:
+        pred_label_col = "label"
     # avoid modifying a view
     results["predictions"] = predictions.copy()
-    results["predictions"]["label"] = results["predictions"]["label"].apply(
+    results["predictions"]["label"] = results["predictions"][pred_label_col].apply(
         lambda x: label_dict[x]
     )
 
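As the TODO above notes, this column selection is duplicated between `__evaluate_wrapper__` and `evaluate_boxes`. One way it could be DRYed up is a tiny shared helper; this is a sketch only, and the helper name is hypothetical rather than part of this commit:

def get_pred_label_column(predictions):
    """Prefer the crop model's refined label over the detector's raw label."""
    return "cropmodel_label" if "cropmodel_label" in predictions.columns else "label"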
@@ -309,12 +324,16 @@ def evaluate_boxes(predictions, ground_df, iou_threshold=0.4):
     )
     results_df = results_df.rename(columns={"label": "true_label"})
     # set the score and predicted label
+    if "cropmodel_label" in predictions.columns:
+        pred_label_col = "cropmodel_label"
+    else:
+        pred_label_col = "label"
     results_df = results_df.merge(
-        predictions[["score", "label"]],
+        predictions[["score", pred_label_col]],
         left_on="prediction_id",
         right_index=True,
     )
-    results_df = results_df.rename(columns={"label": "predicted_label"})
+    results_df = results_df.rename(columns={pred_label_col: "predicted_label"})
     # set whether it is a match
     results_df["match"] = results_df["IoU"] >= iou_threshold
 
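To illustrate the effect of this change in isolation (a self-contained sketch with made-up values; `results_df` stands in for the IoU-matching output): when a `cropmodel_label` column is present, the merge pulls the refined class into `predicted_label`, so per-class recall and precision are computed on the crop model's classes rather than the detector's raw label.

import pandas as pd

predictions = pd.DataFrame(
    {"score": [0.9, 0.8], "label": ["Tree", "Tree"], "cropmodel_label": ["Alive", "Dead"]}
)
results_df = pd.DataFrame({"prediction_id": [0, 1], "true_label": ["Alive", "Alive"]})

pred_label_col = "cropmodel_label" if "cropmodel_label" in predictions.columns else "label"
results_df = results_df.merge(
    predictions[["score", pred_label_col]], left_on="prediction_id", right_index=True
).rename(columns={pred_label_col: "predicted_label"})
print(results_df)  # predicted_label is "Alive"/"Dead", taken from the crop model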