55"""
66
77import glob
8+ import os
9+ import warnings
10+ from pathlib import Path
811
12+ import matplotlib .pyplot as plt
913import numpy as np
1014import supervision as sv
15+ import torch
16+ from PIL import Image
1117from pytorch_lightning import Callback
1218
13- from deepforest import visualize
19+ from deepforest import utilities , visualize
20+ from deepforest .datasets .training import BoxDataset
1421
1522
16- class images_callback (Callback ):
23+ class ImagesCallback (Callback ):
1724 """Log evaluation images during training.
1825
1926 Args:
20- savedir : Directory to save predicted images
27+ save_dir : Directory to save predicted images
2128 n: Number of images to process
2229 every_n_epochs: Run interval in epochs
2330 select_random: Whether to select random images
@@ -26,61 +33,190 @@ class images_callback(Callback):
2633 """
2734
2835 def __init__ (
29- self , savedir , n = 2 , every_n_epochs = 5 , select_random = False , color = None , thickness = 1
36+ self ,
37+ save_dir ,
38+ prediction_samples = 2 ,
39+ dataset_samples = 5 ,
40+ every_n_epochs = 5 ,
41+ select_random = False ,
42+ color = None ,
43+ thickness = 1 ,
3044 ):
31- self .savedir = savedir
32- self .n = n
45+ self .savedir = save_dir
46+ self .prediction_samples = prediction_samples
47+ self .dataset_samples = dataset_samples
3348 self .color = color
3449 self .thickness = thickness
3550 self .select_random = select_random
3651 self .every_n_epochs = every_n_epochs
3752
38- def log_images (self , pl_module ):
39- """Log images to the logger."""
53+ def on_train_start (self , trainer , pl_module ):
54+ """Log sample images from training and validation datasets at training
55+ start."""
56+
57+ if trainer .fast_dev_run :
58+ return
59+
60+ self .trainer = trainer
61+ self .pl_module = pl_module
62+
63+ # Training samples
64+ self .pl_module .print ("Logging training dataset samples." )
65+ train_ds = self .trainer .train_dataloader .dataset
66+ self ._log_dataset_sample (train_ds , split = "train" )
67+
68+ # Validation samples
69+ if self .trainer .val_dataloaders :
70+ self .pl_module .print ("Logging validation dataset samples." )
71+ val_ds = self .trainer .val_dataloaders .dataset
72+ self ._log_dataset_sample (val_ds , split = "validation" )
73+
74+ def on_validation_end (self , trainer , pl_module ):
75+ """Run callback at validation end."""
76+ if trainer .sanity_checking or trainer .fast_dev_run :
77+ return
78+
79+ if trainer .current_epoch % self .every_n_epochs == 0 :
80+ pl_module .print ("Running image callback" )
81+ self ._log_last_predictions (trainer , pl_module )
82+
83+ def _log_dataset_sample (self , dataset : BoxDataset , split : str ):
84+ """Log random samples from a DeepForest BoxDataset."""
85+
86+ if self .dataset_samples == 0 :
87+ return
88+
89+ out_dir = os .path .join (self .savedir , split + "_sample" )
90+ os .makedirs (out_dir , exist_ok = True )
91+ n_samples = min (self .dataset_samples , len (dataset ))
92+ sample_indices = torch .randperm (len (dataset ))[:n_samples ]
93+
94+ sample_data = [dataset [idx ] for idx in sample_indices ]
95+ sample_images = [data [0 ] for data in sample_data ]
96+ sample_targets = [data [1 ] for data in sample_data ]
97+ sample_paths = [data [2 ] for data in sample_data ]
98+
99+ for image , target , path in zip (
100+ sample_images , sample_targets , sample_paths , strict = False
101+ ):
102+ image_annotations = target .copy ()
103+ image_annotations = utilities .format_geometry (image_annotations , scores = False )
104+ image_annotations .root_dir = dataset .root_dir
105+ image_annotations ["image_path" ] = path
106+
107+ # Plot transformed image
108+ basename = Path (path ).stem
109+ image = (255 * image .cpu ().numpy ().transpose ((1 , 2 , 0 ))).astype (np .uint8 )
110+ fig = visualize .plot_annotations (
111+ image = image ,
112+ annotations = image_annotations ,
113+ savedir = out_dir ,
114+ basename = basename ,
115+ thickness = self .thickness ,
116+ show = False ,
117+ )
118+ plt .close (fig )
119+
120+ self ._log_to_all (
121+ image = os .path .join (out_dir , basename + ".png" ),
122+ trainer = self .trainer ,
123+ tag = f"{ split } dataset sample" ,
124+ )
125+
126+ def _log_last_predictions (self , trainer , pl_module ):
127+ """Log sample of predictions + targets from last validation."""
128+ if self .prediction_samples == 0 :
129+ return
130+
131+ out_dir = os .path .join (self .savedir , "predictions" )
132+ os .makedirs (out_dir , exist_ok = True )
40133 df = pl_module .predictions
41134
135+ # Add root_dir to the dataframe
136+ if "root_dir" not in df .columns :
137+ df ["root_dir" ] = trainer .val_dataloaders .dataset .root_dir
138+
42139 # Limit to n images, potentially randomly selected
43140 if self .select_random :
44- selected_images = np .random .choice (df .image_path .unique (), self .n )
141+ selected_images = np .random .choice (
142+ df .image_path .unique (), self .prediction_samples
143+ )
45144 else :
46- selected_images = df .image_path .unique ()[: self .n ]
47- df = df [df .image_path .isin (selected_images )]
145+ selected_images = df .image_path .unique ()[: self .prediction_samples ]
146+
147+ # Ensure color is correctly assigned
148+ if self .color is None :
149+ num_classes = len (df ["label" ].unique ())
150+ results_color = sv .ColorPalette .from_matplotlib ("viridis" , num_classes )
151+ else :
152+ results_color = self .color
153+
154+ for image_name in selected_images :
155+ pred_df = df [df .image_path == image_name ]
156+ targets = utilities .format_geometry (
157+ pl_module .targets [image_name ], scores = False
158+ )
48159
49- # Add root_dir to the dataframe
50- if "root_dir" not in df .columns :
51- df ["root_dir" ] = pl_module .config .validation .root_dir
160+ # Assume that validation images are un-augmented
161+ fig = visualize .plot_results (
162+ results = pred_df ,
163+ ground_truth = targets ,
164+ savedir = out_dir ,
165+ results_color = results_color ,
166+ thickness = self .thickness ,
167+ show = False ,
168+ )
169+ plt .close (fig )
52170
53- # Ensure color is correctly assigned
54- if self .color is None :
55- num_classes = len (df ["label" ].unique ())
56- results_color = sv .ColorPalette .from_matplotlib ("viridis" , num_classes )
57- else :
58- results_color = self .color
59-
60- # Plot results
61- visualize .plot_results (
62- results = df ,
63- savedir = self .savedir ,
64- results_color = results_color ,
65- thickness = self .thickness ,
66- )
171+ saved_plots = glob .glob (f"{ out_dir } /*.png" )
172+ for saved_plot in saved_plots :
173+ self ._log_to_all (image = saved_plot , trainer = trainer , tag = "prediction sample" )
174+
175+ def _log_to_all (self , image : str , trainer , tag ):
176+ """Log to all connected loggers.
67177
178+ Since Comet will pickup image logs to Tensorboard by default, we
179+ add a check to log images preferentially to Tensorboard if both
180+ are enabled.
181+ """
68182 try :
69- saved_plots = glob .glob (f"{ self .savedir } /*.png" )
70- for x in saved_plots :
71- pl_module .logger .experiment .log_image (x )
72- except Exception as e :
73- print (
74- "Could not find comet logger in lightning module, "
75- f"skipping upload, images were saved to { self .savedir } , "
76- f"error was raised { e } "
183+ img = np .array (Image .open (image ).convert ("RGB" ))
184+
185+ loggers = [lg for lg in trainer .loggers if hasattr (lg , "experiment" )]
186+
187+ tb = next ((lg for lg in loggers if hasattr (lg .experiment , "add_image" )), None )
188+ if tb is not None :
189+ tb .experiment .add_image (
190+ tag = f"{ tag } /{ os .path .basename (image )} " ,
191+ img_tensor = img ,
192+ global_step = trainer .global_step ,
193+ dataformats = "HWC" ,
194+ )
195+ return
196+
197+ comet = next (
198+ (lg .experiment for lg in loggers if hasattr (lg .experiment , "log_image" )),
199+ None ,
77200 )
201+ if comet is not None :
202+ comet .experiment .log_image (
203+ img ,
204+ name = tag ,
205+ step = trainer .global_step ,
206+ metadata = {
207+ "image_name" : os .path .basename (image ),
208+ "context" : tag ,
209+ "step" : trainer .global_step ,
210+ },
211+ )
78212
79- def on_validation_end (self , trainer , pl_module ):
80- """Run callback at validation end."""
81- if trainer .sanity_checking :
82- return
213+ except Exception as e :
214+ warnings .warn (f"Tried to log { image } exception raised: { e } " , stacklevel = 2 )
83215
84- if trainer .current_epoch % self .every_n_epochs == 0 :
85- print ("Running image callback" )
86- self .log_images (pl_module )
216+
217+ class images_callback (ImagesCallback ):
218+ def __init__ (self , savedir , ** kwargs ):
219+ warnings .warn (
220+ "Please use ImagesCallback instead." , DeprecationWarning , stacklevel = 2
221+ )
222+ super ().__init__ (save_dir = savedir , ** kwargs )
0 commit comments