11# test dataset model
2- from deepforest import get_data
2+ from deepforest import get_data , main
33from deepforest import utilities
44import os
55import pytest
@@ -125,7 +125,7 @@ def test_BoxDataset_format():
125125 root_dir = os .path .dirname (csv_file )
126126 ds = BoxDataset (csv_file = csv_file , root_dir = root_dir )
127127 image , targets , path = next (iter (ds ))
128-
128+
129129 # Assert image is channels first format
130130 assert image .shape [0 ] == 3
131131
@@ -147,4 +147,43 @@ def test_multi_image_warning():
147147 # Between 0 and 1
148148 batch = ds [i ]
149149 collated_batch = utilities .collate_fn ([None , batch , batch ])
150- len (collated_batch [0 ]) == 2
150+ len (collated_batch [0 ]) == 2
151+
def test_label_validation__training_csv():
    """Check that training CSV labels are validated against ``label_dict``.

    The model is created with a ``label_dict`` that only contains "Bird",
    while the training CSV is annotated as containing the "Tree" label, so
    fitting is expected to raise a ``ValueError`` naming the missing label.
    """
    m = main.deepforest(config_args={"num_classes": 1}, label_dict={"Bird": 0})

    train_csv = get_data("example.csv")  # contains 'Tree' label
    m.config.train.csv_file = train_csv
    m.config.train.root_dir = os.path.dirname(train_csv)
    m.create_trainer()

    # `match` is interpreted as a regular expression, so the literal list
    # brackets in the message must be escaped; a raw string keeps the
    # escapes readable and free of stray characters.
    expected = r"Labels \['Tree'\] are missing from label_dict"
    with pytest.raises(ValueError, match=expected):
        m.trainer.fit(m)
161+
162+
def test_csv_label_validation__validation_csv(m):
    """Check that validation CSV labels are validated against ``label_dict``.

    The training CSV is annotated as containing "Tree", which is present in
    ``label_dict``; the validation CSV is annotated as containing "Dead" and
    "Alive", which are not, so fitting is expected to raise a ``ValueError``.

    Args:
        m: Injected fixture. It is not used by this test; a fresh model is
            built below under a separate name so the fixture is not shadowed.
            NOTE(review): consider dropping this parameter if the fixture is
            not needed for its side effects.
    """
    model = main.deepforest(config_args={"num_classes": 1}, label_dict={"Tree": 0})

    train_csv = get_data("example.csv")  # contains 'Tree' label
    model.config.train.csv_file = train_csv
    model.config.train.root_dir = os.path.dirname(train_csv)

    val_csv = get_data("testfile_multi.csv")  # contains 'Dead', 'Alive' labels
    model.config.validation.csv_file = val_csv
    model.config.validation.root_dir = os.path.dirname(val_csv)

    model.create_trainer()

    # `match` is interpreted as a regular expression, so the literal list
    # brackets in the message must be escaped; a raw string keeps the
    # escapes readable and free of stray characters.
    expected = r"Labels \['Dead', 'Alive'\] are missing from label_dict"
    with pytest.raises(ValueError, match=expected):
        model.trainer.fit(model)
174+
175+
def test_BoxDataset_validate_labels():
    """Check that ``BoxDataset`` validates CSV labels against ``label_dict``.

    Constructing the dataset with a ``label_dict`` that covers the CSV labels
    must succeed, while a ``label_dict`` lacking them must raise a
    ``ValueError`` naming the missing label.
    """
    from deepforest.datasets.training import BoxDataset

    csv_file = get_data("example.csv")  # contains 'Tree' label
    root_dir = os.path.dirname(csv_file)

    # Valid case: CSV labels are in label_dict, so construction must not raise.
    BoxDataset(csv_file=csv_file, root_dir=root_dir, label_dict={"Tree": 0})

    # Invalid case: CSV labels are not in label_dict.
    # `match` is interpreted as a regular expression, so the literal list
    # brackets in the message must be escaped.
    expected = r"Labels \['Tree'\] are missing from label_dict"
    with pytest.raises(ValueError, match=expected):
        BoxDataset(csv_file=csv_file, root_dir=root_dir, label_dict={"Bird": 0})
0 commit comments