Skip to content

Commit 6bf5ecb

Browse files
Transformers/DETR integration (#1078)
Add Transformers-based Deformable DETR and AutoModelForObjectDetection integration for object detection tasks
1 parent bfe200f commit 6bf5ecb

File tree

12 files changed

+271
-49
lines changed

12 files changed

+271
-49
lines changed

docs/user_guide/13_annotation.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ An incomplete list of annotation tools DeepForest users have reported success wi
2222
- AWS Ground Truth
2323
- LabelBox
2424
- Roboflow
25-
- and many more
25+
- and many more
2626

2727
We intentionally do not create our own annotation tools, but rather focus on supporting community-created tools. Look for exports in `.xml`, `.json`, or `.csv` formats, which are all common in the above tools.
2828

@@ -142,9 +142,7 @@ for path in files:
142142
if boxes is None:
143143
continue
144144

145-
image = np.rollaxis(image, 0, 3)
146-
fig = plot_predictions(df=boxes, image=image)
147-
plt.imshow(fig)
145+
plot_results(results=boxes, image=image)
148146

149147
basename = os.path.splitext(os.path.basename(path))[0]
150148
shp = boxes_to_shapefile(boxes, root_dir=PATH_TO_DIR, projected=False)
@@ -183,4 +181,4 @@ Avoid collecting all annotations before model testing. Start with a small number
183181

184182
# Please Make Your Annotations Open-Source!
185183

186-
DeepForest's models are not perfect. Please consider sharing your annotations with the community to make the models stronger. You can post your annotations on **Zenodo** or open an [issue](https://github.com/weecology/DeepForest/issues) to share your data with the maintainers.
184+
DeepForest's models are not perfect. Please consider sharing your annotations with the community to make the models stronger. You can post your annotations on **Zenodo** or open an [issue](https://github.com/weecology/DeepForest/issues) to share your data with the maintainers.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ dependencies = [
7171
"torchvision>=0.13",
7272
"tqdm",
7373
"xmltodict",
74+
"transformers>=4.46.3",
75+
"timm>=1.0.15",
7476
]
7577

7678
[project.optional-dependencies]

src/deepforest/conf/config.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ batch_size: 1
1111
architecture: 'retinanet'
1212
num_classes: 1
1313
nms_thresh: 0.05
14+
score_thresh: 0.1
1415

1516
model:
1617
name: 'weecology/deepforest-tree'
@@ -24,11 +25,6 @@ annotations_xml:
2425
rgb_dir:
2526
path_to_rgb:
2627

27-
# Architecture specific params
28-
retinanet:
29-
# Non-max suppression of overlapping predictions
30-
score_thresh: 0.1
31-
3228
train:
3329
csv_file:
3430
root_dir:

src/deepforest/main.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def __init__(
6363
# If not provided, load default config via hydra.
6464
if config is None:
6565
config = utilities.load_config(overrides=config_args)
66-
elif 'config_file' in config:
66+
# Hub overrides
67+
elif 'config_file' in config or 'config_args' in config:
6768
config = utilities.load_config(overrides=config['config_args'])
6869
elif config_args is not None:
6970
warnings.warn(
@@ -118,7 +119,7 @@ def __init__(
118119

119120
self.save_hyperparameters()
120121

121-
def load_model(self, model_name="weecology/deepforest-tree", revision='main'):
122+
def load_model(self, model_name=None, revision=None):
122123
"""Loads a model that has already been pretrained for a specific task,
123124
like tree crown detection.
124125
@@ -136,16 +137,22 @@ def load_model(self, model_name="weecology/deepforest-tree", revision='main'):
136137
Returns:
137138
None
138139
"""
140+
141+
if model_name is None:
142+
model_name = self.config.model.name
143+
144+
if revision is None:
145+
revision = self.config.model.revision
146+
139147
# Load the model using from_pretrained
140-
self.create_model()
141148
loaded_model = self.from_pretrained(model_name, revision=revision)
142149
self.label_dict = loaded_model.label_dict
143150
self.model = loaded_model.model
144151
self.numeric_to_label_dict = loaded_model.numeric_to_label_dict
145152

146153
# Set bird-specific settings if loading the bird model
147154
if model_name == "weecology/deepforest-bird":
148-
self.config.retinanet.score_thresh = 0.3
155+
self.config.score_thresh = 0.3
149156
self.label_dict = {"Bird": 0}
150157
self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}
151158

src/deepforest/model.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,14 @@ class Model():
2121
statement below.
2222
2323
Args:
24-
num_classes (int): number of classes in the model
25-
nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1]
26-
score_thresh (float): minimum prediction score to keep during prediction [0,1]
27-
Returns:
28-
model: a pytorch nn module
24+
config (DictConfig): DeepForest config settings object
2925
"""
3026

3127
def __init__(self, config):
3228

3329
# Check for required properties and formats
3430
self.config = config
31+
self.nms_thresh = None # Required for some models but not all
3532

3633
def create_model(self):
3734
"""This function converts a deepforest config file into a model.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import warnings
2+
from transformers import DeformableDetrForObjectDetection, DeformableDetrImageProcessor, logging
3+
from deepforest.model import Model
4+
from torch import nn
5+
6+
# Suppress huge amounts of unnecessary warnings from transformers.
7+
logging.set_verbosity_error()
8+
9+
10+
class DeformableDetrWrapper(nn.Module):
11+
"""This class wraps a transformers DeformableDetrForObjectDetection model
12+
so that input pre- and post-processing happens transparently."""
13+
14+
def __init__(self, config, name, revision):
15+
"""Initialize a DeformableDetrForObjectDetection model.
16+
17+
We assume that the provided name applies to both model and
18+
processor. By default this function creates a model with MS-COCO
19+
initialized weights, but can be overridden if needed.
20+
"""
21+
super().__init__()
22+
self.config = config
23+
24+
# This suppresses a bunch of messages which are specific to DETR,
25+
# but do not impact model function.
26+
with warnings.catch_warnings():
27+
warnings.simplefilter("ignore", category=UserWarning)
28+
29+
self.net = DeformableDetrForObjectDetection.from_pretrained(
30+
name,
31+
revision=revision,
32+
num_labels=self.config.num_classes,
33+
ignore_mismatched_sizes=True)
34+
self.processor = DeformableDetrImageProcessor.from_pretrained(
35+
name, revision=revision)
36+
37+
def _prepare_targets(self, targets):
38+
39+
if not isinstance(targets, list):
40+
targets = [targets]
41+
42+
coco_targets = []
43+
44+
for target in targets:
45+
coco_targets.append({
46+
"image_id":
47+
0,
48+
"annotations": [{
49+
"id": i,
50+
"image_id": i,
51+
"category_id": label,
52+
"bbox": box.tolist(),
53+
"area": (box[3] - box[1]) * (box[2] - box[0]),
54+
"iscrowd": 0,
55+
} for i, (label, box) in enumerate(zip(target["labels"], target["boxes"]))
56+
]
57+
})
58+
59+
return coco_targets
60+
61+
def forward(self, images, targets=None, prepare_targets=True):
62+
"""AutoModelForObjectDetection forward pass. If targets are provided
63+
the function returns a loss dictionary, otherwise it returns processed
64+
predictions. For details, see the transformers documentation for
65+
"post_process_object_detection".
66+
67+
Returns:
68+
predictions: list of dictionaries with "score", "boxes" and "labels", or
69+
a loss dict for training.
70+
"""
71+
72+
if targets and prepare_targets:
73+
targets = self._prepare_targets(targets)
74+
75+
encoded_inputs = self.processor.preprocess(images=images,
76+
annotations=targets,
77+
return_tensors="pt",
78+
do_rescale=False)
79+
80+
preds = self.net(**encoded_inputs)
81+
82+
if targets is None:
83+
return self.processor.post_process_object_detection(
84+
preds,
85+
threshold=self.config.score_thresh,
86+
target_sizes=[i.shape[-2:] for i in images]
87+
if isinstance(images, list) else [images.shape[-2:]])
88+
else:
89+
return preds.loss_dict
90+
91+
92+
class Model(Model):
    """Deformable DETR model factory for DeepForest."""

    def __init__(self, config, **kwargs):
        """Set up the model factory.

        Args:
            config: DeepForest config settings object, forwarded to the
                base class initializer.
            **kwargs: accepted for call compatibility; not used here.
        """
        super().__init__(config)

    def create_model(self, name="SenseTime/deformable-detr", revision="main"):
        """Create a Deformable DETR model from pretrained weights.

        The number of classes is set via the config and overrides the value
        stored in the downloaded checkpoint. The default weights load a
        model trained on MS-COCO that should fine-tune well on other tasks.

        Args:
            name: checkpoint name used for both model and processor.
            revision: checkpoint revision to load.

        Returns:
            DeformableDetrWrapper: the wrapped model.
        """
        wrapper = DeformableDetrWrapper(config=self.config,
                                        name=name,
                                        revision=revision)
        return wrapper

src/deepforest/models/retinanet.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,8 @@ def create_anchor_generator(self,
4141
return anchor_generator
4242

4343
def create_model(self):
44-
"""Create a retinanet model
45-
Args:
46-
num_classes (int): number of classes in the model
47-
nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1]
48-
score_thresh (float): minimum prediction score to keep during prediction [0,1]
44+
"""Create a retinanet model.
45+
4946
Returns:
5047
model: a pytorch nn module
5148
"""
@@ -54,7 +51,7 @@ def create_model(self):
5451

5552
model = RetinaNet(backbone=backbone, num_classes=self.config.num_classes)
5653
model.nms_thresh = self.config.nms_thresh
57-
model.score_thresh = self.config.retinanet.score_thresh
54+
model.score_thresh = self.config.score_thresh
5855

5956
# Optionally allow anchor generator parameters to be created here
6057
# https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html

tests/deepforest_config_test.yml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,12 @@ batch_size: 1
1111
architecture: 'retinanet'
1212
num_classes: 1
1313
nms_thresh: 0.05
14-
15-
# Architecture specific params
16-
retinanet:
17-
# Non-max suppression of overlapping predictions
18-
score_thresh: 0.1
14+
score_thresh: 0.1
1915

2016
train:
2117
csv_file:
2218
root_dir:
23-
19+
2420
# Optimizer initial learning rate
2521
lr: 0.001
2622
scheduler:
@@ -50,10 +46,10 @@ train:
5046
fast_dev_run: False
5147
# pin images to GPU memory for fast training. This depends on GPU size and number of images.
5248
preload_images: False
53-
49+
5450
validation:
5551
# callback args
56-
csv_file:
52+
csv_file:
5753
root_dir:
5854
# Intersection over union evaluation
5955
iou_threshold: 0.4

tests/test_FasterRCNN.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _make_empty_sample():
2626
return images, targets
2727

2828

29-
def test_retinanet(config):
29+
def test_faster_rcnn(config):
3030
r = FasterRCNN.Model(config)
3131
assert r
3232

@@ -48,10 +48,10 @@ def test_check_model(config):
4848
@pytest.mark.parametrize("num_classes", [1, 2, 10])
4949
def test_create_model(config, num_classes):
5050
config.num_classes = num_classes
51-
retinanet_model = FasterRCNN.Model(config).create_model()
52-
retinanet_model.eval()
51+
faster_rcnn_model = FasterRCNN.Model(config).create_model()
52+
faster_rcnn_model.eval()
5353
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
54-
predictions = retinanet_model(x)
54+
predictions = faster_rcnn_model(x)
5555

5656

5757
def test_forward_empty(config):

0 commit comments

Comments
 (0)