Commit d8f1105

attempt to fix barriers

1 parent fffcc06 commit d8f1105
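Context for the message above: a DDP barrier only releases once every rank reaches it, so a barrier placed behind a rank-dependent early return can deadlock the job. Below is a minimal illustrative sketch (not code from this commit), assuming a pytorch_lightning Callback running under a distributed strategy; the diff that follows applies the same pattern, keeping barriers unconditional and guarding only the work between them.

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.core import LightningModule


class BarrierSketch(Callback):
    """Hypothetical callback illustrating safe barrier placement."""

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        # Hazard: a conditional early return before a barrier means some ranks
        # never arrive, and the ranks that did arrive wait forever:
        #
        #     if some_rank_dependent_condition:
        #         return
        #     trainer.strategy.barrier()
        #
        # Safe pattern: every rank reaches every barrier; only the work
        # between barriers is guarded by rank checks.
        trainer.strategy.barrier()  # all ranks finished their per-rank work
        if trainer.is_global_zero:
            pass  # rank-0-only reduction would go here
        trainer.strategy.barrier()  # no rank races ahead of rank 0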

File tree: 1 file changed

src/deepforest/callbacks/evaluation.py (166 additions, 150 deletions)
@@ -1,115 +1,94 @@
-"""Evaluation callback for prediction saving and evaluation during training."""
-
 import gzip
 import json
 import os
+import shutil
 import tempfile
 import warnings
 from glob import glob
+from pathlib import Path
 
-import pandas as pd
+import torch
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.core import LightningModule
 
 
 class EvaluationCallback(Callback):
-    """Accumulate validation predictions and save to disk during training.
-
-    This callback accumulates predictions during validation and writes them
-    incrementally to a CSV file. At the end of validation, it saves metadata
-    and optionally runs evaluation. The saved predictions are in the format
-    expected by DeepForest's evaluation functions.
-
-    The saved files follow this naming convention:
-    - Predictions: {save_dir}/predictions_epoch_{epoch}.csv (or .csv.gz if compressed)
-    - Metadata: {save_dir}/predictions_epoch_{epoch}_metadata.json
-
-    Args:
-        save_dir (str): Directory to save prediction files. Will be created if it doesn't exist.
-        every_n_epochs (int, optional): Run interval in epochs. Set to -1 to disable callback.
-            Defaults to 5.
-        iou_threshold (float, optional): IoU threshold for evaluation when run_evaluation=True.
-            Defaults to 0.4.
-        run_evaluation (bool, optional): Whether to run evaluate_boxes at epoch end and log metrics.
-            Defaults to False.
-        compress (bool, optional): Whether to compress CSV files using gzip. When True, saves as
-            .csv.gz files for better storage efficiency. When False (default), saves as plain .csv
-            Defaults to False.
-
-    Attributes:
-        save_dir (str): Directory where files are saved
-        every_n_epochs (int): Epoch interval for running callback
-        iou_threshold (float): IoU threshold used for evaluation
-        run_evaluation (bool): Whether evaluation is run at epoch end
-        predictions_written (int): Number of predictions written in current epoch
-
-    Note:
-        This callback should be used with `val_accuracy_interval = -1` in the model config
-        to disable the built-in evaluation and avoid duplicate processing.
+    """Accumulate validation predictions per batch, write one shard per rank,
+    optionally merge shards on rank 0, and optionally run evaluation.
+
+    File names:
+    - Shards: predictions_epoch_{E}_rank{R}.csv[.gz]
+    - Merged: predictions_epoch_{E}.csv[.gz]
+    - Meta: predictions_epoch_{E}_metadata.json
     """
 
    def __init__(
        self,
-        save_dir: str = None,
+        save_dir: str | None = None,
        every_n_epochs: int = 5,
        iou_threshold: float = 0.4,
        run_evaluation: bool = False,
        compress: bool = False,
    ) -> None:
        super().__init__()
-
-        self.temp_dir_obj = None
-        if not save_dir:
-            self.temp_dir_obj = tempfile.TemporaryDirectory()
-            save_dir = self.temp_dir_obj.name
-
-        self.save_dir = save_dir
+        self._user_save_dir = save_dir
+        self.compress = compress
        self.every_n_epochs = every_n_epochs
        self.iou_threshold = iou_threshold
        self.run_evaluation = run_evaluation
-        self.compress = compress
-        self.predictions_written = 0

-    def _should_skip(self, trainer: Trainer) -> bool:
-        """Check if callback should be skipped for the current trainer
-        state."""
+        self.save_dir: Path | None = None
+        self._is_temp = save_dir is None
+        self._rank_base: Path | None = None
+        self.csv_file = None
+        self.csv_path: Path | None = None
+        self.predictions_written = 0  # rows written by *this rank* this epoch

-        return (
+    def _active_epoch(self, trainer: Trainer) -> bool:
+        e = trainer.current_epoch + 1
+        return not (
            trainer.sanity_checking
            or trainer.fast_dev_run
            or self.every_n_epochs == -1
-            or (trainer.current_epoch + 1) % self.every_n_epochs != 0
+            or (e % self.every_n_epochs != 0)
        )

+    def _open_writer(self, path: Path):
+        if self.compress:
+            return gzip.open(path, "wt", encoding="utf-8")
+        return open(path, "w", encoding="utf-8")
+
+    def setup(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str | None = None
+    ):
+        if self._is_temp:
+            # independent temp base per rank, then a rank subdir for clarity
+            base = Path(tempfile.mkdtemp(prefix="preds_"))
+            self._rank_base = base
+            self.save_dir = base / f"rank{trainer.global_rank}"
+        else:
+            self.save_dir = Path(self._user_save_dir)  # type: ignore[arg-type]
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+
    def on_validation_epoch_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
-        if self._should_skip(trainer):
-            return
-
-        # Create once to avoid races
-        if trainer.is_global_zero:
-            os.makedirs(self.save_dir, exist_ok=True)
-        trainer.strategy.barrier()
-
-        # Per-rank shard filename
-        rank = trainer.global_rank
-        csv_filename = f"predictions_epoch_{trainer.current_epoch + 1}_rank{rank}.csv"
-        if self.compress:
-            csv_filename += ".gz"
-        csv_path = os.path.join(self.save_dir, csv_filename)
-
-        self.csv_path = csv_path  # path to this rank's shard
-        self.predictions_written = 0
-
-        if self.compress:
-            self.csv_file = gzip.open(csv_path, "wt", encoding="utf-8")
+        if self._active_epoch(trainer):
+            epoch = trainer.current_epoch + 1
+            rank = trainer.global_rank
+            suffix = ".csv.gz" if self.compress else ".csv"
+            self.csv_path = (
+                self.save_dir / f"predictions_epoch_{epoch}_rank{rank}{suffix}"
+            )
+            self.csv_file = self._open_writer(self.csv_path)
        else:
-            self.csv_file = open(csv_path, "w")
+            self.csv_path = None
+            self.csv_file = None

-        self.csv_path = csv_path
        self.predictions_written = 0

+        trainer.strategy.barrier()
+
    def on_validation_batch_end(
        self,
        trainer: Trainer,
@@ -119,104 +98,141 @@ def on_validation_batch_end
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
-        """Write predictions from current validation batch to CSV file."""
-        if self._should_skip(trainer) or self.csv_file is None:
+        if not self._active_epoch(trainer) or self.csv_file is None:
            return
-
-        # Get predictions from this batch
-        batch_preds = pl_module.last_preds
-
-        for pred in batch_preds:
-            if pred is not None and not pred.empty:
-                pred.to_csv(
-                    self.csv_file, index=False, header=(self.predictions_written == 0)
-                )
-                self.predictions_written += len(pred)
+        # expected: pl_module.last_preds is list[pd.DataFrame]
+        batch_preds = getattr(pl_module, "last_preds", None)
+        if not batch_preds:
+            return
+        for df in batch_preds:
+            if df is None or df.empty:
+                continue
+            df.to_csv(self.csv_file, index=False, header=(self.predictions_written == 0))
+            self.predictions_written += len(df)

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
-        """Clean up at end of epoch.
-
-        Handles DDP sync and merging of output shards.
-        """
-        if self._should_skip(trainer):
-            return
+        strategy = trainer.strategy
+        world_size = strategy.world_size

        if self.csv_file is not None:
            self.csv_file.close()
            self.csv_file = None

-        # All ranks finished writing
-        trainer.strategy.barrier()
+        strategy.barrier()  # all ranks finished writing
+
+        # Collect each rank's save_dir and row count
+        if (
+            world_size > 1
+            and torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+        ):
+            rank_dirs: list[str | None] = [None] * world_size
+            rank_counts: list[int] = [0] * world_size
+            torch.distributed.all_gather_object(rank_dirs, str(self.save_dir))
+            torch.distributed.all_gather_object(
+                rank_counts, int(self.predictions_written)
+            )
+        else:
+            rank_dirs = [str(self.save_dir)]
+            rank_counts = [int(self.predictions_written)]

-        # Merge on global rank 0
-        if trainer.is_global_zero:
-            epoch = trainer.current_epoch + 1
-            pattern = os.path.join(self.save_dir, f"predictions_epoch_{epoch}_rank*.csv")
+        if self._active_epoch(trainer) and trainer.is_global_zero:
+            self._reduce_and_evaluate(
+                trainer, pl_module, [Path(d) for d in rank_dirs if d], sum(rank_counts)
+            )
+
+        strategy.barrier()  # allow rank 0 to finish
+
+    def teardown(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str | None = None
+    ):
+        if self._is_temp and self._rank_base is not None:
+            shutil.rmtree(self._rank_base, ignore_errors=True)
+
+    def _reduce_and_evaluate(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        rank_dirs: list[Path],
+        total_written: int,
+    ) -> None:
+        epoch = trainer.current_epoch + 1
+        suffix = ".csv.gz" if self.compress else ".csv"
+
+        # discover shards
+        shard_paths: list[Path] = []
+        for d in rank_dirs:
+            pattern = str(d / f"predictions_epoch_{epoch}_rank*.csv")
            if self.compress:
                pattern += ".gz"
+            shard_paths.extend(sorted(Path(p) for p in glob(pattern)))

-            shard_paths = sorted(glob(pattern))
-            merged_filename = f"predictions_epoch_{epoch}.csv"
-            if self.compress:
-                merged_filename += ".gz"
-            merged_path = os.path.join(self.save_dir, merged_filename)
-
-            # Concatenate shards
-            dfs = []
-            for p in shard_paths:
-                if p.endswith(".gz"):
-                    dfs.append(pd.read_csv(p, compression="gzip"))
-                else:
-                    dfs.append(pd.read_csv(p))
-            if dfs:
-                merged = pd.concat(dfs, ignore_index=True)
-                if self.compress:
-                    merged.to_csv(merged_path, index=False, compression="gzip")
-                else:
-                    merged.to_csv(merged_path, index=False)
-                total_written = len(merged)
-            else:
-                total_written = 0
-                merged_path = None
-
-            # Save metadata
-            metadata = {
-                "epoch": epoch,
-                "current_step": trainer.global_step,
-                "predictions_count": total_written,
-                "target_csv_file": getattr(pl_module.config.validation, "csv_file", None),
-                "target_root_dir": getattr(pl_module.config.validation, "root_dir", None),
-                "shards": shard_paths,
-                "merged_predictions": merged_path,
-                "world_size": trainer.world_size,
-            }
-            with open(
-                os.path.join(self.save_dir, f"predictions_epoch_{epoch}_metadata.json"),
-                "w",
-            ) as f:
-                json.dump(metadata, f, indent=2)
-
-            # Optional: cleanup shards after merge
-            for p in shard_paths:
+        merged_path = (
+            (self.save_dir / f"predictions_epoch_{epoch}{suffix}")
+            if shard_paths
+            else None
+        )
+
+        # stream-merge shards into a single file without repeating headers
+        if merged_path is not None:
+            merged_path.parent.mkdir(parents=True, exist_ok=True)
+            open_out = gzip.open if self.compress else open
+            with open_out(merged_path, "wt", encoding="utf-8") as out_f:
+                wrote_header = False
+                for shard in shard_paths:
+                    open_in = (
+                        gzip.open
+                        if shard.suffix == ".gz" or shard.suffixes[-2:] == [".csv", ".gz"]
+                        else open
+                    )
+                    with open_in(shard, "rt", encoding="utf-8") as in_f:
+                        for i, line in enumerate(in_f):
+                            if i == 0 and wrote_header:
+                                continue
+                            out_f.write(line)
+                    wrote_header = True
+
+        # metadata
+        cfg = getattr(pl_module, "config", None)
+        val = getattr(cfg, "validation", None)
+        meta = {
+            "epoch": epoch,
+            "current_step": trainer.global_step,
+            "predictions_count": int(total_written),
+            "target_csv_file": getattr(val, "csv_file", None),
+            "target_root_dir": getattr(val, "root_dir", None),
+            "shards": [str(p) for p in shard_paths],
+            "merged_predictions": str(merged_path) if merged_path else None,
+            "world_size": trainer.strategy.world_size,
+        }
+        with open(
+            self.save_dir / f"predictions_epoch_{epoch}_metadata.json",
+            "w",
+            encoding="utf-8",
+        ) as f:
+            json.dump(meta, f, indent=2)
+
+        # optional shard cleanup
+        for p in shard_paths:
+            try:
                os.remove(p)
+            except OSError:
+                pass

-            # Run evaluation only if we have predictions and user requested it
-            if self.run_evaluation and total_written > 0 and merged_path is not None:
+        # optional evaluation
+        if self.run_evaluation:
+            if merged_path and total_written > 0:
                try:
                    pl_module.evaluate(
-                        predictions=merged_path, csv_file=metadata["target_csv_file"]
+                        predictions=str(merged_path),
+                        csv_file=meta["target_csv_file"],
+                        iou_threshold=self.iou_threshold,
                    )
                except Exception as e:
                    warnings.warn(f"Evaluation failed: {e}", stacklevel=2)
-            elif self.run_evaluation:
+            else:
                warnings.warn(
                    "No predictions written to disk, skipping evaluate.", stacklevel=2
                )
-
-        # Ensure rank 0 finished before next stage
-        trainer.strategy.barrier()
-
-        if self.temp_dir_obj is not None and trainer.is_global_zero:
-            self.temp_dir_obj.cleanup()

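For reference, a hedged usage sketch of the callback after this change. The import path mirrors the file path in this commit; `main.deepforest()` and the `last_preds` attribute the callback reads are assumptions based on the DeepForest package and the surrounding code, not confirmed by this diff.

from pytorch_lightning import Trainer

from deepforest import main  # assumed model entry point
from deepforest.callbacks.evaluation import EvaluationCallback

model = main.deepforest()  # assumed: exposes `last_preds` during validation

callback = EvaluationCallback(
    save_dir="predictions/",  # omit to write shards into per-rank temp dirs
    every_n_epochs=5,         # or -1 to disable the callback entirely
    iou_threshold=0.4,
    run_evaluation=True,      # rank 0 calls pl_module.evaluate(...) after merging
    compress=True,            # shards and merged file written as .csv.gz
)

trainer = Trainer(max_epochs=20, callbacks=[callback])
trainer.fit(model)

In a multi-GPU run, each rank writes its own shard; rank 0 merges them, writes the metadata JSON, and optionally evaluates, with a barrier on each side of the merge.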