From a9fce8c12fec5cac6c597a9c778cdf40c26e7c36 Mon Sep 17 00:00:00 2001 From: polin Date: Fri, 29 May 2020 15:57:36 +0800 Subject: [PATCH] fix issue on geforce GPU --- bit_tf2/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bit_tf2/train.py b/bit_tf2/train.py index 5292887..f95d7f7 100644 --- a/bit_tf2/train.py +++ b/bit_tf2/train.py @@ -27,6 +27,10 @@ import bit_tf2.models as models import input_pipeline_tf2_or_jax as input_pipeline +# set_memory_growth=True to prevent OOM error +gpu = tf.config.experimental.list_physical_devices('GPU') +tf.config.experimental.set_memory_growth(gpu[0], True) + def reshape_for_keras(features, batch_size, crop_size): features["image"] = tf.reshape(features["image"], (batch_size, crop_size, crop_size, 3)) @@ -140,7 +144,7 @@ def main(args): for epoch, accu in enumerate(history.history['val_accuracy']): logger.info( - f'Step: {epoch * args.eval_every}, ' + f'Step: {epoch * steps_per_epoch}, ' f'Test accuracy: {accu:0.3f}')