Skip to content

Commit 229fae7

Browse files
committed
Don't set quantizers as fields
1 parent 01579c3 commit 229fae7

File tree

11 files changed

+108
-65
lines changed

11 files changed

+108
-65
lines changed

larq_zoo/core/model_factory.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,17 @@
1-
from typing import Callable, Optional, Tuple, Union
1+
from typing import Optional, Tuple
22

33
import tensorflow as tf
44
from zookeeper import ComponentField, Field
55
from zookeeper.tf import Dataset
66

77
from larq_zoo.core import utils
88

9-
QuantizerType = Union[
10-
tf.keras.layers.Layer, Callable[[tf.Tensor], tf.Tensor], str, None
11-
]
12-
ConstraintType = Union[tf.keras.constraints.Constraint, str, None]
139
DimType = Optional[int]
1410

1511

1612
class ModelFactory:
1713
"""A base class for Larq Zoo models. Defines some common fields."""
1814

19-
# Don't set any defaults here.
20-
input_quantizer: QuantizerType = Field()
21-
kernel_quantizer: QuantizerType = Field()
22-
kernel_constraint: ConstraintType = Field()
23-
2415
# This field is included for automatic inference of `num_classes`, if no
2516
# value is otherwise provided. We set `allow_missing` because we don't want
2617
# to throw an error if a dataset is not provided, as long as `num_classes`

larq_zoo/literature/binary_alex_net.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ class BinaryAlexNetFactory(ModelFactory):
1717

1818
inflation_ratio: int = Field(1)
1919

20-
input_quantizer = Field("ste_sign")
21-
kernel_quantizer = Field("ste_sign")
22-
kernel_constraint = Field("weight_clip")
20+
input_quantizer = "ste_sign"
21+
kernel_quantizer = "ste_sign"
22+
kernel_constraint = "weight_clip"
2323

2424
def conv_block(
2525
self,

larq_zoo/literature/birealnet.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ class BiRealNetFactory(ModelFactory):
1414

1515
filters: int = Field(64)
1616

17-
input_quantizer = Field("approx_sign")
18-
kernel_quantizer = Field("magnitude_aware_sign")
19-
kernel_constraint = Field("weight_clip")
17+
input_quantizer = "approx_sign"
18+
kernel_quantizer = "magnitude_aware_sign"
19+
kernel_constraint = "weight_clip"
2020

2121
kernel_initializer: Union[tf.keras.initializers.Initializer, str] = Field(
2222
"glorot_normal"

larq_zoo/literature/densenet.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,17 @@ class BinaryDenseNet(tf.keras.models.Model):
2121
class BinaryDenseNetFactory(ModelFactory):
2222
"""Implementation of [BinaryDenseNet](https://arxiv.org/abs/1906.08637)"""
2323

24-
input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3))
25-
kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3))
26-
kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.3))
24+
@property
25+
def input_quantizer(self):
26+
return lq.quantizers.SteSign(clip_value=1.3)
27+
28+
@property
29+
def kernel_quantizer(self):
30+
return lq.quantizers.SteSign(clip_value=1.3)
31+
32+
@property
33+
def kernel_constraint(self):
34+
return lq.constraints.WeightClip(clip_value=1.3)
2735

2836
initial_filters: int = Field(64)
2937
growth_rate: int = Field(64)

larq_zoo/literature/dorefanet.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,15 @@ class DoReFaNetFactory(ModelFactory):
3838

3939
activations_k_bit: int = Field(2)
4040

41-
input_quantizer = Field(
42-
lambda self: lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)
43-
)
44-
kernel_quantizer = Field(lambda: magnitude_aware_sign_unclipped)
45-
kernel_constraint = Field(None)
41+
@property
42+
def input_quantizer(self):
43+
return lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)
44+
45+
@property
46+
def kernel_quantizer(self):
47+
return magnitude_aware_sign_unclipped
48+
49+
kernel_constraint = None
4650

4751
def conv_block(
4852
self, x, filters, kernel_size, strides=1, pool=False, pool_padding="same"

larq_zoo/literature/meliusnet.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,18 @@ class MeliusNetFactory(ModelFactory):
2626
kernel_initializer: Optional[Union[str, tf.keras.initializers.Initializer]] = Field(
2727
"glorot_normal"
2828
)
29-
input_quantizer = Field(lambda: lq.quantizers.SteSign(1.3))
30-
kernel_quantizer = Field(lambda: lq.quantizers.SteSign(1.3))
31-
kernel_constraint = Field(lambda: lq.constraints.WeightClip(1.3))
29+
30+
@property
31+
def input_quantizer(self):
32+
return lq.quantizers.SteSign(1.3)
33+
34+
@property
35+
def kernel_quantizer(self):
36+
return lq.quantizers.SteSign(1.3)
37+
38+
@property
39+
def kernel_constraint(self):
40+
return lq.constraints.WeightClip(1.3)
3241

3342
def pool(self, x: tf.Tensor, name: str = None) -> tf.Tensor:
3443
return tf.keras.layers.MaxPool2D(2, strides=2, padding="same", name=name)(x)

larq_zoo/literature/real_to_bin_nets.py

+40-25
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from zookeeper import Field, factory
1212

1313
from larq_zoo.core import utils
14-
from larq_zoo.core.model_factory import ModelFactory, QuantizerType
14+
from larq_zoo.core.model_factory import ModelFactory
1515

1616

1717
class _SharedBaseFactory(ModelFactory, metaclass=ABCMeta):
@@ -126,8 +126,8 @@ class StrongBaselineNetFactory(_SharedBaseFactory):
126126

127127
scaling_r: int = 8
128128

129-
input_quantizer: QuantizerType = Field(None)
130-
kernel_quantizer: QuantizerType = Field(None)
129+
input_quantizer = None
130+
kernel_quantizer = None
131131

132132
class LearnedRescaleLayer(tf.keras.layers.Layer):
133133
"""Implements the learned activation rescaling XNOR-Net++ style.
@@ -359,53 +359,68 @@ def block(
359359
@factory
360360
class StrongBaselineNetBANFactory(StrongBaselineNetFactory):
361361
model_name = Field("baseline_ban")
362-
input_quantizer = Field("ste_sign")
363-
kernel_quantizer = Field(None)
364-
kernel_constraint = Field(None)
365-
kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
362+
input_quantizer = "ste_sign"
363+
kernel_quantizer = None
364+
kernel_constraint = None
365+
366+
@property
367+
def kernel_regularizer(self):
368+
return tf.keras.regularizers.l2(1e-5)
366369

367370

368371
@factory
369372
class StrongBaselineNetBNNFactory(StrongBaselineNetFactory):
370373
model_name = Field("baseline_bnn")
371-
input_quantizer = Field("ste_sign")
372-
kernel_quantizer = Field("ste_sign")
373-
kernel_constraint = Field("weight_clip")
374+
input_quantizer = "ste_sign"
375+
kernel_quantizer = "ste_sign"
376+
kernel_constraint = "weight_clip"
374377

375378

376379
@factory
377380
class RealToBinNetFPFactory(RealToBinNetFactory):
378381
model_name = Field("r2b_fp")
379-
input_quantizer = Field(lambda: tf.keras.layers.Activation("tanh"))
380-
kernel_quantizer = Field(None)
381-
kernel_constraint = Field(None)
382-
kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
382+
kernel_quantizer = None
383+
kernel_constraint = None
384+
385+
@property
386+
def input_quantizer(self):
387+
return tf.keras.layers.Activation("tanh")
388+
389+
@property
390+
def kernel_regularizer(self):
391+
return tf.keras.regularizers.l2(1e-5)
383392

384393

385394
@factory
386395
class RealToBinNetBANFactory(RealToBinNetFactory):
387396
model_name = Field("r2b_ban")
388-
input_quantizer = Field("ste_sign")
389-
kernel_quantizer = Field(None)
390-
kernel_constraint = Field(None)
391-
kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
397+
input_quantizer = "ste_sign"
398+
kernel_quantizer = None
399+
kernel_constraint = None
400+
401+
@property
402+
def kernel_regularizer(self):
403+
return tf.keras.regularizers.l2(1e-5)
392404

393405

394406
@factory
395407
class RealToBinNetBNNFactory(RealToBinNetFactory):
396408
model_name = Field("r2b_bnn")
397-
input_quantizer = Field("ste_sign")
398-
kernel_quantizer = Field("ste_sign")
399-
kernel_constraint = Field("weight_clip")
409+
input_quantizer = "ste_sign"
410+
kernel_quantizer = "ste_sign"
411+
kernel_constraint = "weight_clip"
400412

401413

402414
@factory
403415
class ResNet18FPFactory(ResNet18Factory):
404416
model_name = Field("resnet_fp")
405-
input_quantizer = Field(None)
406-
kernel_quantizer = Field(None)
407-
kernel_constraint = Field(None)
408-
kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
417+
input_quantizer = None
418+
kernel_quantizer = None
419+
kernel_constraint = None
420+
421+
@property
422+
def kernel_regularizer(self):
423+
return tf.keras.regularizers.l2(1e-5)
409424

410425

411426
def RealToBinaryNet(

larq_zoo/literature/resnet_e.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,17 @@ class BinaryResNetE18Factory(ModelFactory):
1515
num_layers: int = Field(18)
1616
initial_filters: int = Field(64)
1717

18-
input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
19-
kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
20-
kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25))
18+
@property
19+
def input_quantizer(self):
20+
return lq.quantizers.SteSign(clip_value=1.25)
21+
22+
@property
23+
def kernel_quantizer(self):
24+
return lq.quantizers.SteSign(clip_value=1.25)
25+
26+
@property
27+
def kernel_constraint(self):
28+
return lq.constraints.WeightClip(clip_value=1.25)
2129

2230
@property
2331
def spec(self):

larq_zoo/literature/xnornet.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def xnor_weight_scale(x):
2424
class XNORNetFactory(ModelFactory):
2525
"""Implementation of [XNOR-Net](https://arxiv.org/abs/1603.05279)"""
2626

27-
input_quantizer = Field("ste_sign")
28-
kernel_quantizer = Field("xnor_weight_scale")
29-
kernel_constraint = Field("weight_clip")
27+
input_quantizer = "ste_sign"
28+
kernel_quantizer = "xnor_weight_scale"
29+
kernel_constraint = "weight_clip"
3030

3131
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = Field(
3232
lambda: tf.keras.regularizers.l2(5e-7)

larq_zoo/sota/quicknet.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,17 @@ class QuickNetBaseFactory(ModelFactory, abc.ABC):
6565
transition_block: Callable[..., tf.Tensor] = Field()
6666
stem_filters: int = Field(64)
6767

68-
input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
69-
kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
70-
kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25))
68+
@property
69+
def input_quantizer(self):
70+
return lq.quantizers.SteSign(clip_value=1.25)
71+
72+
@property
73+
def kernel_quantizer(self):
74+
return lq.quantizers.SteSign(clip_value=1.25)
75+
76+
@property
77+
def kernel_constraint(self):
78+
return lq.constraints.WeightClip(clip_value=1.25)
7179

7280
def __post_configure__(self):
7381
assert (

larq_zoo/training/knowledge_distillation/knowledge_distillation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tensorflow as tf
55
from zookeeper import ComponentField, Field, factory
66

7-
from larq_zoo.core.model_factory import ConstraintType, ModelFactory, QuantizerType
7+
from larq_zoo.core.model_factory import ModelFactory
88

99

1010
class AttentionMatchingLossLayer(tf.keras.layers.Layer):
@@ -296,9 +296,9 @@ class TeacherStudentModelFactory(ModelFactory):
296296
teacher_model: tf.keras.models.Model = ComponentField(allow_missing=True)
297297
student_model: tf.keras.models.Model = ComponentField()
298298

299-
input_quantizer: QuantizerType = Field(allow_missing=True)
300-
kernel_quantizer: QuantizerType = Field(allow_missing=True)
301-
kernel_constraint: ConstraintType = Field(allow_missing=True)
299+
input_quantizer = None
300+
kernel_quantizer = None
301+
kernel_constraint = None
302302

303303
# Must be set if there is a teacher and allow_missing teacher weights is not True.
304304
# Either a full path or the name of a network (in which case it will be sought in the current `model_dir`).

0 commit comments

Comments (0)