diff --git a/efficientnet/keras.py b/efficientnet/keras.py index 66e9b7f..d7be213 100644 --- a/efficientnet/keras.py +++ b/efficientnet/keras.py @@ -3,6 +3,7 @@ from .preprocessing import center_crop_and_resize +EfficientNetBn = inject_keras_modules(model.EfficientNetBn) EfficientNetB0 = inject_keras_modules(model.EfficientNetB0) EfficientNetB1 = inject_keras_modules(model.EfficientNetB1) EfficientNetB2 = inject_keras_modules(model.EfficientNetB2) diff --git a/efficientnet/model.py b/efficientnet/model.py index 7e09715..e8d2637 100644 --- a/efficientnet/model.py +++ b/efficientnet/model.py @@ -473,6 +473,19 @@ def EfficientNet(width_coefficient, return model +def EfficientNetBn(n, *args, **kwargs): + return [ + EfficientNetB0, + EfficientNetB1, + EfficientNetB2, + EfficientNetB3, + EfficientNetB4, + EfficientNetB5, + EfficientNetB6, + EfficientNetB7, + ][n](*args, **kwargs) + + def EfficientNetB0(include_top=True, weights='imagenet', input_tensor=None, diff --git a/efficientnet/tfkeras.py b/efficientnet/tfkeras.py index 6132984..e79ee7f 100644 --- a/efficientnet/tfkeras.py +++ b/efficientnet/tfkeras.py @@ -3,6 +3,7 @@ from .preprocessing import center_crop_and_resize +EfficientNetBn = inject_tfkeras_modules(model.EfficientNetBn) EfficientNetB0 = inject_tfkeras_modules(model.EfficientNetB0) EfficientNetB1 = inject_tfkeras_modules(model.EfficientNetB1) EfficientNetB2 = inject_tfkeras_modules(model.EfficientNetB2)