Move an R1-specific util function from common utils to R1 models.
PiperOrigin-RevId: 303767122
saberkun authored and tensorflower-gardener committed Mar 30, 2020
1 parent 01d1931 commit fc02382
Showing 4 changed files with 65 additions and 51 deletions.
35 changes: 33 additions & 2 deletions official/r1/resnet/resnet_run_loop.py
@@ -329,6 +329,37 @@ def poly_rate_fn(global_step):
   return learning_rate_fn
 
 
+def per_replica_batch_size(batch_size, num_gpus):
+  """For multi-gpu, batch-size must be a multiple of the number of GPUs.
+
+  Note that distribution strategy handles this automatically when used with
+  Keras. For using with Estimator, we need to get per GPU batch.
+
+  Args:
+    batch_size: Global batch size to be divided among devices. This should be
+      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
+    num_gpus: How many GPUs are used with DistributionStrategies.
+
+  Returns:
+    Batch size per device.
+
+  Raises:
+    ValueError: if batch_size is not divisible by number of devices
+  """
+  if num_gpus <= 1:
+    return batch_size
+
+  remainder = batch_size % num_gpus
+  if remainder:
+    err = ('When running with multiple GPUs, batch size '
+           'must be a multiple of the number of available GPUs. Found {} '
+           'GPUs with a batch size of {}; try --batch_size={} instead.'
+           ).format(num_gpus, batch_size, batch_size - remainder)
+    raise ValueError(err)
+  return int(batch_size / num_gpus)
+
+
 def resnet_model_fn(features, labels, mode, model_class,
                     resnet_size, weight_decay, learning_rate_fn, momentum,
                     data_format, resnet_version, loss_scale,
@@ -620,7 +651,7 @@ def input_fn_train(num_epochs, input_context=None):
     return input_function(
         is_training=True,
         data_dir=flags_obj.data_dir,
-        batch_size=distribution_utils.per_replica_batch_size(
+        batch_size=per_replica_batch_size(
             flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
         num_epochs=num_epochs,
         dtype=flags_core.get_tf_dtype(flags_obj),
@@ -631,7 +662,7 @@ def input_fn_eval():
     return input_function(
         is_training=False,
         data_dir=flags_obj.data_dir,
-        batch_size=distribution_utils.per_replica_batch_size(
+        batch_size=per_replica_batch_size(
             flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
         num_epochs=1,
         dtype=flags_core.get_tf_dtype(flags_obj))
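A quick illustration of how the relocated helper behaves (a minimal sketch, not part of the commit; it assumes the module path official.r1.resnet.resnet_run_loop introduced by this change):

# Illustrative only: exercises the per_replica_batch_size helper added above.
from official.r1.resnet import resnet_run_loop

# With 0 or 1 GPUs the global batch size is returned unchanged.
assert resnet_run_loop.per_replica_batch_size(256, num_gpus=1) == 256
# With 8 GPUs each replica receives 256 / 8 = 32 examples.
assert resnet_run_loop.per_replica_batch_size(256, num_gpus=8) == 32
# A non-divisible global batch raises ValueError and suggests --batch_size=96.
try:
  resnet_run_loop.per_replica_batch_size(100, num_gpus=8)
except ValueError as err:
  print(err)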
34 changes: 32 additions & 2 deletions official/r1/transformer/transformer_main.py
@@ -562,6 +562,36 @@ def construct_estimator(flags_obj, params, schedule_manager):
       },
       config=run_config)
 
+
+def per_replica_batch_size(batch_size, num_gpus):
+  """For multi-gpu, batch-size must be a multiple of the number of GPUs.
+
+  Note that distribution strategy handles this automatically when used with
+  Keras. For using with Estimator, we need to get per GPU batch.
+
+  Args:
+    batch_size: Global batch size to be divided among devices. This should be
+      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
+    num_gpus: How many GPUs are used with DistributionStrategies.
+
+  Returns:
+    Batch size per device.
+
+  Raises:
+    ValueError: if batch_size is not divisible by number of devices
+  """
+  if num_gpus <= 1:
+    return batch_size
+
+  remainder = batch_size % num_gpus
+  if remainder:
+    err = ('When running with multiple GPUs, batch size '
+           'must be a multiple of the number of available GPUs. Found {} '
+           'GPUs with a batch size of {}; try --batch_size={} instead.'
+           ).format(num_gpus, batch_size, batch_size - remainder)
+    raise ValueError(err)
+  return int(batch_size / num_gpus)
+
+
 def run_transformer(flags_obj):
   """Create tf.Estimator to train and evaluate transformer model.
@@ -605,8 +635,8 @@ def run_transformer(flags_obj):
 
   total_batch_size = params["batch_size"]
   if not params["use_tpu"]:
-    params["batch_size"] = distribution_utils.per_replica_batch_size(
-        params["batch_size"], num_gpus)
+    params["batch_size"] = per_replica_batch_size(params["batch_size"],
+                                                  num_gpus)
 
   schedule_manager = schedule.Manager(
       train_steps=flags_obj.train_steps,
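For context, a minimal sketch of the Estimator-side adjustment shown above, with hypothetical values and an assumed import path (official.r1.transformer.transformer_main); tf.distribute only performs this division automatically for Keras:

# Hypothetical example: a global batch of 4096 tokens spread across 4 GPUs.
from official.r1.transformer.transformer_main import per_replica_batch_size

num_gpus = 4
params = {"batch_size": 4096, "use_tpu": False}

total_batch_size = params["batch_size"]
if not params["use_tpu"]:
  # Each Estimator replica is fed separately, so shrink the per-call batch size.
  params["batch_size"] = per_replica_batch_size(params["batch_size"], num_gpus)

print(total_batch_size, params["batch_size"])  # 4096 1024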
31 changes: 0 additions & 31 deletions official/utils/misc/distribution_utils.py
@@ -157,37 +157,6 @@ def get_distribution_strategy(distribution_strategy="mirrored",
         "Unrecognized Distribution Strategy: %r" % distribution_strategy)
 
 
-def per_replica_batch_size(batch_size, num_gpus):
-  """For multi-gpu, batch-size must be a multiple of the number of GPUs.
-
-  Note that distribution strategy handles this automatically when used with
-  Keras. For using with Estimator, we need to get per GPU batch.
-
-  Args:
-    batch_size: Global batch size to be divided among devices. This should be
-      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
-    num_gpus: How many GPUs are used with DistributionStrategies.
-
-  Returns:
-    Batch size per device.
-
-  Raises:
-    ValueError: if batch_size is not divisible by number of devices
-  """
-  if num_gpus <= 1:
-    return batch_size
-
-  remainder = batch_size % num_gpus
-  if remainder:
-    err = ('When running with multiple GPUs, batch size '
-           'must be a multiple of the number of available GPUs. Found {} '
-           'GPUs with a batch size of {}; try --batch_size={} instead.'
-           ).format(num_gpus, batch_size, batch_size - remainder)
-    raise ValueError(err)
-  return int(batch_size / num_gpus)
-
-
 # The `SyntheticDataset` is a temporary solution for generating synthetic data
 # directly on devices. It is only useful for Keras with Distribution
 # Strategies. We will have better support in `tf.data` or Distribution Strategy
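With the shared helper removed, callers outside R1 are expected to rely on tf.distribute, which consumes the global batch size directly under Keras; a rough equivalent for manual per-replica feeding (an illustrative sketch, not part of this commit) is:

import tensorflow as tf

# MirroredStrategy reports how many replicas train in sync on this host.
strategy = tf.distribute.MirroredStrategy()
global_batch_size = 256
# Dividing by num_replicas_in_sync mirrors what per_replica_batch_size did,
# minus the divisibility check and the error message.
per_replica_batch = global_batch_size // strategy.num_replicas_in_sync
print(per_replica_batch)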
16 changes: 0 additions & 16 deletions official/utils/misc/distribution_utils_test.py
@@ -45,21 +45,5 @@ def test_mirrored_strategy(self):
       self.assertIn('GPU', device)
 
 
-class PerReplicaBatchSizeTest(tf.test.TestCase):
-  """Tests for per_replica_batch_size."""
-
-  def test_batch_size(self):
-    self.assertEquals(
-        distribution_utils.per_replica_batch_size(147, num_gpus=0), 147)
-    self.assertEquals(
-        distribution_utils.per_replica_batch_size(147, num_gpus=1), 147)
-    self.assertEquals(
-        distribution_utils.per_replica_batch_size(147, num_gpus=7), 21)
-
-  def test_batch_size_with_remainder(self):
-    with self.assertRaises(ValueError):
-      distribution_utils.per_replica_batch_size(147, num_gpus=5)
-
-
 if __name__ == "__main__":
   tf.test.main()

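The removed test case is not recreated in this commit; if equivalent coverage is wanted next to the R1 copy, a minimal sketch (hypothetical test file and assumed import path) could look like this:

import tensorflow as tf

from official.r1.resnet import resnet_run_loop


class PerReplicaBatchSizeTest(tf.test.TestCase):
  """Tests for the relocated per_replica_batch_size helper."""

  def test_batch_size(self):
    self.assertEqual(resnet_run_loop.per_replica_batch_size(147, num_gpus=0), 147)
    self.assertEqual(resnet_run_loop.per_replica_batch_size(147, num_gpus=1), 147)
    self.assertEqual(resnet_run_loop.per_replica_batch_size(147, num_gpus=7), 21)

  def test_batch_size_with_remainder(self):
    with self.assertRaises(ValueError):
      resnet_run_loop.per_replica_batch_size(147, num_gpus=5)


if __name__ == "__main__":
  tf.test.main()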