From 567db23f314063c0953f9117b6edcf40fd1d6291 Mon Sep 17 00:00:00 2001 From: Kristofer Krus Date: Sat, 20 Jul 2019 16:49:02 +0200 Subject: [PATCH] Add generic constructor Add a parameterized constructor for creating any of the EfficientNet variants, i.e. from EfficientNetB0 to EfficientNetB7. --- efficientnet/keras.py | 1 + efficientnet/model.py | 13 +++++++++++++ efficientnet/tfkeras.py | 1 + 3 files changed, 15 insertions(+) diff --git a/efficientnet/keras.py b/efficientnet/keras.py index 66e9b7f..d7be213 100644 --- a/efficientnet/keras.py +++ b/efficientnet/keras.py @@ -3,6 +3,7 @@ from .preprocessing import center_crop_and_resize +EfficientNetBn = inject_keras_modules(model.EfficientNetBn) EfficientNetB0 = inject_keras_modules(model.EfficientNetB0) EfficientNetB1 = inject_keras_modules(model.EfficientNetB1) EfficientNetB2 = inject_keras_modules(model.EfficientNetB2) diff --git a/efficientnet/model.py b/efficientnet/model.py index 7e09715..e8d2637 100644 --- a/efficientnet/model.py +++ b/efficientnet/model.py @@ -473,6 +473,19 @@ def EfficientNet(width_coefficient, return model +def EfficientNetBn(n, *args, **kwargs): + return [ + EfficientNetB0, + EfficientNetB1, + EfficientNetB2, + EfficientNetB3, + EfficientNetB4, + EfficientNetB5, + EfficientNetB6, + EfficientNetB7, + ][n](*args, **kwargs) + + def EfficientNetB0(include_top=True, weights='imagenet', input_tensor=None, diff --git a/efficientnet/tfkeras.py b/efficientnet/tfkeras.py index 6132984..e79ee7f 100644 --- a/efficientnet/tfkeras.py +++ b/efficientnet/tfkeras.py @@ -3,6 +3,7 @@ from .preprocessing import center_crop_and_resize +EfficientNetBn = inject_tfkeras_modules(model.EfficientNetBn) EfficientNetB0 = inject_tfkeras_modules(model.EfficientNetB0) EfficientNetB1 = inject_tfkeras_modules(model.EfficientNetB1) EfficientNetB2 = inject_tfkeras_modules(model.EfficientNetB2)