Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/deepforest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self,
"""
self.annotations = pd.read_csv(csv_file)
self.root_dir = root_dir
self.validate_annotations()
if transforms is None:
self.transform = get_transform(augment=train)
else:
Expand All @@ -88,6 +89,50 @@ def __init__(self,
image = np.array(Image.open(img_name).convert("RGB")) / 255
self.image_dict[idx] = image.astype("float32")

def validate_annotations(self):
errors = []
for idx, row in self.annotations.iterrows():
img_path = os.path.join(self.root_dir, row['image_path'])
try:
with Image.open(img_path) as img:
width, height = img.size
except Exception as e:
errors.append(f"Failed to open image {img_path}: {e}")
continue

# Extract bounding box
try:
if 'geometry' in self.annotations.columns:
geom = shapely.wkt.loads(row['geometry'])
xmin, ymin, xmax, ymax = geom.bounds
else:
xmin = row['xmin']
ymin = row['ymin']
xmax = row['xmax']
ymax = row['ymax']
except Exception as e:
errors.append(f"Invalid box format at index {idx}: {e}")
continue

#Check if box is valid
oob_issues = []
if xmin < 0:
oob_issues.append(f"xmin ({xmin}) < 0")
if xmax > width:
oob_issues.append(f"xmax ({xmax}) > image width ({width})")
if ymin < 0:
oob_issues.append(f"ymin ({ymin}) < 0")
if ymax > height:
oob_issues.append(f"ymax ({ymax}) > image height ({height})")

if oob_issues:
errors.append(
f"Box, ({xmin}, {ymin}, {xmax}, {ymax}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}."
)

if errors:
raise ValueError("\n".join(errors))

def __len__(self):
return len(self.image_names)

Expand Down
50 changes: 48 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import numpy as np
import tempfile
import rasterio as rio
from deepforest.dataset import BoundingBoxDataset
from deepforest.dataset import BoundingBoxDataset, TreeDataset
from deepforest.dataset import RasterDataset
from torch.utils.data import DataLoader

from PIL import Image


def single_class():
Expand Down Expand Up @@ -211,3 +211,49 @@ def test_raster_dataset():
batch = next(iter(dataloader))
assert batch.shape[0] == 2 # Batch size
assert batch.shape[1] == 3 # Channels first

def test_validate_annotations_invalid_boxes(tmpdir):
img_path = get_data("OSBS_029.tif")
image = Image.open(img_path)
width, height = image.size

test_cases = [
{
"test_box": (width - 5, 0, width + 10, 10),
"reason": "xmax exceeds image width"
},
{
"test_box": (0, height - 5, 10, height + 10),
"reason": "ymax exceeds image height"
},
{
"test_box": (-5, 0, 10, 10),
"reason": "xmin is negative"
},
{
"test_box": (0, -5, 10, 10),
"reason": "ymin is negative"
},
]

for case in test_cases:
test_box = case["test_box"]
error_msg = "exceeds image dimensions"
reason = case["reason"]

csv_path = os.path.join(tmpdir, "test.csv")
df = pd.DataFrame({
"image_path": ["OSBS_029.tif"],
"xmin": [test_box[0]],
"ymin": [test_box[1]],
"xmax": [test_box[2]],
"ymax": [test_box[3]],
"label": ["Tree"]
})
df.to_csv(csv_path, index=False)
root_dir = os.path.dirname(img_path)

with pytest.raises(ValueError) as excinfo:
TreeDataset(csv_file=csv_path, root_dir=root_dir)

assert error_msg in str(excinfo.value), f"Test failed for case: {reason}"