Commit d01a683

Don't set quantizers as fields
1 parent: 959ae94

8 files changed: +52 -33

larq_zoo/core/model_factory.py (+1 -10)

@@ -1,26 +1,17 @@
-from typing import Callable, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import tensorflow as tf
 from zookeeper import ComponentField, Field
 from zookeeper.tf import Dataset
 
 from larq_zoo.core import utils
 
-QuantizerType = Union[
-    tf.keras.layers.Layer, Callable[[tf.Tensor], tf.Tensor], str, None
-]
-ConstraintType = Union[tf.keras.constraints.Constraint, str, None]
 DimType = Optional[int]
 
 
 class ModelFactory:
     """A base class for Larq Zoo models. Defines some common fields."""
 
-    # Don't set any defaults here.
-    input_quantizer: QuantizerType = Field()
-    kernel_quantizer: QuantizerType = Field()
-    kernel_constraint: ConstraintType = Field()
-
     # This field is included for automatic inference of `num_clases`, if no
     # value is otherwise provided. We set `allow_missing` because we don't want
     # to throw an error if a dataset is not provided, as long as `num_classes`
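
With these fields removed from the base class, each factory now supplies `input_quantizer`, `kernel_quantizer`, and `kernel_constraint` either as plain class attributes (for quantizers referenced by name) or as read-only properties (when a configured quantizer object is needed). A minimal sketch of the two patterns, using a hypothetical `ExampleFactory` and a simplified `conv_block` to show roughly how these attributes end up in larq's quantized layers (not part of this commit):

import larq as lq
import tensorflow as tf


class ExampleFactory:  # hypothetical stand-in for a ModelFactory subclass
    # Pattern 1: plain attributes; larq resolves the registered names.
    input_quantizer = "ste_sign"
    kernel_constraint = "weight_clip"

    # Pattern 2: a property that returns a configured quantizer object
    # on each access.
    @property
    def kernel_quantizer(self):
        return lq.quantizers.SteSign(clip_value=1.25)

    def conv_block(self, x: tf.Tensor, filters: int) -> tf.Tensor:
        # The attributes are passed straight into the quantized layer.
        return lq.layers.QuantConv2D(
            filters,
            kernel_size=3,
            padding="same",
            input_quantizer=self.input_quantizer,
            kernel_quantizer=self.kernel_quantizer,
            kernel_constraint=self.kernel_constraint,
            use_bias=False,
        )(x)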

larq_zoo/literature/binary_alex_net.py (+3 -3)

@@ -17,9 +17,9 @@ class BinaryAlexNetFactory(ModelFactory):
 
     inflation_ratio: int = Field(1)
 
-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field("ste_sign")
-    kernel_constraint = Field("weight_clip")
+    input_quantizer = "ste_sign"
+    kernel_quantizer = "ste_sign"
+    kernel_constraint = "weight_clip"
 
     def conv_block(
         self,

larq_zoo/literature/birealnet.py (+3 -3)

@@ -14,9 +14,9 @@ class BiRealNetFactory(ModelFactory):
 
     filters: int = Field(64)
 
-    input_quantizer = Field("approx_sign")
-    kernel_quantizer = Field("magnitude_aware_sign")
-    kernel_constraint = Field("weight_clip")
+    input_quantizer = "approx_sign"
+    kernel_quantizer = "magnitude_aware_sign"
+    kernel_constraint = "weight_clip"
 
     kernel_initializer: Union[tf.keras.initializers.Initializer, str] = Field(
         "glorot_normal"

larq_zoo/literature/densenet.py (+11 -3)

@@ -21,9 +21,17 @@ class BinaryDenseNet(tf.keras.models.Model):
 class BinaryDenseNetFactory(ModelFactory):
     """Implementation of [BinaryDenseNet](https://arxiv.org/abs/1906.08637)"""
 
-    input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3))
-    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3))
-    kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.3))
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.3)
+
+    @property
+    def kernel_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.3)
+
+    @property
+    def kernel_constraint(self):
+        return lq.constraints.WeightClip(clip_value=1.3)
 
     initial_filters: int = Field(64)
     growth_rate: int = Field(64)

larq_zoo/literature/dorefanet.py (+9 -5)

@@ -38,11 +38,15 @@ class DoReFaNetFactory(ModelFactory):
 
     activations_k_bit: int = Field(2)
 
-    input_quantizer = Field(
-        lambda self: lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)
-    )
-    kernel_quantizer = Field(lambda: magnitude_aware_sign_unclipped)
-    kernel_constraint = Field(None)
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)
+
+    @property
+    def kernel_quantizer(self):
+        return magnitude_aware_sign_unclipped
+
+    kernel_constraint = None
 
     def conv_block(
         self, x, filters, kernel_size, strides=1, pool=False, pool_padding="same"
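
Here the input quantizer depends on the factory's own `activations_k_bit` field, which is why the `Field(lambda self: ...)` form becomes a property: the property body can read `self.activations_k_bit` directly whenever it is accessed. A small illustrative sketch (hypothetical `TinyDoReFaFactory`, with a plain int standing in for the zookeeper `Field`):

import larq as lq


class TinyDoReFaFactory:  # hypothetical stand-in for DoReFaNetFactory
    # In the real factory this is `Field(2)`; a plain int keeps the sketch
    # self-contained.
    activations_k_bit: int = 2

    @property
    def input_quantizer(self):
        # Re-evaluated on each access, so it always reflects the current
        # `activations_k_bit` value.
        return lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)


factory = TinyDoReFaFactory()
factory.activations_k_bit = 4
quantizer = factory.input_quantizer  # a DoReFaQuantizer configured with k_bit=4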

larq_zoo/literature/resnet_e.py (+11 -3)

@@ -15,9 +15,17 @@ class BinaryResNetE18Factory(ModelFactory):
     num_layers: int = Field(18)
     initial_filters: int = Field(64)
 
-    input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25))
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_constraint(self):
+        return lq.constraints.WeightClip(clip_value=1.25)
 
     @property
     def spec(self):

larq_zoo/literature/xnornet.py (+3 -3)

@@ -24,9 +24,9 @@ def xnor_weight_scale(x):
 class XNORNetFactory(ModelFactory):
     """Implementation of [XNOR-Net](https://arxiv.org/abs/1603.05279)"""
 
-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field("xnor_weight_scale")
-    kernel_constraint = Field("weight_clip")
+    input_quantizer = "ste_sign"
+    kernel_quantizer = "xnor_weight_scale"
+    kernel_constraint = "weight_clip"
 
     kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = Field(
         lambda: tf.keras.regularizers.l2(5e-7)

larq_zoo/sota/quicknet.py (+11 -3)

@@ -68,9 +68,17 @@ class QuickNetBaseFactory(ModelFactory):
     transition_block: MethodType = Field(None)
     stem_filters: int = Field(64)
 
-    input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25))
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_constraint(self):
+        return lq.constraints.WeightClip(clip_value=1.25)
 
     def conv_block(
         self, x: tf.Tensor, filters: int, use_squeeze_and_excite: bool, strides: int = 1
