1- """Evaluation callback for prediction saving and evaluation during training."""
2-
31import gzip
42import json
53import os
4+ import shutil
65import tempfile
76import warnings
87from glob import glob
8+ from pathlib import Path
99
10- import pandas as pd
10+ import torch
1111from pytorch_lightning import Callback , Trainer
1212from pytorch_lightning .core import LightningModule
1313
1414
1515class EvaluationCallback (Callback ):
16- """Accumulate validation predictions and save to disk during training.
17-
18- This callback accumulates predictions during validation and writes them
19- incrementally to a CSV file. At the end of validation, it saves metadata
20- and optionally runs evaluation. The saved predictions are in the format
21- expected by DeepForest's evaluation functions.
22-
23- The saved files follow this naming convention:
24- - Predictions: {save_dir}/predictions_epoch_{epoch}.csv (or .csv.gz if compressed)
25- - Metadata: {save_dir}/predictions_epoch_{epoch}_metadata.json
26-
27- Args:
28- save_dir (str): Directory to save prediction files. Will be created if it doesn't exist.
29- every_n_epochs (int, optional): Run interval in epochs. Set to -1 to disable callback.
30- Defaults to 5.
31- iou_threshold (float, optional): IoU threshold for evaluation when run_evaluation=True.
32- Defaults to 0.4.
33- run_evaluation (bool, optional): Whether to run evaluate_boxes at epoch end and log metrics.
34- Defaults to False.
35- compress (bool, optional): Whether to compress CSV files using gzip. When True, saves as
36- .csv.gz files for better storage efficiency. When False (default), saves as plain .csv
37- Defaults to False.
38-
39- Attributes:
40- save_dir (str): Directory where files are saved
41- every_n_epochs (int): Epoch interval for running callback
42- iou_threshold (float): IoU threshold used for evaluation
43- run_evaluation (bool): Whether evaluation is run at epoch end
44- predictions_written (int): Number of predictions written in current epoch
45-
46- Note:
47- This callback should be used with `val_accuracy_interval = -1` in the model config
48- to disable the built-in evaluation and avoid duplicate processing.
16+ """Accumulate validation predictions per batch, write one shard per rank,
17+ optionally merge shards on rank 0, and optionally run evaluation.
18+
19+ File names:
20+ - Shards: predictions_epoch_{E}_rank{R}.csv[.gz]
21+ - Merged: predictions_epoch_{E}.csv[.gz]
22+ - Meta: predictions_epoch_{E}_metadata.json
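+
+    Predictions are read from ``pl_module.last_preds`` after each validation
+    batch and are expected to be one ``pd.DataFrame`` per image (as a
+    DeepForest-style validation step provides). With the default
+    ``every_n_epochs=5`` the callback is active at epochs 5, 10, 15, and so on.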
4923 """
5024
     def __init__(
         self,
-        save_dir: str = None,
+        save_dir: str | None = None,
         every_n_epochs: int = 5,
         iou_threshold: float = 0.4,
         run_evaluation: bool = False,
         compress: bool = False,
     ) -> None:
         super().__init__()
-
-        self.temp_dir_obj = None
-        if not save_dir:
-            self.temp_dir_obj = tempfile.TemporaryDirectory()
-            save_dir = self.temp_dir_obj.name
-
-        self.save_dir = save_dir
+        self._user_save_dir = save_dir
+        self.compress = compress
         self.every_n_epochs = every_n_epochs
         self.iou_threshold = iou_threshold
         self.run_evaluation = run_evaluation
-        self.compress = compress
-        self.predictions_written = 0
 
-    def _should_skip(self, trainer: Trainer) -> bool:
-        """Check if callback should be skipped for the current trainer
-        state."""
+        self.save_dir: Path | None = None
+        self._is_temp = save_dir is None
+        self._rank_base: Path | None = None
+        self.csv_file = None
+        self.csv_path: Path | None = None
+        self.predictions_written = 0  # rows written by *this rank* this epoch
 
-        return (
+    def _active_epoch(self, trainer: Trainer) -> bool:
+        e = trainer.current_epoch + 1
+        return not (
             trainer.sanity_checking
             or trainer.fast_dev_run
             or self.every_n_epochs == -1
-            or (trainer.current_epoch + 1) % self.every_n_epochs != 0
+            or (e % self.every_n_epochs != 0)
         )
 
+    def _open_writer(self, path: Path):
+        # newline="" lets pandas' to_csv control line endings when writing
+        # to an already-open handle
+        if self.compress:
+            return gzip.open(path, "wt", encoding="utf-8", newline="")
+        return open(path, "w", encoding="utf-8", newline="")
+
+    def setup(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str | None = None
+    ):
+        if self._is_temp:
+            # independent temp base per rank, then a rank subdir for clarity
+            base = Path(tempfile.mkdtemp(prefix="preds_"))
+            self._rank_base = base
+            self.save_dir = base / f"rank{trainer.global_rank}"
+        else:
+            self.save_dir = Path(self._user_save_dir)  # type: ignore[arg-type]
+        self.save_dir.mkdir(parents=True, exist_ok=True)
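+        # Note: rank-0 merging assumes rank 0 can read every rank's save_dir
+        # (a shared filesystem, or a single node when using per-rank temp dirs).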
+
     def on_validation_epoch_start(
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
-        if self._should_skip(trainer):
-            return
-
-        # Create once to avoid races
-        if trainer.is_global_zero:
-            os.makedirs(self.save_dir, exist_ok=True)
-        trainer.strategy.barrier()
-
-        # Per-rank shard filename
-        rank = trainer.global_rank
-        csv_filename = f"predictions_epoch_{trainer.current_epoch + 1}_rank{rank}.csv"
-        if self.compress:
-            csv_filename += ".gz"
-        csv_path = os.path.join(self.save_dir, csv_filename)
-
-        self.csv_path = csv_path  # path to this rank's shard
-        self.predictions_written = 0
-
-        if self.compress:
-            self.csv_file = gzip.open(csv_path, "wt", encoding="utf-8")
+        if self._active_epoch(trainer):
+            epoch = trainer.current_epoch + 1
+            rank = trainer.global_rank
+            suffix = ".csv.gz" if self.compress else ".csv"
+            self.csv_path = (
+                self.save_dir / f"predictions_epoch_{epoch}_rank{rank}{suffix}"
+            )
+            self.csv_file = self._open_writer(self.csv_path)
         else:
-            self.csv_file = open(csv_path, "w")
+            self.csv_path = None
+            self.csv_file = None
 
-        self.csv_path = csv_path
         self.predictions_written = 0
 
+        trainer.strategy.barrier()
+
     def on_validation_batch_end(
         self,
         trainer: Trainer,
         pl_module: LightningModule,
         outputs,
         batch,
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
122- """Write predictions from current validation batch to CSV file."""
123- if self ._should_skip (trainer ) or self .csv_file is None :
101+ if not self ._active_epoch (trainer ) or self .csv_file is None :
124102 return
125-
126- # Get predictions from this batch
127- batch_preds = pl_module .last_preds
128-
129- for pred in batch_preds :
130- if pred is not None and not pred .empty :
131- pred .to_csv (
132- self .csv_file , index = False , header = (self .predictions_written == 0 )
133- )
134- self .predictions_written += len (pred )
103+ # expected: pl_module.last_preds is list[pd.DataFrame]
104+ batch_preds = getattr (pl_module , "last_preds" , None )
105+ if not batch_preds :
106+ return
107+ for df in batch_preds :
108+ if df is None or df .empty :
109+ continue
110+ df .to_csv (self .csv_file , index = False , header = (self .predictions_written == 0 ))
111+ self .predictions_written += len (df )
 
     def on_validation_epoch_end(
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
-        """Clean up at end of epoch.
-
-        Handles DDP sync and merging of output shards.
-        """
-        if self._should_skip(trainer):
-            return
+        strategy = trainer.strategy
+        world_size = strategy.world_size
 
         if self.csv_file is not None:
             self.csv_file.close()
             self.csv_file = None
 
-        # All ranks finished writing
-        trainer.strategy.barrier()
+        strategy.barrier()  # all ranks finished writing
+
+        # Collect each rank's save_dir and row count
+        if (
+            world_size > 1
+            and torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+        ):
+            rank_dirs: list[str | None] = [None] * world_size
+            rank_counts: list[int] = [0] * world_size
+            torch.distributed.all_gather_object(rank_dirs, str(self.save_dir))
+            torch.distributed.all_gather_object(
+                rank_counts, int(self.predictions_written)
+            )
+        else:
+            rank_dirs = [str(self.save_dir)]
+            rank_counts = [int(self.predictions_written)]
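+        # all_gather_object fills the pre-sized lists in global-rank order,
+        # so rank_dirs[i] and rank_counts[i] come from global rank i.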
 
-        # Merge on global rank 0
-        if trainer.is_global_zero:
-            epoch = trainer.current_epoch + 1
-            pattern = os.path.join(self.save_dir, f"predictions_epoch_{epoch}_rank*.csv")
+        if self._active_epoch(trainer) and trainer.is_global_zero:
+            self._reduce_and_evaluate(
+                trainer, pl_module, [Path(d) for d in rank_dirs if d], sum(rank_counts)
+            )
+
+        strategy.barrier()  # allow rank 0 to finish
+
+    def teardown(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str | None = None
+    ):
+        if self._is_temp and self._rank_base is not None:
+            shutil.rmtree(self._rank_base, ignore_errors=True)
+
+    def _reduce_and_evaluate(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        rank_dirs: list[Path],
+        total_written: int,
+    ) -> None:
+        epoch = trainer.current_epoch + 1
+        suffix = ".csv.gz" if self.compress else ".csv"
+
+        # discover shards; dedupe dirs first, since ranks sharing one save_dir
+        # would otherwise contribute every shard world_size times
+        shard_paths: list[Path] = []
+        for d in dict.fromkeys(rank_dirs):
+            pattern = str(d / f"predictions_epoch_{epoch}_rank*.csv")
             if self.compress:
                 pattern += ".gz"
+            shard_paths.extend(sorted(Path(p) for p in glob(pattern)))
 
-            shard_paths = sorted(glob(pattern))
-            merged_filename = f"predictions_epoch_{epoch}.csv"
-            if self.compress:
-                merged_filename += ".gz"
-            merged_path = os.path.join(self.save_dir, merged_filename)
-
-            # Concatenate shards
-            dfs = []
-            for p in shard_paths:
-                if p.endswith(".gz"):
-                    dfs.append(pd.read_csv(p, compression="gzip"))
-                else:
-                    dfs.append(pd.read_csv(p))
-            if dfs:
-                merged = pd.concat(dfs, ignore_index=True)
-                if self.compress:
-                    merged.to_csv(merged_path, index=False, compression="gzip")
-                else:
-                    merged.to_csv(merged_path, index=False)
-                total_written = len(merged)
-            else:
-                total_written = 0
-                merged_path = None
-
-            # Save metadata
-            metadata = {
-                "epoch": epoch,
-                "current_step": trainer.global_step,
-                "predictions_count": total_written,
-                "target_csv_file": getattr(pl_module.config.validation, "csv_file", None),
-                "target_root_dir": getattr(pl_module.config.validation, "root_dir", None),
-                "shards": shard_paths,
-                "merged_predictions": merged_path,
-                "world_size": trainer.world_size,
-            }
-            with open(
-                os.path.join(self.save_dir, f"predictions_epoch_{epoch}_metadata.json"),
-                "w",
-            ) as f:
-                json.dump(metadata, f, indent=2)
-
-            # Optional: cleanup shards after merge
-            for p in shard_paths:
+        merged_path = (
+            (self.save_dir / f"predictions_epoch_{epoch}{suffix}")
+            if shard_paths
+            else None
+        )
+
+        # stream-merge shards into a single file without repeating headers
+        if merged_path is not None:
+            merged_path.parent.mkdir(parents=True, exist_ok=True)
+            open_out = gzip.open if self.compress else open
+            with open_out(merged_path, "wt", encoding="utf-8") as out_f:
+                wrote_header = False
+                for shard in shard_paths:
+                    open_in = gzip.open if shard.suffix == ".gz" else open
+                    with open_in(shard, "rt", encoding="utf-8") as in_f:
+                        for i, line in enumerate(in_f):
+                            if i == 0:
+                                # keep the first header seen, skip the rest;
+                                # never mark the header written for a shard
+                                # that turns out to be empty
+                                if wrote_header:
+                                    continue
+                                wrote_header = True
+                            out_f.write(line)
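+            # each shard holds at most one header row (written once per rank in
+            # on_validation_batch_end); empty shards from ranks that produced
+            # no predictions contribute nothing to the merge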
+
+        # metadata
+        cfg = getattr(pl_module, "config", None)
+        val = getattr(cfg, "validation", None)
+        meta = {
+            "epoch": epoch,
+            "current_step": trainer.global_step,
+            "predictions_count": int(total_written),
+            "target_csv_file": getattr(val, "csv_file", None),
+            "target_root_dir": getattr(val, "root_dir", None),
+            "shards": [str(p) for p in shard_paths],
+            "merged_predictions": str(merged_path) if merged_path else None,
+            "world_size": trainer.strategy.world_size,
+        }
+        with open(
+            self.save_dir / f"predictions_epoch_{epoch}_metadata.json",
+            "w",
+            encoding="utf-8",
+        ) as f:
+            json.dump(meta, f, indent=2)
+
+        # optional shard cleanup
+        for p in shard_paths:
+            try:
                 os.remove(p)
+            except OSError:
+                pass
 
-        # Run evaluation only if we have predictions and user requested it
-        if self.run_evaluation and total_written > 0 and merged_path is not None:
+        # optional evaluation
+        if self.run_evaluation:
+            if merged_path and total_written > 0:
                 try:
                     pl_module.evaluate(
-                        predictions=merged_path, csv_file=metadata["target_csv_file"]
+                        predictions=str(merged_path),
+                        csv_file=meta["target_csv_file"],
+                        iou_threshold=self.iou_threshold,
                     )
                 except Exception as e:
                     warnings.warn(f"Evaluation failed: {e}", stacklevel=2)
-        elif self.run_evaluation:
+            else:
                 warnings.warn(
                     "No predictions written to disk, skipping evaluate.", stacklevel=2
                 )
-
-        # Ensure rank 0 finished before next stage
-        trainer.strategy.barrier()
-
-        if self.temp_dir_obj is not None and trainer.is_global_zero:
-            self.temp_dir_obj.cleanup()
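+
+
+# Usage sketch (illustrative; `model` is a placeholder for a LightningModule
+# that populates `last_preds` during validation, e.g. a DeepForest model):
+#
+#     from pytorch_lightning import Trainer
+#
+#     callback = EvaluationCallback(
+#         save_dir="predictions/",  # omit to use per-rank temp dirs
+#         every_n_epochs=5,
+#         run_evaluation=True,
+#         compress=True,
+#     )
+#     Trainer(callbacks=[callback]).fit(model)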