Skip to content

Commit f0d4a0d

Browse files
authored
validate labels in BoxDataset (#1093)
1 parent c227169 commit f0d4a0d

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

src/deepforest/datasets/training.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,30 @@ def __init__(self,
5656
self.label_dict = label_dict
5757
self.preload_images = preload_images
5858

59+
self._validate_labels()
60+
5961
# Pin data to memory if desired
6062
if self.preload_images:
6163
print("Pinning dataset to GPU memory")
6264
self.image_dict = {}
6365
for idx, x in enumerate(self.image_names):
6466
self.image_dict[idx] = self.load_image(idx)
6567

68+
def _validate_labels(self):
    """Validate that all labels in annotations exist in label_dict.

    Collects the distinct values of the ``label`` column of
    ``self.annotations`` and checks each one for membership in the keys of
    ``self.label_dict``.

    Raises:
        ValueError: If any label in annotations is missing from label_dict
    """
    # .tolist() turns NumPy scalars (e.g. integer labels) into native Python
    # objects so the error message shows plain values such as [2] rather
    # than [np.int64(2)]. Order of first appearance is preserved.
    csv_labels = self.annotations['label'].unique().tolist()
    missing_labels = [label for label in csv_labels if label not in self.label_dict]

    if missing_labels:
        raise ValueError(
            f"Labels {missing_labels} are missing from label_dict. "
            "Please ensure all labels in the annotations exist as keys in label_dict."
        )
82+
6683
def __len__(self):
6784
return len(self.image_names)
6885

tests/test_datasets_training.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# test dataset model
2-
from deepforest import get_data
2+
from deepforest import get_data, main
33
from deepforest import utilities
44
import os
55
import pytest
@@ -125,7 +125,7 @@ def test_BoxDataset_format():
125125
root_dir = os.path.dirname(csv_file)
126126
ds = BoxDataset(csv_file=csv_file, root_dir=root_dir)
127127
image, targets, path = next(iter(ds))
128-
128+
129129
# Assert image is channels first format
130130
assert image.shape[0] == 3
131131

@@ -147,4 +147,43 @@ def test_multi_image_warning():
147147
# Between 0 and 1
148148
batch = ds[i]
149149
collated_batch = utilities.collate_fn([None, batch, batch])
150-
len(collated_batch[0]) == 2
150+
len(collated_batch[0]) == 2
151+
152+
def test_label_validation__training_csv():
    """Fitting must fail when the training CSV has labels absent from label_dict."""
    # example.csv contains the 'Tree' label, which this label_dict lacks.
    train_csv = get_data("example.csv")

    m = main.deepforest(config_args={"num_classes": 1}, label_dict={"Bird": 0})
    m.config.train.csv_file = train_csv
    m.config.train.root_dir = os.path.dirname(train_csv)
    m.create_trainer()

    expected_error = "Labels \\['Tree'\\] are missing from label_dict"
    with pytest.raises(ValueError, match=expected_error):
        m.trainer.fit(m)
161+
162+
163+
def test_csv_label_validation__validation_csv():
    """Test validation CSV labels are validated against label_dict.

    The training CSV only contains labels present in label_dict, so the
    expected ValueError comes from the validation CSV.
    """
    # The model is built locally; no fixture is requested because any injected
    # model would be overwritten immediately.
    m = main.deepforest(config_args={"num_classes": 1}, label_dict={"Tree": 0})

    train_csv = get_data("example.csv")  # contains 'Tree' label
    m.config.train.csv_file = train_csv
    m.config.train.root_dir = os.path.dirname(train_csv)

    val_csv = get_data("testfile_multi.csv")  # contains 'Dead', 'Alive' labels
    m.config.validation.csv_file = val_csv
    m.config.validation.root_dir = os.path.dirname(val_csv)
    m.create_trainer()

    with pytest.raises(ValueError, match=r"are missing from label_dict") as excinfo:
        m.trainer.fit(m)

    # Check each label separately so the test does not depend on the order in
    # which the labels first appear in the CSV rows.
    message = str(excinfo.value)
    assert "'Dead'" in message
    assert "'Alive'" in message
174+
175+
176+
def test_BoxDataset_validate_labels():
    """Test that BoxDataset validates labels correctly.

    Covers both outcomes of construction-time validation: a label_dict that
    contains every CSV label is accepted, and one that lacks a CSV label
    raises ValueError.
    """
    from deepforest.datasets.training import BoxDataset

    csv_file = get_data("example.csv")  # contains 'Tree' label
    root_dir = os.path.dirname(csv_file)

    # Valid case: CSV labels are in label_dict, so construction must succeed
    # and the dataset keeps the mapping it was given.
    ds = BoxDataset(csv_file=csv_file, root_dir=root_dir, label_dict={"Tree": 0})
    assert ds.label_dict == {"Tree": 0}

    # Invalid case: CSV labels are not in label_dict
    with pytest.raises(ValueError, match=r"Labels \['Tree'\] are missing from label_dict"):
        BoxDataset(csv_file=csv_file, root_dir=root_dir, label_dict={"Bird": 0})

0 commit comments

Comments
 (0)