66import pandas as pd
77import pytest
88import torch
9+ from PIL import Image
910
1011from deepforest import get_data , main , utilities
1112from deepforest .datasets .training import BoxDataset
@@ -25,10 +26,13 @@ def multi_class():
2526
2627@pytest .fixture ()
2728def raster_path ():
28- return get_data (path = ' OSBS_029.tif' )
29+ return get_data (path = " OSBS_029.tif" )
2930
3031
31- @pytest .mark .parametrize ("csv_file,label_dict" , [(single_class (), {"Tree" : 0 }), (multi_class (), {"Alive" : 0 , "Dead" : 1 })])
32+ @pytest .mark .parametrize (
33+ "csv_file,label_dict" ,
34+ [(single_class (), {"Tree" : 0 }), (multi_class (), {"Alive" : 0 , "Dead" : 1 })],
35+ )
3236def test_BoxDataset (csv_file , label_dict ):
3337 root_dir = os .path .dirname (get_data ("OSBS_029.png" ))
3438 ds = BoxDataset (csv_file = csv_file , root_dir = root_dir , label_dict = label_dict )
@@ -48,7 +52,7 @@ def test_BoxDataset(csv_file, label_dict):
4852
4953
5054def test_single_class_with_empty (tmpdir ):
51- """Add fake empty annotations to test parsing """
55+ """Add fake empty annotations to test parsing"""
5256 csv_file1 = get_data ("example.csv" )
5357 csv_file2 = get_data ("OSBS_029.csv" )
5458
@@ -64,9 +68,9 @@ def test_single_class_with_empty(tmpdir):
6468 df .to_csv (f"{ tmpdir } _test_empty.csv" )
6569
6670 root_dir = os .path .dirname (get_data ("OSBS_029.png" ))
67- ds = BoxDataset (csv_file = f" { tmpdir } _test_empty.csv" ,
68- root_dir = root_dir ,
69- label_dict = { "Tree" : 0 } )
71+ ds = BoxDataset (
72+ csv_file = f" { tmpdir } _test_empty.csv" , root_dir = root_dir , label_dict = { "Tree" : 0 }
73+ )
7074 assert len (ds ) == 2
7175 # First image has annotations
7276 assert not torch .sum (ds [0 ][1 ]["boxes" ]) == 0
@@ -78,9 +82,11 @@ def test_single_class_with_empty(tmpdir):
7882def test_BoxDataset_transform (augment ):
7983 csv_file = get_data ("example.csv" )
8084 root_dir = os .path .dirname (csv_file )
81- ds = BoxDataset (csv_file = csv_file ,
82- root_dir = root_dir ,
83- augmentations = ["HorizontalFlip" ] if augment else None )
85+ ds = BoxDataset (
86+ csv_file = csv_file ,
87+ root_dir = root_dir ,
88+ augmentations = ["HorizontalFlip" ] if augment else None ,
89+ )
8490
8591 for i in range (len (ds )):
8692 # Between 0 and 1
@@ -100,8 +106,7 @@ def test_collate():
100106 """Due to data augmentations the dataset class may yield empty bounding box annotations"""
101107 csv_file = get_data ("example.csv" )
102108 root_dir = os .path .dirname (csv_file )
103- ds = BoxDataset (csv_file = csv_file ,
104- root_dir = root_dir )
109+ ds = BoxDataset (csv_file = csv_file , root_dir = root_dir )
105110
106111 for i in range (len (ds )):
107112 # Between 0 and 1
@@ -114,8 +119,7 @@ def test_empty_collate():
114119 """Due to data augmentations the dataset class may yield empty bounding box annotations"""
115120 csv_file = get_data ("example.csv" )
116121 root_dir = os .path .dirname (csv_file )
117- ds = BoxDataset (csv_file = csv_file ,
118- root_dir = root_dir )
122+ ds = BoxDataset (csv_file = csv_file , root_dir = root_dir )
119123
120124 for i in range (len (ds )):
121125 # Between 0 and 1
@@ -145,8 +149,7 @@ def test_multi_image_warning():
145149 df .to_csv (csv_file )
146150
147151 root_dir = os .path .dirname (csv_file1 )
148- ds = BoxDataset (csv_file = csv_file ,
149- root_dir = root_dir )
152+ ds = BoxDataset (csv_file = csv_file , root_dir = root_dir )
150153
151154 for i in range (len (ds )):
152155 # Between 0 and 1
@@ -162,7 +165,9 @@ def test_label_validation__training_csv():
162165 m .config .train .root_dir = os .path .dirname (get_data ("example.csv" ))
163166 m .create_trainer ()
164167
165- with pytest .raises (ValueError , match = "Labels \\ ['Tree'\\ ] are missing from label_dict" ):
168+ with pytest .raises (
169+ ValueError , match = "Labels \\ ['Tree'\\ ] are missing from label_dict"
170+ ):
166171 m .trainer .fit (m )
167172
168173
@@ -171,11 +176,15 @@ def test_csv_label_validation__validation_csv(m):
171176 m = main .deepforest (config_args = {"num_classes" : 1 , "label_dict" : {"Tree" : 0 }})
172177 m .config .train .csv_file = get_data ("example.csv" ) # contains 'Tree' label
173178 m .config .train .root_dir = os .path .dirname (get_data ("example.csv" ))
174- m .config .validation .csv_file = get_data ("testfile_multi.csv" ) # contains 'Dead', 'Alive' labels
179+ m .config .validation .csv_file = get_data (
180+ "testfile_multi.csv"
181+ ) # contains 'Dead', 'Alive' labels
175182 m .config .validation .root_dir = os .path .dirname (get_data ("testfile_multi.csv" ))
176183 m .create_trainer ()
177184
178- with pytest .raises (ValueError , match = "Labels \\ ['Dead', 'Alive'\\ ] are missing from label_dict" ):
185+ with pytest .raises (
186+ ValueError , match = "Labels \\ ['Dead', 'Alive'\\ ] are missing from label_dict"
187+ ):
179188 m .trainer .fit (m )
180189
181190
@@ -191,17 +200,73 @@ def test_BoxDataset_validate_labels():
191200 # Should not raise an error
192201
193202 # Invalid case: CSV labels are not in label_dict
194- with pytest .raises (ValueError , match = "Labels \\ ['Tree'\\ ] are missing from label_dict" ):
203+ with pytest .raises (
204+ ValueError , match = "Labels \\ ['Tree'\\ ] are missing from label_dict"
205+ ):
195206 BoxDataset (csv_file = csv_file , root_dir = root_dir , label_dict = {"Bird" : 0 })
196207
197208
209+ def test_validate_BoxDataset_missing_image (tmpdir , raster_path ):
210+ csv_path = os .path .join (tmpdir , "test.csv" )
211+ df = pd .DataFrame (
212+ {
213+ "image_path" : ["missing.tif" ],
214+ "xmin" : 0 ,
215+ "ymin" : 0 ,
216+ "xmax" : 10 ,
217+ "ymax" : 10 ,
218+ "label" : ["Tree" ],
219+ }
220+ )
221+ df .to_csv (csv_path , index = False )
222+ root_dir = os .path .dirname (raster_path )
223+ with pytest .raises (ValueError , match = "Failed to open image" ):
224+ _ = BoxDataset (csv_file = csv_path , root_dir = root_dir )
225+
226+
227+ def test_BoxDataset_validate_coordinates (tmpdir , raster_path ):
228+ # Valid case: uses example.csv with all valid boxes
229+ csv_path = get_data ("example.csv" )
230+ root_dir = os .path .dirname (csv_path )
231+ _ = BoxDataset (csv_file = csv_path , root_dir = root_dir )
232+
233+ # Test various invalid box coordinates
234+ with Image .open (raster_path ) as image :
235+ width , height = image .size
236+
237+ invalid_boxes = [
238+ (width - 5 , 0 , width + 10 , 10 ), # xmax exceeds width
239+ (0 , height - 5 , 10 , height + 10 ), # ymax exceeds height
240+ (- 5 , 0 , 10 , 10 ), # negative xmin
241+ (0 , - 5 , 10 , 10 ), # negative ymin
242+ ]
243+
244+ for box in invalid_boxes :
245+ csv_path = os .path .join (tmpdir , "test.csv" )
246+ df = pd .DataFrame (
247+ {
248+ "image_path" : ["OSBS_029.tif" ],
249+ "xmin" : [box [0 ]],
250+ "ymin" : [box [1 ]],
251+ "xmax" : [box [2 ]],
252+ "ymax" : [box [3 ]],
253+ "label" : ["Tree" ],
254+ }
255+ )
256+ df .to_csv (csv_path , index = False )
257+
258+ with pytest .raises (ValueError , match = "exceeds image dimensions" ):
259+ BoxDataset (csv_file = csv_path , root_dir = root_dir )
260+
261+
198262def test_BoxDataset_with_projected_shapefile (tmpdir , raster_path ):
199263 """Test that BoxDataset can load a shapefile with projected coordinates and converts to pixel coordinates"""
200264 import geopandas as gpd
201265
202266 # Get the raster to extract CRS and bounds
203267 import rasterio
204268 from shapely import geometry
269+
205270 with rasterio .open (raster_path ) as src :
206271 raster_crs = src .crs
207272 bounds = src .bounds
@@ -216,12 +281,19 @@ def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
216281
217282 sample_geometry = [
218283 geometry .box (sample_x , sample_y , sample_x + box_size , sample_y + box_size ),
219- geometry .box (sample_x + box_size * 2 , sample_y + box_size * 2 , sample_x + box_size * 3 , sample_y + box_size * 3 )
284+ geometry .box (
285+ sample_x + box_size * 2 ,
286+ sample_y + box_size * 2 ,
287+ sample_x + box_size * 3 ,
288+ sample_y + box_size * 3 ,
289+ ),
220290 ]
221291 labels = ["Tree" , "Tree" ]
222292 image_path = os .path .basename (raster_path )
223293
224- df = pd .DataFrame ({"geometry" : sample_geometry , "label" : labels , "image_path" : image_path })
294+ df = pd .DataFrame (
295+ {"geometry" : sample_geometry , "label" : labels , "image_path" : image_path }
296+ )
225297 gdf = gpd .GeoDataFrame (df , geometry = "geometry" , crs = raster_crs )
226298
227299 # Save as shapefile
@@ -241,6 +313,8 @@ def test_BoxDataset_with_projected_shapefile(tmpdir, raster_path):
241313 # Verify boxes are in pixel coordinates (should be positive and reasonable)
242314 # After geo_to_image_coordinates conversion, values should be in pixel space
243315 boxes = targets ["boxes" ]
244- assert torch .all (boxes >= 0 ), "Boxes should have non-negative coordinates in pixel space"
316+ assert torch .all (boxes >= 0 ), (
317+ "Boxes should have non-negative coordinates in pixel space"
318+ )
245319 assert torch .all (boxes [:, 2 ] > boxes [:, 0 ]), "xmax should be greater than xmin"
246320 assert torch .all (boxes [:, 3 ] > boxes [:, 1 ]), "ymax should be greater than ymin"
0 commit comments