Skip to content

Commit a94af37

Browse files
committed
Make Backbone parent classes to allow handling all subclass presets
1 parent c60112e commit a94af37

40 files changed

+330
-895
lines changed

keras_cv/models/backbones/backbone.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616
from keras_cv.api_export import keras_cv_export
1717
from keras_cv.backend import keras
18-
from keras_cv.utils.preset_utils import check_preset_class
18+
from keras_cv.utils.preset_utils import check_config_class
19+
from keras_cv.utils.preset_utils import list_presets
20+
from keras_cv.utils.preset_utils import list_subclasses
1921
from keras_cv.utils.preset_utils import load_from_preset
2022
from keras_cv.utils.python_utils import classproperty
21-
from keras_cv.utils.python_utils import format_docstring
2223

2324

2425
@keras_cv_export("keras_cv.models.Backbone")
@@ -64,12 +65,18 @@ def from_config(cls, config):
6465
@classproperty
6566
def presets(cls):
6667
"""Dictionary of preset names and configs."""
67-
return {}
68+
presets = list_presets(cls)
69+
for subclass in list_subclasses(cls):
70+
presets.update(subclass.presets)
71+
return presets
6872

6973
@classproperty
7074
def presets_with_weights(cls):
7175
"""Dictionary of preset names and configs that include weights."""
72-
return {}
76+
presets = list_presets(cls, with_weights=True)
77+
for subclass in list_subclasses(cls):
78+
presets.update(subclass.presets)
79+
return presets
7380

7481
@classproperty
7582
def presets_without_weights(cls):
@@ -109,47 +116,19 @@ def from_preset(
109116
load_weights=False,
110117
```
111118
"""
112-
# We support short IDs for official presets, e.g. `"bert_base_en"`.
113-
# Map these to a Kaggle Models handle.
114-
if preset in cls.presets:
115-
preset = cls.presets[preset]["kaggle_handle"]
116-
117-
check_preset_class(preset, cls)
119+
preset_cls = check_config_class(preset)
120+
if not issubclass(preset_cls, cls):
121+
raise ValueError(
122+
f"Preset has type `{preset_cls.__name__}` which is not a "
123+
f"a subclass of calling class `{cls.__name__}`. Call "
124+
f"`from_preset` directly on `{preset_cls.__name__}` instead."
125+
)
118126
return load_from_preset(
119127
preset,
120128
load_weights=load_weights,
121129
config_overrides=kwargs,
122130
)
123131

124-
def __init_subclass__(cls, **kwargs):
125-
# Use __init_subclass__ to set up a correct docstring for from_preset.
126-
super().__init_subclass__(**kwargs)
127-
128-
# If the subclass does not define from_preset, assign a wrapper so that
129-
# each class can have a distinct docstring.
130-
if "from_preset" not in cls.__dict__:
131-
132-
def from_preset(calling_cls, *args, **kwargs):
133-
return super(cls, calling_cls).from_preset(*args, **kwargs)
134-
135-
cls.from_preset = classmethod(from_preset)
136-
137-
if not cls.presets:
138-
cls.from_preset.__func__.__doc__ = """Not implemented.
139-
140-
No presets available for this class.
141-
"""
142-
143-
# Format and assign the docstring unless the subclass has overridden it.
144-
if cls.from_preset.__doc__ is None:
145-
cls.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
146-
format_docstring(
147-
model_name=cls.__name__,
148-
example_preset_name=next(iter(cls.presets_with_weights), ""),
149-
preset_names='", "'.join(cls.presets),
150-
preset_with_weights_names='", "'.join(cls.presets_with_weights),
151-
)(cls.from_preset.__func__)
152-
153132
@property
154133
def pyramid_level_inputs(self):
155134
"""Intermediate model outputs for feature extraction.

keras_cv/models/backbones/backbone_presets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
from keras_cv.models.backbones.video_swin import video_swin_backbone_presets
3232
from keras_cv.models.backbones.vit_det import vit_det_backbone_presets
3333
from keras_cv.models.object_detection.yolo_v8 import yolo_v8_backbone_presets
34+
from keras_cv.models.object_detection_3d import center_pillar_backbone_presets
3435

3536
backbone_presets_no_weights = {
37+
**center_pillar_backbone_presets.backbone_presets_no_weights,
3638
**resnet_v1_backbone_presets.backbone_presets_no_weights,
3739
**resnet_v2_backbone_presets.backbone_presets_no_weights,
3840
**mobilenet_v3_backbone_presets.backbone_presets_no_weights,
@@ -47,6 +49,7 @@
4749
}
4850

4951
backbone_presets_with_weights = {
52+
**center_pillar_backbone_presets.backbone_presets_with_weights,
5053
**resnet_v1_backbone_presets.backbone_presets_with_weights,
5154
**resnet_v2_backbone_presets.backbone_presets_with_weights,
5255
**mobilenet_v3_backbone_presets.backbone_presets_with_weights,

keras_cv/models/backbones/csp_darknet/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,22 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
16+
backbone_presets_no_weights, backbone_presets_with_weights,
17+
)
18+
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone import (
19+
CSPDarkNetBackbone,
20+
)
21+
from keras_cv.models.backbones.csp_darknet.csp_darknet_aliases import (
22+
CSPDarkNetTinyBackbone, CSPDarkNetLBackbone,
23+
)
24+
from keras_cv.utils.preset_utils import register_presets, register_preset
25+
26+
register_presets(backbone_presets_no_weights, (CSPDarkNetBackbone, ), with_weights=False)
27+
register_presets(backbone_presets_with_weights, (CSPDarkNetBackbone, ), with_weights=True)
28+
register_presets(backbone_presets_with_weights, (CSPDarkNetBackbone, ), with_weights=True)
29+
register_preset("csp_darknet_tiny_imagenet", backbone_presets_with_weights["csp_darknet_tiny_imagenet"],
30+
(CSPDarkNetTinyBackbone,), with_weights=True)
31+
register_preset("csp_darknet_l_imagenet", backbone_presets_with_weights["csp_darknet_l_imagenet"],
32+
(CSPDarkNetLBackbone,), with_weights=True)

keras_cv/models/backbones/csp_darknet/csp_darknet_aliases.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import copy
1615

1716
from keras_cv.api_export import keras_cv_export
1817
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone import (
1918
CSPDarkNetBackbone,
2019
)
21-
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
22-
backbone_presets,
23-
)
24-
from keras_cv.utils.python_utils import classproperty
2520

2621
ALIAS_DOCSTRING = """CSPDarkNetBackbone model with {stackwise_channels} channels
2722
and {stackwise_depth} depths.
@@ -71,21 +66,6 @@ def __new__(
7166
)
7267
return CSPDarkNetBackbone.from_preset("csp_darknet_tiny", **kwargs)
7368

74-
@classproperty
75-
def presets(cls):
76-
"""Dictionary of preset names and configurations."""
77-
return {
78-
"csp_darknet_tiny_imagenet": copy.deepcopy(
79-
backbone_presets["csp_darknet_tiny_imagenet"]
80-
)
81-
}
82-
83-
@classproperty
84-
def presets_with_weights(cls):
85-
"""Dictionary of preset names and configurations that include
86-
weights."""
87-
return cls.presets
88-
8969

9070
@keras_cv_export("keras_cv.models.CSPDarkNetSBackbone")
9171
class CSPDarkNetSBackbone(CSPDarkNetBackbone):
@@ -106,17 +86,6 @@ def __new__(
10686
)
10787
return CSPDarkNetBackbone.from_preset("csp_darknet_s", **kwargs)
10888

109-
@classproperty
110-
def presets(cls):
111-
"""Dictionary of preset names and configurations."""
112-
return {}
113-
114-
@classproperty
115-
def presets_with_weights(cls):
116-
"""Dictionary of preset names and configurations that include
117-
weights."""
118-
return {}
119-
12089

12190
@keras_cv_export("keras_cv.models.CSPDarkNetMBackbone")
12291
class CSPDarkNetMBackbone(CSPDarkNetBackbone):
@@ -137,17 +106,6 @@ def __new__(
137106
)
138107
return CSPDarkNetBackbone.from_preset("csp_darknet_m", **kwargs)
139108

140-
@classproperty
141-
def presets(cls):
142-
"""Dictionary of preset names and configurations."""
143-
return {}
144-
145-
@classproperty
146-
def presets_with_weights(cls):
147-
"""Dictionary of preset names and configurations that include
148-
weights."""
149-
return {}
150-
151109

152110
@keras_cv_export("keras_cv.models.CSPDarkNetLBackbone")
153111
class CSPDarkNetLBackbone(CSPDarkNetBackbone):
@@ -168,21 +126,6 @@ def __new__(
168126
)
169127
return CSPDarkNetBackbone.from_preset("csp_darknet_l", **kwargs)
170128

171-
@classproperty
172-
def presets(cls):
173-
"""Dictionary of preset names and configurations."""
174-
return {
175-
"csp_darknet_l_imagenet": copy.deepcopy(
176-
backbone_presets["csp_darknet_l_imagenet"]
177-
)
178-
}
179-
180-
@classproperty
181-
def presets_with_weights(cls):
182-
"""Dictionary of preset names and configurations that include
183-
weights."""
184-
return cls.presets
185-
186129

187130
@keras_cv_export("keras_cv.models.CSPDarkNetXLBackbone")
188131
class CSPDarkNetXLBackbone(CSPDarkNetBackbone):
@@ -203,16 +146,6 @@ def __new__(
203146
)
204147
return CSPDarkNetBackbone.from_preset("csp_darknet_xl", **kwargs)
205148

206-
@classproperty
207-
def presets(cls):
208-
"""Dictionary of preset names and configurations."""
209-
return {}
210-
211-
@classproperty
212-
def presets_with_weights(cls):
213-
"""Dictionary of preset names and configurations that include
214-
weights."""
215-
return {}
216149

217150

218151
setattr(

keras_cv/models/backbones/csp_darknet/csp_darknet_backbone.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,11 @@
1313
# limitations under the License.
1414

1515
"""CSPDarkNet backbone model. """
16-
import copy
1716

1817
from keras_cv.api_export import keras_cv_export
1918
from keras_cv.backend import keras
2019
from keras_cv.models import utils
2120
from keras_cv.models.backbones.backbone import Backbone
22-
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
23-
backbone_presets,
24-
)
25-
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
26-
backbone_presets_with_weights,
27-
)
2821
from keras_cv.models.backbones.csp_darknet.csp_darknet_utils import (
2922
CrossStagePartial,
3023
)
@@ -38,8 +31,6 @@
3831
from keras_cv.models.backbones.csp_darknet.csp_darknet_utils import (
3932
SpatialPyramidPoolingBottleneck,
4033
)
41-
from keras_cv.utils.python_utils import classproperty
42-
4334

4435
@keras_cv_export("keras_cv.models.CSPDarkNetBackbone")
4536
class CSPDarkNetBackbone(Backbone):
@@ -169,14 +160,3 @@ def get_config(self):
169160
}
170161
)
171162
return config
172-
173-
@classproperty
174-
def presets(cls):
175-
"""Dictionary of preset names and configurations."""
176-
return copy.deepcopy(backbone_presets)
177-
178-
@classproperty
179-
def presets_with_weights(cls):
180-
"""Dictionary of preset names and configurations that include
181-
weights."""
182-
return copy.deepcopy(backbone_presets_with_weights)

keras_cv/models/backbones/densenet/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,26 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from keras_cv.models.backbones.densenet.densenet_backbone_presets import (
16+
backbone_presets_no_weights,
17+
)
18+
from keras_cv.models.backbones.densenet.densenet_backbone_presets import (
19+
backbone_presets_with_weights,
20+
)
21+
from keras_cv.models.backbones.densenet.densenet_backbone import (
22+
DenseNetBackbone,
23+
)
24+
from keras_cv.models.backbones.densenet.densenet_aliases import (
25+
DenseNet121Backbone, DenseNet169Backbone, DenseNet201Backbone
26+
)
27+
from keras_cv.utils.preset_utils import register_presets, register_preset
28+
29+
register_presets(backbone_presets_no_weights, (DenseNetBackbone, ), with_weights=False)
30+
register_presets(backbone_presets_with_weights, (DenseNetBackbone, ), with_weights=True)
31+
register_preset("densenet121_imagenet", backbone_presets_with_weights["densenet121_imagenet"],
32+
(DenseNet121Backbone,), with_weights=True)
33+
register_preset("densenet169_imagenet", backbone_presets_with_weights["densenet169_imagenet"],
34+
(DenseNet169Backbone,), with_weights=True)
35+
register_preset("densenet201_imagenet", backbone_presets_with_weights["densenet201_imagenet"],
36+
(DenseNet201Backbone,), with_weights=True)

keras_cv/models/backbones/densenet/densenet_aliases.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import copy
16-
1715
from keras_cv.api_export import keras_cv_export
1816
from keras_cv.models.backbones.densenet.densenet_backbone import (
1917
DenseNetBackbone,
2018
)
21-
from keras_cv.models.backbones.densenet.densenet_backbone_presets import (
22-
backbone_presets,
23-
)
24-
from keras_cv.utils.python_utils import classproperty
2519

2620
ALIAS_DOCSTRING = """DenseNetBackbone model with {num_layers} layers.
2721
@@ -69,21 +63,6 @@ def __new__(
6963
)
7064
return DenseNetBackbone.from_preset("densenet121", **kwargs)
7165

72-
@classproperty
73-
def presets(cls):
74-
"""Dictionary of preset names and configurations."""
75-
return {
76-
"densenet121_imagenet": copy.deepcopy(
77-
backbone_presets["densenet121_imagenet"]
78-
),
79-
}
80-
81-
@classproperty
82-
def presets_with_weights(cls):
83-
"""Dictionary of preset names and configurations that include weights.""" # noqa: E501
84-
return cls.presets
85-
86-
8766
@keras_cv_export("keras_cv.models.DenseNet169Backbone")
8867
class DenseNet169Backbone(DenseNetBackbone):
8968
def __new__(
@@ -103,20 +82,6 @@ def __new__(
10382
)
10483
return DenseNetBackbone.from_preset("densenet169", **kwargs)
10584

106-
@classproperty
107-
def presets(cls):
108-
"""Dictionary of preset names and configurations."""
109-
return {
110-
"densenet169_imagenet": copy.deepcopy(
111-
backbone_presets["densenet169_imagenet"]
112-
),
113-
}
114-
115-
@classproperty
116-
def presets_with_weights(cls):
117-
"""Dictionary of preset names and configurations that include weights.""" # noqa: E501
118-
return cls.presets
119-
12085

12186
@keras_cv_export("keras_cv.models.DenseNet201Backbone")
12287
class DenseNet201Backbone(DenseNetBackbone):
@@ -137,20 +102,6 @@ def __new__(
137102
)
138103
return DenseNetBackbone.from_preset("densenet201", **kwargs)
139104

140-
@classproperty
141-
def presets(cls):
142-
"""Dictionary of preset names and configurations."""
143-
return {
144-
"densenet201_imagenet": copy.deepcopy(
145-
backbone_presets["densenet201_imagenet"]
146-
),
147-
}
148-
149-
@classproperty
150-
def presets_with_weights(cls):
151-
"""Dictionary of preset names and configurations that include weights.""" # noqa: E501
152-
return cls.presets
153-
154105

155106
setattr(DenseNet121Backbone, "__doc__", ALIAS_DOCSTRING.format(num_layers=121))
156107
setattr(DenseNet169Backbone, "__doc__", ALIAS_DOCSTRING.format(num_layers=169))

0 commit comments

Comments
 (0)