|
11 | 11 | from zookeeper import Field, factory
|
12 | 12 |
|
13 | 13 | from larq_zoo.core import utils
|
14 |
| -from larq_zoo.core.model_factory import ModelFactory, QuantizerType |
| 14 | +from larq_zoo.core.model_factory import ModelFactory |
15 | 15 |
|
16 | 16 |
|
17 | 17 | class _SharedBaseFactory(ModelFactory, metaclass=ABCMeta):
|
@@ -126,8 +126,8 @@ class StrongBaselineNetFactory(_SharedBaseFactory):
|
126 | 126 |
|
127 | 127 | scaling_r: int = 8
|
128 | 128 |
|
129 |
| - input_quantizer: QuantizerType = Field(None) |
130 |
| - kernel_quantizer: QuantizerType = Field(None) |
| 129 | + input_quantizer = None |
| 130 | + kernel_quantizer = None |
131 | 131 |
|
132 | 132 | class LearnedRescaleLayer(tf.keras.layers.Layer):
|
133 | 133 | """Implements the learned activation rescaling XNOR-Net++ style.
|
@@ -359,53 +359,68 @@ def block(
|
359 | 359 | @factory
|
360 | 360 | class StrongBaselineNetBANFactory(StrongBaselineNetFactory):
|
361 | 361 | model_name = Field("baseline_ban")
|
362 |
| - input_quantizer = Field("ste_sign") |
363 |
| - kernel_quantizer = Field(None) |
364 |
| - kernel_constraint = Field(None) |
365 |
| - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) |
| 362 | + input_quantizer = "ste_sign" |
| 363 | + kernel_quantizer = None |
| 364 | + kernel_constraint = None |
| 365 | + |
| 366 | + @property |
| 367 | + def kernel_regularizer(self): |
| 368 | + return tf.keras.regularizers.l2(1e-5) |
366 | 369 |
|
367 | 370 |
|
368 | 371 | @factory
|
369 | 372 | class StrongBaselineNetBNNFactory(StrongBaselineNetFactory):
|
370 | 373 | model_name = Field("baseline_bnn")
|
371 |
| - input_quantizer = Field("ste_sign") |
372 |
| - kernel_quantizer = Field("ste_sign") |
373 |
| - kernel_constraint = Field("weight_clip") |
| 374 | + input_quantizer = "ste_sign" |
| 375 | + kernel_quantizer = "ste_sign" |
| 376 | + kernel_constraint = "weight_clip" |
374 | 377 |
|
375 | 378 |
|
376 | 379 | @factory
|
377 | 380 | class RealToBinNetFPFactory(RealToBinNetFactory):
|
378 | 381 | model_name = Field("r2b_fp")
|
379 |
| - input_quantizer = Field(lambda: tf.keras.layers.Activation("tanh")) |
380 |
| - kernel_quantizer = Field(None) |
381 |
| - kernel_constraint = Field(None) |
382 |
| - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) |
| 382 | + kernel_quantizer = None |
| 383 | + kernel_constraint = None |
| 384 | + |
| 385 | + @property |
| 386 | + def input_quantizer(self): |
| 387 | + return tf.keras.layers.Activation("tanh") |
| 388 | + |
| 389 | + @property |
| 390 | + def kernel_regularizer(self): |
| 391 | + return tf.keras.regularizers.l2(1e-5) |
383 | 392 |
|
384 | 393 |
|
385 | 394 | @factory
|
386 | 395 | class RealToBinNetBANFactory(RealToBinNetFactory):
|
387 | 396 | model_name = Field("r2b_ban")
|
388 |
| - input_quantizer = Field("ste_sign") |
389 |
| - kernel_quantizer = Field(None) |
390 |
| - kernel_constraint = Field(None) |
391 |
| - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) |
| 397 | + input_quantizer = "ste_sign" |
| 398 | + kernel_quantizer = None |
| 399 | + kernel_constraint = None |
| 400 | + |
| 401 | + @property |
| 402 | + def kernel_regularizer(self): |
| 403 | + return tf.keras.regularizers.l2(1e-5) |
392 | 404 |
|
393 | 405 |
|
394 | 406 | @factory
|
395 | 407 | class RealToBinNetBNNFactory(RealToBinNetFactory):
|
396 | 408 | model_name = Field("r2b_bnn")
|
397 |
| - input_quantizer = Field("ste_sign") |
398 |
| - kernel_quantizer = Field("ste_sign") |
399 |
| - kernel_constraint = Field("weight_clip") |
| 409 | + input_quantizer = "ste_sign" |
| 410 | + kernel_quantizer = "ste_sign" |
| 411 | + kernel_constraint = "weight_clip" |
400 | 412 |
|
401 | 413 |
|
402 | 414 | @factory
|
403 | 415 | class ResNet18FPFactory(ResNet18Factory):
|
404 | 416 | model_name = Field("resnet_fp")
|
405 |
| - input_quantizer = Field(None) |
406 |
| - kernel_quantizer = Field(None) |
407 |
| - kernel_constraint = Field(None) |
408 |
| - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) |
| 417 | + input_quantizer = None |
| 418 | + kernel_quantizer = None |
| 419 | + kernel_constraint = None |
| 420 | + |
| 421 | + @property |
| 422 | + def kernel_regularizer(self): |
| 423 | + return tf.keras.regularizers.l2(1e-5) |
409 | 424 |
|
410 | 425 |
|
411 | 426 | def RealToBinaryNet(
|
|
0 commit comments