Skip to content

Commit

Permalink
copy test util from TFE to avoid dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
mortendahl committed Apr 17, 2020
1 parent 45786b5 commit 65c3c51
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 2 deletions.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
absl-py==0.9.0
cpplint==1.4.4
numpy==1.16.4
pip==19.2.3
Expand Down
3 changes: 2 additions & 1 deletion tf_big/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ py_library(
"python/ops/__init__.py",
]),
deps = [
":big_ops_py"
":big_ops_py",
"//tf_big/python/test:test_py",
],
srcs_version = "PY2AND3",
)
2 changes: 2 additions & 0 deletions tf_big/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tf_big.python.tensor import set_secure_default
from tf_big.python.tensor import get_secure_default

from tf_big.python import test
from tf_big.python.tensor import Tensor

from tf_big.python.tensor import constant
Expand All @@ -21,6 +22,7 @@
'set_secure_default',
'get_secure_default',

'test',
'Tensor',

'constant',
Expand Down
2 changes: 1 addition & 1 deletion tf_big/python/ops/big_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,4 @@ def test_mod(self, run_eagerly):


if __name__ == '__main__':
unittest.main()
unittest.main()
21 changes: 21 additions & 0 deletions tf_big/python/test/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package(default_visibility = ["//visibility:public"])


py_library(
name = "test_py",
srcs = ([
"__init__.py",
"execution_context.py",
]),
)

py_test(
name = "execution_context_py_test",
srcs = [
"execution_context_test.py",
],
main = "execution_context_test.py",
deps = [
":test_py",
],
)
5 changes: 5 additions & 0 deletions tf_big/python/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .execution_context import tf_execution_context

__all__ = [
"tf_execution_context",
]
42 changes: 42 additions & 0 deletions tf_big/python/test/execution_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import contextlib

import tensorflow as tf


class EagerExecutionContext:
def scope(self):
return contextlib.suppress()

def evaluate(self, value):
return value.numpy()


class GraphExecutionContext:
def __init__(self):
self._graph = None
self._session = None

@property
def graph(self):
if self._graph is None:
self._graph = tf.Graph()
return self._graph

@property
def session(self):
if self._session is None:
with self._graph.as_default():
self._session = tf.compat.v1.Session()
return self._session

def scope(self):
return self.graph.as_default()

def evaluate(self, value):
return self.session.run(value)


def tf_execution_context(run_eagerly):
if run_eagerly:
return EagerExecutionContext()
return GraphExecutionContext()
28 changes: 28 additions & 0 deletions tf_big/python/test/execution_context_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# pylint: disable=missing-docstring
import unittest

import numpy as np
import tensorflow as tf
from absl.testing import parameterized

from tf_big.python.test import tf_execution_context


class TestExecutionContext(parameterized.TestCase):
@parameterized.parameters({"run_eagerly": True}, {"run_eagerly": False})
def test_tf_execution_mode(self, run_eagerly):
context = tf_execution_context(run_eagerly)
with context.scope():
x = tf.fill(dims=(2, 2), value=5.0)
assert tf.executing_eagerly() == run_eagerly

assert isinstance(x, tf.Tensor)
actual_result = context.evaluate(x)
assert isinstance(actual_result, np.ndarray)

expected_result = np.array([[5.0, 5.0], [5.0, 5.0]])
np.testing.assert_equal(actual_result, expected_result)


if __name__ == "__main__":
unittest.main()

0 comments on commit 65c3c51

Please sign in to comment.