Skip to content

Commit 8cb6ede

Browse files
Koen Helwegen, lgeiger, AdamHillier
authored
MeliusNet22 implementation (#142)
* MeliusNet22 implementation * Apply suggestions from code review Co-Authored-By: Lukas Geiger <[email protected]> * review remarks * Update larq_zoo/literature/meliusnet.py Co-Authored-By: Lukas Geiger <[email protected]> * wip * correct hashes, docstr etc * isort * Apply suggestions from code review Co-Authored-By: Adam Hillier <[email protected]> Co-Authored-By: Lukas Geiger <[email protected]> * import tuple * linting * explicit naming * add path attributes to base * better TFOpLayer, update hashes * inline if/else Co-authored-by: Lukas Geiger <[email protected]> Co-authored-by: Adam Hillier <[email protected]>
1 parent 6fdbc07 commit 8cb6ede

File tree

4 files changed

+305
-9
lines changed

4 files changed

+305
-9
lines changed

larq_zoo/core/utils.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,7 @@ def slash_join(*args):
2525

2626

2727
def download_pretrained_model(
28-
model: str,
29-
version: str,
30-
file: str,
31-
file_hash: str,
32-
cache_dir: Optional[str] = None,
28+
model: str, version: str, file: str, file_hash: str, cache_dir: Optional[str] = None
3329
) -> str:
3430
root_url = "https://github.com/larq/zoo/releases/download/"
3531

@@ -144,10 +140,12 @@ def global_pool(
144140
pool_size = (
145141
input_shape[1:3] if data_format == "channels_last" else input_shape[2:4]
146142
)
147-
x = keras.layers.AveragePooling2D(pool_size=pool_size, data_format=data_format)(
148-
x
149-
)
150-
x = keras.layers.Flatten()(x)
143+
x = keras.layers.AveragePooling2D(
144+
pool_size=pool_size,
145+
data_format=data_format,
146+
name=f"{name}_pool" if name else None,
147+
)(x)
148+
x = keras.layers.Flatten(name=f"{name}_flatten" if name else None)(x)
151149
except ValueError:
152150
x = keras.layers.GlobalAveragePooling2D(data_format=data_format, name=name)(x)
153151

@@ -170,3 +168,20 @@ def decode_predictions(preds, top=5, **kwargs):
170168
ValueError: In case of invalid shape of the `pred` array (must be 2D).
171169
"""
172170
return keras_decode_predictions(preds, top=top, **kwargs)
171+
172+
173+
def TFOpLayer(tf_op: tf.Operation, *args, **kwargs) -> tf.keras.layers.Layer:
    """Build a Keras `Lambda` layer that applies a TensorFlow op to its input.

    Wrapping the op in a layer makes it possible to give it an explicit layer name.

    Example: `TFOpLayer(tf.split, groups, axis=-1, name="split")(x)`.

    # Arguments
    tf_op: tensorflow op that needs to be wrapped. It is called as
        `tf_op(inputs, *args, **kwargs)`, where `inputs` is the layer input.
    *args: extra positional arguments forwarded to `tf_op` after the layer input.
    **kwargs: extra keyword arguments forwarded to `tf_op`. A `name` entry, if
        present, is removed and used as the name of the layer instead.

    # Returns
    A keras layer wrapping `tf_op`.
    """
    # `name` belongs to the layer, so take it out before the remaining keyword
    # arguments are forwarded to the op.
    layer_name = kwargs.pop("name", None)

    return tf.keras.layers.Lambda(
        lambda inputs: tf_op(inputs, *args, **kwargs), name=layer_name
    )

larq_zoo/literature/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
BinaryDenseNet45,
88
)
99
from larq_zoo.literature.dorefanet import DoReFaNet
10+
from larq_zoo.literature.meliusnet import MeliusNet22
1011
from larq_zoo.literature.real_to_bin_nets import RealToBinaryNet
1112
from larq_zoo.literature.resnet_e import BinaryResNetE18
1213
from larq_zoo.literature.xnornet import XNORNet
@@ -20,6 +21,7 @@
2021
"BinaryDenseNet37Dilated",
2122
"BinaryDenseNet45",
2223
"DoReFaNet",
24+
"MeliusNet22",
2325
"RealToBinaryNet",
2426
"XNORNet",
2527
]

larq_zoo/literature/meliusnet.py

+278
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
from typing import Optional, Sequence, Tuple, Union
2+
3+
import larq as lq
4+
import tensorflow as tf
5+
from zookeeper import Field, factory
6+
7+
from larq_zoo.core import utils
8+
from larq_zoo.core.model_factory import ModelFactory
9+
10+
################
11+
# Base factory #
12+
################
13+
14+
15+
def _scoped_name(name: Optional[str], suffix: str) -> Optional[str]:
    """Return `"{name}_{suffix}"`, or `None` when no base name was given."""
    return None if name is None else f"{name}_{suffix}"


class MeliusNetFactory(ModelFactory):
    """Base factory assembling a MeliusNet model from its building blocks.

    Concrete subclasses set `num_blocks`, `transition_features`, `name` and the two
    pretrained weight paths; `build()` then stacks the stem, the sections and the
    head. All builder methods concatenate, split and pad along the last axis.
    """

    # Overall architecture configuration. These are not `Fields`, as they should
    # not be configurable, but set in the various concrete subclasses.
    num_blocks: Sequence[int]
    # A falsy entry (e.g. `None`) means "no transition block after this section".
    transition_features: Sequence[Optional[int]]
    name: Optional[str] = None
    imagenet_weights_path: str
    imagenet_no_top_weights_path: str

    # Some default layer arguments.
    batch_norm_momentum: float = Field(0.9)
    kernel_initializer: Optional[Union[str, tf.keras.initializers.Initializer]] = Field(
        "glorot_normal"
    )
    input_quantizer = Field(lambda: lq.quantizers.SteSign(1.3))
    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(1.3))
    kernel_constraint = Field(lambda: lq.constraints.WeightClip(1.3))

    def pool(self, x: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
        """Apply a 2x2 max pooling with stride 2 and "same" padding."""
        return tf.keras.layers.MaxPool2D(2, strides=2, padding="same", name=name)(x)

    def norm(self, x: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
        """Apply batch normalization using the configured momentum."""
        return tf.keras.layers.BatchNormalization(
            momentum=self.batch_norm_momentum, epsilon=1e-5, name=name
        )(x)

    def act(self, x: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
        """Apply a ReLU activation."""
        return tf.keras.layers.Activation("relu", name=name)(x)

    def quant_conv(
        self,
        x: tf.Tensor,
        filters: int,
        kernel: Union[int, Tuple[int, int]],
        strides: Union[int, Tuple[int, int]] = 1,
        name: Optional[str] = None,
    ) -> tf.Tensor:
        """Apply a bias-free `QuantConv2D` with the configured quantizers,
        kernel constraint and kernel initializer."""
        return lq.layers.QuantConv2D(
            filters,
            kernel,
            strides=strides,
            padding="same",
            use_bias=False,
            input_quantizer=self.input_quantizer,
            kernel_quantizer=self.kernel_quantizer,
            kernel_constraint=self.kernel_constraint,
            kernel_initializer=self.kernel_initializer,
            name=name,
        )(x)

    def group_conv(
        self,
        x: tf.Tensor,
        filters: int,
        kernel: Union[int, Tuple[int, int]],
        groups: int,
        name: Optional[str] = None,
    ) -> tf.Tensor:
        """Apply a grouped convolution built from `groups` parallel `Conv2D` layers.

        The input is split into `groups` equal parts along the last axis, each part
        is convolved to `filters // groups` channels, and the results are
        concatenated along the last axis.

        # Raises
        ValueError: if `filters`, or the size of the last axis of `x`, is not
            divisible by `groups`, or if the size of that axis is unknown.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if filters % groups != 0:
            raise ValueError(
                f"`filters` ({filters}) must be divisible by `groups` ({groups})."
            )
        in_channels = x.shape.as_list()[-1]
        if in_channels is None or in_channels % groups != 0:
            raise ValueError(
                f"The number of input channels ({in_channels}) must be known and "
                f"divisible by `groups` ({groups})."
            )

        x_split = utils.TFOpLayer(
            tf.split, groups, axis=-1, name=_scoped_name(name, "split")
        )(x)

        y_split = [
            tf.keras.layers.Conv2D(
                filters // groups,
                kernel,
                padding="same",
                use_bias=False,
                kernel_initializer=self.kernel_initializer,
                name=_scoped_name(name, f"conv{i}"),
            )(split)
            for i, split in enumerate(x_split)
        ]

        return utils.TFOpLayer(
            tf.concat, axis=-1, name=_scoped_name(name, "concat")
        )(y_split)

    def group_stem(self, x: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
        """Build the stem: a strided 3x3 convolution, two grouped 3x3 convolutions
        (each step followed by batch norm and ReLU), and a final max pooling."""
        x = tf.keras.layers.Conv2D(
            32,
            3,
            strides=2,
            padding="same",
            use_bias=False,
            kernel_initializer=self.kernel_initializer,
            name=_scoped_name(name, "s0_conv"),
        )(x)
        x = self.norm(x, name=_scoped_name(name, "s0_bn"))
        x = self.act(x, name=_scoped_name(name, "s0_relu"))

        x = self.group_conv(x, 32, 3, 4, name=_scoped_name(name, "s1_groupconv"))
        x = self.norm(x, name=_scoped_name(name, "s1_bn"))
        x = self.act(x, name=_scoped_name(name, "s1_relu"))

        x = self.group_conv(x, 64, 3, 8, name=_scoped_name(name, "s2_groupconv"))
        x = self.norm(x, name=_scoped_name(name, "s2_bn"))
        x = self.act(x, name=_scoped_name(name, "s2_relu"))

        return self.pool(x, name=_scoped_name(name, "pool"))

    def dense_block(self, x: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
        """Append 64 new channels to `x`: batch norm, a 3x3 quantized convolution,
        then concatenation with the block input along the last axis."""
        w = self.norm(x, name=_scoped_name(name, "bn"))
        w = self.quant_conv(w, 64, 3, name=_scoped_name(name, "binconv"))
        return utils.TFOpLayer(
            tf.concat, axis=-1, name=_scoped_name(name, "concat")
        )([x, w])

    def improvement_block(self, x: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
        """Add a 64-channel correction to the last 64 channels of `x`.

        The correction is computed with batch norm and a 3x3 quantized convolution.
        It is zero-padded at the front of the last axis up to the channel count of
        `x`, so that only the trailing channels are modified by the addition.
        """
        filters = 64
        w = self.norm(x, name=_scoped_name(name, "bn"))
        w = self.quant_conv(w, filters, 3, name=_scoped_name(name, "binconv"))
        # Number of leading channels of `x` that are left untouched.
        leading_channels = int(x.shape[-1]) - filters
        return tf.keras.layers.Lambda(
            lambda x_: x_[0]
            + tf.pad(x_[1], [[0, 0], [0, 0], [0, 0], [leading_channels, 0]]),
            name=_scoped_name(name, "merge"),
        )([x, w])

    def transition_block(
        self, x: tf.Tensor, filters: int, name: Optional[str] = None
    ) -> tf.Tensor:
        """Downsample and change the channel count: batch norm, max pooling, ReLU
        and a bias-free 1x1 convolution to `filters` channels."""
        x = self.norm(x, name=_scoped_name(name, "bn"))
        x = self.pool(x, name=_scoped_name(name, "maxpool"))
        x = self.act(x, name=_scoped_name(name, "relu"))
        return tf.keras.layers.Conv2D(
            filters,
            1,
            use_bias=False,
            kernel_initializer=self.kernel_initializer,
            name=_scoped_name(name, "pw"),
        )(x)

    def block(self, x: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
        """Apply a dense block followed by an improvement block."""
        x = self.dense_block(x, name=_scoped_name(name, "dense"))
        return self.improvement_block(x, name=_scoped_name(name, "improve"))

    def build(self) -> tf.keras.models.Model:
        """Assemble the model and optionally load weights.

        The network is the stem, then for every section `num_blocks[i]` blocks
        followed by a transition block when `transition_features[i]` is truthy,
        then a batch norm + ReLU head and, if `include_top` is set, global
        pooling, a dense layer and a softmax.

        # Returns
        The assembled Keras model.
        """
        x = self.image_input
        x = self.group_stem(x, name="stem")
        for i, (n, f) in enumerate(zip(self.num_blocks, self.transition_features)):
            for j in range(n):
                x = self.block(x, f"section_{i}_block_{j}")
            if f:
                x = self.transition_block(x, f, f"section_{i}_transition")

        x = self.norm(x, "head_bn")
        x = self.act(x, "head_relu")

        if self.include_top:
            x = utils.global_pool(x, name="head_globalpool")
            x = tf.keras.layers.Dense(
                self.num_classes,
                kernel_initializer=self.kernel_initializer,
                name="head_dense",
            )(x)
            x = tf.keras.layers.Activation(
                "softmax", dtype="float32", name="head_softmax"
            )(x)

        model = tf.keras.models.Model(
            inputs=self.image_input, outputs=x, name=self.name
        )

        # "imagenet" selects the pretrained weights matching `include_top`; any
        # other non-None value is treated as a path to a weights file.
        if self.weights == "imagenet":
            model.load_weights(
                self.imagenet_weights_path
                if self.include_top
                else self.imagenet_no_top_weights_path
            )
        elif self.weights is not None:
            model.load_weights(self.weights)

        return model
186+
187+
188+
######################
189+
# Concrete factories #
190+
######################
191+
192+
193+
@factory
class MeliusNet22Factory(MeliusNetFactory):
    """Concrete factory holding the MeliusNet22 configuration and weight files."""

    num_blocks = (4, 5, 4, 4)
    transition_features = (160, 224, 256, None)
    name = "meliusnet22"

    def _pretrained_path(self, file: str, file_hash: str) -> str:
        """Return the result of `utils.download_pretrained_model` for `file` of
        the `meliusnet22` `v0.1.0` release."""
        return utils.download_pretrained_model(
            model="meliusnet22", version="v0.1.0", file=file, file_hash=file_hash
        )

    @property
    def imagenet_weights_path(self):
        """Path of the ImageNet weights for the model including the top."""
        return self._pretrained_path(
            "meliusnet22_weights.h5",
            "c1ba85e8389ae326009665ec13331e49fc3df4d0f925fa8553e224f7362c18ed",
        )

    @property
    def imagenet_no_top_weights_path(self):
        """Path of the ImageNet weights for the model without the top."""
        return self._pretrained_path(
            "meliusnet22_weights_notop.h5",
            "b64c8296a3d07ce2799846caf0ad6d390f6cd9bbf21ea3390fafbab87bb79aa5",
        )
216+
217+
218+
#########################
219+
# Functional interfaces #
220+
#########################
221+
222+
223+
def MeliusNet22(
    *,  # Keyword arguments only
    input_shape: Optional[Sequence[Optional[int]]] = None,
    input_tensor: Optional[tf.Tensor] = None,
    weights: Optional[str] = "imagenet",
    include_top: bool = True,
    num_classes: int = 1000,
) -> tf.keras.models.Model:
    """Instantiates the MeliusNet22 architecture.

    Optionally loads weights pre-trained on ImageNet.

    ```netron
    meliusnet22-v0.1.0/meliusnet22.json
    ```
    ```summary
    literature.MeliusNet22
    ```
    ```plot-altair
    /plots/meliusnet22.vg.json
    ```

    # ImageNet Metrics

    | Top-1 Accuracy | Top-5 Accuracy | Parameters | Memory |
    | -------------- | -------------- | ---------- | -------- |
    | 62.4 % | 83.9 % | 6 944 584 | 3.88 MiB |

    # Arguments
    input_shape: Optional shape tuple, to be specified if you would like to use a model
        with an input image resolution that is not (224, 224, 3).
        It should have exactly 3 input channels.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as
        image input for the model.
    weights: one of `None` (random initialization), "imagenet" (pre-training on
        ImageNet), or the path to the weights file to be loaded.
    include_top: whether to include the fully-connected layer at the top of the network.
    num_classes: optional number of classes to classify images into, only to be
        specified if `include_top` is True, and if no `weights` argument is specified.

    # Returns
    A Keras model instance.

    # Raises
    ValueError: in case of invalid argument for `weights`, or invalid input shape.

    # References
    - [MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?](https://arxiv.org/abs/2001.05936)
    """
    # Configure the factory first, then let it assemble (and optionally load
    # weights into) the model.
    model_factory = MeliusNet22Factory(
        input_shape=input_shape,
        input_tensor=input_tensor,
        weights=weights,
        include_top=include_top,
        num_classes=num_classes,
    )
    return model_factory.build()

tests/models_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def parametrize(func):
3737
(lqz.literature.BinaryDenseNet37, 640),
3838
(lqz.literature.BinaryDenseNet37Dilated, 640),
3939
(lqz.literature.BinaryDenseNet45, 800),
40+
(lqz.literature.MeliusNet22, 512),
4041
(lqz.literature.XNORNet, 4096),
4142
(lqz.literature.DoReFaNet, 256),
4243
(lqz.literature.RealToBinaryNet, 512),

0 commit comments

Comments
 (0)