Skip to content

Commit 0ed2528

Browse files
add dinov3
1 parent b03194b commit 0ed2528

File tree

16 files changed

+770
-202
lines changed

16 files changed

+770
-202
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ tests/__pycache__
2020
tests/data/*
2121
.vscode/
2222
*ipynb_checkpoints/
23-
docs/user_guide/deepforestr.md
23+
docs/user_guide/deepforestr.md
24+
.env

pyproject.toml

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"h5py",
3939
"huggingface_hub>=0.25.0",
4040
"hydra-core",
41+
"geopandas>=1.0.0",
4142
"matplotlib",
4243
"numpy<2.0",
4344
"omegaconf",
@@ -56,11 +57,28 @@ dependencies = [
5657
"supervision",
5758
"tensorboard",
5859
"timm",
59-
"torch>=2.2.0,<2.3.0",
60-
"torchvision>=0.17.0,<0.18.0",
60+
"torch>=2.7.0",
61+
"torchvision>=0.17.0",
6162
"tqdm",
62-
"transformers",
63+
"transformers>=4.56",
6364
"xmltodict",
65+
"transformers",
66+
"timm>=1.0.15",
67+
"faster-coco-eval>=1.6.7",
68+
"comet-ml>=3.51.0",
69+
]
70+
71+
[[tool.uv.index]]
72+
name = "pytorch-cu128"
73+
url = "https://download.pytorch.org/whl/cu128"
74+
explicit = true
75+
76+
[tool.uv.sources]
77+
torch = [
78+
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
79+
]
80+
torchvision = [
81+
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
6482
]
6583

6684
[project.urls]

src/deepforest/callbacks.py

Lines changed: 179 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,26 @@
55
"""
66

77
import glob
8+
import os
9+
import warnings
10+
from pathlib import Path
811

12+
import matplotlib.pyplot as plt
913
import numpy as np
1014
import supervision as sv
15+
import torch
16+
from PIL import Image
1117
from pytorch_lightning import Callback
1218

13-
from deepforest import visualize
19+
from deepforest import utilities, visualize
20+
from deepforest.datasets.training import BoxDataset
1421

1522

16-
class images_callback(Callback):
23+
class ImagesCallback(Callback):
1724
"""Log evaluation images during training.
1825
1926
Args:
20-
savedir: Directory to save predicted images
27+
save_dir: Directory to save predicted images
2128
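prediction_samples: Number of validation images to plot predictions for
dataset_samples: Number of samples to plot from each dataset at training start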
2229
every_n_epochs: Run interval in epochs
2330
select_random: Whether to select random images
@@ -26,61 +33,190 @@ class images_callback(Callback):
2633
"""
2734

2835
def __init__(
29-
self, savedir, n=2, every_n_epochs=5, select_random=False, color=None, thickness=1
36+
self,
37+
save_dir,
38+
prediction_samples=2,
39+
dataset_samples=5,
40+
every_n_epochs=5,
41+
select_random=False,
42+
color=None,
43+
thickness=1,
3044
):
31-
self.savedir = savedir
32-
self.n = n
45+
self.savedir = save_dir
46+
self.prediction_samples = prediction_samples
47+
self.dataset_samples = dataset_samples
3348
self.color = color
3449
self.thickness = thickness
3550
self.select_random = select_random
3651
self.every_n_epochs = every_n_epochs
3752

38-
def log_images(self, pl_module):
39-
"""Log images to the logger."""
53+
def on_train_start(self, trainer, pl_module):
54+
"""Log sample images from training and validation datasets at training
55+
start."""
56+
57+
if trainer.fast_dev_run:
58+
return
59+
60+
self.trainer = trainer
61+
self.pl_module = pl_module
62+
63+
# Training samples
64+
self.pl_module.print("Logging training dataset samples.")
65+
train_ds = self.trainer.train_dataloader.dataset
66+
self._log_dataset_sample(train_ds, split="train")
67+
68+
# Validation samples
69+
if self.trainer.val_dataloaders:
70+
self.pl_module.print("Logging validation dataset samples.")
71+
val_ds = self.trainer.val_dataloaders.dataset
72+
self._log_dataset_sample(val_ds, split="validation")
73+
74+
def on_validation_end(self, trainer, pl_module):
75+
"""Run callback at validation end."""
76+
if trainer.sanity_checking or trainer.fast_dev_run:
77+
return
78+
79+
if trainer.current_epoch % self.every_n_epochs == 0:
80+
pl_module.print("Running image callback")
81+
self._log_last_predictions(trainer, pl_module)
82+
83+
def _log_dataset_sample(self, dataset: BoxDataset, split: str):
84+
"""Log random samples from a DeepForest BoxDataset."""
85+
86+
if self.dataset_samples == 0:
87+
return
88+
89+
out_dir = os.path.join(self.savedir, split + "_sample")
90+
os.makedirs(out_dir, exist_ok=True)
91+
n_samples = min(self.dataset_samples, len(dataset))
92+
sample_indices = torch.randperm(len(dataset))[:n_samples]
93+
94+
sample_data = [dataset[idx] for idx in sample_indices]
95+
sample_images = [data[0] for data in sample_data]
96+
sample_targets = [data[1] for data in sample_data]
97+
sample_paths = [data[2] for data in sample_data]
98+
99+
for image, target, path in zip(
100+
sample_images, sample_targets, sample_paths, strict=False
101+
):
102+
image_annotations = target.copy()
103+
image_annotations = utilities.format_geometry(image_annotations, scores=False)
104+
image_annotations.root_dir = dataset.root_dir
105+
image_annotations["image_path"] = path
106+
107+
# Plot transformed image
108+
basename = Path(path).stem
109+
image = (255 * image.cpu().numpy().transpose((1, 2, 0))).astype(np.uint8)
110+
fig = visualize.plot_annotations(
111+
image=image,
112+
annotations=image_annotations,
113+
savedir=out_dir,
114+
basename=basename,
115+
thickness=self.thickness,
116+
show=False,
117+
)
118+
plt.close(fig)
119+
120+
self._log_to_all(
121+
image=os.path.join(out_dir, basename + ".png"),
122+
trainer=self.trainer,
123+
tag=f"{split} dataset sample",
124+
)
125+
126+
def _log_last_predictions(self, trainer, pl_module):
127+
"""Log sample of predictions + targets from last validation."""
128+
if self.prediction_samples == 0:
129+
return
130+
131+
out_dir = os.path.join(self.savedir, "predictions")
132+
os.makedirs(out_dir, exist_ok=True)
40133
df = pl_module.predictions
41134

135+
# Add root_dir to the dataframe
136+
if "root_dir" not in df.columns:
137+
df["root_dir"] = trainer.val_dataloaders.dataset.root_dir
138+
42139
# Limit to n images, potentially randomly selected
43140
if self.select_random:
44-
selected_images = np.random.choice(df.image_path.unique(), self.n)
141+
selected_images = np.random.choice(
142+
df.image_path.unique(), self.prediction_samples
143+
)
45144
else:
46-
selected_images = df.image_path.unique()[: self.n]
47-
df = df[df.image_path.isin(selected_images)]
145+
selected_images = df.image_path.unique()[: self.prediction_samples]
146+
147+
# Ensure color is correctly assigned
148+
if self.color is None:
149+
num_classes = len(df["label"].unique())
150+
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
151+
else:
152+
results_color = self.color
153+
154+
for image_name in selected_images:
155+
pred_df = df[df.image_path == image_name]
156+
targets = utilities.format_geometry(
157+
pl_module.targets[image_name], scores=False
158+
)
48159

49-
# Add root_dir to the dataframe
50-
if "root_dir" not in df.columns:
51-
df["root_dir"] = pl_module.config.validation.root_dir
160+
# Assume that validation images are un-augmented
161+
fig = visualize.plot_results(
162+
results=pred_df,
163+
ground_truth=targets,
164+
savedir=out_dir,
165+
results_color=results_color,
166+
thickness=self.thickness,
167+
show=False,
168+
)
169+
plt.close(fig)
52170

53-
# Ensure color is correctly assigned
54-
if self.color is None:
55-
num_classes = len(df["label"].unique())
56-
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
57-
else:
58-
results_color = self.color
59-
60-
# Plot results
61-
visualize.plot_results(
62-
results=df,
63-
savedir=self.savedir,
64-
results_color=results_color,
65-
thickness=self.thickness,
66-
)
171+
saved_plots = glob.glob(f"{out_dir}/*.png")
172+
for saved_plot in saved_plots:
173+
self._log_to_all(image=saved_plot, trainer=trainer, tag="prediction sample")
174+
175+
def _log_to_all(self, image: str, trainer, tag):
176+
"""Log to all connected loggers.
67177
178+
Since Comet will pick up images logged to TensorBoard by default, we
179+
add a check to log images preferentially to Tensorboard if both
180+
are enabled.
181+
"""
68182
try:
69-
saved_plots = glob.glob(f"{self.savedir}/*.png")
70-
for x in saved_plots:
71-
pl_module.logger.experiment.log_image(x)
72-
except Exception as e:
73-
print(
74-
"Could not find comet logger in lightning module, "
75-
f"skipping upload, images were saved to {self.savedir}, "
76-
f"error was raised {e}"
183+
img = np.array(Image.open(image).convert("RGB"))
184+
185+
loggers = [lg for lg in trainer.loggers if hasattr(lg, "experiment")]
186+
187+
tb = next((lg for lg in loggers if hasattr(lg.experiment, "add_image")), None)
188+
if tb is not None:
189+
tb.experiment.add_image(
190+
tag=f"{tag}/{os.path.basename(image)}",
191+
img_tensor=img,
192+
global_step=trainer.global_step,
193+
dataformats="HWC",
194+
)
195+
return
196+
197+
comet = next(
198+
(lg.experiment for lg in loggers if hasattr(lg.experiment, "log_image")),
199+
None,
77200
)
201+
if comet is not None:
202+
comet.experiment.log_image(
203+
img,
204+
name=tag,
205+
step=trainer.global_step,
206+
metadata={
207+
"image_name": os.path.basename(image),
208+
"context": tag,
209+
"step": trainer.global_step,
210+
},
211+
)
78212

79-
def on_validation_end(self, trainer, pl_module):
80-
"""Run callback at validation end."""
81-
if trainer.sanity_checking:
82-
return
213+
except Exception as e:
214+
warnings.warn(f"Tried to log {image} exception raised: {e}", stacklevel=2)
83215

84-
if trainer.current_epoch % self.every_n_epochs == 0:
85-
print("Running image callback")
86-
self.log_images(pl_module)
216+
217+
class images_callback(ImagesCallback):
218+
def __init__(self, savedir, **kwargs):
219+
warnings.warn(
220+
"Please use ImagesCallback instead.", DeprecationWarning, stacklevel=2
221+
)
222+
super().__init__(save_dir=savedir, **kwargs)

src/deepforest/conf/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ rgb_dir:
3030
path_to_rgb:
3131

3232
train:
33+
# Sanity check annotations on dataset load
34+
check_annotations: False
35+
log_root: logs
3336
csv_file:
3437
root_dir:
3538

src/deepforest/conf/dinov3.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# DINOv3 configuration - overrides applied on top of the base `config` defaults
2+
defaults:
3+
- config
4+
- _self_
5+
6+
model:
7+
name: "facebook/dinov3-vitl16-pretrain-sat493m"
8+
revision: 'main'
9+
10+
train:
11+
epochs: 75
12+
lr: 0.01
13+
scheduler:
14+
type: cosine
15+
params:
16+
T_max: 75
17+
eta_min: 0.0001

src/deepforest/conf/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@ class TrainConfig:
6262

6363
csv_file: str | None = MISSING
6464
root_dir: str | None = MISSING
65+
log_root: str = "logs"
6566
lr: float = 0.001
6667
scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
6768
epochs: int = 1
6869
fast_dev_run: bool = False
6970
preload_images: bool = False
7071
augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"])
72+
check_annotations: bool = False
7173

7274

7375
@dataclass

0 commit comments

Comments
 (0)