Skip to content

Commit 53820cd

Browse files
integrate keypoints
1 parent 078b545 commit 53820cd

21 files changed

+1640
-163
lines changed

src/deepforest/augmentations.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,17 @@ def get_available_augmentations() -> list[str]:
5151

5252
def get_transform(
5353
augmentations: str | list[str] | dict[str, Any] | None = None,
54+
task: str = "box",
5455
) -> A.Compose:
55-
"""Create Albumentations transform for bounding boxes.
56+
"""Create Albumentations transform for boxes or keypoints.
5657
5758
Args:
5859
augmentations: Augmentation configuration:
5960
- str: Single augmentation name
6061
- list: List of augmentation names
6162
- dict: Dict with names as keys and params as values
6263
- None: No augmentations
64+
task: Task type - "box" for bounding boxes or "keypoint" for keypoints
6365
6466
Returns:
6567
Composed albumentations transform
@@ -79,9 +81,13 @@ def get_transform(
7981
... "HorizontalFlip": {"p": 0.5},
8082
... "Downscale": {"scale_min": 0.25, "scale_max": 0.75}
8183
... })
84+
85+
>>> # Keypoint augmentations
86+
>>> transform = get_transform(augmentations=["HorizontalFlip"], task="keypoint")
8287
"""
8388
transforms_list = []
8489
bbox_params = None
90+
keypoint_params = None
8591

8692
if augmentations is not None:
8793
augment_configs = _parse_augmentations(augmentations)
@@ -90,12 +96,19 @@ def get_transform(
9096
aug_transform = _create_augmentation(aug_name, aug_params)
9197
transforms_list.append(aug_transform)
9298

93-
bbox_params = A.BboxParams(format="pascal_voc", label_fields=["category_ids"])
99+
if task == "box":
100+
bbox_params = A.BboxParams(format="pascal_voc", label_fields=["labels"])
101+
elif task == "keypoint":
102+
keypoint_params = A.KeypointParams(format="xy", label_fields=["labels"])
103+
else:
104+
raise ValueError(f"Unsupported task: {task}. Must be 'box' or 'keypoint'.")
94105

95106
# Always add ToTensorV2 at the end
96107
transforms_list.append(ToTensorV2())
97108

98-
return A.Compose(transforms_list, bbox_params=bbox_params)
109+
return A.Compose(
110+
transforms_list, bbox_params=bbox_params, keypoint_params=keypoint_params
111+
)
99112

100113

101114
def _parse_augmentations(

src/deepforest/callbacks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ def _log_last_predictions(self, trainer, pl_module):
151151
else:
152152
selected_images = df.image_path.unique()[: self.prediction_samples]
153153

154-
# Ensure color is correctly assigned
155-
if self.color is None:
156-
num_classes = len(df["label"].unique())
157-
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
158-
else:
159-
results_color = self.color
154+
# Ensure color is correctly assigned
155+
if self.color is None:
156+
num_classes = len(df["label"].unique())
157+
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
158+
else:
159+
results_color = self.color
160160

161161
for image_name in selected_images:
162162
pred_df = df[df.image_path == image_name]

src/deepforest/conf/bird.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Ensure we inherit from default config + overlay these overrides.
2+
defaults:
3+
- config
4+
- _self_
5+
6+
task: 'box'
7+
8+
model:
9+
name: 'weecology/deepforest-bird'
10+
revision: 'main'

src/deepforest/conf/config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ accelerator: auto
88
batch_size: 1
99

1010
# Model Architecture
11+
task: 'box'
1112
architecture: 'retinanet'
1213
num_classes: 1
1314
nms_thresh: 0.05
@@ -66,6 +67,10 @@ train:
6667
min_lr: 0
6768
eps: 0.00000001
6869

70+
# Currently sgd and adamw are supported. If you use Adam,
71+
# make sure your learning rate is lowered sufficiently.
72+
optimizer: sgd
73+
6974
# How many epochs to run for
7075
epochs: 1
7176
# Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings.

src/deepforest/conf/keypoint.yaml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Config file for DeepForest keypoint detection tasks
2+
3+
# Ensure we inherit from default config + overlay these overrides.
4+
defaults:
5+
- config
6+
- _self_
7+
8+
# Task and Model Architecture
9+
task: 'keypoint'
10+
architecture: 'DeformableDetr'
11+
12+
# Keypoint-specific parameters for Deformable DETR
13+
# point_cost: Relative weight of point distance in matching cost (default: 5.0)
14+
# point_loss_coefficient: Weight of point loss in total loss (default: 5.0)
15+
# point_loss_type: Type of loss for coordinates - "l1" (default) or "mse"
16+
point_cost: 5.0
17+
point_loss_coefficient: 5.0
18+
point_loss_type: 'l1'
19+
20+
# For keypoint detection, start from pretrained Deformable DETR backbone
21+
# Override with our DETR backbone once trained.
22+
model:
23+
name: 'SenseTime/deformable-detr'
24+
revision: 'main'
25+
26+
# Transformer-based models often prefer lower learning rates
27+
train:
28+
lr: 0.0001
29+
30+
# Pixel distance threshold for keypoint matching (instead of IoU for boxes)
31+
# A prediction is considered correct if within this many pixels of ground truth
32+
validation:
33+
pixel_distance_threshold: 10.0

src/deepforest/conf/livestock.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Ensure we inherit from default config + overlay these overrides.
2+
defaults:
3+
- config
4+
- _self_
5+
6+
task: 'box'
7+
8+
model:
9+
name: 'weecology/deepforest-livestock'
10+
revision: 'main'

src/deepforest/conf/schema.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from dataclasses import dataclass, field
22

3-
from omegaconf import MISSING
4-
53

64
@dataclass
75
class ModelConfig:
@@ -53,21 +51,28 @@ class TrainConfig:
5351
architectures, such as transformers-based models which sometimes
5452
prefer a lower learning rate.
5553
54+
The optimizer can be "sgd" (with momentum=0.9) or "adamw". SGD is
55+
the default and works well for RetinaNet. AdamW is recommended for
56+
transformer-based models like DeformableDetr, typically with a lower
57+
learning rate (e.g., 1e-4 to 5e-4).
58+
5659
The number of epochs should be user-specified and depends on the
5760
size of the dataset (e.g. how many iterations the model will train
5861
for and how diverse the imagery is). DeepForest uses Lightning to
5962
manage the training loop and you can set fast_dev_run to True for
6063
sanity checking.
6164
"""
6265

63-
csv_file: str | None = MISSING
64-
root_dir: str | None = MISSING
66+
csv_file: str | None = None
67+
root_dir: str | None = None
6568
lr: float = 0.001
69+
optimizer: str = "sgd"
6670
scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
6771
epochs: int = 1
6872
fast_dev_run: bool = False
6973
preload_images: bool = False
7074
augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"])
75+
log_root: str = "logs"
7176

7277

7378
@dataclass
@@ -79,14 +84,16 @@ class ValidationConfig:
7984
converged or is overfitting.
8085
"""
8186

82-
csv_file: str | None = MISSING
83-
root_dir: str | None = MISSING
87+
csv_file: str | None = None
88+
root_dir: str | None = None
8489
preload_images: bool = False
8590
size: int | None = None
8691
iou_threshold: float = 0.4
8792
val_accuracy_interval: int = 20
8893
lr_plateau_target: str = "val_loss"
8994
augmentations: list[str] | None = field(default_factory=lambda: [])
95+
# Keypoint-specific validation (used when task="keypoint")
96+
pixel_distance_threshold: float = 10.0
9097

9198

9299
@dataclass
@@ -130,21 +137,27 @@ class Config:
130137
accelerator: str = "auto"
131138
batch_size: int = 1
132139

140+
task: str = "box"
133141
architecture: str = "retinanet"
134142
num_classes: int = 1
135143
label_dict: dict[str, int] = field(default_factory=lambda: {"Tree": 0})
136144

145+
# Keypoint-specific parameters (used when task="keypoint")
146+
point_cost: float = 5.0
147+
point_loss_coefficient: float = 5.0
148+
point_loss_type: str = "l1"
149+
137150
nms_thresh: float = 0.05
138151
score_thresh: float = 0.1
139152
model: ModelConfig = field(default_factory=ModelConfig)
140153

141154
# Preprocessing
142-
path_to_raster: str | None = MISSING
155+
path_to_raster: str | None = None
143156
patch_size: int = 400
144157
patch_overlap: float = 0.05
145-
annotations_xml: str | None = MISSING
146-
rgb_dir: str | None = MISSING
147-
path_to_rgb: str | None = MISSING
158+
annotations_xml: str | None = None
159+
rgb_dir: str | None = None
160+
path_to_rgb: str | None = None
148161

149162
train: TrainConfig = field(default_factory=TrainConfig)
150163
validation: ValidationConfig = field(default_factory=ValidationConfig)

src/deepforest/conf/tree.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Ensure we inherit from default config + overlay these overrides.
2+
defaults:
3+
- config
4+
- _self_
5+
6+
task: 'box'
7+
8+
model:
9+
name: 'weecology/deepforest-tree'
10+
revision: 'main'

0 commit comments

Comments
 (0)