Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/deepforest/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,17 @@ def get_available_augmentations() -> list[str]:

def get_transform(
augmentations: str | list[str] | dict[str, Any] | None = None,
task: str = "box",
) -> A.Compose:
"""Create Albumentations transform for bounding boxes.
"""Create Albumentations transform for boxes or keypoints.

Args:
augmentations: Augmentation configuration:
- str: Single augmentation name
- list: List of augmentation names
- dict: Dict with names as keys and params as values
- None: No augmentations
task: Task type - "box" for bounding boxes or "keypoint" for keypoints

Returns:
Composed albumentations transform
Expand All @@ -79,9 +81,13 @@ def get_transform(
... "HorizontalFlip": {"p": 0.5},
... "Downscale": {"scale_min": 0.25, "scale_max": 0.75}
... })

>>> # Keypoint augmentations
>>> transform = get_transform(augmentations=["HorizontalFlip"], task="keypoint")
"""
transforms_list = []
bbox_params = None
keypoint_params = None

if augmentations is not None:
augment_configs = _parse_augmentations(augmentations)
Expand All @@ -90,12 +96,19 @@ def get_transform(
aug_transform = _create_augmentation(aug_name, aug_params)
transforms_list.append(aug_transform)

bbox_params = A.BboxParams(format="pascal_voc", label_fields=["category_ids"])
if task == "box":
bbox_params = A.BboxParams(format="pascal_voc", label_fields=["labels"])
elif task == "keypoint":
keypoint_params = A.KeypointParams(format="xy", label_fields=["labels"])
else:
raise ValueError(f"Unsupported task: {task}. Must be 'box' or 'keypoint'.")

# Always add ToTensorV2 at the end
transforms_list.append(ToTensorV2())

return A.Compose(transforms_list, bbox_params=bbox_params)
return A.Compose(
transforms_list, bbox_params=bbox_params, keypoint_params=keypoint_params
)


def _parse_augmentations(
Expand Down
12 changes: 6 additions & 6 deletions src/deepforest/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ def _log_last_predictions(self, trainer, pl_module):
else:
selected_images = df.image_path.unique()[: self.prediction_samples]

# Ensure color is correctly assigned
if self.color is None:
num_classes = len(df["label"].unique())
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
else:
results_color = self.color
# Ensure color is correctly assigned
if self.color is None:
num_classes = len(df["label"].unique())
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
else:
results_color = self.color

for image_name in selected_images:
pred_df = df[df.image_path == image_name]
Expand Down
10 changes: 10 additions & 0 deletions src/deepforest/conf/bird.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Ensure we inherit from default config + overlay these overrides.
defaults:
- config
- _self_

task: 'box'

model:
name: 'weecology/deepforest-bird'
revision: 'main'
5 changes: 5 additions & 0 deletions src/deepforest/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ accelerator: auto
batch_size: 1

# Model Architecture
task: 'box'
architecture: 'retinanet'
num_classes: 1
nms_thresh: 0.05
Expand Down Expand Up @@ -66,6 +67,10 @@ train:
min_lr: 0
eps: 0.00000001

# Supported optimizers: "sgd" and "adamw". If you use adamw,
# make sure your learning rate is lowered sufficiently (e.g. 1e-4 to 5e-4).
optimizer: sgd

# How many epochs to run for
epochs: 1
# Useful debugging flag in PyTorch Lightning: set to True to run a single batch of training to test settings.
Expand Down
33 changes: 33 additions & 0 deletions src/deepforest/conf/keypoint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Config file for DeepForest keypoint detection tasks

# Ensure we inherit from default config + overlay these overrides.
defaults:
- config
- _self_

# Task and Model Architecture
task: 'keypoint'
architecture: 'DeformableDetr'

# Keypoint-specific parameters for Deformable DETR
# point_cost: Relative weight of point distance in matching cost (default: 5.0)
# point_loss_coefficient: Weight of point loss in total loss (default: 5.0)
# point_loss_type: Type of loss for coordinates - "l1" (default) or "mse"
point_cost: 5.0
point_loss_coefficient: 5.0
point_loss_type: 'l1'

# For keypoint detection, start from a pretrained Deformable DETR backbone.
# Override this with our own DETR backbone once it has been trained.
model:
name: 'SenseTime/deformable-detr'
revision: 'main'

# Transformer-based models often prefer lower learning rates
train:
lr: 0.0001

# Pixel distance threshold for keypoint matching (used instead of IoU, which applies to boxes).
# A prediction is considered correct if it lies within this many pixels of the ground truth.
validation:
pixel_distance_threshold: 10.0
10 changes: 10 additions & 0 deletions src/deepforest/conf/livestock.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Ensure we inherit from default config + overlay these overrides.
defaults:
- config
- _self_

task: 'box'

model:
name: 'weecology/deepforest-livestock'
revision: 'main'
33 changes: 23 additions & 10 deletions src/deepforest/conf/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from dataclasses import dataclass, field

from omegaconf import MISSING


@dataclass
class ModelConfig:
Expand Down Expand Up @@ -53,21 +51,28 @@ class TrainConfig:
architectures, such as transformers-based models which sometimes
prefer a lower learning rate.

The optimizer can be "sgd" (with momentum=0.9) or "adamw". SGD is
the default and works well for RetinaNet. AdamW is recommended for
transformer-based models like DeformableDetr, typically with a lower
learning rate (e.g., 1e-4 to 5e-4).

The number of epochs should be user-specified and depends on the
size of the dataset (e.g. how many iterations the model will train
for and how diverse the imagery is). DeepForest uses Lightning to
manage the training loop and you can set fast_dev_run to True for
sanity checking.
"""

csv_file: str | None = MISSING
root_dir: str | None = MISSING
csv_file: str | None = None
root_dir: str | None = None
lr: float = 0.001
optimizer: str = "sgd"
scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
epochs: int = 1
fast_dev_run: bool = False
preload_images: bool = False
augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"])
log_root: str = "logs"


@dataclass
Expand All @@ -79,14 +84,16 @@ class ValidationConfig:
converged or is overfitting.
"""

csv_file: str | None = MISSING
root_dir: str | None = MISSING
csv_file: str | None = None
root_dir: str | None = None
preload_images: bool = False
size: int | None = None
iou_threshold: float = 0.4
val_accuracy_interval: int = 20
lr_plateau_target: str = "val_loss"
augmentations: list[str] | None = field(default_factory=lambda: [])
# Keypoint-specific validation (used when task="keypoint")
pixel_distance_threshold: float = 10.0


@dataclass
Expand Down Expand Up @@ -130,21 +137,27 @@ class Config:
accelerator: str = "auto"
batch_size: int = 1

task: str = "box"
architecture: str = "retinanet"
num_classes: int = 1
label_dict: dict[str, int] = field(default_factory=lambda: {"Tree": 0})

# Keypoint-specific parameters (used when task="keypoint")
point_cost: float = 5.0
point_loss_coefficient: float = 5.0
point_loss_type: str = "l1"

nms_thresh: float = 0.05
score_thresh: float = 0.1
model: ModelConfig = field(default_factory=ModelConfig)

# Preprocessing
path_to_raster: str | None = MISSING
path_to_raster: str | None = None
patch_size: int = 400
patch_overlap: float = 0.05
annotations_xml: str | None = MISSING
rgb_dir: str | None = MISSING
path_to_rgb: str | None = MISSING
annotations_xml: str | None = None
rgb_dir: str | None = None
path_to_rgb: str | None = None

train: TrainConfig = field(default_factory=TrainConfig)
validation: ValidationConfig = field(default_factory=ValidationConfig)
Expand Down
10 changes: 10 additions & 0 deletions src/deepforest/conf/tree.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Ensure we inherit from default config + overlay these overrides.
defaults:
- config
- _self_

task: 'box'

model:
name: 'weecology/deepforest-tree'
revision: 'main'
Loading