Skip to content

Commit 9280cd3

Browse files
allow arbitrary dinov3 checkpoints, resnet-imagenet
1 parent a4c7e57 commit 9280cd3

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

src/deepforest/conf/dinov3.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ defaults:
44
- _self_
55

66
model:
7-
name: "dinov3"
7+
name: "facebook/dinov3-vitl16-pretrain-sat493m"
88
revision: 'main'
99

1010
train:

src/deepforest/models/retinanet.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import torch
66
import torchvision
7-
from torchvision.models.detection.retinanet import RetinaNet
7+
from torchvision.models.detection.retinanet import RetinaNet, RetinaNet_ResNet50_FPN_Weights
8+
from torchvision.models.resnet import ResNet50_Weights
89
from torchvision.models.detection.retinanet import AnchorGenerator
910
from deepforest.model import BaseModel
1011
from deepforest.models.dinov3 import Dinov3Model
@@ -15,6 +16,7 @@ class RetinaNetHub(RetinaNet, PyTorchModelHubMixin):
1516
"""RetinaNet extension that allows the use of the HF Hub API."""
1617

1718
def __init__(self,
19+
weights: str | None = None,
1820
backbone_weights: str | None = None,
1921
num_classes: int = 1,
2022
nms_thresh: float = 0.05,
@@ -25,8 +27,9 @@ def __init__(self,
2527
freeze_backbone: bool = True,
2628
**kwargs):
2729

28-
if backbone_weights == "dinov3":
30+
if isinstance(weights, str) and "dinov3" in weights:
2931
backbone = Dinov3Model(
32+
repo_id=weights,
3033
use_conv_pyramid=use_conv_pyramid,
3134
fpn_out_channels=fpn_out_channels,
3235
frozen=freeze_backbone,
@@ -45,7 +48,7 @@ def __init__(self,
4548
"Frozen backbone is currently not enabled for ResNet, but you can set the learning rate to zero."
4649
)
4750
backbone = torchvision.models.detection.retinanet_resnet50_fpn(
48-
weights=backbone_weights).backbone
51+
weights=weights, backbone_weights=backbone_weights).backbone
4952
anchor_generator = None # Use default
5053

5154
# Explicitly use ImageNet
@@ -199,19 +202,32 @@ def create_model(self,
199202
model: a pytorch nn module
200203
"""
201204

202-
if pretrained == "resnet50":
205+
if pretrained is "resnet50-imagenet":
206+
if revision is not None:
207+
warnings.warn(
208+
"Ignoring revision and using an un-initialized RetinaNet head, ImageNet backbone."
209+
)
210+
model = RetinaNetHub(weights=None,
211+
backbone_weights=ResNet50_Weights.IMAGENET1K_V2,
212+
num_classes=self.config.num_classes,
213+
nms_thresh=self.config.nms_thresh,
214+
score_thresh=self.config.score_thresh,
215+
label_dict=self.config.label_dict)
216+
elif pretrained == "resnet50-mscoco":
203217
if revision is not None:
204218
warnings.warn(
205219
"Ignoring revision and fine-tuning from ResNet50 MS-COCO checkpoint.")
206-
model = RetinaNetHub(backbone_weights="COCO_V1",
220+
model = RetinaNetHub(weights=RetinaNet_ResNet50_FPN_Weights.COCO_V1,
207221
num_classes=self.config.num_classes,
208222
nms_thresh=self.config.nms_thresh,
209223
score_thresh=self.config.score_thresh,
210224
label_dict=self.config.label_dict)
211-
elif pretrained == "dinov3":
212-
warnings.warn(
213-
"Ignoring revision and fine-tuning from DinoV3 Sat-493M checkpoint.")
214-
model = RetinaNetHub(backbone_weights="dinov3",
225+
elif "dinov3" in pretrained:
226+
if revision is not None:
227+
warnings.warn(
228+
f"Ignoring revision and fine-tuning from DinoV3 {pretrained} checkpoint."
229+
)
230+
model = RetinaNetHub(weights=pretrained,
215231
num_classes=self.config.num_classes,
216232
nms_thresh=self.config.nms_thresh,
217233
score_thresh=self.config.score_thresh,

0 commit comments

Comments
 (0)