Check whether GPUs are present instead of whether TF is built with CUDA support.
The TF pip packages are always built with CUDA support, so tf.test.is_built_with_cuda() would return True even if the user had no GPU.

PiperOrigin-RevId: 302928378
reedwm authored and tensorflower-gardener committed Mar 25, 2020
1 parent 7c83a9d commit ad09cf4
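
For context, here is a minimal standalone sketch (not part of this commit; TF 2.x assumed) of why the two checks can disagree on a machine that runs the CUDA pip build but has no GPU attached:

import tensorflow as tf

# On the CUDA-enabled pip packages the build-flag check is always True,
# even when no GPU is present on the machine.
print(tf.test.is_built_with_cuda())            # True for the pip packages.
print(tf.config.list_physical_devices('GPU'))  # [] when no GPU is visible.

# The pattern this commit switches to: gate on visible GPUs, not on build flags.
if tf.config.list_physical_devices('GPU'):
  print('At least one GPU is available.')
else:
  print('No GPU detected; running on CPU.')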
Showing 6 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions official/benchmark/models/resnet_cifar_main.py
@@ -138,8 +138,8 @@ def run(flags_obj):

data_format = flags_obj.data_format
if data_format is None:
- data_format = ('channels_first'
- if tf.test.is_built_with_cuda() else 'channels_last')
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)

strategy = distribution_utils.get_distribution_strategy(
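Pulled out of the diff, the updated selection reads roughly as below; choose_data_format is a hypothetical helper added here for illustration, while the real code reads the requested value from flags_obj inline:

import tensorflow as tf

def choose_data_format(requested=None):
  """Hypothetical helper: prefer channels_first only when a GPU is visible."""
  if requested is not None:
    return requested
  return ('channels_first' if tf.config.list_physical_devices('GPU')
          else 'channels_last')

# Mirrors the call several of the changed files make right after picking the format.
tf.keras.backend.set_image_data_format(choose_data_format())
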
2 changes: 1 addition & 1 deletion official/modeling/model_training_utils.py
@@ -415,7 +415,7 @@ def _run_callbacks_on_batch_end(batch, logs):
# Runs several steps in the host while loop.
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

- if tf.test.is_built_with_cuda():
+ if tf.config.list_physical_devices('GPU'):
# TODO(zongweiz): merge with train_steps once tf.while_loop
# GPU performance bugs are fixed.
for _ in range(steps):
4 changes: 2 additions & 2 deletions official/r1/mnist/mnist.py
@@ -183,8 +183,8 @@ def run_mnist(flags_obj):

data_format = flags_obj.data_format
if data_format is None:
- data_format = ('channels_first'
- if tf.test.is_built_with_cuda() else 'channels_last')
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
mnist_classifier = tf.estimator.Estimator(
model_fn=model_function,
model_dir=flags_obj.model_dir,
4 changes: 2 additions & 2 deletions official/r1/resnet/resnet_model.py
@@ -391,8 +391,8 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
self.resnet_size = resnet_size

if not data_format:
- data_format = (
- 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')

self.resnet_version = resnet_version
if resnet_version not in (1, 2):
@@ -119,8 +119,8 @@ def run(flags_obj):
# TODO(anj-s): Set data_format without using Keras.
data_format = flags_obj.data_format
if data_format is None:
- data_format = ('channels_first'
- if tf.test.is_built_with_cuda() else 'channels_last')
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)

strategy = distribution_utils.get_distribution_strategy(
@@ -71,8 +71,8 @@ def run(flags_obj):

data_format = flags_obj.data_format
if data_format is None:
- data_format = ('channels_first'
- if tf.test.is_built_with_cuda() else 'channels_last')
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)

# Configures cluster spec for distribution strategy.
