Skip to content

Commit d6f90e9

Browse files
integrate keypoints
1 parent 5d3e5f5 commit d6f90e9

File tree

16 files changed

+1181
-112
lines changed

16 files changed

+1181
-112
lines changed

src/deepforest/augmentations.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,17 @@ def get_available_augmentations() -> list[str]:
5151

5252
def get_transform(
5353
augmentations: str | list[str] | dict[str, Any] | None = None,
54+
task: str = "box",
5455
) -> A.Compose:
55-
"""Create Albumentations transform for bounding boxes.
56+
"""Create Albumentations transform for boxes or keypoints.
5657
5758
Args:
5859
augmentations: Augmentation configuration:
5960
- str: Single augmentation name
6061
- list: List of augmentation names
6162
- dict: Dict with names as keys and params as values
6263
- None: No augmentations
64+
task: Task type - "box" for bounding boxes or "keypoint" for keypoints
6365
6466
Returns:
6567
Composed albumentations transform
@@ -79,9 +81,13 @@ def get_transform(
7981
... "HorizontalFlip": {"p": 0.5},
8082
... "Downscale": {"scale_min": 0.25, "scale_max": 0.75}
8183
... })
84+
85+
>>> # Keypoint augmentations
86+
>>> transform = get_transform(augmentations=["HorizontalFlip"], task="keypoint")
8287
"""
8388
transforms_list = []
8489
bbox_params = None
90+
keypoint_params = None
8591

8692
if augmentations is not None:
8793
augment_configs = _parse_augmentations(augmentations)
@@ -90,12 +96,19 @@ def get_transform(
9096
aug_transform = _create_augmentation(aug_name, aug_params)
9197
transforms_list.append(aug_transform)
9298

93-
bbox_params = A.BboxParams(format="pascal_voc", label_fields=["category_ids"])
99+
if task == "box":
100+
bbox_params = A.BboxParams(format="pascal_voc", label_fields=["labels"])
101+
elif task == "keypoint":
102+
keypoint_params = A.KeypointParams(format="xy", label_fields=["labels"])
103+
else:
104+
raise ValueError(f"Unsupported task: {task}. Must be 'box' or 'keypoint'.")
94105

95106
# Always add ToTensorV2 at the end
96107
transforms_list.append(ToTensorV2())
97108

98-
return A.Compose(transforms_list, bbox_params=bbox_params)
109+
return A.Compose(
110+
transforms_list, bbox_params=bbox_params, keypoint_params=keypoint_params
111+
)
99112

100113

101114
def _parse_augmentations(

src/deepforest/callbacks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ def _log_last_predictions(self, trainer, pl_module):
151151
else:
152152
selected_images = df.image_path.unique()[: self.prediction_samples]
153153

154-
# Ensure color is correctly assigned
155-
if self.color is None:
156-
num_classes = len(df["label"].unique())
157-
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
158-
else:
159-
results_color = self.color
154+
# Ensure color is correctly assigned
155+
if self.color is None:
156+
num_classes = len(df["label"].unique())
157+
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
158+
else:
159+
results_color = self.color
160160

161161
for image_name in selected_images:
162162
pred_df = df[df.image_path == image_name]

src/deepforest/conf/bird.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Ensure we inherit from default config + overlay these overrides.
2+
defaults:
3+
- config
4+
- _self_
5+
6+
task: 'box'
7+
8+
model:
9+
name: 'weecology/deepforest-bird'
10+
revision: 'main'

src/deepforest/conf/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ accelerator: auto
88
batch_size: 1
99

1010
# Model Architecture
11+
task: 'box'
1112
architecture: 'retinanet'
1213
num_classes: 1
1314
nms_thresh: 0.05

src/deepforest/conf/keypoint.yaml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Config file for DeepForest keypoint detection tasks
2+
3+
# Ensure we inherit from default config + overlay these overrides.
4+
defaults:
5+
- config
6+
- _self_
7+
8+
# Task and Model Architecture
9+
task: 'keypoint'
10+
architecture: 'DeformableDetr'
11+
12+
# Keypoint-specific parameters for Deformable DETR
13+
# point_cost: Relative weight of point distance in matching cost (default: 5.0)
14+
# point_loss_coefficient: Weight of point loss in total loss (default: 5.0)
15+
# point_loss_type: Type of loss for coordinates - "l1" (default) or "mse"
16+
point_cost: 5.0
17+
point_loss_coefficient: 5.0
18+
point_loss_type: 'l1'
19+
20+
# For keypoint detection, start from the pretrained Deformable DETR weights named below
21+
model:
22+
name: 'SenseTime/deformable-detr'
23+
revision: 'main'
24+
25+
# Transformer-based models often prefer lower learning rates
26+
train:
27+
lr: 0.0001
28+
29+
# Pixel distance threshold for keypoint matching (instead of IoU for boxes)
30+
# A prediction is considered correct if it lies within this many pixels of the ground truth
31+
validation:
32+
pixel_distance_threshold: 10.0

src/deepforest/conf/livestock.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Ensure we inherit from default config + overlay these overrides.
2+
defaults:
3+
- box
4+
- _self_
5+
6+
task: 'box'
7+
8+
model:
9+
name: 'weecology/deepforest-livestock'
10+
revision: 'main'

src/deepforest/conf/schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class TrainConfig:
6868
fast_dev_run: bool = False
6969
preload_images: bool = False
7070
augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"])
71+
log_root: str = "logs"
7172

7273

7374
@dataclass
@@ -87,6 +88,8 @@ class ValidationConfig:
8788
val_accuracy_interval: int = 20
8889
lr_plateau_target: str = "val_loss"
8990
augmentations: list[str] | None = field(default_factory=lambda: [])
91+
# Keypoint-specific validation (used when task="keypoint")
92+
pixel_distance_threshold: float = 10.0
9093

9194

9295
@dataclass
@@ -130,10 +133,16 @@ class Config:
130133
accelerator: str = "auto"
131134
batch_size: int = 1
132135

136+
task: str = "box"
133137
architecture: str = "retinanet"
134138
num_classes: int = 1
135139
label_dict: dict[str, int] = field(default_factory=lambda: {"Tree": 0})
136140

141+
# Keypoint-specific parameters (used when task="keypoint")
142+
point_cost: float = 5.0
143+
point_loss_coefficient: float = 5.0
144+
point_loss_type: str = "l1"
145+
137146
nms_thresh: float = 0.05
138147
score_thresh: float = 0.1
139148
model: ModelConfig = field(default_factory=ModelConfig)

src/deepforest/conf/tree.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Ensure we inherit from default config + overlay these overrides.
2+
defaults:
3+
- box
4+
- _self_
5+
6+
task: 'box'
7+
8+
model:
9+
name: 'weecology/deepforest-tree'
10+
revision: 'main'

0 commit comments

Comments
 (0)