Skip to content

Commit fdb8fed

Browse files
committed
fix: bug in omitting pred labels to compute class recall/precision
1 parent dd4dc18 commit fdb8fed

File tree

1 file changed

+45
-26
lines changed

1 file changed

+45
-26
lines changed

src/deepforest/evaluate.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
def compute_class_recall(results):
    """Compute per-class recall, precision and class size from matched boxes.

    Args:
        results: DataFrame of matched boxes with ``true_label`` and
            ``predicted_label`` columns (one row per ground-truth box,
            ``predicted_label`` is NaN when the box was unmatched).

    Returns:
        DataFrame with columns ``label``, ``recall``, ``precision`` and
        ``size`` (number of ground-truth boxes per class), or ``None``
        when there are no predictions at all.
    """
    # Per class recall and precision
    class_recall_dict = {}
    class_precision_dict = {}
    class_size_dict = {}

    box_results = results[results.predicted_label.notna()]
    if box_results.empty:
        print("No predictions made")
        class_recall = None
        return class_recall

    # Iterate over the union of labels so classes that appear only in the
    # predictions (0 ground truths) or only in the ground truth
    # (0 predictions) still get a row.
    labels = set(box_results["predicted_label"].unique()).union(
        box_results["true_label"].unique()
    )
    for label in labels:
        ground_df = box_results[box_results["true_label"] == label]
        n_ground_boxes = ground_df.shape[0]
        if n_ground_boxes > 0:
            class_recall_dict[label] = (
                sum(ground_df.true_label == ground_df.predicted_label) / n_ground_boxes
            )

        # Precision must be computed over the boxes PREDICTED as this label,
        # not the ground-truth boxes (fixes precision == recall bug).
        pred_df = box_results[box_results["predicted_label"] == label]
        n_pred_boxes = pred_df.shape[0]
        if n_pred_boxes > 0:
            class_precision_dict[label] = (
                sum(pred_df.true_label == pred_df.predicted_label) / n_pred_boxes
            )

        class_size_dict[label] = n_ground_boxes

    # the fillna is needed for the missing labels with 0 ground truths or 0 predictions
    class_recall = (
        pd.DataFrame(
            {
                "recall": pd.Series(class_recall_dict),
                "precision": pd.Series(class_precision_dict),
                "size": pd.Series(class_size_dict),
            }
        )
        .reset_index(names="label")
        .fillna(0)
    )

    return class_recall
102110

@@ -160,9 +168,16 @@ def __evaluate_wrapper__(predictions, ground_df, iou_threshold, label_dict):
160168
results["results"]["true_label"] = results["results"]["true_label"].apply(
161169
lambda x: label_dict[x]
162170
)
171+
# TODO: do we need to return the "predictions" in the results?
172+
# TODO: DRY getting the proper predicted label column with `evaluate_boxes`
173+
# set the score and predicted label
174+
if "cropmodel_label" in predictions.columns:
175+
pred_label_col = "cropmodel_label"
176+
else:
177+
pred_label_col = "label"
163178
# avoid modifying a view
164179
results["predictions"] = predictions.copy()
165-
results["predictions"]["label"] = results["predictions"]["label"].apply(
180+
results["predictions"]["label"] = results["predictions"][pred_label_col].apply(
166181
lambda x: label_dict[x]
167182
)
168183

@@ -309,12 +324,16 @@ def evaluate_boxes(predictions, ground_df, iou_threshold=0.4):
309324
)
310325
results_df = results_df.rename(columns={"label": "true_label"})
311326
# set the score and predicted label
327+
if "cropmodel_label" in predictions.columns:
328+
pred_label_col = "cropmodel_label"
329+
else:
330+
pred_label_col = "label"
312331
results_df = results_df.merge(
313-
predictions[["score", "label"]],
332+
predictions[["score", pred_label_col]],
314333
left_on="prediction_id",
315334
right_index=True,
316335
)
317-
results_df = results_df.rename(columns={"label": "predicted_label"})
336+
results_df = results_df.rename(columns={pred_label_col: "predicted_label"})
318337
# set whether it is a match
319338
results_df["match"] = results_df["IoU"] >= iou_threshold
320339

0 commit comments

Comments
 (0)