diff --git a/official/vision/image_classification/resnet/resnet_model.py b/official/vision/image_classification/resnet/resnet_model.py
index 643ca5ad18c..10f1233356e 100644
--- a/official/vision/image_classification/resnet/resnet_model.py
+++ b/official/vision/image_classification/resnet/resnet_model.py
@@ -35,15 +35,11 @@ from tensorflow.python.keras import regularizers
 
 from official.vision.image_classification.resnet import imagenet_preprocessing
 
-L2_WEIGHT_DECAY = 1e-4
-BATCH_NORM_DECAY = 0.9
-BATCH_NORM_EPSILON = 1e-5
-
 layers = tf.keras.layers
 
 
-def _gen_l2_regularizer(use_l2_regularizer=True):
-  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+  return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None
 
 
 def identity_block(input_tensor,
@@ -51,7 +47,9 @@ def identity_block(input_tensor,
                    filters,
                    stage,
                    block,
-                   use_l2_regularizer=True):
+                   use_l2_regularizer=True,
+                   batch_norm_decay=0.9,
+                   batch_norm_epsilon=1e-5):
   """The identity block is the block that has no conv layer at shortcut.
 
   Args:
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     Output tensor for the block.
@@ -82,8 +82,8 @@ def identity_block(input_tensor,
           input_tensor)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2a')(
           x)
   x = layers.Activation('relu')(x)
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2b')(
           x)
   x = layers.Activation('relu')(x)
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)
 
@@ -130,7 +130,9 @@ def conv_block(input_tensor,
                stage,
                block,
                strides=(2, 2),
-               use_l2_regularizer=True):
+               use_l2_regularizer=True,
+               batch_norm_decay=0.9,
+               batch_norm_epsilon=1e-5):
   """A block that has a conv layer at shortcut.
 
   Note that from stage 3,
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     Output tensor for the block.
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
           input_tensor)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2a')(
           x)
   x = layers.Activation('relu')(x)
@@ -184,8 +188,8 @@ def conv_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2b')(
           x)
   x = layers.Activation('relu')(x)
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)
 
@@ -214,8 +218,8 @@ def conv_block(input_tensor,
           input_tensor)
   shortcut = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '1')(
           shortcut)
 
@@ -227,7 +231,9 @@ def conv_block(input_tensor,
 def resnet50(num_classes,
              batch_size=None,
             use_l2_regularizer=True,
-             rescale_inputs=False):
+             rescale_inputs=False,
+             batch_norm_decay=0.9,
+             batch_norm_epsilon=1e-5):
   """Instantiates the ResNet50 architecture.
 
   Args:
@@ -235,6 +241,8 @@ def resnet50(num_classes,
     batch_size: Size of the batches for each step.
     use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
     rescale_inputs: whether to rescale inputs from 0 to 1.
+    batch_norm_decay: Momentum of batch norm layers.
+    batch_norm_epsilon: Epsilon of batch norm layers.
 
   Returns:
     A Keras model instance.
@@ -260,6 +268,10 @@ def resnet50(num_classes,
   else:  # channels_last
     bn_axis = 3
 
+  block_config = dict(
+      use_l2_regularizer=use_l2_regularizer,
+      batch_norm_decay=batch_norm_decay,
+      batch_norm_epsilon=batch_norm_epsilon)
   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
   x = layers.Conv2D(
       64, (7, 7),
@@ -272,113 +284,33 @@ def resnet50(num_classes,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name='bn_conv1')(
           x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
 
   x = conv_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='a',
-      strides=(1, 1),
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='e',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='f',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
+      x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', **block_config)
+
+  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', **block_config)
+
+  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', **block_config)
+
+  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', **block_config)
 
   x = layers.GlobalAveragePooling2D()(x)
   x = layers.Dense(
diff --git a/official/vision/image_classification/resnet/resnet_runnable.py b/official/vision/image_classification/resnet/resnet_runnable.py
index b6992ae631d..2bd1052e840 100644
--- a/official/vision/image_classification/resnet/resnet_runnable.py
+++ b/official/vision/image_classification/resnet/resnet_runnable.py
@@ -158,9 +158,9 @@ def step_fn(inputs):
         loss = tf.reduce_sum(prediction_loss) * (1.0 / self.flags_obj.batch_size)
         num_replicas = self.strategy.num_replicas_in_sync
-
+        l2_weight_decay = 1e-4
         if self.flags_obj.single_l2_loss_op:
-          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
+          l2_loss = l2_weight_decay * 2 * tf.add_n([
               tf.nn.l2_loss(v)
               for v in self.model.trainable_variables
               if 'bn' not in v.name
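
Usage note: with the module-level constants lifted into keyword arguments, the batch norm hyperparameters can be tuned per model instance instead of by patching module globals. A minimal sketch of a call site follows; the num_classes=1000 and 0.997 values are illustrative and not part of the patch. The defaults (batch_norm_decay=0.9, batch_norm_epsilon=1e-5) match the removed BATCH_NORM_DECAY / BATCH_NORM_EPSILON constants, so existing callers keep the same behavior.

    from official.vision.image_classification.resnet import resnet_model

    # Per-instance override of what used to be module-level constants;
    # the values here are illustrative.
    model = resnet_model.resnet50(
        num_classes=1000,
        batch_norm_decay=0.997,
        batch_norm_epsilon=1e-5)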
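
Arithmetic note on the resnet_runnable.py hunk: tf.nn.l2_loss(v) returns sum(v ** 2) / 2, so multiplying by 2 * l2_weight_decay yields l2_weight_decay * sum(v ** 2), the same penalty that regularizers.l2(l2_weight_decay) attaches to each variable; variables with 'bn' in their name are skipped because batch norm parameters are conventionally not weight-decayed. A self-contained sketch of the same computation, with stand-in variables in place of the model's trainable variables:

    import tensorflow as tf

    l2_weight_decay = 1e-4  # mirrors the value hardcoded in step_fn
    # Stand-ins for model.trainable_variables; the real code filters out
    # batch norm variables by name before summing.
    variables = [tf.ones([3, 3]), tf.ones([5])]
    # tf.nn.l2_loss returns sum(v ** 2) / 2, hence the explicit factor of 2.
    l2_loss = l2_weight_decay * 2 * tf.add_n(
        [tf.nn.l2_loss(v) for v in variables])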