11from dataclasses import dataclass , field
22
3- from omegaconf import MISSING
4-
53
64@dataclass
75class ModelConfig :
@@ -53,21 +51,28 @@ class TrainConfig:
5351 architectures, such as transformers-based models which sometimes
5452 prefer a lower learning rate.
5553
54+ The optimizer can be "sgd" (with momentum=0.9) or "adamw". SGD is
55+ the default and works well for RetinaNet. AdamW is recommended for
56+ transformer-based models like DeformableDetr, typically with a lower
57+ learning rate (e.g., 1e-4 to 5e-4).
58+
5659 The number of epochs should be user-specified and depends on the
5760 size of the dataset (e.g. how many iterations the model will train
5861 for and how diverse the imagery is). DeepForest uses Lightning to
5962 manage the training loop and you can set fast_dev_run to True for
6063 sanity checking.
6164 """
6265
63- csv_file : str | None = MISSING
64- root_dir : str | None = MISSING
66+ csv_file : str | None = None
67+ root_dir : str | None = None
6568 lr : float = 0.001
69+ optimizer : str = "sgd"
6670 scheduler : SchedulerConfig = field (default_factory = SchedulerConfig )
6771 epochs : int = 1
6872 fast_dev_run : bool = False
6973 preload_images : bool = False
7074 augmentations : list [str ] | None = field (default_factory = lambda : ["HorizontalFlip" ])
75+ log_root : str = "logs"
7176
7277
7378@dataclass
@@ -79,14 +84,16 @@ class ValidationConfig:
7984 converged or is overfitting.
8085 """
8186
82- csv_file : str | None = MISSING
83- root_dir : str | None = MISSING
87+ csv_file : str | None = None
88+ root_dir : str | None = None
8489 preload_images : bool = False
8590 size : int | None = None
8691 iou_threshold : float = 0.4
8792 val_accuracy_interval : int = 20
8893 lr_plateau_target : str = "val_loss"
8994 augmentations : list [str ] | None = field (default_factory = lambda : [])
95+ # Keypoint-specific validation (used when task="keypoint")
96+ pixel_distance_threshold : float = 10.0
9097
9198
9299@dataclass
@@ -130,21 +137,27 @@ class Config:
130137 accelerator : str = "auto"
131138 batch_size : int = 1
132139
140+ task : str = "box"
133141 architecture : str = "retinanet"
134142 num_classes : int = 1
135143 label_dict : dict [str , int ] = field (default_factory = lambda : {"Tree" : 0 })
136144
145+ # Keypoint-specific parameters (used when task="keypoint")
146+ point_cost : float = 5.0
147+ point_loss_coefficient : float = 5.0
148+ point_loss_type : str = "l1"
149+
137150 nms_thresh : float = 0.05
138151 score_thresh : float = 0.1
139152 model : ModelConfig = field (default_factory = ModelConfig )
140153
141154 # Preprocessing
142- path_to_raster : str | None = MISSING
155+ path_to_raster : str | None = None
143156 patch_size : int = 400
144157 patch_overlap : float = 0.05
145- annotations_xml : str | None = MISSING
146- rgb_dir : str | None = MISSING
147- path_to_rgb : str | None = MISSING
158+ annotations_xml : str | None = None
159+ rgb_dir : str | None = None
160+ path_to_rgb : str | None = None
148161
149162 train : TrainConfig = field (default_factory = TrainConfig )
150163 validation : ValidationConfig = field (default_factory = ValidationConfig )
0 commit comments