Commit

Readability: Avoid global variables in resnet construction.
PiperOrigin-RevId: 302932162
saberkun authored and tensorflower-gardener committed Mar 25, 2020
1 parent ad09cf4 commit 7a25758
Showing 2 changed files with 58 additions and 126 deletions.
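
The pattern throughout the change: hyperparameters that previously lived in module-level constants become defaulted keyword arguments, threaded explicitly through the builder functions. A minimal before/after sketch of that pattern (the function names here are illustrative, not from the commit):

```python
import tensorflow as tf

# Before: the layer silently reads module-level state; changing the decay
# means editing or monkey-patching the module.
BATCH_NORM_DECAY = 0.9

def make_bn_with_global(axis=3):
  return tf.keras.layers.BatchNormalization(
      axis=axis, momentum=BATCH_NORM_DECAY)

# After: the same default, but as an explicit per-call argument, so each
# model build can override it without global side effects.
def make_bn_with_param(axis=3, batch_norm_decay=0.9):
  return tf.keras.layers.BatchNormalization(
      axis=axis, momentum=batch_norm_decay)
```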
official/vision/image_classification/resnet/resnet_model.py (180 changes: 56 additions & 124 deletions)
@@ -35,23 +35,21 @@
 from tensorflow.python.keras import regularizers
 from official.vision.image_classification.resnet import imagenet_preprocessing

-L2_WEIGHT_DECAY = 1e-4
-BATCH_NORM_DECAY = 0.9
-BATCH_NORM_EPSILON = 1e-5
-
 layers = tf.keras.layers


-def _gen_l2_regularizer(use_l2_regularizer=True):
-  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+  return regularizers.l2(l2_weight_decay) if use_l2_regularizer else None


 def identity_block(input_tensor,
                    kernel_size,
                    filters,
                    stage,
                    block,
-                   use_l2_regularizer=True):
+                   use_l2_regularizer=True,
+                   batch_norm_decay=0.9,
+                   batch_norm_epsilon=1e-5):
   """The identity block is the block that has no conv layer at shortcut.

   Args:
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of the batch norm layers.
+    batch_norm_epsilon: Epsilon of the batch norm layers.

   Returns:
     Output tensor for the block.
@@ -82,8 +82,8 @@ def identity_block(input_tensor,
           input_tensor)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2a')(
           x)
   x = layers.Activation('relu')(x)
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2b')(
           x)
   x = layers.Activation('relu')(x)
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)

@@ -130,7 +130,9 @@ def conv_block(input_tensor,
                stage,
                block,
                strides=(2, 2),
-               use_l2_regularizer=True):
+               use_l2_regularizer=True,
+               batch_norm_decay=0.9,
+               batch_norm_epsilon=1e-5):
   """A block that has a conv layer at shortcut.

   Note that from stage 3,
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
     use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+    batch_norm_decay: Momentum of the batch norm layers.
+    batch_norm_epsilon: Epsilon of the batch norm layers.

   Returns:
     Output tensor for the block.
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
           input_tensor)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2a')(
           x)
   x = layers.Activation('relu')(x)
@@ -184,8 +188,8 @@ def conv_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2b')(
           x)
   x = layers.Activation('relu')(x)
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '2c')(
           x)

@@ -214,8 +218,8 @@ def conv_block(input_tensor,
           input_tensor)
   shortcut = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name=bn_name_base + '1')(
           shortcut)

@@ -227,14 +231,18 @@ def conv_block(input_tensor,
 def resnet50(num_classes,
              batch_size=None,
              use_l2_regularizer=True,
-             rescale_inputs=False):
+             rescale_inputs=False,
+             batch_norm_decay=0.9,
+             batch_norm_epsilon=1e-5):
   """Instantiates the ResNet50 architecture.

   Args:
     num_classes: `int` number of classes for image classification.
     batch_size: Size of the batches for each step.
     use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
     rescale_inputs: whether to rescale inputs from 0 to 1.
+    batch_norm_decay: Momentum of the batch norm layers.
+    batch_norm_epsilon: Epsilon of the batch norm layers.

   Returns:
     A Keras model instance.
@@ -260,6 +268,10 @@ def resnet50(num_classes,
   else:  # channels_last
     bn_axis = 3

+  block_config = dict(
+      use_l2_regularizer=use_l2_regularizer,
+      batch_norm_decay=batch_norm_decay,
+      batch_norm_epsilon=batch_norm_epsilon)
   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
   x = layers.Conv2D(
       64, (7, 7),
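
The `block_config` dict plus `**` unpacking is what collapses the per-block calls in the next hunk to one line each; every call is equivalent to spelling out all three keyword arguments. A toy sketch of the mechanism (the `block` function is a stand-in, not from the commit):

```python
def block(x, use_l2_regularizer, batch_norm_decay, batch_norm_epsilon):
  # Stand-in for identity_block/conv_block; just echoes its config.
  return (x, use_l2_regularizer, batch_norm_decay, batch_norm_epsilon)

# Shared keyword arguments are packed once...
block_config = dict(
    use_l2_regularizer=True,
    batch_norm_decay=0.9,
    batch_norm_epsilon=1e-5)

# ...and unpacked at each call site.
assert block('in', **block_config) == block(
    'in',
    use_l2_regularizer=True,
    batch_norm_decay=0.9,
    batch_norm_epsilon=1e-5)
```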
@@ -272,113 +284,33 @@ def resnet50(num_classes,
           x)
   x = layers.BatchNormalization(
       axis=bn_axis,
-      momentum=BATCH_NORM_DECAY,
-      epsilon=BATCH_NORM_EPSILON,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
       name='bn_conv1')(
           x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

   x = conv_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='a',
-      strides=(1, 1),
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [64, 64, 256],
-      stage=2,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [128, 128, 512],
-      stage=3,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='d',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='e',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [256, 256, 1024],
-      stage=4,
-      block='f',
-      use_l2_regularizer=use_l2_regularizer)
-
-  x = conv_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='a',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='b',
-      use_l2_regularizer=use_l2_regularizer)
-  x = identity_block(
-      x,
-      3, [512, 512, 2048],
-      stage=5,
-      block='c',
-      use_l2_regularizer=use_l2_regularizer)
+      x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', **block_config)
+  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', **block_config)
+
+  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', **block_config)
+  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', **block_config)
+
+  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', **block_config)
+  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', **block_config)
+
+  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', **block_config)
+  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', **block_config)

   x = layers.GlobalAveragePooling2D()(x)
   x = layers.Dense(
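
With the new signature, callers can tune batch norm behavior per model instead of mutating module globals. A usage sketch against the changed API (the override values are arbitrary examples):

```python
from official.vision.image_classification.resnet import resnet_model

# Defaults (decay=0.9, epsilon=1e-5) reproduce the old global constants;
# overrides flow through every BatchNormalization layer in the network.
model = resnet_model.resnet50(
    num_classes=1000,
    batch_norm_decay=0.99,
    batch_norm_epsilon=1e-3)
```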
@@ -158,9 +158,9 @@ def step_fn(inputs):
         loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                  self.flags_obj.batch_size)
         num_replicas = self.strategy.num_replicas_in_sync
-
+        l2_weight_decay = 1e-4
         if self.flags_obj.single_l2_loss_op:
-          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
+          l2_loss = l2_weight_decay * 2 * tf.add_n([
               tf.nn.l2_loss(v)
               for v in self.model.trainable_variables
               if 'bn' not in v.name
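
The `* 2` factor survives the rename because `tf.nn.l2_loss(v)` computes `sum(v ** 2) / 2`, so `l2_weight_decay * 2 * tf.nn.l2_loss(v)` matches the `weight_decay * sum(v ** 2)` penalty that `regularizers.l2(l2_weight_decay)` would add per layer. A small self-contained check of that identity:

```python
import tensorflow as tf

l2_weight_decay = 1e-4
v = tf.constant([1.0, -2.0, 3.0])

# Keras-style penalty: weight_decay * sum(v ** 2).
keras_style = l2_weight_decay * tf.reduce_sum(tf.square(v))

# Single-op form from the diff; tf.nn.l2_loss already halves the sum,
# hence the explicit factor of 2.
single_op = l2_weight_decay * 2 * tf.nn.l2_loss(v)

tf.debugging.assert_near(keras_style, single_op)
```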
