From 54602a66490de1f5a48190bf905d9d3aff2a858f Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Wed, 25 Mar 2020 11:52:47 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 302937425 --- .../vision/image_classification/README.md | 214 ++-- .../vision/image_classification/augment.py | 1002 +++++++++++++++++ .../image_classification/augment_test.py | 137 +++ .../vision/image_classification/callbacks.py | 136 +++ .../image_classification/callbacks_test.py | 86 ++ .../classifier_trainer.py | 427 +++++++ .../classifier_trainer_test.py | 317 ++++++ .../image_classification/configs/__init__.py | 14 + .../configs/base_configs.py | 223 ++++ .../image_classification/configs/configs.py | 121 ++ .../imagenet/efficientnet-b0-gpu.yaml | 51 + .../imagenet/efficientnet-b0-tpu.yaml | 52 + .../imagenet/efficientnet-b1-gpu.yaml | 44 + .../imagenet/efficientnet-b1-tpu.yaml | 49 + .../configs/examples/resnet/imagenet/gpu.yaml | 53 + .../configs/examples/resnet/imagenet/tpu.yaml | 58 + .../image_classification/dataset_factory.py | 476 ++++++++ .../efficientnet/__init__.py | 0 .../efficientnet/common_modules.py | 100 ++ .../efficientnet/efficientnet_config.py | 75 ++ .../efficientnet/efficientnet_model.py | 503 +++++++++ .../image_classification/learning_rate.py | 120 ++ .../learning_rate_test.py | 90 ++ .../image_classification/optimizer_factory.py | 161 +++ .../optimizer_factory_test.py | 115 ++ .../image_classification/preprocessing.py | 391 +++++++ .../image_classification/resnet/README.md | 129 +++ 27 files changed, 5010 insertions(+), 134 deletions(-) create mode 100644 official/vision/image_classification/augment.py create mode 100644 official/vision/image_classification/augment_test.py create mode 100644 official/vision/image_classification/callbacks.py create mode 100644 official/vision/image_classification/callbacks_test.py create mode 100644 official/vision/image_classification/classifier_trainer.py create mode 100644 official/vision/image_classification/classifier_trainer_test.py create mode 100644 official/vision/image_classification/configs/__init__.py create mode 100644 official/vision/image_classification/configs/base_configs.py create mode 100644 official/vision/image_classification/configs/configs.py create mode 100644 official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml create mode 100644 official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml create mode 100644 official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml create mode 100644 official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml create mode 100644 official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml create mode 100644 official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml create mode 100644 official/vision/image_classification/dataset_factory.py create mode 100644 official/vision/image_classification/efficientnet/__init__.py create mode 100644 official/vision/image_classification/efficientnet/common_modules.py create mode 100644 official/vision/image_classification/efficientnet/efficientnet_config.py create mode 100644 official/vision/image_classification/efficientnet/efficientnet_model.py create mode 100644 official/vision/image_classification/learning_rate.py create mode 100644 official/vision/image_classification/learning_rate_test.py create mode 100644 
official/vision/image_classification/optimizer_factory.py create mode 100644 official/vision/image_classification/optimizer_factory_test.py create mode 100644 official/vision/image_classification/preprocessing.py create mode 100644 official/vision/image_classification/resnet/README.md diff --git a/official/vision/image_classification/README.md b/official/vision/image_classification/README.md index 8b4e2d13422..b5958cb0418 100644 --- a/official/vision/image_classification/README.md +++ b/official/vision/image_classification/README.md @@ -1,190 +1,136 @@ # Image Classification -This folder contains the TF 2.0 model examples for image classification: +This folder contains TF 2.0 model examples for image classification: -* [ResNet](#resnet) * [MNIST](#mnist) +* [Classifier Trainer](#classifier-trainer), a framework that uses the Keras +compile/fit methods for image classification models, including: + * ResNet + * EfficientNet[^1] +[^1]: Currently a work in progress. We cannot match "AutoAugment (AA)" in [the original version](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet). For more information about other types of models, please refer to this [README file](../../README.md). -## ResNet - -Similar to the [estimator implementation](../../r1/resnet), the Keras -implementation has code for the ImageNet dataset. The ImageNet -version uses a ResNet50 model implemented in -[`resnet_model.py`](./resnet/resnet_model.py). - +## Before you begin Please make sure that you have the latest version of TensorFlow installed and [add the models folder to your Python path](/official/#running-the-models). -### Pretrained Models - -* [ResNet50 Checkpoints](https://storage.googleapis.com/cloud-tpu-checkpoints/resnet/resnet50.tar.gz) - -* ResNet50 TFHub: [feature vector](https://tfhub.dev/tensorflow/resnet_50/feature_vector/1) -and [classification](https://tfhub.dev/tensorflow/resnet_50/classification/1) - -### ImageNet Training +### ImageNet preparation Download the ImageNet dataset and convert it to TFRecord format. The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py) and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy) provide a few options. -Once your dataset is ready, you can begin training the model as follows: - -```bash -python resnet/resnet_imagenet_main.py -``` - -Again, if you did not download the data to the default directory, specify the -location with the `--data_dir` flag: - -```bash -python resnet/resnet_imagenet_main.py --data_dir=/path/to/imagenet -``` - -There are more flag options you can specify. Here are some examples: - -- `--use_synthetic_data`: when set to true, synthetic data, rather than real -data, are used; -- `--batch_size`: the batch size used for the model; -- `--model_dir`: the directory to save the model checkpoint; -- `--train_epochs`: number of epoches to run for training the model; -- `--train_steps`: number of steps to run for training the model. We now only -support a number that is smaller than the number of batches in an epoch. 
-- `--skip_eval`: when set to true, evaluation as well as validation during -training is skipped - -For example, this is a typical command line to run with ImageNet data with -batch size 128 per GPU: - -```bash -python -m resnet/resnet_imagenet_main.py \ - --model_dir=/tmp/model_dir/something \ - --num_gpus=2 \ - --batch_size=128 \ - --train_epochs=90 \ - --train_steps=10 \ - --use_synthetic_data=false -``` - -See [`common.py`](common.py) for full list of options. - -### Using multiple GPUs - -You can train these models on multiple GPUs using `tf.distribute.Strategy` API. -You can read more about them in this -[guide](https://www.tensorflow.org/guide/distribute_strategy). - -In this example, we have made it easier to use is with just a command line flag -`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA, -and 0 otherwise. - -- --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device. -- --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device. -- --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous -distributed training across the GPUs. - -If you wish to run without `tf.distribute.Strategy`, you can do so by setting -`--distribution_strategy=off`. - -### Running on multiple GPU hosts - -You can also train these models on multiple hosts, each with GPUs, using -`tf.distribute.Strategy`. - -The easiest way to run multi-host benchmarks is to set the -[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG) -appropriately at each host. e.g., to run using `MultiWorkerMirroredStrategy` on -2 hosts, the `cluster` in `TF_CONFIG` should have 2 `host:port` entries, and -host `i` should have the `task` in `TF_CONFIG` set to `{"type": "worker", -"index": i}`. `MultiWorkerMirroredStrategy` will automatically use all the -available GPUs at each host. - ### Running on Cloud TPUs -Note: This model will **not** work with TPUs on Colab. +Note: These models will **not** work with TPUs on Colab. -You can train the ResNet CTL model on Cloud TPUs using +You can train image classification models on Cloud TPUs using `tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is strongly recommended that you go through the [quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to create a TPU and GCE VM. -To run ResNet model on a TPU, you must set `--distribution_strategy=tpu` and -`--tpu=$TPU_NAME`, where `$TPU_NAME` the name of your TPU in the Cloud Console. 
-From a GCE VM, you can run the following command to train ResNet for one epoch
-on a v2-8 or v3-8 TPU:
+## MNIST
+
+To download the data and run the MNIST sample model locally for the first time,
+run the following command:
 
 ```bash
-python resnet/resnet_ctl_imagenet_main.py \
-  --tpu=$TPU_NAME \
+python3 mnist_main.py \
   --model_dir=$MODEL_DIR \
   --data_dir=$DATA_DIR \
-  --batch_size=1024 \
-  --steps_per_loop=500 \
-  --train_epochs=1 \
-  --use_synthetic_data=false \
-  --dtype=fp32 \
-  --enable_eager=true \
-  --enable_tensorboard=true \
-  --distribution_strategy=tpu \
-  --log_steps=50 \
-  --single_l2_loss_op=true \
-  --use_tf_function=true
+  --train_epochs=10 \
+  --distribution_strategy=one_device \
+  --num_gpus=$NUM_GPUS \
+  --download
 ```
 
-To train the ResNet to convergence, run it for 90 epochs:
+To train the model on a Cloud TPU, run the following command:
 
 ```bash
-python resnet/resnet_ctl_imagenet_main.py \
+python3 mnist_main.py \
   --tpu=$TPU_NAME \
   --model_dir=$MODEL_DIR \
   --data_dir=$DATA_DIR \
-  --batch_size=1024 \
-  --steps_per_loop=500 \
-  --train_epochs=90 \
-  --use_synthetic_data=false \
-  --dtype=fp32 \
-  --enable_eager=true \
-  --enable_tensorboard=true \
+  --train_epochs=10 \
   --distribution_strategy=tpu \
-  --log_steps=50 \
-  --single_l2_loss_op=true \
-  --use_tf_function=true
+  --download
 ```
 
-Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.
+Note: the `--download` flag is only required the first time you run the model.
 
-## MNIST
+## Classifier Trainer
+The classifier trainer is a unified framework for running image classification
+models using Keras's compile/fit methods. Experiments should be provided in the
+form of YAML files; see [configs/examples](./configs/examples) for example
+configurations.
 
-To download the data and run the MNIST sample model locally for the first time,
-run one of the following command:
+The provided configuration files specify a per-replica batch size, which is
+scaled by the number of devices. For instance, if `batch_size` = 64, then for
+1 GPU the global batch size would be 64 * 1 = 64. For 8 GPUs, the global batch
+size would be 64 * 8 = 512. Similarly, for a v3-8 TPU, the global batch size
+would be 64 * 8 = 512, and for a v3-32, the global batch size is
+64 * 32 = 2048.
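+
+As a quick sketch (illustrative only; assumes training on local GPUs with
+`MirroredStrategy`), the global batch size follows directly from the
+strategy's replica count:
+
+```python
+import tensorflow as tf
+
+per_replica_batch_size = 64  # the `batch_size` value from the YAML config
+strategy = tf.distribute.MirroredStrategy()
+# num_replicas_in_sync is the number of devices training synchronously.
+global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync
+```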
+### ResNet50
+
+#### On GPU:
+```bash
+python3 classifier_trainer.py \
+  --mode=train_and_eval \
+  --model_type=resnet \
+  --dataset=imagenet \
+  --model_dir=$MODEL_DIR \
+  --data_dir=$DATA_DIR \
+  --config_file=configs/examples/resnet/imagenet/gpu.yaml \
+  --params_override='runtime.num_gpus=$NUM_GPUS'
+```
+
+#### On TPU:
+```bash
+python3 classifier_trainer.py \
+  --mode=train_and_eval \
+  --model_type=resnet \
+  --dataset=imagenet \
+  --tpu=$TPU_NAME \
+  --model_dir=$MODEL_DIR \
+  --data_dir=$DATA_DIR \
+  --config_file=configs/examples/resnet/imagenet/tpu.yaml
+```
+
+### EfficientNet
+**Note: EfficientNet development is a work in progress.**
+
+#### On GPU:
+```bash
+python3 classifier_trainer.py \
+  --mode=train_and_eval \
+  --model_type=efficientnet \
+  --dataset=imagenet \
+  --model_dir=$MODEL_DIR \
+  --data_dir=$DATA_DIR \
+  --config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml \
+  --params_override='runtime.num_gpus=$NUM_GPUS'
+```
+
+#### On TPU:
+```bash
+python3 classifier_trainer.py \
+  --mode=train_and_eval \
+  --model_type=efficientnet \
+  --dataset=imagenet \
+  --tpu=$TPU_NAME \
+  --model_dir=$MODEL_DIR \
+  --data_dir=$DATA_DIR \
+  --config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
+```
+
+Note that the number of GPU devices can be overridden on the command line using
+`--params_override`. The TPU run does not need this override, as the device is
+fixed by providing the TPU address or name with the `--tpu` flag.
+
diff --git a/official/vision/image_classification/augment.py b/official/vision/image_classification/augment.py
new file mode 100644
index 00000000000..a71c8d00832
--- /dev/null
+++ b/official/vision/image_classification/augment.py
@@ -0,0 +1,1002 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AutoAugment and RandAugment policies for enhanced image preprocessing.
+
+AutoAugment Reference: https://arxiv.org/abs/1805.09501
+RandAugment Reference: https://arxiv.org/abs/1909.13719
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import math
+import tensorflow.compat.v2 as tf
+from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union
+
+from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
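+
+# Levels in [0, _MAX_LEVEL] are rescaled per operation; for example, a level
+# of 5 maps to a rotation of (5 / 10) * 30 = 15 degrees in
+# `_rotate_level_to_arg` below.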
+
+
+def to_4d(image: tf.Tensor) -> tf.Tensor:
+  """Converts an input Tensor to 4 dimensions.
+
+  4D image => [N, H, W, C] or [N, C, H, W]
+  3D image => [1, H, W, C] or [1, C, H, W]
+  2D image => [1, H, W, 1]
+
+  Args:
+    image: The 2/3/4D input tensor.
+
+  Returns:
+    A 4D image tensor.
+
+  Raises:
+    `TypeError` if `image` is not a 2/3/4D tensor.
+
+  """
+  shape = tf.shape(image)
+  original_rank = tf.rank(image)
+  left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
+  right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
+  new_shape = tf.concat(
+      [
+          tf.ones(shape=left_pad, dtype=tf.int32),
+          shape,
+          tf.ones(shape=right_pad, dtype=tf.int32),
+      ],
+      axis=0,
+  )
+  return tf.reshape(image, new_shape)
+
+
+def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
+  """Converts a 4D image back to `ndims` rank."""
+  shape = tf.shape(image)
+  begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
+  end = 4 - tf.cast(tf.equal(ndims, 2), dtype=tf.int32)
+  new_shape = shape[begin:end]
+  return tf.reshape(image, new_shape)
+
+
+def _convert_translation_to_transform(
+    translations: Iterable[int]) -> tf.Tensor:
+  """Converts translations to a projective transform.
+
+  The translation matrix looks like this:
+    [[1 0 -dx]
+     [0 1 -dy]
+     [0 0 1]]
+
+  Args:
+    translations: The 2-element list representing [dx, dy], or a matrix of
+      2-element lists representing [dx dy] to translate for each image. The
+      shape must be static.
+
+  Returns:
+    The transformation matrix of shape (num_images, 8).
+
+  Raises:
+    `TypeError` if
+      - the shape of `translations` is not known or
+      - the shape of `translations` is not rank 1 or 2.
+
+  """
+  translations = tf.convert_to_tensor(translations, dtype=tf.float32)
+  if translations.get_shape().ndims is None:
+    raise TypeError('translations rank must be statically known')
+  elif len(translations.get_shape()) == 1:
+    translations = translations[None]
+  elif len(translations.get_shape()) != 2:
+    raise TypeError('translations should have rank 1 or 2.')
+  num_translations = tf.shape(translations)[0]
+
+  return tf.concat(
+      values=[
+          tf.ones((num_translations, 1), tf.dtypes.float32),
+          tf.zeros((num_translations, 1), tf.dtypes.float32),
+          -translations[:, 0, None],
+          tf.zeros((num_translations, 1), tf.dtypes.float32),
+          tf.ones((num_translations, 1), tf.dtypes.float32),
+          -translations[:, 1, None],
+          tf.zeros((num_translations, 2), tf.dtypes.float32),
+      ],
+      axis=1,
+  )
+
+
+def _convert_angles_to_transform(
+    angles: Union[Iterable[float], float],
+    image_width: int,
+    image_height: int) -> tf.Tensor:
+  """Converts an angle or angles to a projective transform.
+
+  Args:
+    angles: A scalar angle to rotate all images by, or a vector of angles to
+      rotate a batch of images by.
+    image_width: The width of the image(s) to be transformed.
+    image_height: The height of the image(s) to be transformed.
+
+  Returns:
+    A tensor of shape (num_images, 8).
+
+  Raises:
+    `TypeError` if `angles` is not rank 0 or 1.
+ + """ + angles = tf.convert_to_tensor(angles, dtype=tf.float32) + if len(angles.get_shape()) == 0: # pylint:disable=g-explicit-length-test + angles = angles[None] + elif len(angles.get_shape()) != 1: + raise TypeError('Angles should have a rank 0 or 1.') + x_offset = ((image_width - 1) - + (tf.math.cos(angles) * (image_width - 1) - tf.math.sin(angles) * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - + (tf.math.sin(angles) * (image_width - 1) + tf.math.cos(angles) * + (image_height - 1))) / 2.0 + num_angles = tf.shape(angles)[0] + return tf.concat( + values=[ + tf.math.cos(angles)[:, None], + -tf.math.sin(angles)[:, None], + x_offset[:, None], + tf.math.sin(angles)[:, None], + tf.math.cos(angles)[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +def transform(image: tf.Tensor, + transforms: Iterable[float]) -> tf.Tensor: + """Prepares input data for `image_ops.transform`.""" + original_ndims = tf.rank(image) + transforms = tf.convert_to_tensor(transforms, dtype=tf.float32) + if len(tf.shape(transforms)) == 1: + transforms = transforms[None] + image = to_4d(image) + image = image_ops.transform( + images=image, + transforms=transforms, + interpolation='nearest') + return from_4d(image, original_ndims) + + +def translate(image: tf.Tensor, + translations: Iterable[int]) -> tf.Tensor: + """Translates image(s) by provided vectors. + + Args: + image: An image Tensor of type uint8. + translations: A vector or matrix representing [dx dy]. + + Returns: + The translated version of the image. + + """ + transforms = _convert_translation_to_transform(translations) + return transform(image, transforms=transforms) + + +def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor: + """Rotates the image by degrees either clockwise or counterclockwise. + + Args: + image: An image Tensor of type uint8. + degrees: Float, a scalar angle in degrees to rotate all images by. If + degrees is positive the image will be rotated clockwise otherwise it will + be rotated counterclockwise. + + Returns: + The rotated version of image. + + """ + # Convert from degrees to radians. + degrees_to_radians = math.pi / 180.0 + radians = degrees * degrees_to_radians + + original_ndims = tf.rank(image) + image = to_4d(image) + + image_height = tf.cast(tf.shape(image)[1], tf.float32) + image_width = tf.cast(tf.shape(image)[2], tf.float32) + transforms = _convert_angles_to_transform(angles=radians, + image_width=image_width, + image_height=image_height) + # In practice, we should randomize the rotation degrees by flipping + # it negatively half the time, but that's done on 'degrees' outside + # of the function. + image = transform(image, transforms=transforms) + return from_4d(image, original_ndims) + + +def blend(image1: tf.Tensor, image2: tf.Tensor, factor: float) -> tf.Tensor: + """Blend image1 and image2 using 'factor'. + + Factor can be above 0.0. A value of 0.0 means only image1 is used. + A value of 1.0 means only image2 is used. A value between 0.0 and + 1.0 means we linearly interpolate the pixel values between the two + images. A value greater than 1.0 "extrapolates" the difference + between the two pixel values, and we clip the results to values + between 0 and 255. + + Args: + image1: An image Tensor of type uint8. + image2: An image Tensor of type uint8. + factor: A floating point value above 0.0. + + Returns: + A blended image Tensor of type uint8. 
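+
+  For example, `blend(image1, image2, 0.5)` returns the per-pixel average of
+  the two images, and `blend(image1, image2, 0.0)` returns `image1` unchanged.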
+  """
+  if factor == 0.0:
+    return tf.convert_to_tensor(image1)
+  if factor == 1.0:
+    return tf.convert_to_tensor(image2)
+
+  image1 = tf.cast(image1, tf.float32)
+  image2 = tf.cast(image2, tf.float32)
+
+  difference = image2 - image1
+  scaled = factor * difference
+
+  # Do addition in float.
+  temp = tf.cast(image1, tf.float32) + scaled
+
+  # Interpolate
+  if factor > 0.0 and factor < 1.0:
+    # Interpolation means we always stay within 0 and 255.
+    return tf.cast(temp, tf.uint8)
+
+  # Extrapolate:
+  #
+  # We need to clip and then cast.
+  return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)
+
+
+def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
+  """Apply cutout (https://arxiv.org/abs/1708.04552) to image.
+
+  This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
+  a random location within `image`. The pixel values filled in will be of the
+  value `replace`. The location where the mask is applied is chosen uniformly
+  at random over the whole image.
+
+  Args:
+    image: An image Tensor of type uint8.
+    pad_size: Specifies the size of the zero mask applied to the image; the
+      mask will be of size (2*pad_size x 2*pad_size).
+    replace: What pixel value to fill in the image in the area that has
+      the cutout mask applied to it.
+
+  Returns:
+    An image Tensor that is of type uint8.
+  """
+  image_height = tf.shape(image)[0]
+  image_width = tf.shape(image)[1]
+
+  # Sample the center location in the image where the zero mask is applied.
+  cutout_center_height = tf.random.uniform(
+      shape=[], minval=0, maxval=image_height,
+      dtype=tf.int32)
+
+  cutout_center_width = tf.random.uniform(
+      shape=[], minval=0, maxval=image_width,
+      dtype=tf.int32)
+
+  lower_pad = tf.maximum(0, cutout_center_height - pad_size)
+  upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
+  left_pad = tf.maximum(0, cutout_center_width - pad_size)
+  right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
+
+  cutout_shape = [image_height - (lower_pad + upper_pad),
+                  image_width - (left_pad + right_pad)]
+  padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
+  mask = tf.pad(
+      tf.zeros(cutout_shape, dtype=image.dtype),
+      padding_dims, constant_values=1)
+  mask = tf.expand_dims(mask, -1)
+  mask = tf.tile(mask, [1, 1, 3])
+  image = tf.where(
+      tf.equal(mask, 0),
+      tf.ones_like(image, dtype=image.dtype) * replace,
+      image)
+  return image
+
+
+def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
+  # For each pixel in the image, select the pixel
+  # if the value is less than the threshold.
+  # Otherwise, invert the pixel by subtracting it from 255.
+  return tf.where(image < threshold, image, 255 - image)
+
+
+def solarize_add(image: tf.Tensor,
+                 addition: int = 0,
+                 threshold: int = 128) -> tf.Tensor:
+  # For each pixel in the image less than threshold
+  # we add 'addition' amount to it and then clip the
+  # pixel value to be between 0 and 255. The value
+  # of 'addition' is between -128 and 128.
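+  # For example, with addition=64 and threshold=128, a pixel of value 100 is
+  # mapped to 100 + 64 = 164, while a pixel of value 200 is left unchanged.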
+ added_image = tf.cast(image, tf.int64) + addition + added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) + return tf.where(image < threshold, added_image, image) + + +def color(image: tf.Tensor, factor: float) -> tf.Tensor: + """Equivalent of PIL Color.""" + degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) + return blend(degenerate, image, factor) + + +def contrast(image: tf.Tensor, factor: float) -> tf.Tensor: + """Equivalent of PIL Contrast.""" + degenerate = tf.image.rgb_to_grayscale(image) + # Cast before calling tf.histogram. + degenerate = tf.cast(degenerate, tf.int32) + + # Compute the grayscale histogram, then compute the mean pixel value, + # and create a constant image size of that value. Use that as the + # blending degenerate target of the original image. + hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) + mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 + degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) + return blend(degenerate, image, factor) + + +def brightness(image: tf.Tensor, factor: float) -> tf.Tensor: + """Equivalent of PIL Brightness.""" + degenerate = tf.zeros_like(image) + return blend(degenerate, image, factor) + + +def posterize(image: tf.Tensor, bits: int) -> tf.Tensor: + """Equivalent of PIL Posterize.""" + shift = 8 - bits + return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) + + +def wrapped_rotate(image: tf.Tensor, degrees: float, replace: int) -> tf.Tensor: + """Applies rotation with wrap/unwrap.""" + image = rotate(wrap(image), degrees=degrees) + return unwrap(image, replace) + + +def translate_x(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor: + """Equivalent of PIL Translate in X dimension.""" + image = translate(wrap(image), [-pixels, 0]) + return unwrap(image, replace) + + +def translate_y(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor: + """Equivalent of PIL Translate in Y dimension.""" + image = translate(wrap(image), [0, -pixels]) + return unwrap(image, replace) + + +def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor: + """Equivalent of PIL Shearing in X dimension.""" + # Shear parallel to x axis is a projective transform + # with a matrix form of: + # [1 level + # 0 1]. + image = transform(image=wrap(image), + transforms=[1., level, 0., 0., 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor: + """Equivalent of PIL Shearing in Y dimension.""" + # Shear parallel to y axis is a projective transform + # with a matrix form of: + # [1 0 + # level 1]. + image = transform(image=wrap(image), + transforms=[1., 0., 0., level, 1., 0., 0., 0.]) + return unwrap(image, replace) + + +def autocontrast(image: tf.Tensor) -> tf.Tensor: + """Implements Autocontrast function from PIL using TF ops. + + Args: + image: A 3D uint8 tensor. + + Returns: + The image after it has had autocontrast applied to it and will be of type + uint8. + """ + + def scale_channel(image: tf.Tensor) -> tf.Tensor: + """Scale the 2D image using the autocontrast rule.""" + # A possibly cheaper version can be done using cumsum/unique_with_counts + # over the histogram values, rather than iterating over the entire image. + # to compute mins and maxes. 
+ lo = tf.cast(tf.reduce_min(image), tf.float32) + hi = tf.cast(tf.reduce_max(image), tf.float32) + + # Scale the image, making the lowest value 0 and the highest value 255. + def scale_values(im): + scale = 255.0 / (hi - lo) + offset = -lo * scale + im = tf.cast(im, tf.float32) * scale + offset + im = tf.clip_by_value(im, 0.0, 255.0) + return tf.cast(im, tf.uint8) + + result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) + return result + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image[:, :, 0]) + s2 = scale_channel(image[:, :, 1]) + s3 = scale_channel(image[:, :, 2]) + image = tf.stack([s1, s2, s3], 2) + return image + + +def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor: + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image = tf.cast(image, tf.float32) + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + kernel = tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, + shape=[3, 3, 1, 1]) / 13. + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding='VALID', dilations=[1, 1]) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + + # Blend the final result. + return blend(result, orig_image, factor) + + +def equalize(image: tf.Tensor) -> tf.Tensor: + """Implements Equalize function from PIL using TF ops.""" + def scale_channel(im, c): + """Scale the data in the channel to implement equalize.""" + im = tf.cast(im[:, :, c], tf.int32) + # Compute the histogram of the image channel. + histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = tf.where(tf.not_equal(histo, 0)) + nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) + step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + def build_lut(histo, step): + # Compute the cumulative sum, shifting by step // 2 + # and then normalization by step. + lut = (tf.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = tf.concat([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done + # in the C code for image.point. + return tf.clip_by_value(lut, 0, 255) + + # If step is zero, return the original image. Otherwise, build + # lut from the full histogram and step and then index from it. + result = tf.cond(tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im)) + + return tf.cast(result, tf.uint8) + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. 
+ s1 = scale_channel(image, 0) + s2 = scale_channel(image, 1) + s3 = scale_channel(image, 2) + image = tf.stack([s1, s2, s3], 2) + return image + + +def invert(image: tf.Tensor) -> tf.Tensor: + """Inverts the image pixels.""" + image = tf.convert_to_tensor(image) + return 255 - image + + +def wrap(image: tf.Tensor) -> tf.Tensor: + """Returns 'image' with an extra channel set to all 1s.""" + shape = tf.shape(image) + extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) + extended = tf.concat([image, extended_channel], axis=2) + return extended + + +def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor: + """Unwraps an image produced by wrap. + + Where there is a 0 in the last channel for every spatial position, + the rest of the three channels in that spatial dimension are grayed + (set to 128). Operations like translate and shear on a wrapped + Tensor will leave 0s in empty locations. Some transformations look + at the intensity of values to do preprocessing, and we want these + empty pixels to assume the 'average' value, rather than pure black. + + + Args: + image: A 3D Image Tensor with 4 channels. + replace: A one or three value 1D tensor to fill empty pixels. + + Returns: + image: A 3D image Tensor with 3 channels. + """ + image_shape = tf.shape(image) + # Flatten the spatial dimensions. + flattened_image = tf.reshape(image, [-1, image_shape[2]]) + + # Find all pixels where the last channel is zero. + alpha_channel = tf.expand_dims(flattened_image[:, 3], axis=-1) + + replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) + + # Where they are zero, fill them in with 'replace'. + flattened_image = tf.where( + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image) + + image = tf.reshape(flattened_image, image_shape) + image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) + return image + + +def _randomly_negate_tensor(tensor: tf.Tensor) -> tf.Tensor: + """With 50% prob turn the tensor negative.""" + should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool) + final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) + return final_tensor + + +def _rotate_level_to_arg(level: float): + level = (level/_MAX_LEVEL) * 30. + level = _randomly_negate_tensor(level) + return (level,) + + +def _shrink_level_to_arg(level: float): + """Converts level to ratio by which we shrink the image content.""" + if level == 0: + return (1.0,) # if level is zero, do not shrink the image + # Maximum shrinking ratio is 2.9. + level = 2. / (_MAX_LEVEL / level) + 0.9 + return (level,) + + +def _enhance_level_to_arg(level: float): + return ((level/_MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level: float): + level = (level/_MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def _translate_level_to_arg(level: float, translate_const: float): + level = (level/_MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def _mult_to_arg(level: float, multiplier: float = 1.): + return (int((level / _MAX_LEVEL) * multiplier),) + + +def _apply_func_with_prob(func: Any, + image: tf.Tensor, + args: Any, + prob: float): + """Apply `func` to image w/ `args` as input with probability `prob`.""" + assert isinstance(args, tuple) + + # Apply the function with probability `prob`. 
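+  # tf.floor(u + prob) with u ~ Uniform[0, 1) equals 1 with probability `prob`
+  # and 0 otherwise, so the cast below is effectively a Bernoulli(prob) draw.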
+ should_apply_op = tf.cast( + tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool) + augmented_image = tf.cond( + should_apply_op, + lambda: func(image, *args), + lambda: image) + return augmented_image + + +def select_and_apply_random_policy(policies: Any, image: tf.Tensor): + """Select a random policy from `policies` and apply it to `image`.""" + policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32) + # Note that using tf.case instead of tf.conds would result in significantly + # larger graphs and would even break export for some larger policies. + for (i, policy) in enumerate(policies): + image = tf.cond( + tf.equal(i, policy_to_select), + lambda selected_policy=policy: selected_policy(image), + lambda: image) + return image + + +NAME_TO_FUNC = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': wrapped_rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, +} + +# Functions that have a 'replace' parameter +REPLACE_FUNCS = frozenset({ + 'Rotate', + 'TranslateX', + 'ShearX', + 'ShearY', + 'TranslateY', + 'Cutout', +}) + + +def level_to_arg(cutout_const: float, translate_const: float): + """Creates a dict mapping image operation names to their arguments.""" + + no_arg = lambda level: () + posterize_arg = lambda level: _mult_to_arg(level, 4) + solarize_arg = lambda level: _mult_to_arg(level, 256) + solarize_add_arg = lambda level: _mult_to_arg(level, 110) + cutout_arg = lambda level: _mult_to_arg(level, cutout_const) + translate_arg = lambda level: _translate_level_to_arg(level, translate_const) + + args = { + 'AutoContrast': no_arg, + 'Equalize': no_arg, + 'Invert': no_arg, + 'Rotate': _rotate_level_to_arg, + 'Posterize': posterize_arg, + 'Solarize': solarize_arg, + 'SolarizeAdd': solarize_add_arg, + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': cutout_arg, + 'TranslateX': translate_arg, + 'TranslateY': translate_arg, + } + return args + + +def _parse_policy_info(name: Text, + prob: float, + level: float, + replace_value: List[int], + cutout_const: float, + translate_const: float) -> Tuple[Any, float, Any]: + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = level_to_arg(cutout_const, translate_const)[name](level) + + if name in REPLACE_FUNCS: + # Add in replace arg if it is required for the function that is called. + args = tuple(list(args) + [replace_value]) + + return func, prob, args + + +class ImageAugment(object): + """Image augmentation class for applying image distortions.""" + + def distort(self, image: tf.Tensor) -> tf.Tensor: + """Given an image tensor, returns a distorted image with the same shape. + + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + + Returns: + The augmented version of `image`. + """ + raise NotImplementedError() + + +class AutoAugment(ImageAugment): + """Applies the AutoAugment policy to images. + + AutoAugment is from the paper: https://arxiv.org/abs/1805.09501. 
+  """
+
+  def __init__(self,
+               augmentation_name: Text = 'v0',
+               policies: Optional[Dict[Text, Any]] = None,
+               cutout_const: float = 100,
+               translate_const: float = 250):
+    """Applies the AutoAugment policy to images.
+
+    Args:
+      augmentation_name: The name of the AutoAugment policy to use. The
+        available options are `v0`, `test`, and `simple`. `v0` is the policy
+        used for all of the results in the paper and was found to achieve the
+        best results on the COCO dataset. The paper's `v1`, `v2` and `v3`
+        policies are slight variations found during the same search (differing
+        in which operations were used and in how many operations are applied
+        in parallel to a single image, 2 vs 3), but they are not included
+        here.
+      policies: An optional dict mapping policy names to policies. Each policy
+        is a list of sub-policies, and each sub-policy is a list of tuples of
+        the form `(func, prob, level)`, where `func` is a string name of the
+        augmentation function, `prob` is the probability of applying `func`,
+        and `level` is the input argument for `func`. If provided, these
+        policies replace the built-in ones.
+      cutout_const: multiplier for applying cutout.
+      translate_const: multiplier for applying translation.
+    """
+    super(AutoAugment, self).__init__()
+
+    if policies is None:
+      self.available_policies = {
+          'v0': self.policy_v0(),
+          'test': self.policy_test(),
+          'simple': self.policy_simple(),
+      }
+    else:
+      self.available_policies = policies
+
+    if augmentation_name not in self.available_policies:
+      raise ValueError(
+          'Invalid augmentation_name: {}'.format(augmentation_name))
+
+    self.augmentation_name = augmentation_name
+    self.policies = self.available_policies[augmentation_name]
+    self.cutout_const = float(cutout_const)
+    self.translate_const = float(translate_const)
+
+  def distort(self, image: tf.Tensor) -> tf.Tensor:
+    """Applies the AutoAugment policy to `image`.
+
+    AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
+
+    Args:
+      image: `Tensor` of shape [height, width, 3] representing an image.
+
+    Returns:
+      A version of image that now has data augmentation applied to it based on
+      the `policies` passed into the function.
+    """
+    input_image_type = image.dtype
+
+    if input_image_type != tf.uint8:
+      image = tf.clip_by_value(image, 0.0, 255.0)
+      image = tf.cast(image, dtype=tf.uint8)
+
+    replace_value = [128] * 3
+
+    # func is the string name of the augmentation function, prob is the
+    # probability of applying the operation and level is the parameter
+    # associated with the tf op.
+
+    # tf_policies are functions that take in an image and return an augmented
+    # image.
+    tf_policies = []
+    for policy in self.policies:
+      tf_policy = []
+      # Link string name to the correct python function and make sure the
+      # correct argument is passed into that function.
+      for policy_info in policy:
+        policy_info = list(policy_info) + [
+            replace_value, self.cutout_const, self.translate_const
+        ]
+        tf_policy.append(_parse_policy_info(*policy_info))
+      # Now build the tf policy that will apply the augmentation procedure
+      # on image.
+      def make_final_policy(tf_policy_):
+
+        def final_policy(image_):
+          for func, prob, args in tf_policy_:
+            image_ = _apply_func_with_prob(func, image_, args, prob)
+          return image_
+
+        return final_policy
+
+      tf_policies.append(make_final_policy(tf_policy))
+
+    image = select_and_apply_random_policy(tf_policies, image)
+    image = tf.cast(image, dtype=input_image_type)
+    return image
+
+  @staticmethod
+  def policy_v0():
+    """AutoAugment policy that was used in the AutoAugment paper.
+
+    Each tuple is an augmentation operation of the form
+    (operation, probability, magnitude). Each element in policy is a
+    sub-policy that will be applied sequentially on the image.
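+
+    For example, the first sub-policy below applies `Equalize` with
+    probability 0.8 and magnitude 1, followed by `ShearY` with probability
+    0.8 and magnitude 4.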
+
+    Returns:
+      the policy.
+    """
+
+    # TODO(dankondratyuk): tensorflow_addons defines custom ops, which
+    # for some reason are not included when building/linking.
+    # This results in the error, "Op type not registered
+    # 'Addons>ImageProjectiveTransformV2' in binary" when running on borg TPUs
+    policy = [
+        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+        [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
+        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+    ]
+    return policy
+
+  @staticmethod
+  def policy_simple():
+    """Same as `policy_v0`, except with custom ops removed."""
+
+    policy = [
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+    ]
+    return policy
+
+  @staticmethod
+  def policy_test():
+    """Autoaugment test policy for debugging."""
+    policy = [
+        [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
+    ]
+    return policy
+
+
+class RandAugment(ImageAugment):
+  """Applies the RandAugment policy to images.
+
+  RandAugment is from the paper https://arxiv.org/abs/1909.13719.
+  """
+
+  def __init__(self,
+               num_layers: int = 2,
+               magnitude: float = 10.,
+               cutout_const: float = 40.,
+               translate_const: float = 100.):
+    """Applies the RandAugment policy to images.
+
+    Args:
+      num_layers: Integer, the number of augmentation transformations to apply
+        sequentially to an image. Represented as (N) in the paper. Usually best
+        values will be in the range [1, 3].
+      magnitude: Float, shared magnitude across all augmentation operations.
+        Represented as (M) in the paper. Usually best values are in the range
+        [5, 10].
+      cutout_const: multiplier for applying cutout.
+      translate_const: multiplier for applying translation.
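+
+    Example usage (an illustrative sketch):
+
+      augmenter = RandAugment(num_layers=2, magnitude=10.)
+      image = augmenter.distort(image)  # `image`: [height, width, 3] Tensor.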
+ """ + super(RandAugment, self).__init__() + + self.num_layers = num_layers + self.magnitude = float(magnitude) + self.cutout_const = float(cutout_const) + self.translate_const = float(translate_const) + self.available_ops = [ + 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize', + 'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', + 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd' + ] + + def distort(self, image: tf.Tensor) -> tf.Tensor: + """Applies the RandAugment policy to `image`. + + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + + Returns: + The augmented version of `image`. + """ + input_image_type = image.dtype + + if input_image_type != tf.uint8: + image = tf.clip_by_value(image, 0.0, 255.0) + image = tf.cast(image, dtype=tf.uint8) + + replace_value = [128] * 3 + min_prob, max_prob = 0.2, 0.8 + + for _ in range(self.num_layers): + op_to_select = tf.random.uniform( + [], maxval=len(self.available_ops) + 1, dtype=tf.int32) + + branch_fns = [] + for (i, op_name) in enumerate(self.available_ops): + prob = tf.random.uniform([], + minval=min_prob, + maxval=max_prob, + dtype=tf.float32) + func, _, args = _parse_policy_info(op_name, + prob, + self.magnitude, + replace_value, + self.cutout_const, + self.translate_const) + branch_fns.append(( + i, + # pylint:disable=g-long-lambda + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args))) + # pylint:enable=g-long-lambda + + image = tf.switch_case(branch_index=op_to_select, + branch_fns=branch_fns, + default=lambda: tf.identity(image)) + + image = tf.cast(image, dtype=input_image_type) + return image diff --git a/official/vision/image_classification/augment_test.py b/official/vision/image_classification/augment_test.py new file mode 100644 index 00000000000..364aeaec4e7 --- /dev/null +++ b/official/vision/image_classification/augment_test.py @@ -0,0 +1,137 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for autoaugment.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +from absl.testing import parameterized + +import tensorflow.compat.v2 as tf + +from official.vision.image_classification import augment + + +def get_dtype_test_cases(): + return [ + ('uint8', tf.uint8), + ('int32', tf.int32), + ('float16', tf.float16), + ('float32', tf.float32), + ] + + +@parameterized.named_parameters(get_dtype_test_cases()) +class TransformsTest(parameterized.TestCase, tf.test.TestCase): + """Basic tests for fundamental transformations.""" + + def test_to_from_4d(self, dtype): + for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]: + original_ndims = len(shape) + image = tf.zeros(shape, dtype=dtype) + image_4d = augment.to_4d(image) + self.assertEqual(4, tf.rank(image_4d)) + self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims)) + + def test_transform(self, dtype): + image = tf.constant([[1, 2], [3, 4]], dtype=dtype) + self.assertAllEqual(augment.transform(image, transforms=[1]*8), + [[4, 4], [4, 4]]) + + def test_translate(self, dtype): + image = tf.constant( + [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], + dtype=dtype) + translations = [-1, -1] + translated = augment.translate(image=image, + translations=translations) + expected = [[1, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]] + self.assertAllEqual(translated, expected) + + def test_translate_shapes(self, dtype): + translation = [0, 0] + for shape in [(3, 3), (5, 5), (224, 224, 3)]: + image = tf.zeros(shape, dtype=dtype) + self.assertAllEqual(image, augment.translate(image, translation)) + + def test_translate_invalid_translation(self, dtype): + image = tf.zeros((1, 1), dtype=dtype) + invalid_translation = [[[1, 1]]] + with self.assertRaisesRegex(TypeError, 'rank 1 or 2'): + _ = augment.translate(image, invalid_translation) + + def test_rotate(self, dtype): + image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3)) + rotation = 90. + transformed = augment.rotate(image=image, degrees=rotation) + expected = [[2, 5, 8], + [1, 4, 7], + [0, 3, 6]] + self.assertAllEqual(transformed, expected) + + def test_rotate_shapes(self, dtype): + degrees = 0. 
+ for shape in [(3, 3), (5, 5), (224, 224, 3)]: + image = tf.zeros(shape, dtype=dtype) + self.assertAllEqual(image, augment.rotate(image, degrees)) + + +class AutoaugmentTest(tf.test.TestCase): + + def test_autoaugment(self): + """Smoke test to be sure there are no syntax errors.""" + image = tf.zeros((224, 224, 3), dtype=tf.uint8) + + augmenter = augment.AutoAugment() + aug_image = augmenter.distort(image) + + self.assertEqual((224, 224, 3), aug_image.shape) + + def test_randaug(self): + """Smoke test to be sure there are no syntax errors.""" + image = tf.zeros((224, 224, 3), dtype=tf.uint8) + + augmenter = augment.RandAugment() + aug_image = augmenter.distort(image) + + self.assertEqual((224, 224, 3), aug_image.shape) + + def test_all_policy_ops(self): + """Smoke test to be sure all augmentation functions can execute.""" + + prob = 1 + magnitude = 10 + replace_value = [128] * 3 + cutout_const = 100 + translate_const = 250 + + image = tf.ones((224, 224, 3), dtype=tf.uint8) + + for op_name in augment.NAME_TO_FUNC: + func, _, args = augment._parse_policy_info(op_name, + prob, + magnitude, + replace_value, + cutout_const, + translate_const) + image = func(image, *args) + + self.assertEqual((224, 224, 3), image.shape) + +if __name__ == '__main__': + assert tf.version.VERSION.startswith('2.') + tf.test.main() diff --git a/official/vision/image_classification/callbacks.py b/official/vision/image_classification/callbacks.py new file mode 100644 index 00000000000..b35736aed8b --- /dev/null +++ b/official/vision/image_classification/callbacks.py @@ -0,0 +1,136 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Common modules for callbacks.""" +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +import os +from absl import logging + +import tensorflow as tf +from typing import Any, List, MutableMapping, Text + + +def get_callbacks(model_checkpoint: bool = True, + include_tensorboard: bool = True, + track_lr: bool = True, + write_model_weights: bool = True, + initial_step: int = 0, + model_dir: Text = None) -> List[tf.keras.callbacks.Callback]: + """Get all callbacks.""" + model_dir = model_dir or '' + callbacks = [] + if model_checkpoint: + ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}') + callbacks.append(tf.keras.callbacks.ModelCheckpoint( + ckpt_full_path, save_weights_only=True, verbose=1)) + if include_tensorboard: + callbacks.append(CustomTensorBoard( + log_dir=model_dir, + track_lr=track_lr, + initial_step=initial_step, + write_images=write_model_weights)) + return callbacks + + +def get_scalar_from_tensor(t: tf.Tensor) -> int: + """Utility function to convert a Tensor to a scalar.""" + t = tf.keras.backend.get_value(t) + if callable(t): + return t() + else: + return t + + +class CustomTensorBoard(tf.keras.callbacks.TensorBoard): + """A customized TensorBoard callback that tracks additional datapoints. + + Metrics tracked: + - Global learning rate + + Attributes: + log_dir: the path of the directory where to save the log files to be + parsed by TensorBoard. + track_lr: `bool`, whether or not to track the global learning rate. + initial_step: the initial step, used for preemption recovery. + **kwargs: Additional arguments for backwards compatibility. Possible key + is `period`. 
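+
+  Example usage (an illustrative sketch):
+
+    callback = CustomTensorBoard(log_dir=model_dir, track_lr=True)
+    model.fit(train_dataset, callbacks=[callback])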
+  """
+  # TODO(b/146499062): track params, flops, log lr, l2 loss,
+  # classification loss
+
+  def __init__(self,
+               log_dir: Text,
+               track_lr: bool = False,
+               initial_step: int = 0,
+               **kwargs):
+    super(CustomTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
+    self.step = initial_step
+    self._track_lr = track_lr
+
+  def on_batch_begin(self,
+                     batch: int,
+                     logs: MutableMapping[Text, Any] = None) -> None:
+    self.step += 1
+    if logs is None:
+      logs = {}
+    logs.update(self._calculate_metrics())
+    super(CustomTensorBoard, self).on_batch_begin(batch, logs)
+
+  def on_epoch_begin(self,
+                     epoch: int,
+                     logs: MutableMapping[Text, Any] = None) -> None:
+    if logs is None:
+      logs = {}
+    metrics = self._calculate_metrics()
+    logs.update(metrics)
+    for k, v in metrics.items():
+      logging.info('Current %s: %f', k, v)
+    super(CustomTensorBoard, self).on_epoch_begin(epoch, logs)
+
+  def on_epoch_end(self,
+                   epoch: int,
+                   logs: MutableMapping[Text, Any] = None) -> None:
+    if logs is None:
+      logs = {}
+    metrics = self._calculate_metrics()
+    logs.update(metrics)
+    super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
+
+  def _calculate_metrics(self) -> MutableMapping[Text, Any]:
+    logs = {}
+    if self._track_lr:
+      logs['learning_rate'] = self._calculate_lr()
+    return logs
+
+  def _calculate_lr(self) -> float:
+    """Calculates the learning rate given the current step."""
+    lr = self._get_base_optimizer().lr
+    if callable(lr):
+      lr = lr(self.step)
+    return get_scalar_from_tensor(lr)
+
+  def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
+    """Get the base optimizer used by the current model."""
+
+    optimizer = self.model.optimizer
+
+    # The optimizer might be wrapped by another class, so unwrap it.
+    while hasattr(optimizer, '_optimizer'):
+      optimizer = optimizer._optimizer  # pylint:disable=protected-access
+
+    return optimizer
diff --git a/official/vision/image_classification/callbacks_test.py b/official/vision/image_classification/callbacks_test.py
new file mode 100644
index 00000000000..970fcff847c
--- /dev/null
+++ b/official/vision/image_classification/callbacks_test.py
@@ -0,0 +1,86 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for callbacks.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +import collections +import functools +import os + +from absl.testing import parameterized + +import numpy as np +import tensorflow as tf + +from tensorflow.python.keras import callbacks_test +from tensorflow.python.keras import keras_parameterized +from official.vision.image_classification import callbacks + +_ObservedSummary = collections.namedtuple('_ObservedSummary', ('logdir', 'tag')) + + +def _trivial_function(a): + return a + + +class UtilFunctionTests(tf.test.TestCase, parameterized.TestCase): + """Tests to check utility functions provided in callbacks.py.""" + + @parameterized.named_parameters( + ('integer', 1), + ('float', 1.), + ('lambda', lambda: 1), + ('partial', functools.partial(_trivial_function, 1))) + def test_scalar_from_tensors(self, t): + t = tf.Variable(t) + value = callbacks.get_scalar_from_tensor(t) + print (value) + self.assertTrue(np.isscalar(value)) + + +@keras_parameterized.run_with_all_model_types +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) +class CustomTensorBoardTest(callbacks_test.TestTensorBoardV2): + + def test_custom_tb_learning_rate(self): + os.chdir(self.get_temp_dir()) + model = self._get_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = callbacks.CustomTensorBoard(log_dir=self.logdir, + track_lr=True) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk]) + + summary_file = callbacks_test.list_summaries(logdir=self.logdir) + self.assertEqual( + summary_file.scalars, { + _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), + _ObservedSummary(logdir=self.train_dir, tag='epoch_learning_rate'), + _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + }) + + +if __name__ == '__main__': + tf.test.main() diff --git a/official/vision/image_classification/classifier_trainer.py b/official/vision/image_classification/classifier_trainer.py new file mode 100644 index 00000000000..0c37f42378e --- /dev/null +++ b/official/vision/image_classification/classifier_trainer.py @@ -0,0 +1,427 @@ +# Lint as: python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Runs an Image Classification model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import pprint +from typing import Any, Tuple, Text, Optional, Mapping + +from absl import app +from absl import flags +from absl import logging +import tensorflow.compat.v2 as tf + +from official.modeling import performance +from official.modeling.hyperparams import params_dict +from official.utils import hyperparams_flags +from official.utils.logs import logger +from official.utils.misc import distribution_utils +from official.utils.misc import keras_utils +from official.vision.image_classification import callbacks as custom_callbacks +from official.vision.image_classification import dataset_factory +from official.vision.image_classification import optimizer_factory +from official.vision.image_classification.configs import base_configs +from official.vision.image_classification.configs import configs +from official.vision.image_classification.efficientnet import efficientnet_model +from official.vision.image_classification.resnet import common +from official.vision.image_classification.resnet import resnet_model + +MODELS = { + 'efficientnet': efficientnet_model.EfficientNet.from_name, + 'resnet': resnet_model.resnet50, +} + + +def _get_metrics(one_hot: bool) -> Mapping[Text, Any]: + """Get a dict of available metrics to track.""" + if one_hot: + return { + # (name, metric_fn) + 'acc': tf.keras.metrics.CategoricalAccuracy(name='accuracy'), + 'accuracy': tf.keras.metrics.CategoricalAccuracy(name='accuracy'), + 'top_1': tf.keras.metrics.CategoricalAccuracy(name='accuracy'), + 'top_5': tf.keras.metrics.TopKCategoricalAccuracy( + k=5, + name='top_5_accuracy'), + } + else: + return { + # (name, metric_fn) + 'acc': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), + 'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), + 'top_1': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), + 'top_5': tf.keras.metrics.SparseTopKCategoricalAccuracy( + k=5, + name='top_5_accuracy'), + } + + +def get_image_size_from_model( + params: base_configs.ExperimentConfig) -> Optional[int]: + """If the given model has a preferred image size, return it.""" + if params.model_name == 'efficientnet': + efficientnet_name = params.model.model_params.model_name + if efficientnet_name in efficientnet_model.MODEL_CONFIGS: + return efficientnet_model.MODEL_CONFIGS[efficientnet_name].resolution + return None + + +def _get_dataset_builders(params: base_configs.ExperimentConfig, + strategy: tf.distribute.Strategy, + one_hot: bool + ) -> Tuple[Any, Any, Any]: + """Create and return train, validation, and test dataset builders.""" + if one_hot: + logging.warning('label_smoothing > 0, so datasets will be one hot encoded.') + else: + logging.warning('label_smoothing not applied, so datasets will not be one ' + 'hot encoded.') + + num_devices = strategy.num_replicas_in_sync + image_size = get_image_size_from_model(params) + + dataset_configs = [ + params.train_dataset, params.validation_dataset, params.test_dataset + ] + builders = [] + + for config in dataset_configs: + if config is not None and config.has_data: + builder = dataset_factory.DatasetBuilder( + config, + image_size=image_size or config.image_size, + num_devices=num_devices, + one_hot=one_hot) + else: + builder = None + builders.append(builder) + + return builders + + +def 
get_loss_scale(params: base_configs.ExperimentConfig, + fp16_default: float = 128.) -> float: + """Returns the loss scale for initializations.""" + loss_scale = params.model.loss.loss_scale + if loss_scale == 'dynamic': + return loss_scale + elif loss_scale is not None: + return float(loss_scale) + elif params.train_dataset.dtype == 'float32': + return 1. + else: + assert params.train_dataset.dtype == 'float16' + return fp16_default + + +def _get_params_from_flags(flags_obj: flags.FlagValues): + """Get ParamsDict from flags.""" + model = flags_obj.model_type.lower() + dataset = flags_obj.dataset.lower() + params = configs.get_config(model=model, dataset=dataset) + + flags_overrides = { + 'model_dir': flags_obj.model_dir, + 'mode': flags_obj.mode, + 'model': { + 'name': model, + }, + 'runtime': { + 'enable_eager': flags_obj.enable_eager, + 'tpu': flags_obj.tpu, + }, + 'train_dataset': { + 'data_dir': flags_obj.data_dir, + }, + 'validation_dataset': { + 'data_dir': flags_obj.data_dir, + }, + 'test_dataset': { + 'data_dir': flags_obj.data_dir, + }, + } + + overriding_configs = (flags_obj.config_file, + flags_obj.params_override, + flags_overrides) + + pp = pprint.PrettyPrinter() + + logging.info('Base params: %s', pp.pformat(params.as_dict())) + + for param in overriding_configs: + logging.info('Overriding params: %s', param) + # Set is_strict to false because we can have dynamic dict parameters. + params = params_dict.override_params_dict(params, param, is_strict=False) + + params.validate() + params.lock() + + logging.info('Final model parameters: %s', pp.pformat(params.as_dict())) + return params + + +def resume_from_checkpoint(model: tf.keras.Model, + model_dir: str, + train_steps: int) -> int: + """Resumes from the latest checkpoint, if possible. + + Loads the model weights and optimizer settings from a checkpoint. + This function should be used in case of preemption recovery. + + Args: + model: The model whose weights should be restored. + model_dir: The directory where model weights were saved. + train_steps: The number of steps to train. + + Returns: + The epoch of the latest checkpoint, or 0 if not restoring. 
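+
+  Example (a hypothetical usage sketch; the path and step counts below are
+  illustrative, not part of the training pipeline):
+    initial_epoch = resume_from_checkpoint(model, '/tmp/model_dir', 100)
+    model.fit(dataset, initial_epoch=initial_epoch, epochs=10)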
+ + """ + logging.info('Load from checkpoint is enabled.') + latest_checkpoint = tf.train.latest_checkpoint(model_dir) + logging.info('latest_checkpoint: %s', latest_checkpoint) + if not latest_checkpoint: + logging.info('No checkpoint detected.') + return 0 + + logging.info('Checkpoint file %s found and restoring from ' + 'checkpoint', latest_checkpoint) + model.load_weights(latest_checkpoint) + initial_epoch = model.optimizer.iterations // train_steps + logging.info('Completed loading from checkpoint.') + logging.info('Resuming from epoch %d', initial_epoch) + return int(initial_epoch) + + +def initialize(params: base_configs.ExperimentConfig): + """Initializes backend related initializations.""" + keras_utils.set_session_config( + enable_eager=params.runtime.enable_eager, + enable_xla=params.runtime.enable_xla) + if params.runtime.gpu_threads_enabled: + keras_utils.set_gpu_thread_mode_and_count( + per_gpu_thread_count=params.runtime.per_gpu_thread_count, + gpu_thread_mode=params.runtime.gpu_thread_mode, + num_gpus=params.runtime.num_gpus, + datasets_num_private_threads=params.runtime.dataset_num_private_threads) + + dataset = params.train_dataset or params.validation_dataset + performance.set_mixed_precision_policy(dataset.dtype) + + if dataset.data_format: + data_format = dataset.data_format + elif tf.config.list_physical_devices('GPU'): + data_format = 'channels_first' + else: + data_format = 'channels_last' + tf.keras.backend.set_image_data_format(data_format) + distribution_utils.configure_cluster( + params.runtime.worker_hosts, + params.runtime.task_index) + if params.runtime.enable_eager: + # Enable eager execution to allow step-by-step debugging + tf.config.experimental_run_functions_eagerly(True) + + +def define_classifier_flags(): + """Defines common flags for image classification.""" + hyperparams_flags.initialize_common_flags() + flags.DEFINE_string( + 'data_dir', + default=None, + help='The location of the input data.') + flags.DEFINE_string( + 'mode', + default=None, + help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.') + flags.DEFINE_bool( + 'enable_eager', + default=None, + help='Use eager execution and disable autograph for debugging.') + flags.DEFINE_string( + 'model_type', + default=None, + help='The type of the model, e.g. EfficientNet, etc.') + flags.DEFINE_string( + 'dataset', + default=None, + help='The name of the dataset, e.g. 
ImageNet, etc.') + + +def serialize_config(params: base_configs.ExperimentConfig, + model_dir: str): + """Serializes and saves the experiment config.""" + params_save_path = os.path.join(model_dir, 'params.yaml') + logging.info('Saving experiment configuration to %s', params_save_path) + tf.io.gfile.makedirs(model_dir) + params_dict.save_params_dict_to_yaml(params, params_save_path) + + +def train_and_eval( + params: base_configs.ExperimentConfig, + strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]: + """Runs the train and eval path using compile/fit.""" + logging.info('Running train and eval.') + + # Note: for TPUs, strategy and scope should be created before the dataset + strategy = strategy_override or distribution_utils.get_distribution_strategy( + distribution_strategy=params.runtime.distribution_strategy, + all_reduce_alg=params.runtime.all_reduce_alg, + num_gpus=params.runtime.num_gpus, + tpu_address=params.runtime.tpu) + + strategy_scope = distribution_utils.get_strategy_scope(strategy) + + logging.info('Detected %d devices.', strategy.num_replicas_in_sync) + + label_smoothing = params.model.loss.label_smoothing + one_hot = label_smoothing and label_smoothing > 0 + + builders = _get_dataset_builders(params, strategy, one_hot) + datasets = [builder.build() if builder else None for builder in builders] + + # Unpack datasets and builders based on train/val/test splits + train_builder, validation_builder, test_builder = builders # pylint: disable=unbalanced-tuple-unpacking + train_dataset, validation_dataset, test_dataset = datasets + + train_epochs = params.train.epochs + train_steps = params.train.steps or train_builder.num_steps + validation_steps = params.evaluation.steps or validation_builder.num_steps + + logging.info('Global batch size: %d', train_builder.global_batch_size) + + with strategy_scope: + model_params = params.model.model_params.as_dict() + model = MODELS[params.model.name](**model_params) + learning_rate = optimizer_factory.build_learning_rate( + params=params.model.learning_rate, + batch_size=train_builder.global_batch_size, + train_steps=train_steps) + optimizer = optimizer_factory.build_optimizer( + optimizer_name=params.model.optimizer.name, + base_learning_rate=learning_rate, + params=params.model.optimizer.as_dict()) + + metrics_map = _get_metrics(one_hot) + metrics = [metrics_map[metric] for metric in params.train.metrics] + + if one_hot: + loss_obj = tf.keras.losses.CategoricalCrossentropy( + label_smoothing=params.model.loss.label_smoothing) + else: + loss_obj = tf.keras.losses.SparseCategoricalCrossentropy() + model.compile(optimizer=optimizer, + loss=loss_obj, + metrics=metrics, + run_eagerly=params.runtime.enable_eager) + + initial_epoch = 0 + if params.train.resume_checkpoint: + initial_epoch = resume_from_checkpoint(model=model, + model_dir=params.model_dir, + train_steps=train_steps) + + serialize_config(params=params, model_dir=params.model_dir) + # TODO(dankondratyuk): callbacks significantly slow down training + callbacks = custom_callbacks.get_callbacks( + model_checkpoint=params.train.callbacks.enable_checkpoint_and_export, + include_tensorboard=params.train.callbacks.enable_tensorboard, + track_lr=params.train.tensorboard.track_lr, + write_model_weights=params.train.tensorboard.write_model_weights, + initial_step=initial_epoch * train_steps, + model_dir=params.model_dir) + + history = model.fit( + train_dataset, + epochs=train_epochs, + steps_per_epoch=train_steps, + initial_epoch=initial_epoch, + callbacks=callbacks, + 
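# Evaluation cadence: `validation_freq` below runs validation once every
# `epochs_between_evals` training epochs.
+ 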
validation_data=validation_dataset, + validation_steps=validation_steps, + validation_freq=params.evaluation.epochs_between_evals) + + validation_output = model.evaluate( + validation_dataset, steps=validation_steps, verbose=2) + + # TODO(dankondratyuk): eval and save final test accuracy + + stats = common.build_stats(history, + validation_output, + callbacks) + return stats + + +def export(params: base_configs.ExperimentConfig): + """Runs the model export functionality.""" + logging.info('Exporting model.') + model_params = params.model.model_params.as_dict() + model = MODELS[params.model.name](**model_params) + checkpoint = params.export.checkpoint + if checkpoint is None: + logging.info('No export checkpoint was provided. Using the latest ' + 'checkpoint from model_dir.') + checkpoint = tf.train.latest_checkpoint(params.model_dir) + + model.load_weights(checkpoint) + model.save(params.export.destination) + + +def run(flags_obj: flags.FlagValues, + strategy_override: tf.distribute.Strategy = None) -> Mapping[str, Any]: + """Runs Image Classification model using native Keras APIs. + + Args: + flags_obj: An object containing parsed flag values. + strategy_override: A `tf.distribute.Strategy` object to use for model. + + Returns: + Dictionary of training/eval stats + """ + params = _get_params_from_flags(flags_obj) + initialize(params) + + if params.mode == 'train_and_eval': + return train_and_eval(params, strategy_override) + elif params.mode == 'export_only': + export(params) + else: + raise ValueError('{} is not a valid mode.'.format(params.mode)) + + +def main(_): + with logger.benchmark_context(flags.FLAGS): + stats = run(flags.FLAGS) + if stats: + logging.info('Run stats:\n%s', stats) + + +if __name__ == '__main__': + logging.set_verbosity(logging.INFO) + define_classifier_flags() + flags.mark_flag_as_required('data_dir') + flags.mark_flag_as_required('mode') + flags.mark_flag_as_required('model_type') + flags.mark_flag_as_required('dataset') + + assert tf.version.VERSION.startswith('2.') + app.run(main) diff --git a/official/vision/image_classification/classifier_trainer_test.py b/official/vision/image_classification/classifier_trainer_test.py new file mode 100644 index 00000000000..84c66c36daf --- /dev/null +++ b/official/vision/image_classification/classifier_trainer_test.py @@ -0,0 +1,317 @@ +# Lint as: python3 +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Unit tests for the classifier trainer models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import functools +import json + +import os +import sys + +from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple + +from absl import flags +from absl.testing import parameterized +import tensorflow.compat.v2 as tf + +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import strategy_combinations +from official.utils.flags import core as flags_core +from official.vision.image_classification import classifier_trainer +from official.vision.image_classification import dataset_factory +from official.vision.image_classification import test_utils +from official.vision.image_classification.configs import base_configs + +classifier_trainer.define_classifier_flags() + + +def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]: + """Returns the combinations of end-to-end tests to run.""" + return combinations.combine( + distribution=[ + strategy_combinations.default_strategy, + strategy_combinations.tpu_strategy, + strategy_combinations.one_device_strategy_gpu, + ], + model=[ + 'efficientnet', + 'resnet', + ], + mode='eager', + dataset=[ + 'imagenet', + ], + ) + + +def get_params_override(params_override: Mapping[str, Any]) -> str: + """Converts params_override dict to string command.""" + return '--params_override=' + json.dumps(params_override) + + +def basic_params_override() -> MutableMapping[str, Any]: + """Returns a basic parameter configuration for testing.""" + return { + 'train_dataset': { + 'builder': 'synthetic', + 'use_per_replica_batch_size': True, + 'batch_size': 1, + 'image_size': 224, + }, + 'validation_dataset': { + 'builder': 'synthetic', + 'batch_size': 1, + 'use_per_replica_batch_size': True, + 'image_size': 224, + }, + 'test_dataset': { + 'builder': 'synthetic', + 'batch_size': 1, + 'use_per_replica_batch_size': True, + 'image_size': 224, + }, + 'train': { + 'steps': 1, + 'epochs': 1, + 'callbacks': { + 'enable_checkpoint_and_export': True, + 'enable_tensorboard': False, + }, + }, + 'evaluation': { + 'steps': 1, + }, + } + + +def get_trivial_model(num_classes: int) -> tf.keras.Model: + """Creates and compiles trivial model for ImageNet dataset.""" + model = test_utils.trivial_model(num_classes=num_classes) + lr = 0.01 + optimizer = tf.keras.optimizers.SGD(learning_rate=lr) + loss_obj = tf.keras.losses.SparseCategoricalCrossentropy() + model.compile(optimizer=optimizer, + loss=loss_obj, + run_eagerly=True) + return model + + +def get_trivial_data() -> tf.data.Dataset: + """Gets trivial data in the ImageNet size.""" + def generate_data(_) -> tf.data.Dataset: + image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32) + label = tf.zeros([1], dtype=tf.int32) + return image, label + + dataset = tf.data.Dataset.range(1) + dataset = dataset.repeat() + dataset = dataset.map(generate_data, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + dataset = dataset.prefetch(buffer_size=1).batch(1) + return dataset + + +def run_end_to_end(main: Callable[[Any], None], + extra_flags: Optional[Iterable[str]] = None, + model_dir: Optional[str] = None): + """Runs the classifier trainer end-to-end.""" + extra_flags = [] if extra_flags is None else extra_flags + args = [sys.argv[0], '--model_dir', model_dir] + extra_flags + flags_core.parse_flags(argv=args) + 
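# Note: absl treats argv[0] as the program name, which is why sys.argv[0]
# is prepended to the flag list before parsing above.
+ 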
main(flags.FLAGS) + + +class ClassifierTest(tf.test.TestCase, parameterized.TestCase): + """Unit tests for Keras models.""" + _tempdir = None + + @classmethod + def setUpClass(cls): # pylint: disable=invalid-name + super(ClassifierTest, cls).setUpClass() + + def tearDown(self): + super(ClassifierTest, self).tearDown() + tf.io.gfile.rmtree(self.get_temp_dir()) + + @combinations.generate(distribution_strategy_combinations()) + def test_end_to_end_train_and_eval_export(self, distribution, model, dataset): + """Test train_and_eval and export for Keras classifier models.""" + # Some parameters are not defined as flags (e.g. cannot run + # classifier_train.py --batch_size=...) by design, so use + # "--params_override=..." instead + model_dir = self.get_temp_dir() + base_flags = [ + '--data_dir=not_used', + '--model_type=' + model, + '--dataset=' + dataset, + ] + train_and_eval_flags = base_flags + [ + get_params_override(basic_params_override()), + '--mode=train_and_eval', + ] + + export_params = basic_params_override() + export_path = os.path.join(model_dir, 'export') + export_params['export'] = {} + export_params['export']['destination'] = export_path + export_flags = base_flags + [ + '--mode=export_only', + get_params_override(export_params) + ] + + run = functools.partial(classifier_trainer.run, + strategy_override=distribution) + run_end_to_end(main=run, + extra_flags=train_and_eval_flags, + model_dir=model_dir) + run_end_to_end(main=run, + extra_flags=export_flags, + model_dir=model_dir) + self.assertTrue(os.path.exists(export_path)) + + @combinations.generate(distribution_strategy_combinations()) + def test_end_to_end_invalid_mode(self, distribution, model, dataset): + """Test the Keras EfficientNet model with `strategy`.""" + model_dir = self.get_temp_dir() + extra_flags = [ + '--data_dir=not_used', + '--mode=invalid_mode', + '--model_type=' + model, + '--dataset=' + dataset, + get_params_override(basic_params_override()), + ] + + run = functools.partial(classifier_trainer.run, + strategy_override=distribution) + with self.assertRaises(ValueError): + run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir) + + +class UtilTests(parameterized.TestCase, tf.test.TestCase): + """Tests for individual utility functions within classifier_trainer.py.""" + + @parameterized.named_parameters( + ('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224), + ('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240), + ('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260), + ('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300), + ('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380), + ('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456), + ('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528), + ('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600), + ('resnet', 'resnet', '', None), + ) + def test_get_model_size(self, model, model_name, expected): + config = base_configs.ExperimentConfig( + model_name=model, + model=base_configs.ModelConfig( + model_params={ + 'model_name': model_name, + }, + ) + ) + size = classifier_trainer.get_image_size_from_model(config) + self.assertEqual(size, expected) + + @parameterized.named_parameters( + ('dynamic', 'dynamic', None, 'dynamic'), + ('scalar', 128., None, 128.), + ('float32', None, 'float32', 1), + ('float16', None, 'float16', 128), + ) + def test_get_loss_scale(self, loss_scale, dtype, expected): + config = base_configs.ExperimentConfig( + model=base_configs.ModelConfig( + 
loss=base_configs.LossConfig(loss_scale=loss_scale)), + train_dataset=dataset_factory.DatasetConfig(dtype=dtype)) + ls = classifier_trainer.get_loss_scale(config, fp16_default=128) + self.assertEqual(ls, expected) + + @parameterized.named_parameters( + ('float16', 'float16'), + ('bfloat16', 'bfloat16') + ) + def test_initialize(self, dtype): + config = base_configs.ExperimentConfig( + runtime=base_configs.RuntimeConfig( + enable_eager=False, + enable_xla=False, + gpu_threads_enabled=True, + per_gpu_thread_count=1, + gpu_thread_mode='gpu_private', + num_gpus=1, + dataset_num_private_threads=1, + ), + train_dataset=dataset_factory.DatasetConfig(dtype=dtype), + model=base_configs.ModelConfig( + loss=base_configs.LossConfig(loss_scale='dynamic')), + ) + classifier_trainer.initialize(config) + + def test_resume_from_checkpoint(self): + """Tests functionality for resuming from checkpoint.""" + # Set the keras policy + policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') + tf.keras.mixed_precision.experimental.set_policy(policy) + + # Get the model, datasets, and compile it. + model = get_trivial_model(10) + + # Create the checkpoint + model_dir = self.get_temp_dir() + train_epochs = 1 + train_steps = 10 + ds = get_trivial_data() + callbacks = [ + tf.keras.callbacks.ModelCheckpoint( + os.path.join(model_dir, 'model.ckpt-{epoch:04d}'), + save_weights_only=True) + ] + model.fit( + ds, + callbacks=callbacks, + epochs=train_epochs, + steps_per_epoch=train_steps) + + # Test load from checkpoint + clean_model = get_trivial_model(10) + weights_before_load = copy.deepcopy(clean_model.get_weights()) + initial_epoch = classifier_trainer.resume_from_checkpoint( + model=clean_model, + model_dir=model_dir, + train_steps=train_steps) + self.assertEqual(initial_epoch, 1) + self.assertNotAllClose(weights_before_load, clean_model.get_weights()) + + tf.io.gfile.rmtree(model_dir) + + def test_serialize_config(self): + """Tests functionality for serializing data.""" + config = base_configs.ExperimentConfig() + model_dir = self.get_temp_dir() + classifier_trainer.serialize_config(params=config, model_dir=model_dir) + saved_params_path = os.path.join(model_dir, 'params.yaml') + self.assertTrue(os.path.exists(saved_params_path)) + tf.io.gfile.rmtree(model_dir) + +if __name__ == '__main__': + assert tf.version.VERSION.startswith('2.') + tf.test.main() diff --git a/official/vision/image_classification/configs/__init__.py b/official/vision/image_classification/configs/__init__.py new file mode 100644 index 00000000000..931c2ef11db --- /dev/null +++ b/official/vision/image_classification/configs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
diff --git a/official/vision/image_classification/configs/base_configs.py b/official/vision/image_classification/configs/base_configs.py
new file mode 100644
index 00000000000..850eb7d8656
--- /dev/null
+++ b/official/vision/image_classification/configs/base_configs.py
@@ -0,0 +1,223 @@
+# Lint as: python3
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Definitions for high-level configuration groups."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from typing import Any, List, Mapping, Optional
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+
+
+CallbacksConfig = base_config.CallbacksConfig
+TensorboardConfig = base_config.TensorboardConfig
+RuntimeConfig = base_config.RuntimeConfig
+
+
+@dataclasses.dataclass
+class ExportConfig(base_config.Config):
+  """Configuration for exports.
+
+  Attributes:
+    checkpoint: the path to the checkpoint to export.
+    destination: the path to where the checkpoint should be exported.
+  """
+  checkpoint: str = None
+  destination: str = None
+
+
+@dataclasses.dataclass
+class MetricsConfig(base_config.Config):
+  """Configuration for Metrics.
+
+  Attributes:
+    accuracy: Whether or not to track accuracy as a Callback. Defaults to None.
+    top_5: Whether or not to track top_5_accuracy as a Callback. Defaults to
+      None.
+  """
+  accuracy: bool = None
+  top_5: bool = None
+
+
+@dataclasses.dataclass
+class TrainConfig(base_config.Config):
+  """Configuration for training.
+
+  Attributes:
+    resume_checkpoint: Whether or not to enable loading from a checkpoint,
+      e.g. to recover from preemption. Defaults to None.
+    epochs: The number of training epochs to run. Defaults to None.
+    steps: The number of steps to run per epoch. If None, then this will be
+      inferred based on the number of images and batch size. Defaults to None.
+    callbacks: An instance of CallbacksConfig.
+    metrics: A list of the names of metrics to track. Defaults to None.
+    tensorboard: An instance of TensorboardConfig.
+  """
+  resume_checkpoint: bool = None
+  epochs: int = None
+  steps: int = None
+  callbacks: CallbacksConfig = CallbacksConfig()
+  metrics: List[str] = None
+  tensorboard: TensorboardConfig = TensorboardConfig()
+
+
+@dataclasses.dataclass
+class EvalConfig(base_config.Config):
+  """Configuration for evaluation.
+
+  Attributes:
+    epochs_between_evals: The number of train epochs to run between
+      evaluations. Defaults to None.
+    steps: The number of eval steps to run during evaluation. If None, this
+      will be inferred based on the number of images and batch size. Defaults
+      to None.
+  """
+  epochs_between_evals: int = None
+  steps: int = None
+
+
+@dataclasses.dataclass
+class LossConfig(base_config.Config):
+  """Configuration for Loss.
+
+  Attributes:
+    name: The name of the loss. Defaults to None.
+    loss_scale: The type of loss scale to use: 'dynamic', a constant float
+      value, or None, in which case a default is chosen based on the dtype.
+    label_smoothing: The amount of label smoothing to apply to the loss. This
+      only applies to 'categorical_cross_entropy'.
+  """
+  name: str = None
+  loss_scale: str = None
+  label_smoothing: float = None
+
+
+@dataclasses.dataclass
+class OptimizerConfig(base_config.Config):
+  """Configuration for Optimizers.
+
+  Attributes:
+    name: The name of the optimizer. Defaults to None.
+    decay: Decay or rho, discounting factor for gradient. Defaults to None.
+    epsilon: Small value used to avoid 0 denominator. Defaults to None.
+    momentum: Plain momentum constant. Defaults to None.
+    nesterov: Whether or not to apply Nesterov momentum. Defaults to None.
+    moving_average_decay: The amount of decay to apply. If 0 or None, then
+      exponential moving average is not used. Defaults to None.
+    lookahead: Whether or not to apply the lookahead optimizer. Defaults to
+      None.
+    beta_1: The exponential decay rate for the 1st moment estimates. Used in
+      the Adam optimizers. Defaults to None.
+    beta_2: The exponential decay rate for the 2nd moment estimates. Used in
+      the Adam optimizers. Defaults to None.
+  """
+  name: str = None
+  decay: float = None
+  epsilon: float = None
+  momentum: float = None
+  nesterov: bool = None
+  moving_average_decay: Optional[float] = None
+  lookahead: Optional[bool] = None
+  beta_1: float = None
+  beta_2: float = None
+
+
+@dataclasses.dataclass
+class LearningRateConfig(base_config.Config):
+  """Configuration for learning rates.
+
+  Attributes:
+    name: The name of the learning rate. Defaults to None.
+    initial_lr: The initial learning rate. Defaults to None.
+    decay_epochs: The number of decay epochs. Defaults to None.
+    decay_rate: The rate of decay. Defaults to None.
+    warmup_epochs: The number of warmup epochs. Defaults to None.
+    examples_per_epoch: The number of examples in a single epoch.
+      Defaults to None.
+    boundaries: Boundaries used in piecewise constant decay with warmup.
+    multipliers: Multipliers used in piecewise constant decay with warmup.
+    scale_by_batch_size: Scale the learning rate by a fraction of the batch
+      size. Set to 0 for no scaling (default).
+  """
+  name: str = None
+  initial_lr: float = None
+  decay_epochs: float = None
+  decay_rate: float = None
+  warmup_epochs: int = None
+  examples_per_epoch: int = None
+  boundaries: List[int] = None
+  multipliers: List[float] = None
+  scale_by_batch_size: float = 0.
+
+
+@dataclasses.dataclass
+class ModelConfig(base_config.Config):
+  """Configuration for Models.
+
+  Attributes:
+    name: The name of the model. Defaults to None.
+    model_params: The parameters used to create the model. Defaults to None.
+    num_classes: The number of classes in the model. Defaults to None.
+    loss: A `LossConfig` instance. Defaults to None.
+    optimizer: An `OptimizerConfig` instance. Defaults to None.
+  """
+  name: str = None
+  model_params: Mapping[str, Any] = None
+  num_classes: int = None
+  loss: LossConfig = None
+  optimizer: OptimizerConfig = None
+
+
+@dataclasses.dataclass
+class ExperimentConfig(base_config.Config):
+  """Base configuration for an image classification experiment.
+
+  Attributes:
+    model_dir: The directory to use when running an experiment.
+    model_name: The name of the model, e.g. 'efficientnet' or 'resnet'.
+    mode: e.g. 'train_and_eval', 'export'.
+    runtime: A `RuntimeConfig` instance.
+    train_dataset: The dataset configuration used for training.
+    validation_dataset: The dataset configuration used for validation.
+    test_dataset: The dataset configuration used for testing.
+    train: A `TrainConfig` instance.
+ evaluation: An `EvalConfig` instance. + model: A `ModelConfig` instance. + export: An `ExportConfig` instance. + + """ + model_dir: str = None + model_name: str = None + mode: str = None + runtime: RuntimeConfig = None + train_dataset: Any = None + validation_dataset: Any = None + test_dataset: Any = None + train: TrainConfig = None + evaluation: EvalConfig = None + model: ModelConfig = None + export: ExportConfig = None diff --git a/official/vision/image_classification/configs/configs.py b/official/vision/image_classification/configs/configs.py new file mode 100644 index 00000000000..f4ae1bdb0a3 --- /dev/null +++ b/official/vision/image_classification/configs/configs.py @@ -0,0 +1,121 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Configuration utils for image classification experiments.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import dataclasses + +from official.vision.image_classification import dataset_factory +from official.vision.image_classification.configs import base_configs +from official.vision.image_classification.efficientnet import efficientnet_config +from official.vision.image_classification.resnet import resnet_config + + +@dataclasses.dataclass +class EfficientNetImageNetConfig(base_configs.ExperimentConfig): + """Base configuration to train efficientnet-b0 on ImageNet. + + Attributes: + export: An `ExportConfig` instance + runtime: A `RuntimeConfig` instance. + dataset: A `DatasetConfig` instance. + train: A `TrainConfig` instance. + evaluation: An `EvalConfig` instance. + model: A `ModelConfig` instance. 
+  """
+  export: base_configs.ExportConfig = base_configs.ExportConfig()
+  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
+  train_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='train')
+  validation_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='validation')
+  test_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='validation')
+  train: base_configs.TrainConfig = base_configs.TrainConfig(
+      resume_checkpoint=True,
+      epochs=500,
+      steps=None,
+      callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
+                                             enable_tensorboard=True),
+      metrics=['accuracy', 'top_5'],
+      tensorboard=base_configs.TensorboardConfig(track_lr=True,
+                                                 write_model_weights=False))
+  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
+      epochs_between_evals=1,
+      steps=None)
+  model: base_configs.ModelConfig = \
+      efficientnet_config.EfficientNetModelConfig()
+
+
+@dataclasses.dataclass
+class ResNetImagenetConfig(base_configs.ExperimentConfig):
+  """Base configuration to train resnet-50 on ImageNet."""
+  export: base_configs.ExportConfig = base_configs.ExportConfig()
+  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
+  train_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='train',
+                                     one_hot=False,
+                                     mean_subtract=True,
+                                     standardize=True)
+  validation_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='validation',
+                                     one_hot=False,
+                                     mean_subtract=True,
+                                     standardize=True)
+  test_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='validation',
+                                     one_hot=False,
+                                     mean_subtract=True,
+                                     standardize=True)
+  train: base_configs.TrainConfig = base_configs.TrainConfig(
+      resume_checkpoint=True,
+      epochs=90,
+      steps=None,
+      callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
+                                             enable_tensorboard=True),
+      metrics=['accuracy', 'top_5'],
+      tensorboard=base_configs.TensorboardConfig(track_lr=True,
+                                                 write_model_weights=False))
+  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
+      epochs_between_evals=1,
+      steps=None)
+  model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
+
+
+def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
+  """Given model and dataset names, returns the ExperimentConfig."""
+  dataset_model_config_map = {
+      'imagenet': {
+          'efficientnet': EfficientNetImageNetConfig(),
+          'resnet': ResNetImagenetConfig(),
+      }
+  }
+  try:
+    return dataset_model_config_map[dataset][model]
+  except KeyError:
+    if dataset not in dataset_model_config_map:
+      raise KeyError('Invalid dataset received. Received: {}. Supported '
+                     'datasets include: {}'.format(
+                         dataset,
+                         ', '.join(dataset_model_config_map.keys())))
+    raise KeyError('Invalid model received. Received: {}. Supported models '
+                   'for {} include: {}'.format(
+                       model,
+                       dataset,
+                       ', '.join(dataset_model_config_map[dataset].keys())))
diff --git a/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml
new file mode 100644
index 00000000000..47b3e8f5ba4
--- /dev/null
+++ b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml
@@ -0,0 +1,51 @@
+# Training configuration for EfficientNet-b0 trained on ImageNet on GPUs.
+# Takes ~32 minutes per epoch for 8 V100s.
+# Reaches ~76.1% within 350 epochs. +# Note: This configuration uses a scaled per-replica batch size based on the number of devices. +runtime: + model_dir: null + mode: 'train_and_eval' + distribution_strategy: 'mirrored' + num_gpus: 1 +train_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'train' + num_classes: 1000 + num_examples: 1281167 + batch_size: 32 + use_per_replica_batch_size: True + dtype: 'float32' + augmenter: + name: 'autoaugment' +validation_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'validation' + num_classes: 1000 + num_examples: 50000 + batch_size: 32 + use_per_replica_batch_size: True + dtype: 'float32' +model: + model_params: + model_name: 'efficientnet-b0' + overrides: + num_classes: 1000 + batch_norm: 'default' + dtype: 'float32' + optimizer: + name: 'rmsprop' + momentum: 0.9 + decay: 0.9 + learning_rate: + name: 'exponential' + loss: + label_smoothing: 0.1 +train: + resume_checkpoint: True + epochs: 500 +evaluation: + epochs_between_evals: 1 diff --git a/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml new file mode 100644 index 00000000000..d6f86728798 --- /dev/null +++ b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml @@ -0,0 +1,52 @@ +# Training configuration for EfficientNet-b0 trained on ImageNet on TPUs. +# Takes ~2 minutes, 50 seconds per epoch for v3-32. +# Reaches ~76.1% within 350 epochs. +# Note: This configuration uses a scaled per-replica batch size based on the number of devices. +runtime: + model_dir: null + mode: 'train_and_eval' + distribution_strategy: 'tpu' +train_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'train' + num_classes: 1000 + num_examples: 1281167 + batch_size: 128 + use_per_replica_batch_size: True + dtype: 'bfloat16' + augmenter: + name: 'autoaugment' +validation_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'validation' + num_classes: 1000 + num_examples: 50000 + batch_size: 128 + use_per_replica_batch_size: True + dtype: 'bfloat16' +model: + model_params: + model_name: 'efficientnet-b0' + overrides: + num_classes: 1000 + batch_norm: 'tpu' + dtype: 'bfloat16' + optimizer: + name: 'rmsprop' + momentum: 0.9 + decay: 0.9 + moving_average_decay: 0. + lookahead: false + learning_rate: + name: 'exponential' + loss: + label_smoothing: 0.1 +train: + resume_checkpoint: True + epochs: 500 +evaluation: + epochs_between_evals: 1 diff --git a/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml new file mode 100644 index 00000000000..0807672e5ce --- /dev/null +++ b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml @@ -0,0 +1,44 @@ +# Note: This configuration uses a scaled per-replica batch size based on the number of devices. 
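+# For example (illustrative numbers, not a recommendation): with a per-replica
+# batch_size of 32 and 8 GPU replicas, the effective global batch size is
+# 32 * 8 = 256.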
+runtime: + model_dir: null + mode: 'train_and_eval' + distribution_strategy: 'mirrored' + num_gpus: 1 +train_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'train' + num_classes: 1000 + num_examples: 1281167 + batch_size: 32 + dtype: 'float32' +validation_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'validation' + num_classes: 1000 + num_examples: 50000 + batch_size: 32 + dtype: 'float32' +model: + model_params: + model_name: 'efficientnet-b1' + overrides: + num_classes: 1000 + batch_norm: 'default' + dtype: 'float32' + optimizer: + name: 'rmsprop' + momentum: 0.9 + decay: 0.9 + learning_rate: + name: 'exponential' + loss: + label_smoothing: 0.1 +train: + resume_checkpoint: True + epochs: 500 +evaluation: + epochs_between_evals: 1 diff --git a/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml new file mode 100644 index 00000000000..382482b9791 --- /dev/null +++ b/official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml @@ -0,0 +1,49 @@ +# Training configuration for EfficientNet-b1 trained on ImageNet on TPUs. +# Takes ~3 minutes, 15 seconds per epoch for v3-32. +# Note: This configuration uses a scaled per-replica batch size based on the number of devices. +runtime: + model_dir: null + mode: 'train_and_eval' + distribution_strategy: 'tpu' +train_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'train' + num_classes: 1000 + num_examples: 1281167 + batch_size: 128 + use_per_replica_batch_size: True + dtype: 'bfloat16' + augmenter: + name: 'autoaugment' +validation_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'validation' + num_classes: 1000 + num_examples: 50000 + batch_size: 128 + use_per_replica_batch_size: True + dtype: 'bfloat16' +model: + model_params: + model_name: 'efficientnet-b1' + overrides: + num_classes: 1000 + batch_norm: 'tpu' + dtype: 'bfloat16' + optimizer: + name: 'rmsprop' + momentum: 0.9 + decay: 0.9 + learning_rate: + name: 'exponential' + loss: + label_smoothing: 0.1 +train: + resume_checkpoint: True + epochs: 500 +evaluation: + epochs_between_evals: 1 diff --git a/official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml b/official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml new file mode 100644 index 00000000000..9bb684ffc44 --- /dev/null +++ b/official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml @@ -0,0 +1,53 @@ +# Training configuration for ResNet trained on ImageNet on GPUs. +# Takes ~3 minutes, 15 seconds per epoch for 8 V100s. +# Reaches ~76.1% within 90 epochs. +# Note: This configuration uses a scaled per-replica batch size based on the number of devices. 
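+# The mean_subtract/standardize flags below enable per-channel ImageNet mean
+# subtraction and rescaling during preprocessing (see preprocessing.py).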
+runtime: + model_dir: null + mode: 'train_and_eval' + distribution_strategy: 'mirrored' + num_gpus: 1 +train_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'train' + image_size: 224 + num_classes: 1000 + num_examples: 1281167 + batch_size: 128 + use_per_replica_batch_size: True + dtype: 'float32' + mean_subtract: True + standardize: True +validation_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'validation' + image_size: 224 + num_classes: 1000 + num_examples: 50000 + batch_size: 128 + use_per_replica_batch_size: True + dtype: 'float32' + mean_subtract: True + standardize: True +model: + model_name: 'resnet' + model_params: + rescale_inputs: False + optimizer: + name: 'momentum' + momentum: 0.9 + decay: 0.9 + epsilon: 0.001 + learning_rate: + name: 'piecewise_constant_with_warmup' + loss: + label_smoothing: 0.1 +train: + resume_checkpoint: True + epochs: 90 +evaluation: + epochs_between_evals: 1 diff --git a/official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml b/official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml new file mode 100644 index 00000000000..3e7055dff33 --- /dev/null +++ b/official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml @@ -0,0 +1,58 @@ +# Training configuration for ResNet trained on ImageNet on TPUs. +# Takes ~2 minutes, 43 seconds per epoch for a v3-32. +# Reaches ~76.1% within 90 epochs. +# Note: This configuration uses a scaled per-replica batch size based on the number of devices. +runtime: + model_dir: null + mode: 'train_and_eval' + distribution_strategy: 'tpu' +train_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'train' + one_hot: False + image_size: 224 + num_classes: 1000 + num_examples: 1281167 + batch_size: 128 + use_per_replica_batch_size: True + mean_subtract: True + standardize: True + dtype: 'bfloat16' +validation_dataset: + name: 'imagenet2012' + data_dir: null + builder: 'records' + split: 'validation' + one_hot: False + image_size: 224 + num_classes: 1000 + num_examples: 50000 + batch_size: 128 + use_per_replica_batch_size: True + mean_subtract: True + standardize: True + dtype: 'bfloat16' +model: + model_name: 'resnet' + model_params: + rescale_inputs: False + optimizer: + name: 'momentum' + momentum: 0.9 + decay: 0.9 + epsilon: 0.001 + moving_average_decay: 0. + lookahead: false + learning_rate: + name: 'piecewise_constant_with_warmup' + loss: + label_smoothing: 0.1 +train: + callbacks: + enable_checkpoint_and_export: True + resume_checkpoint: True + epochs: 90 +evaluation: + epochs_between_evals: 1 diff --git a/official/vision/image_classification/dataset_factory.py b/official/vision/image_classification/dataset_factory.py new file mode 100644 index 00000000000..27e770b40b4 --- /dev/null +++ b/official/vision/image_classification/dataset_factory.py @@ -0,0 +1,476 @@ +# Lint as: python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import os
+from typing import Any, List, Optional, Tuple, Mapping, Union
+from absl import logging
+from dataclasses import dataclass
+import tensorflow.compat.v2 as tf
+import tensorflow_datasets as tfds
+
+from official.modeling.hyperparams import base_config
+from official.vision.image_classification import augment
+from official.vision.image_classification import preprocessing
+
+
+AUGMENTERS = {
+    'autoaugment': augment.AutoAugment,
+    'randaugment': augment.RandAugment,
+}
+
+
+@dataclass
+class AugmentConfig(base_config.Config):
+  """Configuration for image augmenters.
+
+  Attributes:
+    name: The name of the image augmentation to use. Possible options are
+      None (default), 'autoaugment', or 'randaugment'.
+    params: Any parameters used to initialize the augmenter.
+  """
+  name: Optional[str] = None
+  params: Optional[Mapping[str, Any]] = None
+
+  def build(self) -> augment.ImageAugment:
+    """Build the augmenter using this config."""
+    params = self.params or {}
+    augmenter = AUGMENTERS.get(self.name, None)
+    return augmenter(**params) if augmenter is not None else None
+
+
+@dataclass
+class DatasetConfig(base_config.Config):
+  """The base configuration for building datasets.
+
+  Attributes:
+    name: The name of the Dataset. Usually should correspond to a TFDS dataset.
+    data_dir: The path where the dataset files are stored, if available.
+    filenames: Optional list of strings representing the TFRecord names.
+    builder: The builder type used to load the dataset. Value should be one of
+      'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
+      (generate dummy synthetic data without reading from files).
+    split: The split of the dataset. Usually 'train', 'validation', or 'test'.
+    image_size: The size of the image in the dataset. This assumes that
+      `width` == `height`. Set to 'infer' to infer the image size from TFDS
+      info. This requires `name` to be a registered dataset in TFDS.
+    num_classes: The number of classes given by the dataset. Set to 'infer'
+      to infer the number of classes from TFDS info. This requires `name` to
+      be a registered dataset in TFDS.
+    num_channels: The number of channels given by the dataset. Set to 'infer'
+      to infer the number of channels from TFDS info. This requires `name` to
+      be a registered dataset in TFDS.
+    num_examples: The number of examples given by the dataset. Set to 'infer'
+      to infer the number of examples from TFDS info. This requires `name` to
+      be a registered dataset in TFDS.
+    batch_size: The base batch size for the dataset.
+    use_per_replica_batch_size: Whether to scale the batch size based on
+      available resources. If set to `True`, the dataset builder will return
+      batch_size multiplied by `num_devices`, the number of device replicas
+      (e.g., the number of GPUs or TPU cores).
+    num_devices: The number of replica devices to use. This should be set by
+      `strategy.num_replicas_in_sync` when using a distribution strategy.
+    data_format: The data format of the images. Should be 'channels_last' or
+      'channels_first'.
+    dtype: The desired dtype of the dataset. This will be set during
+      preprocessing.
+    one_hot: Whether to apply one hot encoding. Set to `True` to be able to
+      use label smoothing.
+    augmenter: The augmenter config to use. No augmentation is used by
+      default.
+    download: Whether to download data using TFDS.
+    shuffle_buffer_size: The buffer size used for shuffling training data.
+    file_shuffle_buffer_size: The buffer size used for shuffling raw training
+      files.
+    skip_decoding: Whether to skip image decoding when loading from TFDS.
+    deterministic_train: Whether the examples in the training set should be
+      output in a deterministic order.
+    use_slack: Whether to introduce slack in the last prefetch. This may
+      reduce CPU contention at the start of a training step.
+    cache: Whether to cache dataset examples. Can be used to avoid re-reading
+      from disk on the second epoch. Requires significant memory overhead.
+    mean_subtract: Whether or not to apply mean subtraction to the dataset.
+    standardize: Whether or not to apply standardization to the dataset.
+  """
+  name: Optional[str] = None
+  data_dir: Optional[str] = None
+  filenames: Optional[List[str]] = None
+  builder: str = 'tfds'
+  split: str = 'train'
+  image_size: Union[int, str] = 'infer'
+  num_classes: Union[int, str] = 'infer'
+  num_channels: Union[int, str] = 'infer'
+  num_examples: Union[int, str] = 'infer'
+  batch_size: int = 128
+  use_per_replica_batch_size: bool = False
+  num_devices: int = 1
+  data_format: str = 'channels_last'
+  dtype: str = 'float32'
+  one_hot: bool = True
+  augmenter: AugmentConfig = AugmentConfig()
+  download: bool = False
+  shuffle_buffer_size: int = 10000
+  file_shuffle_buffer_size: int = 1024
+  skip_decoding: bool = True
+  deterministic_train: bool = False
+  use_slack: bool = True
+  cache: bool = False
+  mean_subtract: bool = False
+  standardize: bool = False
+
+  @property
+  def has_data(self):
+    """Whether this dataset has any data associated with it."""
+    return self.name or self.data_dir or self.filenames
+
+
+@dataclass
+class ImageNetConfig(DatasetConfig):
+  """The base ImageNet dataset config."""
+  name: str = 'imagenet2012'
+  # Note: for large datasets like ImageNet, using records is faster than tfds
+  builder: str = 'records'
+  image_size: int = 224
+  batch_size: int = 128
+
+
+@dataclass
+class Cifar10Config(DatasetConfig):
+  """The base CIFAR-10 dataset config."""
+  name: str = 'cifar10'
+  image_size: int = 224
+  batch_size: int = 128
+  download: bool = True
+  cache: bool = True
+
+
+class DatasetBuilder:
+  """An object for building datasets.
+
+  Allows building various pipelines fetching examples, preprocessing, etc.
+  Maintains additional state information calculated from the dataset, i.e.,
+  training set split, batch size, and number of steps (batches).
+ """ + + def __init__(self, config: DatasetConfig, **overrides: Any): + """Initialize the builder from the config.""" + self.config = config.replace(**overrides) + self.builder_info = None + + if self.config.augmenter is not None: + logging.info('Using augmentation: %s', self.config.augmenter.name) + self.augmenter = self.config.augmenter.build() + else: + self.augmenter = None + + @property + def is_training(self) -> bool: + """Whether this is the training set.""" + return self.config.split == 'train' + + @property + def batch_size(self) -> int: + """The batch size, multiplied by the number of replicas (if configured).""" + if self.config.use_per_replica_batch_size: + return self.global_batch_size + else: + return self.config.batch_size + + @property + def global_batch_size(self): + """The global batch size across all replicas.""" + return self.config.batch_size * self.config.num_devices + + @property + def num_steps(self) -> int: + """The number of steps (batches) to exhaust this dataset.""" + # Always divide by the global batch size to get the correct # of steps + return self.num_examples // self.global_batch_size + + @property + def image_size(self) -> int: + """The size of each image (can be inferred from the dataset).""" + + if self.config.image_size == 'infer': + return self.info.features['image'].shape[0] + else: + return int(self.config.image_size) + + @property + def num_channels(self) -> int: + """The number of image channels (can be inferred from the dataset).""" + if self.config.num_channels == 'infer': + return self.info.features['image'].shape[-1] + else: + return int(self.config.num_channels) + + @property + def num_examples(self) -> int: + """The number of examples (can be inferred from the dataset).""" + if self.config.num_examples == 'infer': + return self.info.splits[self.config.split].num_examples + else: + return int(self.config.num_examples) + + @property + def num_classes(self) -> int: + """The number of classes (can be inferred from the dataset).""" + if self.config.num_classes == 'infer': + return self.info.features['label'].num_classes + else: + return int(self.config.num_classes) + + @property + def info(self) -> tfds.core.DatasetInfo: + """The TFDS dataset info, if available.""" + if self.builder_info is None: + self.builder_info = tfds.builder(self.config.name).info + return self.builder_info + + def build(self, input_context: tf.distribute.InputContext = None + ) -> tf.data.Dataset: + """Construct a dataset end-to-end and return it. + + Args: + input_context: An optional context provided by `tf.distribute` for + cross-replica training. This isn't necessary if using Keras + compile/fit. + + Returns: + A TensorFlow dataset outputting batched images and labels. 
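+
+    Example (a hypothetical sketch; the data path and device count below are
+    illustrative):
+      builder = DatasetBuilder(
+          ImageNetConfig(split='train', data_dir='/path/to/tfrecords'),
+          num_devices=8)
+      dataset = builder.build()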
+    """
+
+    builders = {
+        'tfds': self.load_tfds,
+        'records': self.load_records,
+        'synthetic': self.load_synthetic,
+    }
+
+    builder = builders.get(self.config.builder, None)
+
+    if builder is None:
+      raise ValueError('Unknown builder type {}'.format(self.config.builder))
+
+    dataset = builder()
+    dataset = self.pipeline(dataset, input_context)
+
+    return dataset
+
+  def load_tfds(self) -> tf.data.Dataset:
+    """Return a dataset loading files from TFDS."""
+
+    logging.info('Using TFDS to load data.')
+
+    builder = tfds.builder(self.config.name,
+                           data_dir=self.config.data_dir)
+
+    if self.config.download:
+      builder.download_and_prepare()
+
+    decoders = {}
+
+    if self.config.skip_decoding:
+      decoders['image'] = tfds.decode.SkipDecoding()
+
+    read_config = tfds.ReadConfig(
+        interleave_parallel_reads=64,
+        interleave_block_length=1)
+
+    dataset = builder.as_dataset(
+        split=self.config.split,
+        as_supervised=True,
+        shuffle_files=True,
+        decoders=decoders,
+        read_config=read_config)
+
+    return dataset
+
+  def load_records(self) -> tf.data.Dataset:
+    """Return a dataset loading files with TFRecords."""
+    logging.info('Using TFRecords to load data.')
+
+    if self.config.filenames is None:
+      if self.config.data_dir is None:
+        raise ValueError('Dataset must specify a path for the data files.')
+
+      file_pattern = os.path.join(self.config.data_dir,
+                                  '{}*'.format(self.config.split))
+      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)
+    else:
+      dataset = tf.data.Dataset.from_tensor_slices(self.config.filenames)
+      if self.is_training:
+        # Shuffle the input files. Note that `shuffle` returns a new dataset,
+        # so the result must be reassigned.
+        dataset = dataset.shuffle(
+            buffer_size=self.config.file_shuffle_buffer_size)
+
+    return dataset
+
+  def load_synthetic(self) -> tf.data.Dataset:
+    """Return a dataset generating dummy synthetic data."""
+    logging.info('Generating a synthetic dataset.')
+
+    def generate_data(_):
+      image = tf.zeros([self.image_size, self.image_size, self.num_channels],
+                       dtype=self.config.dtype)
+      label = tf.zeros([1], dtype=tf.int32)
+      return image, label
+
+    dataset = tf.data.Dataset.range(1)
+    dataset = dataset.repeat()
+    dataset = dataset.map(generate_data,
+                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    return dataset
+
+  def pipeline(self,
+               dataset: tf.data.Dataset,
+               input_context: tf.distribute.InputContext = None
+               ) -> tf.data.Dataset:
+    """Build a pipeline fetching, shuffling, and preprocessing the dataset.
+
+    Args:
+      dataset: A `tf.data.Dataset` that loads raw files.
+      input_context: An optional context provided by `tf.distribute` for
+        cross-replica training. This isn't necessary if using Keras
+        compile/fit.
+
+    Returns:
+      A TensorFlow dataset outputting batched images and labels.
+ """ + if input_context and input_context.num_input_pipelines > 1: + dataset = dataset.shard(input_context.num_input_pipelines, + input_context.input_pipeline_id) + + if self.is_training and not self.config.cache: + dataset = dataset.repeat() + + if self.config.builder == 'records': + # Read the data from disk in parallel + buffer_size = 8 * 1024 * 1024 # Use 8 MiB per file + dataset = dataset.interleave( + lambda name: tf.data.TFRecordDataset(name, buffer_size=buffer_size), + cycle_length=16, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + dataset = dataset.prefetch(self.global_batch_size) + + if self.config.cache: + dataset = dataset.cache() + + if self.is_training: + dataset = dataset.shuffle(self.config.shuffle_buffer_size) + dataset = dataset.repeat() + + # Parse, pre-process, and batch the data in parallel + if self.config.builder == 'records': + preprocess = self.parse_record + else: + preprocess = self.preprocess + dataset = dataset.map(preprocess, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training) + + # Note: we could do image normalization here, but we defer it to the model + # which can perform it much faster on a GPU/TPU + # TODO(dankondratyuk): if we fix prefetching, we can do it here + + if self.is_training and self.config.deterministic_train is not None: + options = tf.data.Options() + options.experimental_deterministic = self.config.deterministic_train + options.experimental_slack = self.config.use_slack + options.experimental_optimization.parallel_batch = True + options.experimental_optimization.map_fusion = True + options.experimental_optimization.map_vectorization.enabled = True + options.experimental_optimization.map_parallelization = True + dataset = dataset.with_options(options) + + # Prefetch overlaps in-feed with training + # Note: autotune here is not recommended, as this can lead to memory leaks. + # Instead, use a constant prefetch size like the the number of devices. 
+ dataset = dataset.prefetch(self.config.num_devices) + + return dataset + + def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """Parse an ImageNet record from a serialized string Tensor.""" + keys_to_features = { + 'image/encoded': + tf.io.FixedLenFeature((), tf.string, ''), + 'image/format': + tf.io.FixedLenFeature((), tf.string, 'jpeg'), + 'image/class/label': + tf.io.FixedLenFeature([], tf.int64, -1), + 'image/class/text': + tf.io.FixedLenFeature([], tf.string, ''), + 'image/object/bbox/xmin': + tf.io.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymin': + tf.io.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/xmax': + tf.io.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymax': + tf.io.VarLenFeature(dtype=tf.float32), + 'image/object/class/label': + tf.io.VarLenFeature(dtype=tf.int64), + } + + parsed = tf.io.parse_single_example(record, keys_to_features) + + label = tf.reshape(parsed['image/class/label'], shape=[1]) + label = tf.cast(label, dtype=tf.int32) + + # Subtract one so that labels are in [0, 1000) + label -= 1 + + image_bytes = tf.reshape(parsed['image/encoded'], shape=[]) + image, label = self.preprocess(image_bytes, label) + + return image, label + + def preprocess(self, image: tf.Tensor, label: tf.Tensor + ) -> Tuple[tf.Tensor, tf.Tensor]: + """Apply image preprocessing and augmentation to the image and label.""" + if self.is_training: + image = preprocessing.preprocess_for_train( + image, + image_size=self.image_size, + mean_subtract=self.config.mean_subtract, + standardize=self.config.standardize, + dtype=self.config.dtype, + augmenter=self.augmenter) + else: + image = preprocessing.preprocess_for_eval( + image, + image_size=self.image_size, + num_channels=self.num_channels, + mean_subtract=self.config.mean_subtract, + standardize=self.config.standardize, + dtype=self.config.dtype) + + label = tf.cast(label, tf.int32) + if self.config.one_hot: + label = tf.one_hot(label, self.num_classes) + label = tf.reshape(label, [self.num_classes]) + + return image, label + + @classmethod + def from_params(cls, *args, **kwargs): + """Construct a dataset builder from a default config and any overrides.""" + config = DatasetConfig.from_args(*args, **kwargs) + return cls(config) diff --git a/official/vision/image_classification/efficientnet/__init__.py b/official/vision/image_classification/efficientnet/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/official/vision/image_classification/efficientnet/common_modules.py b/official/vision/image_classification/efficientnet/common_modules.py new file mode 100644 index 00000000000..b25dcb1c0e8 --- /dev/null +++ b/official/vision/image_classification/efficientnet/common_modules.py @@ -0,0 +1,100 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Common modeling utilities."""
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow.compat.v1 as tf1
+import tensorflow.compat.v2 as tf
+from typing import Optional, Text, Type
+
+from tensorflow.python.tpu import tpu_function
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
+ """Cross replica batch normalization."""
+
+ def __init__(self, fused: Optional[bool] = False, **kwargs):
+ if fused in (True, None):
+ raise ValueError('TpuBatchNormalization does not support fused=True.')
+ super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
+
+ def _cross_replica_average(self, t: tf.Tensor, num_shards_per_group: int):
+ """Calculates the average value of the input tensor across TPU replicas."""
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ group_assignment = None
+ if num_shards_per_group > 1:
+ if num_shards % num_shards_per_group != 0:
+ raise ValueError(
+ 'num_shards: %d is not divisible by num_shards_per_group: %d' %
+ (num_shards, num_shards_per_group))
+ num_groups = num_shards // num_shards_per_group
+ group_assignment = [[
+ x for x in range(num_shards) if x // num_shards_per_group == y
+ ] for y in range(num_groups)]
+ return tf1.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
+ num_shards_per_group, t.dtype)
+
+ def _moments(self, inputs: tf.Tensor, reduction_axes: int, keep_dims: bool):
+ """Computes the mean and variance; overrides the original _moments."""
+ shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
+ inputs, reduction_axes, keep_dims=keep_dims)
+
+ num_shards = tpu_function.get_tpu_context().number_of_shards or 1
+ if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
+ num_shards_per_group = 1
+ else:
+ num_shards_per_group = max(8, num_shards // 8)
+ if num_shards_per_group > 1:
+ # Compute variance using: Var[X] = E[X^2] - E[X]^2.
+ shard_square_of_mean = tf.math.square(shard_mean)
+ shard_mean_of_square = shard_variance + shard_square_of_mean
+ group_mean = self._cross_replica_average(shard_mean, num_shards_per_group)
+ group_mean_of_square = self._cross_replica_average(
+ shard_mean_of_square, num_shards_per_group)
+ group_variance = group_mean_of_square - tf.math.square(group_mean)
+ return (group_mean, group_variance)
+ else:
+ return (shard_mean, shard_variance)
+
+
+def get_batch_norm(
+ batch_norm_type: Text) -> Type[tf.keras.layers.BatchNormalization]:
+ """Selects the batch normalization layer class to use.
+
+ Args:
+ batch_norm_type: The type of batch normalization layer implementation. `tpu`
+ will use `TpuBatchNormalization`.
+
+ Returns:
+ The batch normalization class: `TpuBatchNormalization` for `tpu`,
+ otherwise `tf.keras.layers.BatchNormalization`.
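+
+ Example:
+ batch_norm = get_batch_norm('tpu') # the TpuBatchNormalization class
+ layer = batch_norm(momentum=0.99, epsilon=1e-3)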
+ """ + if batch_norm_type == 'tpu': + return TpuBatchNormalization + + return tf.keras.layers.BatchNormalization + + +def count_params(model, trainable_only=True): + """Returns the count of all model parameters, or just trainable ones.""" + if not trainable_only: + return model.count_params() + else: + return int(np.sum([tf.keras.backend.count_params(p) + for p in model.trainable_weights])) diff --git a/official/vision/image_classification/efficientnet/efficientnet_config.py b/official/vision/image_classification/efficientnet/efficientnet_config.py new file mode 100644 index 00000000000..b939145fb8d --- /dev/null +++ b/official/vision/image_classification/efficientnet/efficientnet_config.py @@ -0,0 +1,75 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Configuration definitions for EfficientNet losses, learning rates, and optimizers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import Any, Mapping + +import dataclasses + +from official.vision.image_classification.configs import base_configs + + +@dataclasses.dataclass +class EfficientNetModelConfig(base_configs.ModelConfig): + """Configuration for the EfficientNet model. + + This configuration will default to settings used for training efficientnet-b0 + on a v3-8 TPU on ImageNet. + + Attributes: + name: The name of the model. Defaults to 'EfficientNet'. + num_classes: The number of classes in the model. + model_params: A dictionary that represents the parameters of the + EfficientNet model. These will be passed in to the "from_name" function. + loss: The configuration for loss. Defaults to a categorical cross entropy + implementation. + optimizer: The configuration for optimizations. Defaults to an RMSProp + configuration. + learning_rate: The configuration for learning rate. Defaults to an + exponential configuration. + + """ + name: str = 'EfficientNet' + num_classes: int = 1000 + model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: { + 'model_name': 'efficientnet-b0', + 'model_weights_path': '', + 'copy_to_local': False, + 'overrides': { + 'batch_norm': 'default', + 'rescale_input': True, + 'num_classes': 1000, + } + }) + loss: base_configs.LossConfig = base_configs.LossConfig( + name='categorical_crossentropy', + label_smoothing=0.1) + optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig( + name='rmsprop', + decay=0.9, + epsilon=0.001, + momentum=0.9, + moving_average_decay=None) + learning_rate: base_configs.LearningRateConfig = base_configs.LearningRateConfig( # pylint: disable=line-too-long + name='exponential', + initial_lr=0.008, + decay_epochs=2.4, + decay_rate=0.97, + warmup_epochs=5, + scale_by_batch_size=1. / 128.) 
diff --git a/official/vision/image_classification/efficientnet/efficientnet_model.py b/official/vision/image_classification/efficientnet/efficientnet_model.py new file mode 100644 index 00000000000..657661b2c97 --- /dev/null +++ b/official/vision/image_classification/efficientnet/efficientnet_model.py @@ -0,0 +1,503 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains definitions for EfficientNet model. + +[1] Mingxing Tan, Quoc V. Le + EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. + ICML'19, https://arxiv.org/abs/1905.11946 +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import os +from typing import Any, Dict, Optional, Text, Tuple + +from absl import logging +from dataclasses import dataclass +import tensorflow.compat.v2 as tf + +from official.modeling import tf_utils +from official.modeling.hyperparams import base_config +from official.vision.image_classification import preprocessing +from official.vision.image_classification.efficientnet import common_modules + + +@dataclass +class BlockConfig(base_config.Config): + """Config for a single MB Conv Block.""" + input_filters: int = 0 + output_filters: int = 0 + kernel_size: int = 3 + num_repeat: int = 1 + expand_ratio: int = 1 + strides: Tuple[int, int] = (1, 1) + se_ratio: Optional[float] = None + id_skip: bool = True + fused_conv: bool = False + conv_type: str = 'depthwise' + + +@dataclass +class ModelConfig(base_config.Config): + """Default Config for Efficientnet-B0.""" + width_coefficient: float = 1.0 + depth_coefficient: float = 1.0 + resolution: int = 224 + dropout_rate: float = 0.2 + blocks: Tuple[BlockConfig, ...] 
= ( + # (input_filters, output_filters, kernel_size, num_repeat, + # expand_ratio, strides, se_ratio) + # pylint: disable=bad-whitespace + BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25), + BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25), + BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25), + BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25), + BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25), + BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25), + BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25), + # pylint: enable=bad-whitespace + ) + stem_base_filters: int = 32 + top_base_filters: int = 1280 + activation: str = 'simple_swish' + batch_norm: str = 'default' + bn_momentum: float = 0.99 + bn_epsilon: float = 1e-3 + # While the original implementation used a weight decay of 1e-5, + # tf.nn.l2_loss divides it by 2, so we halve this to compensate in Keras + weight_decay: float = 5e-6 + drop_connect_rate: float = 0.2 + depth_divisor: int = 8 + min_depth: Optional[int] = None + use_se: bool = True + input_channels: int = 3 + num_classes: int = 1000 + model_name: str = 'efficientnet' + rescale_input: bool = True + data_format: str = 'channels_last' + dtype: str = 'float32' + + +MODEL_CONFIGS = { + # (width, depth, resolution, dropout) + 'efficientnet-b0': ModelConfig.from_args(1.0, 1.0, 224, 0.2), + 'efficientnet-b1': ModelConfig.from_args(1.0, 1.1, 240, 0.2), + 'efficientnet-b2': ModelConfig.from_args(1.1, 1.2, 260, 0.3), + 'efficientnet-b3': ModelConfig.from_args(1.2, 1.4, 300, 0.3), + 'efficientnet-b4': ModelConfig.from_args(1.4, 1.8, 380, 0.4), + 'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4), + 'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5), + 'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5), +} + +CONV_KERNEL_INITIALIZER = { + 'class_name': 'VarianceScaling', + 'config': { + 'scale': 2.0, + 'mode': 'fan_out', + # Note: this is a truncated normal distribution + 'distribution': 'normal' + } +} + +DENSE_KERNEL_INITIALIZER = { + 'class_name': 'VarianceScaling', + 'config': { + 'scale': 1 / 3.0, + 'mode': 'fan_out', + 'distribution': 'uniform' + } +} + + +def round_filters(filters: int, + config: ModelConfig) -> int: + """Round number of filters based on width coefficient.""" + width_coefficient = config.width_coefficient + min_depth = config.min_depth + divisor = config.depth_divisor + orig_filters = filters + + if not width_coefficient: + return filters + + filters *= width_coefficient + min_depth = min_depth or divisor + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
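+ # e.g. filters=32 with width_coefficient=1.1: 32 * 1.1 = 35.2 rounds to 32
+ # (the nearest multiple of divisor=8), and 32 >= 0.9 * 35.2, so no
+ # correction is applied.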
+ if new_filters < 0.9 * filters: + new_filters += divisor + logging.info('round_filter input=%s output=%s', orig_filters, new_filters) + return int(new_filters) + + +def round_repeats(repeats: int, depth_coefficient: float) -> int: + """Round number of repeats based on depth coefficient.""" + return int(math.ceil(depth_coefficient * repeats)) + + +def conv2d_block(inputs: tf.Tensor, + conv_filters: Optional[int], + config: ModelConfig, + kernel_size: Any = (1, 1), + strides: Any = (1, 1), + use_batch_norm: bool = True, + use_bias: bool = False, + activation: Any = None, + depthwise: bool = False, + name: Text = None): + """A conv2d followed by batch norm and an activation.""" + batch_norm = common_modules.get_batch_norm(config.batch_norm) + bn_momentum = config.bn_momentum + bn_epsilon = config.bn_epsilon + data_format = config.data_format + weight_decay = config.weight_decay + + name = name or '' + + # Collect args based on what kind of conv2d block is desired + init_kwargs = { + 'kernel_size': kernel_size, + 'strides': strides, + 'use_bias': use_bias, + 'padding': 'same', + 'name': name + '_conv2d', + 'kernel_regularizer': tf.keras.regularizers.l2(weight_decay), + 'bias_regularizer': tf.keras.regularizers.l2(weight_decay), + } + + if depthwise: + conv2d = tf.keras.layers.DepthwiseConv2D + init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER}) + else: + conv2d = tf.keras.layers.Conv2D + init_kwargs.update({'filters': conv_filters, + 'kernel_initializer': CONV_KERNEL_INITIALIZER}) + + x = conv2d(**init_kwargs)(inputs) + + if use_batch_norm: + bn_axis = 1 if data_format == 'channels_first' else -1 + x = batch_norm(axis=bn_axis, + momentum=bn_momentum, + epsilon=bn_epsilon, + name=name + '_bn')(x) + + if activation is not None: + x = tf.keras.layers.Activation(activation, + name=name + '_activation')(x) + return x + + +def mb_conv_block(inputs: tf.Tensor, + block: BlockConfig, + config: ModelConfig, + prefix: Text = None): + """Mobile Inverted Residual Bottleneck. + + Args: + inputs: the Keras input to the block + block: BlockConfig, arguments to create a Block + config: ModelConfig, a set of model parameters + prefix: prefix for naming all layers + + Returns: + the output of the block + """ + use_se = config.use_se + activation = tf_utils.get_activation(config.activation) + drop_connect_rate = config.drop_connect_rate + data_format = config.data_format + use_depthwise = block.conv_type != 'no_depthwise' + prefix = prefix or '' + + filters = block.input_filters * block.expand_ratio + + x = inputs + + if block.fused_conv: + # If we use fused mbconv, skip expansion and use regular conv. 
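+ # (The fused variant folds the 1x1 expansion and the depthwise conv into
+ # one full convolution with the block's kernel size and stride.)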
+ x = conv2d_block(x, + filters, + config, + kernel_size=block.kernel_size, + strides=block.strides, + activation=activation, + name=prefix + 'fused') + else: + if block.expand_ratio != 1: + # Expansion phase + kernel_size = (1, 1) if use_depthwise else (3, 3) + x = conv2d_block(x, + filters, + config, + kernel_size=kernel_size, + activation=activation, + name=prefix + 'expand') + + # Depthwise Convolution + if use_depthwise: + x = conv2d_block(x, + conv_filters=None, + config=config, + kernel_size=block.kernel_size, + strides=block.strides, + activation=activation, + depthwise=True, + name=prefix + 'depthwise') + + # Squeeze and Excitation phase + if use_se: + assert block.se_ratio is not None + assert 0 < block.se_ratio <= 1 + num_reduced_filters = max(1, int( + block.input_filters * block.se_ratio + )) + + if data_format == 'channels_first': + se_shape = (filters, 1, 1) + else: + se_shape = (1, 1, filters) + + se = tf.keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x) + se = tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se) + + se = conv2d_block(se, + num_reduced_filters, + config, + use_bias=True, + use_batch_norm=False, + activation=activation, + name=prefix + 'se_reduce') + se = conv2d_block(se, + filters, + config, + use_bias=True, + use_batch_norm=False, + activation='sigmoid', + name=prefix + 'se_expand') + x = tf.keras.layers.multiply([x, se], name=prefix + 'se_excite') + + # Output phase + x = conv2d_block(x, + block.output_filters, + config, + activation=None, + name=prefix + 'project') + + # Add identity so that quantization-aware training can insert quantization + # ops correctly. + x = tf.keras.layers.Activation(tf_utils.get_activation('identity'), + name=prefix + 'id')(x) + + if (block.id_skip + and all(s == 1 for s in block.strides) + and block.input_filters == block.output_filters): + if drop_connect_rate and drop_connect_rate > 0: + # Apply dropconnect + # The only difference between dropout and dropconnect in TF is scaling by + # drop_connect_rate during training. See: + # https://github.com/keras-team/keras/pull/9898#issuecomment-380577612 + x = tf.keras.layers.Dropout(drop_connect_rate, + noise_shape=(None, 1, 1, 1), + name=prefix + 'drop')(x) + + x = tf.keras.layers.add([x, inputs], name=prefix + 'add') + + return x + + +def efficientnet(image_input: tf.keras.layers.Input, + config: ModelConfig): + """Creates an EfficientNet graph given the model parameters. + + This function is wrapped by the `EfficientNet` class to make a tf.keras.Model. 
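+ The resulting graph is a stem conv block, the configured stacks of mobile
+ inverted bottleneck (MBConv) blocks, a top conv block, global average
+ pooling, dropout, and a dense softmax classification head.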
+ + Args: + image_input: the input batch of images + config: the model config + + Returns: + the output of efficientnet + """ + depth_coefficient = config.depth_coefficient + blocks = config.blocks + stem_base_filters = config.stem_base_filters + top_base_filters = config.top_base_filters + activation = tf_utils.get_activation(config.activation) + dropout_rate = config.dropout_rate + drop_connect_rate = config.drop_connect_rate + num_classes = config.num_classes + input_channels = config.input_channels + rescale_input = config.rescale_input + data_format = config.data_format + dtype = config.dtype + weight_decay = config.weight_decay + + x = image_input + + if rescale_input: + x = preprocessing.normalize_images(x, + num_channels=input_channels, + dtype=dtype, + data_format=data_format) + + # Build stem + x = conv2d_block(x, + round_filters(stem_base_filters, config), + config, + kernel_size=[3, 3], + strides=[2, 2], + activation=activation, + name='stem') + + # Build blocks + num_blocks_total = sum(block.num_repeat for block in blocks) + block_num = 0 + + for stack_idx, block in enumerate(blocks): + assert block.num_repeat > 0 + # Update block input and output filters based on depth multiplier + block = block.replace( + input_filters=round_filters(block.input_filters, config), + output_filters=round_filters(block.output_filters, config), + num_repeat=round_repeats(block.num_repeat, depth_coefficient)) + + # The first block needs to take care of stride and filter size increase + drop_rate = drop_connect_rate * float(block_num) / num_blocks_total + config = config.replace(drop_connect_rate=drop_rate) + block_prefix = 'stack_{}/block_0/'.format(stack_idx) + x = mb_conv_block(x, block, config, block_prefix) + block_num += 1 + if block.num_repeat > 1: + block = block.replace( + input_filters=block.output_filters, + strides=[1, 1] + ) + + for block_idx in range(block.num_repeat - 1): + drop_rate = drop_connect_rate * float(block_num) / num_blocks_total + config = config.replace(drop_connect_rate=drop_rate) + block_prefix = 'stack_{}/block_{}/'.format(stack_idx, block_idx + 1) + x = mb_conv_block(x, block, config, prefix=block_prefix) + block_num += 1 + + # Build top + x = conv2d_block(x, + round_filters(top_base_filters, config), + config, + activation=activation, + name='top') + + # Build classifier + x = tf.keras.layers.GlobalAveragePooling2D(name='top_pool')(x) + if dropout_rate and dropout_rate > 0: + x = tf.keras.layers.Dropout(dropout_rate, name='top_dropout')(x) + x = tf.keras.layers.Dense( + num_classes, + kernel_initializer=DENSE_KERNEL_INITIALIZER, + kernel_regularizer=tf.keras.regularizers.l2(weight_decay), + bias_regularizer=tf.keras.regularizers.l2(weight_decay), + name='logits')(x) + x = tf.keras.layers.Activation('softmax', name='probs')(x) + + return x + + +@tf.keras.utils.register_keras_serializable(package='Vision') +class EfficientNet(tf.keras.Model): + """Wrapper class for an EfficientNet Keras model. + + Contains helper methods to build, manage, and save metadata about the model. + """ + + def __init__(self, + config: ModelConfig = None, + overrides: Dict[Text, Any] = None): + """Create an EfficientNet model. 
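+
+ The built model accepts images of any spatial size, since the input
+ shape is (None, None, input_channels).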
+ + Args: + config: (optional) the main model parameters to create the model + overrides: (optional) a dict containing keys that can override + config + """ + overrides = overrides or {} + config = config or ModelConfig() + + self.config = config.replace(**overrides) + + input_channels = self.config.input_channels + model_name = self.config.model_name + input_shape = (None, None, input_channels) # Should handle any size image + image_input = tf.keras.layers.Input(shape=input_shape) + + output = efficientnet(image_input, self.config) + + # Cast to float32 in case we have a different model dtype + output = tf.cast(output, tf.float32) + + logging.info('Building model %s with params %s', + model_name, + self.config) + + super(EfficientNet, self).__init__( + inputs=image_input, outputs=output, name=model_name) + + @classmethod + def from_name(cls, + model_name: Text, + model_weights_path: Text = None, + copy_to_local: bool = False, + overrides: Dict[Text, Any] = None): + """Construct an EfficientNet model from a predefined model name. + + E.g., `EfficientNet.from_name('efficientnet-b0')`. + + Args: + model_name: the predefined model name + model_weights_path: the path to the weights (h5 file or saved model dir) + copy_to_local: copy the weights to a local tmp dir + overrides: (optional) a dict containing keys that can override config + + Returns: + A constructed EfficientNet instance. + """ + model_configs = dict(MODEL_CONFIGS) + overrides = dict(overrides) if overrides else {} + + # One can define their own custom models if necessary + model_configs.update(overrides.pop('model_config', {})) + + if model_name not in model_configs: + raise ValueError('Unknown model name {}'.format(model_name)) + + config = model_configs[model_name] + + model = cls(config=config, overrides=overrides) + + if model_weights_path: + if copy_to_local: + tmp_file = os.path.join('/tmp', model_name + '.h5') + model_weights_file = os.path.join(model_weights_path, 'model.h5') + tf.io.gfile.copy(model_weights_file, tmp_file, overwrite=True) + model_weights_path = tmp_file + + model.load_weights(model_weights_path) + + return model diff --git a/official/vision/image_classification/learning_rate.py b/official/vision/image_classification/learning_rate.py new file mode 100644 index 00000000000..ae2b279866e --- /dev/null +++ b/official/vision/image_classification/learning_rate.py @@ -0,0 +1,120 @@ +# Lint as: python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Learning rate utilities for vision tasks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from typing import Any, List, Mapping
+
+import tensorflow.compat.v2 as tf
+
+BASE_LEARNING_RATE = 0.1
+
+
+class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
+ """A wrapper for LearningRateSchedule that includes warmup steps."""
+
+ def __init__(
+ self,
+ lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
+ warmup_steps: int):
+ """Add a linear warmup phase to a learning rate schedule.
+
+ Args:
+ lr_schedule: the base learning rate schedule to wrap.
+ warmup_steps: the number of warmup steps.
+ """
+ super(WarmupDecaySchedule, self).__init__()
+ self._lr_schedule = lr_schedule
+ self._warmup_steps = warmup_steps
+
+ def __call__(self, step: int):
+ lr = self._lr_schedule(step)
+ if self._warmup_steps:
+ initial_learning_rate = tf.convert_to_tensor(
+ self._lr_schedule.initial_learning_rate, name="initial_learning_rate")
+ dtype = initial_learning_rate.dtype
+ global_step_recomp = tf.cast(step, dtype)
+ warmup_steps = tf.cast(self._warmup_steps, dtype)
+ warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
+ lr = tf.cond(global_step_recomp < warmup_steps,
+ lambda: warmup_lr,
+ lambda: lr)
+ return lr
+
+ def get_config(self) -> Mapping[str, Any]:
+ config = self._lr_schedule.get_config()
+ config.update({
+ "warmup_steps": self._warmup_steps,
+ })
+ return config
+
+
+# TODO(b/149030439) - refactor this with
+# tf.keras.optimizers.schedules.PiecewiseConstantDecay + WarmupDecaySchedule.
+class PiecewiseConstantDecayWithWarmup(
+ tf.keras.optimizers.schedules.LearningRateSchedule):
+ """Piecewise constant decay with warmup schedule."""
+
+ def __init__(self,
+ batch_size: int,
+ epoch_size: int,
+ warmup_epochs: int,
+ boundaries: List[int],
+ multipliers: List[float]):
+ """Piecewise constant decay with warmup.
+
+ Args:
+ batch_size: The training batch size used in the experiment.
+ epoch_size: The size of an epoch, or the number of examples in an epoch.
+ warmup_epochs: The number of warmup epochs to apply.
+ boundaries: The list of epoch boundaries, with strictly increasing entries.
+ multipliers: The list of multipliers/learning rates to use for the
+ piecewise portion. The length must be one more than that of boundaries.
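+
+ For example, with batch_size=256 and epoch_size=256 (one step per epoch),
+ warmup_epochs=1, boundaries=[2, 3], and multipliers=[1.0, 0.1, 0.001],
+ the rescaled base rate is 0.1; the rate warms up from 0 at step 0, holds
+ at 0.1 after warmup, and drops to 0.01 and then 0.0001 as the step
+ boundaries are crossed.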
+ + """ + super(PiecewiseConstantDecayWithWarmup, self).__init__() + if len(boundaries) != len(multipliers) - 1: + raise ValueError("The length of boundaries must be 1 less than the " + "length of multipliers") + + base_lr_batch_size = 256 + steps_per_epoch = epoch_size // batch_size + + self._rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size + self._step_boundaries = [float(steps_per_epoch) * x for x in boundaries] + self._lr_values = [self._rescaled_lr * m for m in multipliers] + self._warmup_steps = warmup_epochs * steps_per_epoch + + def __call__(self, step: int): + """Compute learning rate at given step.""" + def warmup_lr(): + return self._rescaled_lr * ( + step / tf.cast(self._warmup_steps, tf.float32)) + def piecewise_lr(): + return tf.compat.v1.train.piecewise_constant( + tf.cast(step, tf.float32), self._step_boundaries, self._lr_values) + return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr) + + def get_config(self) -> Mapping[str, Any]: + return { + "rescaled_lr": self._rescaled_lr, + "step_boundaries": self._step_boundaries, + "lr_values": self._lr_values, + "warmup_steps": self._warmup_steps, + } diff --git a/official/vision/image_classification/learning_rate_test.py b/official/vision/image_classification/learning_rate_test.py new file mode 100644 index 00000000000..39b2bf11a3e --- /dev/null +++ b/official/vision/image_classification/learning_rate_test.py @@ -0,0 +1,90 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for learning_rate.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf + +from official.vision.image_classification import learning_rate + + +class LearningRateTests(tf.test.TestCase): + + def test_warmup_decay(self): + """Basic computational test for warmup decay.""" + initial_lr = 0.01 + decay_steps = 100 + decay_rate = 0.01 + warmup_steps = 10 + + base_lr = tf.keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate=initial_lr, + decay_steps=decay_steps, + decay_rate=decay_rate) + lr = learning_rate.WarmupDecaySchedule( + lr_schedule=base_lr, + warmup_steps=warmup_steps) + + for step in range(warmup_steps - 1): + config = lr.get_config() + self.assertEqual(config['warmup_steps'], warmup_steps) + self.assertAllClose(self.evaluate(lr(step)), + step / warmup_steps * initial_lr) + + def test_piecewise_constant_decay_with_warmup(self): + """Basic computational test for piecewise constant decay with warmup.""" + boundaries = [1, 2, 3] + warmup_epochs = boundaries[0] + learning_rate_multipliers = [1.0, 0.1, 0.001] + expected_keys = [ + 'rescaled_lr', 'step_boundaries', 'lr_values', 'warmup_steps', + ] + + expected_lrs = [0.0, 0.1, 0.1] + + lr = learning_rate.PiecewiseConstantDecayWithWarmup( + batch_size=256, + epoch_size=256, + warmup_epochs=warmup_epochs, + boundaries=boundaries[1:], + multipliers=learning_rate_multipliers) + + step = 0 + + config = lr.get_config() + self.assertAllInSet(list(config.keys()), expected_keys) + + for boundary, expected_lr in zip(boundaries, expected_lrs): + for _ in range(step, boundary): + self.assertAllClose(self.evaluate(lr(step)), expected_lr) + step += 1 + + def test_piecewise_constant_decay_invalid_boundaries(self): + with self.assertRaisesRegex(ValueError, + 'The length of boundaries must be 1 less '): + learning_rate.PiecewiseConstantDecayWithWarmup( + batch_size=256, + epoch_size=256, + warmup_epochs=1, + boundaries=[1, 2], + multipliers=[1, 2]) + + +if __name__ == '__main__': + assert tf.version.VERSION.startswith('2.') + tf.test.main() diff --git a/official/vision/image_classification/optimizer_factory.py b/official/vision/image_classification/optimizer_factory.py new file mode 100644 index 00000000000..ae14b9f165d --- /dev/null +++ b/official/vision/image_classification/optimizer_factory.py @@ -0,0 +1,161 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Optimizer factory for vision tasks.""" +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +from absl import logging +import tensorflow.compat.v2 as tf +import tensorflow_addons as tfa + +from typing import Any, Dict, Text +from official.vision.image_classification import learning_rate +from official.vision.image_classification.configs import base_configs + + +def build_optimizer( + optimizer_name: Text, + base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule, + params: Dict[Text, Any]): + """Build the optimizer based on name. + + Args: + optimizer_name: String representation of the optimizer name. Examples: + sgd, momentum, rmsprop. + base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule` + base learning rate. + params: String -> Any dictionary representing the optimizer params. + This should contain optimizer specific parameters such as + `base_learning_rate`, `decay`, etc. + + Returns: + A tf.keras.Optimizer. + + Raises: + ValueError if the provided optimizer_name is not supported. + + """ + optimizer_name = optimizer_name.lower() + logging.info('Building %s optimizer with params %s', optimizer_name, params) + + if optimizer_name == 'sgd': + logging.info('Using SGD optimizer') + nesterov = params.get('nesterov', False) + optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate, + nesterov=nesterov) + elif optimizer_name == 'momentum': + logging.info('Using momentum optimizer') + nesterov = params.get('nesterov', False) + optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate, + momentum=params['momentum'], + nesterov=nesterov) + elif optimizer_name == 'rmsprop': + logging.info('Using RMSProp') + rho = params.get('decay', None) or params.get('rho', 0.9) + momentum = params.get('momentum', 0.9) + epsilon = params.get('epsilon', 1e-07) + optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate, + rho=rho, + momentum=momentum, + epsilon=epsilon) + elif optimizer_name == 'adam': + logging.info('Using Adam') + beta_1 = params.get('beta_1', 0.9) + beta_2 = params.get('beta_2', 0.999) + epsilon = params.get('epsilon', 1e-07) + optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon) + elif optimizer_name == 'adamw': + logging.info('Using AdamW') + weight_decay = params.get('weight_decay', 0.01) + beta_1 = params.get('beta_1', 0.9) + beta_2 = params.get('beta_2', 0.999) + epsilon = params.get('epsilon', 1e-07) + optimizer = tfa.optimizers.AdamW(weight_decay=weight_decay, + learning_rate=base_learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon) + else: + raise ValueError('Unknown optimizer %s' % optimizer_name) + + moving_average_decay = params.get('moving_average_decay', 0.) 
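+ # A decay of 0 (the default) or None leaves the optimizer unwrapped.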
+ if moving_average_decay is not None and moving_average_decay > 0.: + logging.info('Including moving average decay.') + optimizer = tfa.optimizers.MovingAverage( + optimizer, + average_decay=params['moving_average_decay'], + num_updates=None) + if params.get('lookahead', None): + logging.info('Using lookahead optimizer.') + optimizer = tfa.optimizers.Lookahead(optimizer) + return optimizer + + +def build_learning_rate(params: base_configs.LearningRateConfig, + batch_size: int = None, + train_steps: int = None): + """Build the learning rate given the provided configuration.""" + decay_type = params.name + base_lr = params.initial_lr + decay_rate = params.decay_rate + if params.decay_epochs is not None: + decay_steps = params.decay_epochs * train_steps + else: + decay_steps = 0 + if params.warmup_epochs is not None: + warmup_steps = params.warmup_epochs * train_steps + else: + warmup_steps = 0 + + lr_multiplier = params.scale_by_batch_size + + if lr_multiplier and lr_multiplier > 0: + # Scale the learning rate based on the batch size and a multiplier + base_lr *= lr_multiplier * batch_size + logging.info('Scaling the learning rate based on the batch size ' + 'multiplier. New base_lr: %f', base_lr) + + if decay_type == 'exponential': + logging.info('Using exponential learning rate with: ' + 'initial_learning_rate: %f, decay_steps: %d, ' + 'decay_rate: %f', base_lr, decay_steps, decay_rate) + lr = tf.keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate=base_lr, + decay_steps=decay_steps, + decay_rate=decay_rate) + elif decay_type == 'piecewise_constant_with_warmup': + logging.info('Using Piecewise constant decay with warmup. ' + 'Parameters: batch_size: %d, epoch_size: %d, ' + 'warmup_epochs: %d, boundaries: %s, multipliers: %s', + batch_size, params.examples_per_epoch, + params.warmup_epochs, params.boundaries, + params.multipliers) + lr = learning_rate.PiecewiseConstantDecayWithWarmup( + batch_size=batch_size, + epoch_size=params.examples_per_epoch, + warmup_epochs=params.warmup_epochs, + boundaries=params.boundaries, + multipliers=params.multipliers) + if warmup_steps > 0: + if decay_type != 'piecewise_constant_with_warmup': + logging.info('Applying %d warmup steps to the learning rate', + warmup_steps) + lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps) + return lr diff --git a/official/vision/image_classification/optimizer_factory_test.py b/official/vision/image_classification/optimizer_factory_test.py new file mode 100644 index 00000000000..7d2a18ddb50 --- /dev/null +++ b/official/vision/image_classification/optimizer_factory_test.py @@ -0,0 +1,115 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for optimizer_factory.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +import tensorflow.compat.v2 as tf + +from absl.testing import parameterized +from official.vision.image_classification import optimizer_factory +from official.vision.image_classification.configs import base_configs + + +class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('sgd', 'sgd', 0., False), + ('momentum', 'momentum', 0., False), + ('rmsprop', 'rmsprop', 0., False), + ('adam', 'adam', 0., False), + ('adamw', 'adamw', 0., False), + ('momentum_lookahead', 'momentum', 0., True), + ('sgd_ema', 'sgd', 0.001, False), + ('momentum_ema', 'momentum', 0.001, False), + ('rmsprop_ema', 'rmsprop', 0.001, False)) + def test_optimizer(self, optimizer_name, moving_average_decay, lookahead): + """Smoke test to be sure no syntax errors.""" + params = { + 'learning_rate': 0.001, + 'rho': 0.09, + 'momentum': 0., + 'epsilon': 1e-07, + 'moving_average_decay': moving_average_decay, + 'lookahead': lookahead, + } + optimizer = optimizer_factory.build_optimizer( + optimizer_name=optimizer_name, + base_learning_rate=params['learning_rate'], + params=params) + self.assertTrue(issubclass(type(optimizer), tf.keras.optimizers.Optimizer)) + + def test_unknown_optimizer(self): + with self.assertRaises(ValueError): + optimizer_factory.build_optimizer( + optimizer_name='this_optimizer_does_not_exist', + base_learning_rate=None, + params=None) + + def test_learning_rate_without_decay_or_warmups(self): + params = base_configs.LearningRateConfig( + name='exponential', + initial_lr=0.01, + decay_rate=0.01, + decay_epochs=None, + warmup_epochs=None, + scale_by_batch_size=0.01, + examples_per_epoch=1, + boundaries=[0], + multipliers=[0, 1]) + batch_size = 1 + train_steps = 1 + + lr = optimizer_factory.build_learning_rate( + params=params, + batch_size=batch_size, + train_steps=train_steps) + self.assertTrue( + issubclass( + type(lr), tf.keras.optimizers.schedules.LearningRateSchedule)) + + @parameterized.named_parameters( + ('exponential', 'exponential'), + ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup')) + def test_learning_rate_with_decay_and_warmup(self, lr_decay_type): + """Basic smoke test for syntax.""" + params = base_configs.LearningRateConfig( + name=lr_decay_type, + initial_lr=0.01, + decay_rate=0.01, + decay_epochs=1, + warmup_epochs=1, + scale_by_batch_size=0.01, + examples_per_epoch=1, + boundaries=[0], + multipliers=[0, 1]) + batch_size = 1 + train_steps = 1 + + lr = optimizer_factory.build_learning_rate( + params=params, + batch_size=batch_size, + train_steps=train_steps) + self.assertTrue( + issubclass( + type(lr), tf.keras.optimizers.schedules.LearningRateSchedule)) + + +if __name__ == '__main__': + assert tf.version.VERSION.startswith('2.') + tf.test.main() diff --git a/official/vision/image_classification/preprocessing.py b/official/vision/image_classification/preprocessing.py new file mode 100644 index 00000000000..200b228bfbf --- /dev/null +++ b/official/vision/image_classification/preprocessing.py @@ -0,0 +1,391 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preprocessing functions for images."""
+
+from __future__ import absolute_import
+from __future__ import division
+# from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+from typing import List, Optional, Text, Tuple
+
+from official.vision.image_classification import augment
+
+
+# Calculated from the ImageNet training set
+MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
+STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
+
+IMAGE_SIZE = 224
+CROP_PADDING = 32
+
+
+def mean_image_subtraction(
+ image_bytes: tf.Tensor,
+ means: Tuple[float, ...],
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+) -> tf.Tensor:
+ """Subtracts the given means from each image channel.
+
+ For example:
+ means = [123.68, 116.779, 103.939]
+ image_bytes = mean_image_subtraction(image_bytes, means)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image_bytes: a tensor of size [height, width, C].
+ means: a C-vector of values to subtract from each channel.
+ num_channels: number of color channels in the image that will be distorted.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ the centered image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `means`.
+ """
+ if image_bytes.get_shape().ndims != 3:
+ raise ValueError('Input must be of size [height, width, C>0]')
+
+ if len(means) != num_channels:
+ raise ValueError('len(means) must match the number of channels')
+
+ # We have a 1-D tensor of means; convert to 3-D.
+ # Note(b/130245863): we explicitly call `broadcast` instead of simply
+ # expanding dimensions for better performance.
+ means = tf.broadcast_to(means, tf.shape(image_bytes))
+ if dtype is not None:
+ means = tf.cast(means, dtype=dtype)
+
+ return image_bytes - means
+
+
+def standardize_image(
+ image_bytes: tf.Tensor,
+ stddev: Tuple[float, ...],
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+) -> tf.Tensor:
+ """Divides each image channel by the given stddev.
+
+ For example:
+ stddev = [58.395, 57.120, 57.375]
+ image_bytes = standardize_image(image_bytes, stddev)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image_bytes: a tensor of size [height, width, C].
+ stddev: a C-vector of values to divide each channel by.
+ num_channels: number of color channels in the image that will be distorted.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ the standardized image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `stddev`.
+ """ + if image_bytes.get_shape().ndims != 3: + raise ValueError('Input must be of size [height, width, C>0]') + + if len(stddev) != num_channels: + raise ValueError('len(stddev) must match the number of channels') + + # We have a 1-D tensor of stddev; convert to 3-D. + # Note(b/130245863): we explicitly call `broadcast` instead of simply + # expanding dimensions for better performance. + stddev = tf.broadcast_to(stddev, tf.shape(image_bytes)) + if dtype is not None: + stddev = tf.cast(stddev, dtype=dtype) + + return image_bytes / stddev + + +def normalize_images(features: tf.Tensor, + mean_rgb: Tuple[float, ...] = MEAN_RGB, + stddev_rgb: Tuple[float, ...] = STDDEV_RGB, + num_channels: int = 3, + dtype: tf.dtypes.DType = tf.float32, + data_format: Text = 'channels_last') -> tf.Tensor: + """Normalizes the input image channels with the given mean and stddev. + + Args: + features: `Tensor` representing decoded images in float format. + mean_rgb: the mean of the channels to subtract. + stddev_rgb: the stddev of the channels to divide. + num_channels: the number of channels in the input image tensor. + dtype: the dtype to convert the images to. Set to `None` to skip conversion. + data_format: the format of the input image tensor + ['channels_first', 'channels_last']. + + Returns: + A normalized image `Tensor`. + """ + # TODO(allencwang) - figure out how to use mean_image_subtraction and + # standardize_image on batches of images and replace the following. + if data_format == 'channels_first': + stats_shape = [num_channels, 1, 1] + else: + stats_shape = [1, 1, num_channels] + + if dtype is not None: + features = tf.image.convert_image_dtype(features, dtype=dtype) + + if mean_rgb is not None: + mean_rgb = tf.constant(mean_rgb, + shape=stats_shape, + dtype=features.dtype) + mean_rgb = tf.broadcast_to(mean_rgb, tf.shape(features)) + features = features - mean_rgb + + if stddev_rgb is not None: + stddev_rgb = tf.constant(stddev_rgb, + shape=stats_shape, + dtype=features.dtype) + stddev_rgb = tf.broadcast_to(stddev_rgb, tf.shape(features)) + features = features / stddev_rgb + + return features + + +def decode_and_center_crop(image_bytes: tf.Tensor, + image_size: int = IMAGE_SIZE, + crop_padding: int = CROP_PADDING) -> tf.Tensor: + """Crops to center of image with padding then scales image_size. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + image_size: image height/width dimension. + crop_padding: the padding size to use when centering the crop. + + Returns: + A decoded and cropped image `Tensor`. 
+ """ + decoded = image_bytes.dtype != tf.string + shape = (tf.shape(image_bytes) if decoded + else tf.image.extract_jpeg_shape(image_bytes)) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + crop_padding)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), + tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([offset_height, offset_width, + padded_center_crop_size, padded_center_crop_size]) + if decoded: + image = tf.image.crop_to_bounding_box( + image_bytes, + offset_height=offset_height, + offset_width=offset_width, + target_height=padded_center_crop_size, + target_width=padded_center_crop_size) + else: + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + + image = resize_image(image_bytes=image, + height=image_size, + width=image_size) + + return image + + +def decode_crop_and_flip(image_bytes: tf.Tensor) -> tf.Tensor: + """Crops an image to a random part of the image, then randomly flips. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + + Returns: + A decoded and cropped image `Tensor`. + + """ + decoded = image_bytes.dtype != tf.string + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + shape = (tf.shape(image_bytes) if decoded + else tf.image.extract_jpeg_shape(image_bytes)) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=0.1, + aspect_ratio_range=[0.75, 1.33], + area_range=[0.05, 1.0], + max_attempts=100, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Reassemble the bounding box in the format the crop op requires. + offset_height, offset_width, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_height, offset_width, + target_height, target_width]) + if decoded: + cropped = tf.image.crop_to_bounding_box( + image_bytes, + offset_height=offset_height, + offset_width=offset_width, + target_height=target_height, + target_width=target_width) + else: + cropped = tf.image.decode_and_crop_jpeg(image_bytes, + crop_window, + channels=3) + + # Flip to add a little more random distortion in. + cropped = tf.image.random_flip_left_right(cropped) + return cropped + + +def resize_image(image_bytes: tf.Tensor, + height: int = IMAGE_SIZE, + width: int = IMAGE_SIZE) -> tf.Tensor: + """Resizes an image to a given height and width. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + height: image height dimension. + width: image width dimension. + + Returns: + A tensor containing the resized image. + + """ + return tf.compat.v1.image.resize( + image_bytes, [height, width], method=tf.image.ResizeMethod.BILINEAR, + align_corners=False) + + +def preprocess_for_eval( + image_bytes: tf.Tensor, + image_size: int = IMAGE_SIZE, + num_channels: int = 3, + mean_subtract: bool = False, + standardize: bool = False, + dtype: tf.dtypes.DType = tf.float32 +) -> tf.Tensor: + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + image_size: image height/width dimension. + num_channels: number of image input channels. + mean_subtract: whether or not to apply mean subtraction. 
+ standardize: whether or not to apply standardization.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ images = decode_and_center_crop(image_bytes, image_size)
+ images = tf.reshape(images, [image_size, image_size, num_channels])
+
+ if mean_subtract:
+ images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB)
+ if standardize:
+ images = standardize_image(image_bytes=images, stddev=STDDEV_RGB)
+ if dtype is not None:
+ images = tf.image.convert_image_dtype(images, dtype=dtype)
+
+ return images
+
+
+def load_eval_image(filename: Text, image_size: int = IMAGE_SIZE) -> tf.Tensor:
+ """Reads an image from the filesystem and applies image preprocessing.
+
+ Args:
+ filename: a filename path of an image.
+ image_size: image height/width dimension.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ image_bytes = tf.io.read_file(filename)
+ image = preprocess_for_eval(image_bytes, image_size)
+
+ return image
+
+
+def build_eval_dataset(filenames: List[Text],
+ labels: Optional[List[int]] = None,
+ image_size: int = IMAGE_SIZE,
+ batch_size: int = 1) -> tf.data.Dataset:
+ """Builds a tf.data.Dataset from a list of filenames and labels.
+
+ Args:
+ filenames: a list of filename paths of images.
+ labels: a list of labels corresponding to each image.
+ image_size: image height/width dimension.
+ batch_size: the batch size used by the dataset.
+
+ Returns:
+ A `tf.data.Dataset` yielding batches of (preprocessed image, label) pairs.
+ """
+ if labels is None:
+ labels = [0] * len(filenames)
+
+ filenames = tf.constant(filenames)
+ labels = tf.constant(labels)
+ dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
+
+ dataset = dataset.map(
+ lambda filename, label: (load_eval_image(filename, image_size), label))
+ dataset = dataset.batch(batch_size)
+
+ return dataset
+
+
+def preprocess_for_train(image_bytes: tf.Tensor,
+ image_size: int = IMAGE_SIZE,
+ augmenter: Optional[augment.ImageAugment] = None,
+ mean_subtract: bool = False,
+ standardize: bool = False,
+ dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ """Preprocesses the given image for training.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of
+ arbitrary size of dtype tf.uint8.
+ image_size: image height/width dimension.
+ augmenter: the image augmenter to apply.
+ mean_subtract: whether or not to apply mean subtraction.
+ standardize: whether or not to apply standardization.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ images = decode_crop_and_flip(image_bytes=image_bytes)
+ images = resize_image(images, height=image_size, width=image_size)
+ if mean_subtract:
+ images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB)
+ if standardize:
+ images = standardize_image(image_bytes=images, stddev=STDDEV_RGB)
+ if augmenter is not None:
+ images = augmenter.distort(images)
+ if dtype is not None:
+ images = tf.image.convert_image_dtype(images, dtype)
+
+ return images
diff --git a/official/vision/image_classification/resnet/README.md b/official/vision/image_classification/resnet/README.md
new file mode 100644
index 00000000000..d5923a83a5a
--- /dev/null
+++ b/official/vision/image_classification/resnet/README.md
@@ -0,0 +1,129 @@
+This folder contains compile/fit and
+[custom training loop (CTL)](#resnet-custom-training-loop) implementations for
+ResNet50.
diff --git a/official/vision/image_classification/resnet/README.md b/official/vision/image_classification/resnet/README.md
new file mode 100644
index 00000000000..d5923a83a5a
--- /dev/null
+++ b/official/vision/image_classification/resnet/README.md
@@ -0,0 +1,129 @@
+This folder contains both a compile/fit implementation and a
+[custom training loop (CTL)](#resnet-custom-training-loop) implementation for
+ResNet50.

## Before you begin
Please refer to the [README](../README.md) in the parent directory for
information on setup and preparing the data.

## ResNet (custom training loop)

Similar to the [estimator implementation](../../../r1/resnet), the Keras
implementation has code for the ImageNet dataset. The ImageNet
version uses a ResNet50 model implemented in
[`resnet_model.py`](./resnet_model.py).


### Pretrained Models

* [ResNet50 Checkpoints](https://storage.googleapis.com/cloud-tpu-checkpoints/resnet/resnet50.tar.gz)

* ResNet50 TFHub: [feature vector](https://tfhub.dev/tensorflow/resnet_50/feature_vector/1)
and [classification](https://tfhub.dev/tensorflow/resnet_50/classification/1)

### ImageNet Training

To train the model on ImageNet, run:

```bash
python3 resnet_imagenet_main.py
```

If you did not download the data to the default directory, specify the
location with the `--data_dir` flag:

```bash
python3 resnet_imagenet_main.py --data_dir=/path/to/imagenet
```

There are more flag options you can specify. Here are some examples:

- `--use_synthetic_data`: when set to true, synthetic data, rather than real
data, are used;
- `--batch_size`: the batch size used for the model;
- `--model_dir`: the directory to save the model checkpoint;
- `--train_epochs`: the number of epochs to run for training the model;
- `--train_steps`: the number of steps to run for training the model. Only
values smaller than the number of batches in an epoch are currently supported;
- `--skip_eval`: when set to true, both evaluation and validation during
training are skipped.

For example, here is a typical command for training on ImageNet data with a
batch size of 128 per GPU:

```bash
python3 resnet_imagenet_main.py \
  --model_dir=/tmp/model_dir/something \
  --num_gpus=2 \
  --batch_size=128 \
  --train_epochs=90 \
  --train_steps=10 \
  --use_synthetic_data=false
```

See [`common.py`](common.py) for the full list of options.

### Using multiple GPUs

You can train these models on multiple GPUs using the `tf.distribute.Strategy`
API. You can read more about it in this
[guide](https://www.tensorflow.org/guide/distribute_strategy).

In this example, we have made this easier to use with just the `--num_gpus`
command line flag. By default this flag is 1 if TensorFlow is compiled with
CUDA, and 0 otherwise.

- `--num_gpus=0`: uses `tf.distribute.OneDeviceStrategy` with CPU as the
device.
- `--num_gpus=1`: uses `tf.distribute.OneDeviceStrategy` with GPU as the
device.
- `--num_gpus=2+`: uses `tf.distribute.MirroredStrategy` to run synchronous
distributed training across the GPUs.

If you wish to run without `tf.distribute.Strategy`, you can do so by setting
`--distribution_strategy=off`.

### Running on multiple GPU hosts

You can also train these models on multiple hosts, each with GPUs, using
`tf.distribute.Strategy`.

The easiest way to run multi-host benchmarks is to set the
[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
environment variable appropriately on each host. For example, to run with
`MultiWorkerMirroredStrategy` on 2 hosts, the `cluster` in `TF_CONFIG` should
have 2 `host:port` entries, and host `i` should have the `task` in `TF_CONFIG`
set to `{"type": "worker", "index": i}`, as shown in the sketch below.
`MultiWorkerMirroredStrategy` will automatically use all the available GPUs on
each host.
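For instance, here is a minimal sketch of setting `TF_CONFIG` from Python on
the first of two hosts; the addresses and port are placeholders:

```python
# Hypothetical 2-host cluster; replace the addresses with your own hosts.
# Run the same setup on each host, changing "index" to that host's worker
# number, before launching the trainer.
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["10.0.0.1:2222", "10.0.0.2:2222"],
    },
    "task": {"type": "worker", "index": 0},  # Use "index": 1 on host 2.
})
```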
### Running on Cloud TPUs

Note: This model will **not** work with TPUs on Colab.

You can train the ResNet CTL model on Cloud TPUs using
`tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is
strongly recommended that you go through the
[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
create a TPU and GCE VM.

To run the ResNet model on a TPU, you must set `--distribution_strategy=tpu`
and `--tpu=$TPU_NAME`, where `$TPU_NAME` is the name of your TPU in the Cloud
Console. From a GCE VM, you can run the following command to train ResNet for
one epoch on a v2-8 or v3-8 TPU by setting `TRAIN_EPOCHS` to 1:

```bash
python3 resnet_ctl_imagenet_main.py \
  --tpu=$TPU_NAME \
  --model_dir=$MODEL_DIR \
  --data_dir=$DATA_DIR \
  --batch_size=1024 \
  --steps_per_loop=500 \
  --train_epochs=$TRAIN_EPOCHS \
  --use_synthetic_data=false \
  --dtype=fp32 \
  --enable_eager=true \
  --enable_tensorboard=true \
  --distribution_strategy=tpu \
  --log_steps=50 \
  --single_l2_loss_op=true \
  --use_tf_function=true
```

To train ResNet to convergence, run it for 90 epochs by setting
`TRAIN_EPOCHS` to 90.

Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.
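For reference, the `tpu` distribution strategy resolves and initializes the
TPU before entering the strategy scope. Here is a minimal sketch of that
setup, assuming TF 2.1-era APIs and a placeholder TPU name; the trainer's
internal wiring may differ:

```python
import tensorflow as tf

# "my-tpu" is a placeholder; pass the TPU name shown in the Cloud Console.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

with strategy.scope():
  # Build and compile the Keras model inside the strategy scope so its
  # variables are placed on the TPU.
  model = tf.keras.applications.ResNet50(weights=None)
  model.compile(optimizer="sgd",
                loss="sparse_categorical_crossentropy",
                metrics=["accuracy"])
```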