Skip to content

Commit c8a48c2

Browse files
Add validation checks for bounding boxes
Check whether bounding boxes extend outside the image boundaries and clearly report it to the user when they do. This is a version of #1015 that accounts for changes in the codebase over the last year. Closes #1014. Co-authored-by: Keerthi Reddy <[email protected]>
1 parent 682e326 commit c8a48c2

File tree

2 files changed

+140
-22
lines changed

2 files changed

+140
-22
lines changed

src/deepforest/datasets/training.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
self.preload_images = preload_images
7272

7373
self._validate_labels()
74+
self._validate_coordinates()
7475

7576
# Pin data to memory if desired
7677
if self.preload_images:
@@ -94,6 +95,49 @@ def _validate_labels(self):
9495
f"Please ensure all labels in the annotations exist as keys in label_dict."
9596
)
9697

98+
def _validate_coordinates(self):
    """Validate that every annotation's bounding box lies within its image.

    Each distinct image referenced by ``self.annotations`` is opened once
    (its path is resolved against ``self.root_dir``) to read its pixel
    size, and the bounds of every row's ``geometry`` are compared against
    that size. Problems are collected across all rows so they can be
    reported together instead of stopping at the first one.

    Raises:
        ValueError: If an image cannot be opened, a row's geometry does
            not yield ``(xmin, ymin, xmax, ymax)`` bounds, or a box
            extends past an image edge. The message lists every problem
            found, one per line.
    """
    errors = []
    # Image size keyed by full image path. ``None`` marks an image that
    # could not be opened, so the file is read only once and the failure
    # is reported once rather than once per box.
    image_sizes = {}

    for idx, row in self.annotations.iterrows():
        img_path = os.path.join(self.root_dir, row["image_path"])

        if img_path not in image_sizes:
            try:
                with Image.open(img_path) as img:
                    image_sizes[img_path] = img.size
            except Exception as e:
                # Deliberately broad: the failure is not swallowed, it is
                # surfaced through the aggregated ValueError below.
                errors.append(f"Failed to open image {img_path}: {e}")
                image_sizes[img_path] = None

        size = image_sizes[img_path]
        if size is None:
            # Boxes of an unreadable image cannot be checked.
            continue
        width, height = size

        # Extract bounding box
        try:
            xmin, ymin, xmax, ymax = row["geometry"].bounds
        except Exception as e:
            errors.append(f"Invalid box format at index {idx}: {e}")
            continue

        # Check each edge separately so the message names every violation
        oob_issues = []
        if xmin < 0:
            oob_issues.append(f"xmin ({xmin}) < 0")
        if xmax > width:
            oob_issues.append(f"xmax ({xmax}) > image width ({width})")
        if ymin < 0:
            oob_issues.append(f"ymin ({ymin}) < 0")
        if ymax > height:
            oob_issues.append(f"ymax ({ymax}) > image height ({height})")

        if oob_issues:
            # Prefix with the image and row so the offending annotation
            # can be located; the rest of the message is unchanged.
            errors.append(
                f"Image {img_path}, index {idx}: "
                f"Box, ({xmin}, {ymin}, {xmax}, {ymax}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}."
            )

    if errors:
        raise ValueError("\n".join(errors))
140+
97141
def __len__(self):
98142
return len(self.image_names)
99143

tests/test_datasets_training.py

Lines changed: 96 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88
import torch
9+
from PIL import Image
910

1011
from deepforest import get_data, main, utilities
1112
from deepforest.datasets.training import BoxDataset
@@ -25,10 +26,13 @@ def multi_class():
2526

2627
@pytest.fixture()
2728
def raster_path():
28-
return get_data(path='OSBS_029.tif')
29+
return get_data(path="OSBS_029.tif")
2930

3031

31-
@pytest.mark.parametrize("csv_file,label_dict", [(single_class(), {"Tree": 0}), (multi_class(), {"Alive": 0, "Dead": 1})])
32+
@pytest.mark.parametrize(
33+
"csv_file,label_dict",
34+
[(single_class(), {"Tree": 0}), (multi_class(), {"Alive": 0, "Dead": 1})],
35+
)
3236
def test_BoxDataset(csv_file, label_dict):
3337
root_dir = os.path.dirname(get_data("OSBS_029.png"))
3438
ds = BoxDataset(csv_file=csv_file, root_dir=root_dir, label_dict=label_dict)
@@ -48,7 +52,7 @@ def test_BoxDataset(csv_file, label_dict):
4852

4953

5054
def test_single_class_with_empty(tmpdir):
51-
"""Add fake empty annotations to test parsing """
55+
"""Add fake empty annotations to test parsing"""
5256
csv_file1 = get_data("example.csv")
5357
csv_file2 = get_data("OSBS_029.csv")
5458

@@ -64,9 +68,9 @@ def test_single_class_with_empty(tmpdir):
6468
df.to_csv(f"{tmpdir}_test_empty.csv")
6569

6670
root_dir = os.path.dirname(get_data("OSBS_029.png"))
67-
ds = BoxDataset(csv_file=f"{tmpdir}_test_empty.csv",
68-
root_dir=root_dir,
69-
label_dict={"Tree": 0})
71+
ds = BoxDataset(
72+
csv_file=f"{tmpdir}_test_empty.csv", root_dir=root_dir, label_dict={"Tree": 0}
73+
)
7074
assert len(ds) == 2
7175
# First image has annotations
7276
assert not torch.sum(ds[0][1]["boxes"]) == 0
@@ -78,9 +82,11 @@ def test_single_class_with_empty(tmpdir):
7882
def test_BoxDataset_transform(augment):
7983
csv_file = get_data("example.csv")
8084
root_dir = os.path.dirname(csv_file)
81-
ds = BoxDataset(csv_file=csv_file,
82-
root_dir=root_dir,
83-
augmentations=["HorizontalFlip"] if augment else None)
85+
ds = BoxDataset(
86+
csv_file=csv_file,
87+
root_dir=root_dir,
88+
augmentations=["HorizontalFlip"] if augment else None,
89+
)
8490

8591
for i in range(len(ds)):
8692
# Between 0 and 1
@@ -100,8 +106,7 @@ def test_collate():
100106
"""Due to data augmentations the dataset class may yield empty bounding box annotations"""
101107
csv_file = get_data("example.csv")
102108
root_dir = os.path.dirname(csv_file)
103-
ds = BoxDataset(csv_file=csv_file,
104-
root_dir=root_dir)
109+
ds = BoxDataset(csv_file=csv_file, root_dir=root_dir)
105110

106111
for i in range(len(ds)):
107112
# Between 0 and 1
@@ -114,8 +119,7 @@ def test_empty_collate():
114119
"""Due to data augmentations the dataset class may yield empty bounding box annotations"""
115120
csv_file = get_data("example.csv")
116121
root_dir = os.path.dirname(csv_file)
117-
ds = BoxDataset(csv_file=csv_file,
118-
root_dir=root_dir)
122+
ds = BoxDataset(csv_file=csv_file, root_dir=root_dir)
119123

120124
for i in range(len(ds)):
121125
# Between 0 and 1
@@ -145,8 +149,7 @@ def test_multi_image_warning():
145149
df.to_csv(csv_file)
146150

147151
root_dir = os.path.dirname(csv_file1)
148-
ds = BoxDataset(csv_file=csv_file,
149-
root_dir=root_dir)
152+
ds = BoxDataset(csv_file=csv_file, root_dir=root_dir)
150153

151154
for i in range(len(ds)):
152155
# Between 0 and 1
@@ -162,7 +165,9 @@ def test_label_validation__training_csv():
162165
m.config.train.root_dir = os.path.dirname(get_data("example.csv"))
163166
m.create_trainer()
164167

165-
with pytest.raises(ValueError, match="Labels \\['Tree'\\] are missing from label_dict"):
168+
with pytest.raises(
169+
ValueError, match="Labels \\['Tree'\\] are missing from label_dict"
170+
):
166171
m.trainer.fit(m)
167172

168173

@@ -171,11 +176,15 @@ def test_csv_label_validation__validation_csv(m):
171176
m = main.deepforest(config_args={"num_classes": 1, "label_dict": {"Tree": 0}})
172177
m.config.train.csv_file = get_data("example.csv") # contains 'Tree' label
173178
m.config.train.root_dir = os.path.dirname(get_data("example.csv"))
174-
m.config.validation.csv_file = get_data("testfile_multi.csv") # contains 'Dead', 'Alive' labels
179+
m.config.validation.csv_file = get_data(
180+
"testfile_multi.csv"
181+
) # contains 'Dead', 'Alive' labels
175182
m.config.validation.root_dir = os.path.dirname(get_data("testfile_multi.csv"))
176183
m.create_trainer()
177184

178-
with pytest.raises(ValueError, match="Labels \\['Dead', 'Alive'\\] are missing from label_dict"):
185+
with pytest.raises(
186+
ValueError, match="Labels \\['Dead', 'Alive'\\] are missing from label_dict"
187+
):
179188
m.trainer.fit(m)
180189

181190

@@ -191,17 +200,73 @@ def test_BoxDataset_validate_labels():
191200
# Should not raise an error
192201

193202
# Invalid case: CSV labels are not in label_dict
194-
with pytest.raises(ValueError, match="Labels \\['Tree'\\] are missing from label_dict"):
203+
with pytest.raises(
204+
ValueError, match="Labels \\['Tree'\\] are missing from label_dict"
205+
):
195206
BoxDataset(csv_file=csv_file, root_dir=root_dir, label_dict={"Bird": 0})
196207

197208

209+
def test_validate_BoxDataset_missing_image(tmpdir, raster_path):
    """BoxDataset must raise a ValueError when an annotated image cannot be opened."""
    annotation = {
        "image_path": "missing.tif",
        "xmin": 0,
        "ymin": 0,
        "xmax": 10,
        "ymax": 10,
        "label": "Tree",
    }
    annotations_file = os.path.join(tmpdir, "test.csv")
    pd.DataFrame([annotation]).to_csv(annotations_file, index=False)

    # root_dir is the directory of the sample raster, which is expected
    # not to contain a file named "missing.tif".
    data_dir = os.path.dirname(raster_path)
    with pytest.raises(ValueError, match="Failed to open image"):
        BoxDataset(csv_file=annotations_file, root_dir=data_dir)
225+
226+
227+
def test_BoxDataset_validate_coordinates(tmpdir, raster_path):
    """Check that BoxDataset accepts in-bounds boxes and rejects out-of-bounds ones.

    Every invalid box is paired with the issue text it must trigger, so the
    test fails if the wrong boundary check fires instead of passing on the
    generic "exceeds image dimensions" message alone.
    """
    # Valid case: uses example.csv with all valid boxes
    csv_path = get_data("example.csv")
    root_dir = os.path.dirname(csv_path)
    _ = BoxDataset(csv_file=csv_path, root_dir=root_dir)

    # Invalid boxes are built relative to the real size of the test raster
    with Image.open(raster_path) as image:
        width, height = image.size

    # (box, regex for the specific issue the error message must report)
    invalid_boxes = [
        ((width - 5, 0, width + 10, 10), r"xmax \(.+\) > image width"),
        ((0, height - 5, 10, height + 10), r"ymax \(.+\) > image height"),
        ((-5, 0, 10, 10), r"xmin \(.+\) < 0"),
        ((0, -5, 10, 10), r"ymin \(.+\) < 0"),
    ]

    # The same CSV file is rewritten for each case
    csv_path = os.path.join(tmpdir, "test.csv")
    for (xmin, ymin, xmax, ymax), expected_issue in invalid_boxes:
        df = pd.DataFrame(
            {
                "image_path": ["OSBS_029.tif"],
                "xmin": [xmin],
                "ymin": [ymin],
                "xmax": [xmax],
                "ymax": [ymax],
                "label": ["Tree"],
            }
        )
        df.to_csv(csv_path, index=False)

        with pytest.raises(
            ValueError, match=rf"exceeds image dimensions.*{expected_issue}"
        ):
            BoxDataset(csv_file=csv_path, root_dir=root_dir)
260+
261+
198262
def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
199263
"""Test that BoxDataset can load a shapefile with projected coordinates and converts to pixel coordinates"""
200264
import geopandas as gpd
201265

202266
# Get the raster to extract CRS and bounds
203267
import rasterio
204268
from shapely import geometry
269+
205270
with rasterio.open(raster_path) as src:
206271
raster_crs = src.crs
207272
bounds = src.bounds
@@ -216,12 +281,19 @@ def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
216281

217282
sample_geometry = [
218283
geometry.box(sample_x, sample_y, sample_x + box_size, sample_y + box_size),
219-
geometry.box(sample_x + box_size * 2, sample_y + box_size * 2, sample_x + box_size * 3, sample_y + box_size * 3)
284+
geometry.box(
285+
sample_x + box_size * 2,
286+
sample_y + box_size * 2,
287+
sample_x + box_size * 3,
288+
sample_y + box_size * 3,
289+
),
220290
]
221291
labels = ["Tree", "Tree"]
222292
image_path = os.path.basename(raster_path)
223293

224-
df = pd.DataFrame({"geometry": sample_geometry, "label": labels, "image_path": image_path})
294+
df = pd.DataFrame(
295+
{"geometry": sample_geometry, "label": labels, "image_path": image_path}
296+
)
225297
gdf = gpd.GeoDataFrame(df, geometry="geometry", crs=raster_crs)
226298

227299
# Save as shapefile
@@ -241,6 +313,8 @@ def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
241313
# Verify boxes are in pixel coordinates (should be positive and reasonable)
242314
# After geo_to_image_coordinates conversion, values should be in pixel space
243315
boxes = targets["boxes"]
244-
assert torch.all(boxes >= 0), "Boxes should have non-negative coordinates in pixel space"
316+
assert torch.all(boxes >= 0), (
317+
"Boxes should have non-negative coordinates in pixel space"
318+
)
245319
assert torch.all(boxes[:, 2] > boxes[:, 0]), "xmax should be greater than xmin"
246320
assert torch.all(boxes[:, 3] > boxes[:, 1]), "ymax should be greater than ymin"

0 commit comments

Comments
 (0)