Skip to content

Commit a60b8ff

Browse files
Add validation checks for bounding boxes
Checks to see if bounding boxes occur outside of image boundaries and clearly communicates to user when they do. This is a version of weecology#1015 that accounts for changes in the codebase over the last year. Closes weecology#1014 Co-authored-by: Keerthi Reddy <[email protected]>
1 parent 682e326 commit a60b8ff

File tree

2 files changed

+93
-3
lines changed

2 files changed

+93
-3
lines changed

src/deepforest/datasets/training.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
self.preload_images = preload_images
7272

7373
self._validate_labels()
74+
self._validate_coordinates()
7475

7576
# Pin data to memory if desired
7677
if self.preload_images:
@@ -94,6 +95,49 @@ def _validate_labels(self):
9495
f"Please ensure all labels in the annotations exist as keys in label_dict."
9596
)
9697

98+
def _validate_coordinates(self):
99+
"""Validate that all bounding box coordinates occur within the image.
100+
101+
Raises:
102+
ValueError: If any bounding box coordinate occurs outside the image
103+
"""
104+
errors = []
105+
for idx, row in self.annotations.iterrows():
106+
img_path = os.path.join(self.root_dir, row["image_path"])
107+
try:
108+
with Image.open(img_path) as img:
109+
width, height = img.size
110+
except Exception as e:
111+
errors.append(f"Failed to open image {img_path}: {e}")
112+
continue
113+
114+
# Extract bounding box
115+
try:
116+
geom = row["geometry"]
117+
xmin, ymin, xmax, ymax = geom.bounds
118+
except Exception as e:
119+
errors.append(f"Invalid box format at index {idx}: {e}")
120+
continue
121+
122+
# Check if box is valid
123+
oob_issues = []
124+
if xmin < 0:
125+
oob_issues.append(f"xmin ({xmin}) < 0")
126+
if xmax > width:
127+
oob_issues.append(f"xmax ({xmax}) > image width ({width})")
128+
if ymin < 0:
129+
oob_issues.append(f"ymin ({ymin}) < 0")
130+
if ymax > height:
131+
oob_issues.append(f"ymax ({ymax}) > image height ({height})")
132+
133+
if oob_issues:
134+
errors.append(
135+
f"Box, ({xmin}, {ymin}, {xmax}, {ymax}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}."
136+
)
137+
138+
if errors:
139+
raise ValueError("\n".join(errors))
140+
97141
def __len__(self):
98142
return len(self.image_names)
99143

tests/test_datasets_training.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88
import torch
9+
from PIL import Image
910

1011
from deepforest import get_data, main, utilities
1112
from deepforest.datasets.training import BoxDataset
@@ -195,13 +196,49 @@ def test_BoxDataset_validate_labels():
195196
BoxDataset(csv_file=csv_file, root_dir=root_dir, label_dict={"Bird": 0})
196197

197198

199+
def test_BoxDataset_validate_coordinates(tmpdir, raster_path):
200+
# Valid case: uses example.csv with all valid boxes
201+
csv_path = get_data("example.csv")
202+
root_dir = os.path.dirname(csv_path)
203+
ds = BoxDataset(csv_file=csv_path, root_dir=root_dir)
204+
205+
# Test various invalid box coordinates
206+
with Image.open(raster_path) as image:
207+
width, height = image.size
208+
209+
invalid_boxes = [
210+
(width - 5, 0, width + 10, 10), # xmax exceeds width
211+
(0, height - 5, 10, height + 10), # ymax exceeds height
212+
(-5, 0, 10, 10), # negative xmin
213+
(0, -5, 10, 10), # negative ymin
214+
]
215+
216+
for box in invalid_boxes:
217+
csv_path = os.path.join(tmpdir, "test.csv")
218+
df = pd.DataFrame(
219+
{
220+
"image_path": ["OSBS_029.tif"],
221+
"xmin": [box[0]],
222+
"ymin": [box[1]],
223+
"xmax": [box[2]],
224+
"ymax": [box[3]],
225+
"label": ["Tree"],
226+
}
227+
)
228+
df.to_csv(csv_path, index=False)
229+
230+
with pytest.raises(ValueError, match="exceeds image dimensions"):
231+
print(BoxDataset(csv_file=csv_path, root_dir=root_dir))
232+
233+
198234
def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
199235
"""Test that BoxDataset can load a shapefile with projected coordinates and converts to pixel coordinates"""
200236
import geopandas as gpd
201237

202238
# Get the raster to extract CRS and bounds
203239
import rasterio
204240
from shapely import geometry
241+
205242
with rasterio.open(raster_path) as src:
206243
raster_crs = src.crs
207244
bounds = src.bounds
@@ -216,12 +253,19 @@ def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
216253

217254
sample_geometry = [
218255
geometry.box(sample_x, sample_y, sample_x + box_size, sample_y + box_size),
219-
geometry.box(sample_x + box_size * 2, sample_y + box_size * 2, sample_x + box_size * 3, sample_y + box_size * 3)
256+
geometry.box(
257+
sample_x + box_size * 2,
258+
sample_y + box_size * 2,
259+
sample_x + box_size * 3,
260+
sample_y + box_size * 3,
261+
),
220262
]
221263
labels = ["Tree", "Tree"]
222264
image_path = os.path.basename(raster_path)
223265

224-
df = pd.DataFrame({"geometry": sample_geometry, "label": labels, "image_path": image_path})
266+
df = pd.DataFrame(
267+
{"geometry": sample_geometry, "label": labels, "image_path": image_path}
268+
)
225269
gdf = gpd.GeoDataFrame(df, geometry="geometry", crs=raster_crs)
226270

227271
# Save as shapefile
@@ -241,6 +285,8 @@ def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
241285
# Verify boxes are in pixel coordinates (should be positive and reasonable)
242286
# After geo_to_image_coordinates conversion, values should be in pixel space
243287
boxes = targets["boxes"]
244-
assert torch.all(boxes >= 0), "Boxes should have non-negative coordinates in pixel space"
288+
assert torch.all(boxes >= 0), (
289+
"Boxes should have non-negative coordinates in pixel space"
290+
)
245291
assert torch.all(boxes[:, 2] > boxes[:, 0]), "xmax should be greater than xmin"
246292
assert torch.all(boxes[:, 3] > boxes[:, 1]), "ymax should be greater than ymin"

0 commit comments

Comments
 (0)