From 65c3c5135fa2819f72dd71182b23b89154afc1a6 Mon Sep 17 00:00:00 2001 From: Morten Dahl Date: Fri, 17 Apr 2020 18:24:23 +0200 Subject: [PATCH] copy test util from TFE to avoid dependency --- requirements-dev.txt | 1 + tf_big/BUILD | 3 +- tf_big/__init__.py | 2 + tf_big/python/ops/big_ops_test.py | 2 +- tf_big/python/test/BUILD | 21 ++++++++++ tf_big/python/test/__init__.py | 5 +++ tf_big/python/test/execution_context.py | 42 ++++++++++++++++++++ tf_big/python/test/execution_context_test.py | 28 +++++++++++++ 8 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 tf_big/python/test/BUILD create mode 100644 tf_big/python/test/__init__.py create mode 100644 tf_big/python/test/execution_context.py create mode 100644 tf_big/python/test/execution_context_test.py diff --git a/requirements-dev.txt b/requirements-dev.txt index bf5dd1b..1226dee 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +absl-py==0.9.0 cpplint==1.4.4 numpy==1.16.4 pip==19.2.3 diff --git a/tf_big/BUILD b/tf_big/BUILD index 83c7b47..ded0c53 100644 --- a/tf_big/BUILD +++ b/tf_big/BUILD @@ -63,7 +63,8 @@ py_library( "python/ops/__init__.py", ]), deps = [ - ":big_ops_py" + ":big_ops_py", + "//tf_big/python/test:test_py", ], srcs_version = "PY2AND3", ) diff --git a/tf_big/__init__.py b/tf_big/__init__.py index 67204de..b9e43a8 100644 --- a/tf_big/__init__.py +++ b/tf_big/__init__.py @@ -1,6 +1,7 @@ from tf_big.python.tensor import set_secure_default from tf_big.python.tensor import get_secure_default +from tf_big.python import test from tf_big.python.tensor import Tensor from tf_big.python.tensor import constant @@ -21,6 +22,7 @@ 'set_secure_default', 'get_secure_default', + 'test', 'Tensor', 'constant', diff --git a/tf_big/python/ops/big_ops_test.py b/tf_big/python/ops/big_ops_test.py index 557ee92..0e413a3 100644 --- a/tf_big/python/ops/big_ops_test.py +++ b/tf_big/python/ops/big_ops_test.py @@ -147,4 +147,4 @@ def test_mod(self, run_eagerly): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tf_big/python/test/BUILD b/tf_big/python/test/BUILD new file mode 100644 index 0000000..1562023 --- /dev/null +++ b/tf_big/python/test/BUILD @@ -0,0 +1,21 @@ +package(default_visibility = ["//visibility:public"]) + + +py_library( + name = "test_py", + srcs = ([ + "__init__.py", + "execution_context.py", + ]), +) + +py_test( + name = "execution_context_py_test", + srcs = [ + "execution_context_test.py", + ], + main = "execution_context_test.py", + deps = [ + ":test_py", + ], +) diff --git a/tf_big/python/test/__init__.py b/tf_big/python/test/__init__.py new file mode 100644 index 0000000..4ac4e90 --- /dev/null +++ b/tf_big/python/test/__init__.py @@ -0,0 +1,5 @@ +from .execution_context import tf_execution_context + +__all__ = [ + "tf_execution_context", +] diff --git a/tf_big/python/test/execution_context.py b/tf_big/python/test/execution_context.py new file mode 100644 index 0000000..18582ee --- /dev/null +++ b/tf_big/python/test/execution_context.py @@ -0,0 +1,42 @@ +import contextlib + +import tensorflow as tf + + +class EagerExecutionContext: + def scope(self): + return contextlib.suppress() + + def evaluate(self, value): + return value.numpy() + + +class GraphExecutionContext: + def __init__(self): + self._graph = None + self._session = None + + @property + def graph(self): + if self._graph is None: + self._graph = tf.Graph() + return self._graph + + @property + def session(self): + if self._session is None: + with self._graph.as_default(): + self._session = tf.compat.v1.Session() + return self._session + + def scope(self): + return self.graph.as_default() + + def evaluate(self, value): + return self.session.run(value) + + +def tf_execution_context(run_eagerly): + if run_eagerly: + return EagerExecutionContext() + return GraphExecutionContext() diff --git a/tf_big/python/test/execution_context_test.py b/tf_big/python/test/execution_context_test.py new file mode 100644 index 0000000..45423e4 --- /dev/null +++ b/tf_big/python/test/execution_context_test.py @@ -0,0 +1,28 @@ +# pylint: disable=missing-docstring +import unittest + +import numpy as np +import tensorflow as tf +from absl.testing import parameterized + +from tf_big.python.test import tf_execution_context + + +class TestExecutionContext(parameterized.TestCase): + @parameterized.parameters({"run_eagerly": True}, {"run_eagerly": False}) + def test_tf_execution_mode(self, run_eagerly): + context = tf_execution_context(run_eagerly) + with context.scope(): + x = tf.fill(dims=(2, 2), value=5.0) + assert tf.executing_eagerly() == run_eagerly + + assert isinstance(x, tf.Tensor) + actual_result = context.evaluate(x) + assert isinstance(actual_result, np.ndarray) + + expected_result = np.array([[5.0, 5.0], [5.0, 5.0]]) + np.testing.assert_equal(actual_result, expected_result) + + +if __name__ == "__main__": + unittest.main()