44
55import torch
66import torchvision
7- from torchvision .models .detection .retinanet import RetinaNet
7+ from torchvision .models .detection .retinanet import RetinaNet , RetinaNet_ResNet50_FPN_Weights
8+ from torchvision .models .resnet import ResNet50_Weights
89from torchvision .models .detection .retinanet import AnchorGenerator
910from deepforest .model import BaseModel
1011from deepforest .models .dinov3 import Dinov3Model
@@ -15,6 +16,7 @@ class RetinaNetHub(RetinaNet, PyTorchModelHubMixin):
1516 """RetinaNet extension that allows the use of the HF Hub API."""
1617
1718 def __init__ (self ,
19+ weights : str | None = None ,
1820 backbone_weights : str | None = None ,
1921 num_classes : int = 1 ,
2022 nms_thresh : float = 0.05 ,
@@ -25,8 +27,9 @@ def __init__(self,
2527 freeze_backbone : bool = True ,
2628 ** kwargs ):
2729
28- if backbone_weights == "dinov3" :
30+ if isinstance ( weights , str ) and "dinov3" in weights :
2931 backbone = Dinov3Model (
32+ repo_id = weights ,
3033 use_conv_pyramid = use_conv_pyramid ,
3134 fpn_out_channels = fpn_out_channels ,
3235 frozen = freeze_backbone ,
@@ -45,7 +48,7 @@ def __init__(self,
4548 "Frozen backbone is currently not enabled for ResNet, but you can set the learning rate to zero."
4649 )
4750 backbone = torchvision .models .detection .retinanet_resnet50_fpn (
48- weights = backbone_weights ).backbone
51+ weights = weights , backbone_weights = backbone_weights ).backbone
4952 anchor_generator = None # Use default
5053
5154 # Explicitly use ImageNet
@@ -199,19 +202,32 @@ def create_model(self,
199202 model: a pytorch nn module
200203 """
201204
202- if pretrained == "resnet50" :
205+ if pretrained is "resnet50-imagenet" :
206+ if revision is not None :
207+ warnings .warn (
208+ "Ignoring revision and using an un-initialized RetinaNet head, ImageNet backbone."
209+ )
210+ model = RetinaNetHub (weights = None ,
211+ backbone_weights = ResNet50_Weights .IMAGENET1K_V2 ,
212+ num_classes = self .config .num_classes ,
213+ nms_thresh = self .config .nms_thresh ,
214+ score_thresh = self .config .score_thresh ,
215+ label_dict = self .config .label_dict )
216+ elif pretrained == "resnet50-mscoco" :
203217 if revision is not None :
204218 warnings .warn (
205219 "Ignoring revision and fine-tuning from ResNet50 MS-COCO checkpoint." )
206- model = RetinaNetHub (backbone_weights = " COCO_V1" ,
220+ model = RetinaNetHub (weights = RetinaNet_ResNet50_FPN_Weights . COCO_V1 ,
207221 num_classes = self .config .num_classes ,
208222 nms_thresh = self .config .nms_thresh ,
209223 score_thresh = self .config .score_thresh ,
210224 label_dict = self .config .label_dict )
211- elif pretrained == "dinov3" :
212- warnings .warn (
213- "Ignoring revision and fine-tuning from DinoV3 Sat-493M checkpoint." )
214- model = RetinaNetHub (backbone_weights = "dinov3" ,
225+ elif "dinov3" in pretrained :
226+ if revision is not None :
227+ warnings .warn (
228+ f"Ignoring revision and fine-tuning from DinoV3 { pretrained } checkpoint."
229+ )
230+ model = RetinaNetHub (weights = pretrained ,
215231 num_classes = self .config .num_classes ,
216232 nms_thresh = self .config .nms_thresh ,
217233 score_thresh = self .config .score_thresh ,
0 commit comments