Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 303780351
  • Loading branch information
saberkun authored and tensorflower-gardener committed Mar 30, 2020
1 parent f302469 commit fddab2e
Show file tree
Hide file tree
Showing 15 changed files with 18 additions and 27 deletions.
13 changes: 5 additions & 8 deletions official/vision/image_classification/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from __future__ import print_function

import math
import tensorflow.compat.v2 as tf
import tensorflow as tf
from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union

from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
Expand Down Expand Up @@ -75,8 +75,7 @@ def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
return tf.reshape(image, new_shape)


def _convert_translation_to_transform(
translations: Iterable[int]) -> tf.Tensor:
def _convert_translation_to_transform(translations) -> tf.Tensor:
"""Converts translations to a projective transform.
The translation matrix looks like this:
Expand Down Expand Up @@ -166,8 +165,7 @@ def _convert_angles_to_transform(
)


def transform(image: tf.Tensor,
transforms: Iterable[float]) -> tf.Tensor:
def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
Expand All @@ -181,8 +179,7 @@ def transform(image: tf.Tensor,
return from_4d(image, original_ndims)


def translate(image: tf.Tensor,
translations: Iterable[int]) -> tf.Tensor:
def translate(image: tf.Tensor, translations) -> tf.Tensor:
"""Translates image(s) by provided vectors.
Args:
Expand Down Expand Up @@ -577,7 +574,7 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
return image


def _randomly_negate_tensor(tensor: tf.Tensor) -> tf.Tensor:
def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative."""
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
Expand Down
3 changes: 1 addition & 2 deletions official/vision/image_classification/augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from absl.testing import parameterized

import tensorflow.compat.v2 as tf
import tensorflow as tf

from official.vision.image_classification import augment

Expand Down Expand Up @@ -133,5 +133,4 @@ def test_all_policy_ops(self):
self.assertEqual((224, 224, 3), image.shape)

if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
3 changes: 1 addition & 2 deletions official/vision/image_classification/classifier_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow as tf

from official.modeling import performance
from official.modeling.hyperparams import params_dict
Expand Down Expand Up @@ -423,5 +423,4 @@ def main(_):
flags.mark_flag_as_required('model_type')
flags.mark_flag_as_required('dataset')

assert tf.version.VERSION.startswith('2.')
app.run(main)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from absl import flags
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
Expand Down Expand Up @@ -313,5 +313,4 @@ def test_serialize_config(self):
tf.io.gfile.rmtree(model_dir)

if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
2 changes: 1 addition & 1 deletion official/vision/image_classification/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Any, List, Optional, Tuple, Mapping, Union
from absl import logging
from dataclasses import dataclass
import tensorflow.compat.v2 as tf
import tensorflow as tf
import tensorflow_datasets as tfds

from official.modeling.hyperparams import base_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
import tensorflow as tf
from typing import Text, Optional

from tensorflow.python.tpu import tpu_function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from absl import logging
from dataclasses import dataclass
import tensorflow.compat.v2 as tf
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
Expand Down
2 changes: 1 addition & 1 deletion official/vision/image_classification/learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from typing import Any, List, Mapping

import tensorflow.compat.v2 as tf
import tensorflow as tf

BASE_LEARNING_RATE = 0.1

Expand Down
3 changes: 1 addition & 2 deletions official/vision/image_classification/learning_rate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

from official.vision.image_classification import learning_rate

Expand Down Expand Up @@ -86,5 +86,4 @@ def test_piecewise_constant_decay_invalid_boundaries(self):


if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
1 change: 0 additions & 1 deletion official/vision/image_classification/mnist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,4 @@ def test_end_to_end(self, distribution):


if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
tf.test.main()
2 changes: 1 addition & 1 deletion official/vision/image_classification/optimizer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import print_function

from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow as tf
import tensorflow_addons as tfa

from typing import Any, Dict, Text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# from __future__ import google_type_annotations
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

from absl.testing import parameterized
from official.vision.image_classification import optimizer_factory
Expand Down Expand Up @@ -111,5 +111,4 @@ def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):


if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
2 changes: 1 addition & 1 deletion official/vision/image_classification/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# from __future__ import google_type_annotations
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf
from typing import List, Optional, Text, Tuple

from official.vision.image_classification import augment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

from official.modeling import performance
from official.staging.training import grad_utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from absl import app
from absl import flags

import tensorflow.compat.v2 as tf
import tensorflow as tf

from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
Expand Down

0 comments on commit fddab2e

Please sign in to comment.