diff --git a/.bazelversion b/.bazelversion index 4be2c727..f22d756d 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -6.5.0 \ No newline at end of file +6.5.0 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 35efe54c..69b851c4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,7 +22,7 @@ on: branches: - master release: - types: [published] + types: [published] jobs: build: diff --git a/.github/workflows/ci-lint.yml b/.github/workflows/ci-lint.yml new file mode 100644 index 00000000..dede434d --- /dev/null +++ b/.github/workflows/ci-lint.yml @@ -0,0 +1,21 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [master] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4.1.7 + with: + # Ensure the full history is fetched + # This is required to run pre-commit on a specific set of commits + # TODO: Remove this when all the pre-commit issues are fixed + fetch-depth: 0 + - uses: actions/setup-python@v5.1.1 + with: + python-version: 3.13 + - uses: pre-commit/action@v3.0.1 diff --git a/.gitignore b/.gitignore index fdf94603..3ecf3ba3 100644 --- a/.gitignore +++ b/.gitignore @@ -126,4 +126,4 @@ dmypy.json .pyre/ # pb2.py files -*_pb2.py \ No newline at end of file +*_pb2.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..d74e3dbe --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +# pre-commit is a tool to perform a predefined set of tasks manually and/or +# automatically before git commits are made. +# +# Config reference: https://pre-commit.com/#pre-commit-configyaml---top-level +# +# Common tasks +# +# - Register git hooks: pre-commit install --install-hooks +# - Run on all files: pre-commit run --all-files +# +# These pre-commit hooks are run as CI. +# +# NOTE: if it can be avoided, add configs/args in pyproject.toml or below instead of creating a new `.config.file`. +# https://pre-commit.ci/#configuration +ci: + autoupdate_schedule: monthly + autofix_commit_msg: | + [pre-commit.ci] Apply automatic pre-commit fixes + +repos: + # general + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: end-of-file-fixer + exclude: '\.svg$|\.patch$' + - id: trailing-whitespace + exclude: '\.svg$|\.patch$' + - id: check-json + - id: check-yaml + args: [--allow-multiple-documents, --unsafe] + - id: check-toml + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.6 + hooks: + - id: ruff + args: ["--fix"] + - id: ruff-format diff --git a/LICENSE b/LICENSE index c1d8805b..f0e600d3 100644 --- a/LICENSE +++ b/LICENSE @@ -226,4 +226,4 @@ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pyproject.toml b/pyproject.toml index 6a345a2e..27839ccc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,3 +21,112 @@ requires = [ # (b/206845101) "numpy~=1.22.0", ] + +[tool.ruff] +line-length = 88 + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + "W", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # pep8 naming + "N", + # pydocstyle + "D", + # annotations + "ANN", + # debugger + "T10", + # flake8-pytest + "PT", + # flake8-return + "RET", + # flake8-unused-arguments + "ARG", + # flake8-fixme + "FIX", + # flake8-eradicate + "ERA", + # pandas-vet + "PD", + # numpy-specific rules + "NPY", +] + +ignore = [ + "D104", # Missing docstring in public package + "D100", # Missing docstring in public module + "D211", # No blank line before class + "PD901", # Avoid using 'df' for pandas dataframes. Perfectly fine in functions with limited scope + "ANN201", # Missing return type annotation for public function (makes no sense for NoneType return types...) + "ANN101", # Missing type annotation for `self` + "ANN204", # Missing return type annotation for special method + "ANN002", # Missing type annotation for `*args` + "ANN003", # Missing type annotation for `**kwargs` + "D105", # Missing docstring in magic method + "D203", # 1 blank line before after class docstring + "D204", # 1 blank line required after class docstring + "D413", # 1 blank line after parameters + "SIM108", # Simplify if/else to one line; not always clearer + "D206", # Docstrings should be indented with spaces; unnecessary when running ruff-format + "E501", # Line length too long; unnecessary when running ruff-format + "W191", # Indentation contains tabs; unnecessary when running ruff-format + + # FIX AND REMOVE BELOW CODES: + "ANN001", # Missing type annotation for function argument + "ANN102", # Missing type annotation for `cls` in classmethod + "ANN202", # Missing return type annotation for private function + "ANN205", # Missing return type annotation for staticmethod + "ANN206", # Missing return type annotation for classmethod `setUpClass` + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "ARG001", # Unused function argument + "ARG002", # Unused method argument + "ARG005", # Unused lambda argument + "B007", # Loop control variable `...` not used within loop body + "B008", # Do not perform function call in argument defaults + "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D107", # Missing docstring in `__init__` + "D401", # First line of docstring should be in imperative mood + "D404", # First word of the docstring should not be "This" + "D417", # Missing argument description in the docstring + "E731", # Do not assign a `lambda` expression, use a `def` + "E741", # Ambiguous variable name + "ERA001", # Found commented-out code + "F401", # `...` imported but unused + "F403", # `from ... import *` used; unable to detect undefined names + "FIX002", # Line contains TODO, consider resolving the issue + "FIX004", # Line contains HACK, consider resolving the issue + "N802", # Function name should be lowercase + "NPY002", # Replace legacy `np.random.rand` call with `np.random.Generator` + "PD011", # Use `.to_numpy()` instead of `.values` + "PT009", # Use a regular `assert` instead of unittest-style asserts + "PT018", # Assertion should be broken down into multiple parts + "PT027", # Use `pytest.raises` instead of unittest-style `assertRaisesRegex` + "RET504", # Unnecessary assignment to `...` before `return` statement + "RET505", # Unnecessary `elif` or `else` after `return` statement + "SIM103", # Return the negated condition directly + "SIM105", # Use `contextlib.suppress(...)` instead of `try`-`except`-`pass` + "SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements + "SIM118", # Use `key in dict` instead of `key in dict.keys()` + "SIM212", # Use `... if ... else ...` instead of `... if not ... else ...` + "UP008", # Use `super()` instead of `super(__class__, self)` + "UP028", # Replace `yield` over `for` loop with `yield from` + "UP031", # Use format specifiers instead of percent format +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/setup.py b/setup.py index c8427cf7..4992b714 100644 --- a/setup.py +++ b/setup.py @@ -18,199 +18,203 @@ import shutil import subprocess import sys - +from distutils.command import build # pylint:disable=g-bad-import-order # setuptools must be imported prior to distutils. import setuptools -from distutils.command import build -# pylint:enable=g-bad-import-order -from setuptools import find_packages -from setuptools import setup +# pylint:enable=g-bad-import-order +from setuptools import find_packages, setup from setuptools.command.install import install from setuptools.dist import Distribution class _BuildCommand(build.build): - """Build everything that is needed to install. + """Build everything that is needed to install. + + This overrides the original distutils "build" command to to run gen_proto + command before any sub_commands. - This overrides the original distutils "build" command to to run gen_proto - command before any sub_commands. + build command is also invoked from bdist_wheel and install command, therefore + this implementation covers the following commands: + - pip install . (which invokes bdist_wheel) + - python setup.py install (which invokes install command) + - python setup.py bdist_wheel (which invokes bdist_wheel command) + """ - build command is also invoked from bdist_wheel and install command, therefore - this implementation covers the following commands: - - pip install . (which invokes bdist_wheel) - - python setup.py install (which invokes install command) - - python setup.py bdist_wheel (which invokes bdist_wheel command) - """ + def _build_cc_extensions(self): + return True - def _build_cc_extensions(self): - return True - # Add "bazel_build" command as the first sub_command of "build". Each - # sub_command of "build" (e.g. "build_py", "build_ext", etc.) is executed - # sequentially when running a "build" command, if the second item in the tuple - # (predicate method) is evaluated to true. - sub_commands = [ - ('bazel_build', _build_cc_extensions), - ] + build.build.sub_commands + # Add "bazel_build" command as the first sub_command of "build". Each + # sub_command of "build" (e.g. "build_py", "build_ext", etc.) is executed + # sequentially when running a "build" command, if the second item in the tuple + # (predicate method) is evaluated to true. + sub_commands = [ + ("bazel_build", _build_cc_extensions), + ] + build.build.sub_commands # TFX BSL is not a purelib. However because of the extension module is not # built by setuptools, it will be incorrectly treated as a purelib. The # following works around that bug. class _InstallPlatlibCommand(install): - - def finalize_options(self): - install.finalize_options(self) - self.install_lib = self.install_platlib + def finalize_options(self): + install.finalize_options(self) + self.install_lib = self.install_platlib class _BazelBuildCommand(setuptools.Command): - """Generate proto stub files in python. - - Running this command will populate foo_pb2.py file next to your foo.proto - file. - """ - - def initialize_options(self): - pass - - def finalize_options(self): - self._bazel_cmd = shutil.which('bazel') - if not self._bazel_cmd: - raise RuntimeError( - 'Could not find "bazel" binary. Please visit ' - 'https://docs.bazel.build/versions/master/install.html for ' - 'installation instruction.') - self._additional_build_options = ['--verbose_failures', '--sandbox_debug'] - if platform.system() == 'Darwin': - # This flag determines the platform qualifier of the macos wheel. - if platform.machine() == 'arm64': - self._additional_build_options = ['--macos_minimum_os=11.0', - '--config=macos_arm64'] - else: - self._additional_build_options = ['--macos_minimum_os=10.14'] - - def run(self): - subprocess.check_call( - [self._bazel_cmd, 'run', '-c', 'opt'] - + self._additional_build_options - + ['//tfx_bsl:move_generated_files'], - # Bazel should be invoked in a directory containing bazel WORKSPACE - # file, which is the root directory. - cwd=os.path.dirname(os.path.realpath(__file__)), - env=dict(os.environ, PYTHON_BIN_PATH=sys.executable), - ) + """Generate proto stub files in python. + + Running this command will populate foo_pb2.py file next to your foo.proto + file. + """ + + def initialize_options(self): + pass + + def finalize_options(self): + self._bazel_cmd = shutil.which("bazel") + if not self._bazel_cmd: + raise RuntimeError( + 'Could not find "bazel" binary. Please visit ' + "https://docs.bazel.build/versions/master/install.html for " + "installation instruction." + ) + self._additional_build_options = ["--verbose_failures", "--sandbox_debug"] + if platform.system() == "Darwin": + # This flag determines the platform qualifier of the macos wheel. + if platform.machine() == "arm64": + self._additional_build_options = [ + "--macos_minimum_os=11.0", + "--config=macos_arm64", + ] + else: + self._additional_build_options = ["--macos_minimum_os=10.14"] + + def run(self): + subprocess.check_call( + [self._bazel_cmd, "run", "-c", "opt"] + + self._additional_build_options + + ["//tfx_bsl:move_generated_files"], + # Bazel should be invoked in a directory containing bazel WORKSPACE + # file, which is the root directory. + cwd=os.path.dirname(os.path.realpath(__file__)), + env=dict(os.environ, PYTHON_BIN_PATH=sys.executable), + ) class _BinaryDistribution(Distribution): - """This class is needed in order to create OS specific wheels.""" + """This class is needed in order to create OS specific wheels.""" - def is_pure(self): - return False + def is_pure(self): + return False - def has_ext_modules(self): - return True + def has_ext_modules(self): + return True def select_constraint(default, nightly=None, git_master=None): - """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" - selector = os.environ.get('TFX_DEPENDENCY_SELECTOR') - if selector == 'UNCONSTRAINED': - return '' - elif selector == 'NIGHTLY' and nightly is not None: - return nightly - elif selector == 'GIT_MASTER' and git_master is not None: - return git_master - else: - return default + """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" + selector = os.environ.get("TFX_DEPENDENCY_SELECTOR") + if selector == "UNCONSTRAINED": + return "" + elif selector == "NIGHTLY" and nightly is not None: + return nightly + elif selector == "GIT_MASTER" and git_master is not None: + return git_master + else: + return default # Get version from version module. -with open('tfx_bsl/version.py') as fp: - globals_dict = {} - exec(fp.read(), globals_dict) # pylint: disable=exec-used -__version__ = globals_dict['__version__'] +with open("tfx_bsl/version.py") as fp: + globals_dict = {} + exec(fp.read(), globals_dict) # pylint: disable=exec-used +__version__ = globals_dict["__version__"] # Get the long description from the README file. -with open('README.md') as fp: - _LONG_DESCRIPTION = fp.read() +with open("README.md") as fp: + _LONG_DESCRIPTION = fp.read() setup( - name='tfx-bsl', + name="tfx-bsl", version=__version__, - author='Google LLC', - author_email='tensorflow-extended-dev@googlegroups.com', - license='Apache 2.0', + author="Google LLC", + author_email="tensorflow-extended-dev@googlegroups.com", + license="Apache 2.0", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3 :: Only', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], namespace_packages=[], # Make sure to sync the versions of common dependencies (absl-py, numpy, # and protobuf) with TF. install_requires=[ - 'absl-py>=0.9,<2.0.0', + "absl-py>=0.9,<2.0.0", 'apache-beam[gcp]>=2.53,<3;python_version>="3.11"', 'apache-beam[gcp]>=2.50,<2.51;python_version<"3.11"', - 'google-api-python-client>=1.7.11,<2', - 'numpy>=1.22.0', - 'pandas>=1.0,<2', + "google-api-python-client>=1.7.11,<2", + "numpy>=1.22.0", + "pandas>=1.0,<2", 'protobuf>=4.25.2,<6.0.0;python_version>="3.11"', 'protobuf>=4.21.6,<6.0.0;python_version<"3.11"', - 'pyarrow>=10,<11', - 'tensorflow>=2.17,<2.18', - 'tensorflow-metadata' + "pyarrow>=10,<11", + "tensorflow>=2.17,<2.18", + "tensorflow-metadata" + select_constraint( - default='>=1.17.1,<1.18.0', - nightly='>=1.18.0.dev', - git_master='@git+https://github.com/tensorflow/metadata@master', + default=">=1.17.1,<1.18.0", + nightly=">=1.18.0.dev", + git_master="@git+https://github.com/tensorflow/metadata@master", ), - 'tensorflow-serving-api' + "tensorflow-serving-api" + select_constraint( - default='>=2.13.0,<3', - nightly='>=2.13.0.dev', - git_master='@git+https://github.com/tensorflow/serving@master', + default=">=2.13.0,<3", + nightly=">=2.13.0.dev", + git_master="@git+https://github.com/tensorflow/serving@master", ), ], - python_requires='>=3.9,<4', + extras_require={ + "dev": ["pre-commit"], + }, + python_requires=">=3.9,<4", packages=find_packages(), include_package_data=True, - package_data={'': ['*.lib', '*.pyd', '*.so']}, + package_data={"": ["*.lib", "*.pyd", "*.so"]}, zip_safe=False, distclass=_BinaryDistribution, description=( - 'tfx_bsl (TFX Basic Shared Libraries) contains libraries ' - 'shared by many TFX (TensorFlow eXtended) libraries and ' - 'components.' + "tfx_bsl (TFX Basic Shared Libraries) contains libraries " + "shared by many TFX (TensorFlow eXtended) libraries and " + "components." ), long_description=_LONG_DESCRIPTION, - long_description_content_type='text/markdown', - keywords='tfx bsl', - url='https://www.tensorflow.org/tfx', - download_url='https://github.com/tensorflow/tfx-bsl/tags', + long_description_content_type="text/markdown", + keywords="tfx bsl", + url="https://www.tensorflow.org/tfx", + download_url="https://github.com/tensorflow/tfx-bsl/tags", requires=[], cmdclass={ - 'install': _InstallPlatlibCommand, - 'build': _BuildCommand, - 'bazel_build': _BazelBuildCommand, + "install": _InstallPlatlibCommand, + "build": _BuildCommand, + "bazel_build": _BazelBuildCommand, }, ) diff --git a/tfx_bsl/arrow/array_util.py b/tfx_bsl/arrow/array_util.py index 9526a8b0..f7e6772b 100644 --- a/tfx_bsl/arrow/array_util.py +++ b/tfx_bsl/arrow/array_util.py @@ -13,180 +13,202 @@ # limitations under the License. """Arrow Array utilities.""" -from typing import Tuple, Optional, Union +from typing import Optional, Tuple, Union import numpy as np import pyarrow as pa + # pytype: disable=import-error # pylint: disable=g-import-not-at-top # pylint: disable=unused-import # See b/148667210 for why the ImportError is ignored. try: - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import ListLengthsFromListArray - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import GetElementLengths - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import GetFlattenedArrayParentIndices - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import GetArrayNullBitmapAsByteArray - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import GetBinaryArrayTotalByteSize - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import IndexIn - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import MakeListArrayFromParentIndicesAndValues as _MakeListArrayFromParentIndicesAndValues - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import CooFromListArray - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import FillNullLists - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import GetByteSize - from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import CountInvalidUTF8 + from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import ( + CooFromListArray, + CountInvalidUTF8, + FillNullLists, + GetArrayNullBitmapAsByteArray, + GetBinaryArrayTotalByteSize, + GetByteSize, + GetElementLengths, + GetFlattenedArrayParentIndices, + IndexIn, + ListLengthsFromListArray, + ) + from tfx_bsl.cc.tfx_bsl_extension.arrow.array_util import ( + MakeListArrayFromParentIndicesAndValues as _MakeListArrayFromParentIndicesAndValues, + ) except ImportError: - import sys - sys.stderr.write("Error importing tfx_bsl_extension.arrow.array_util. " - "Some tfx_bsl functionalities are not available") + import sys + + sys.stderr.write( + "Error importing tfx_bsl_extension.arrow.array_util. " + "Some tfx_bsl functionalities are not available" + ) # pytype: enable=import-error # pylint: enable=g-import-not-at-top # pylint: enable=unused-import def ToSingletonListArray(array: pa.Array) -> pa.Array: # pylint: disable=invalid-name - """Converts an array of `type` to a `LargeListArray`. - - Where result[i] is null if array[i] is null; [array[i]] otherwise. - - Args: - array: an arrow Array. - - Returns: - a LargeListArray. - """ - array_size = len(array) - # fast path: values are not copied. - if array.null_count == 0: + """Converts an array of `type` to a `LargeListArray`. + + Where result[i] is null if array[i] is null; [array[i]] otherwise. + + Args: + ---- + array: an arrow Array. + + Returns: + ------- + a LargeListArray. + """ + array_size = len(array) + # fast path: values are not copied. + if array.null_count == 0: + return pa.LargeListArray.from_arrays( + pa.array(np.arange(0, array_size + 1, dtype=np.int32)), array + ) + + # null_mask[i] = 1 iff array[i] is null. + null_mask = np.asarray(GetArrayNullBitmapAsByteArray(array)) + # presence_mask[i] = 0 iff array[i] is null + presence_mask = np.subtract(1, null_mask, dtype=np.uint8) + offsets_np = np.zeros((array_size + 1,), np.int32) + np.cumsum(presence_mask, out=offsets_np[1:]) + + # This is the null mask over offsets (but ListArray.from_arrays() uses it as + # the null mask for the ListArray), so its length is array_size +1, but the + # last element is always False. + list_array_null_mask = np.zeros((array_size + 1,), bool) + list_array_null_mask[:array_size] = null_mask.view(bool) + values_non_null = array.take(pa.array(np.flatnonzero(presence_mask))) return pa.LargeListArray.from_arrays( - pa.array(np.arange(0, array_size + 1, dtype=np.int32)), array) - - # null_mask[i] = 1 iff array[i] is null. - null_mask = np.asarray(GetArrayNullBitmapAsByteArray(array)) - # presence_mask[i] = 0 iff array[i] is null - presence_mask = np.subtract(1, null_mask, dtype=np.uint8) - offsets_np = np.zeros((array_size + 1,), np.int32) - np.cumsum(presence_mask, out=offsets_np[1:]) - - # This is the null mask over offsets (but ListArray.from_arrays() uses it as - # the null mask for the ListArray), so its length is array_size +1, but the - # last element is always False. - list_array_null_mask = np.zeros((array_size + 1,), bool) - list_array_null_mask[:array_size] = null_mask.view(bool) - values_non_null = array.take(pa.array(np.flatnonzero(presence_mask))) - return pa.LargeListArray.from_arrays( - pa.array(offsets_np, mask=list_array_null_mask), values_non_null) - - -def MakeListArrayFromParentIndicesAndValues(num_parents: int, # pylint: disable=invalid-name - parent_indices: pa.Array, - values: pa.Array, - empty_list_as_null: bool = True): - """Makes an Arrow LargeListArray from parent indices and values. - - For example, if `num_parents = 6`, `parent_indices = [0, 1, 1, 3, 3]` and - `values` is (an arrow Array of) `[0, 1, 2, 3, 4]`, then the result will - be a `pa.LargeListArray` of integers: - `[[0], [1, 2], , [3, 4], ]` - where `` is `null` if `empty_list_as_null` is True, or `[]` if - False. - - Args: - num_parents: integer, number of sub-list. Must be greater than or equal to - `max(parent_indices) + 1`. - parent_indices: an int64 pa.Array. Must be sorted in increasing order. - values: a pa.Array. Its length must equal to the length of `parent_indices`. - empty_list_as_null: if True, empty sub-lists will become null elements - in the result ListArray. Otherwise they become empty sub-lists. - - Returns: - A LargeListArray. - """ - return _MakeListArrayFromParentIndicesAndValues(num_parents, parent_indices, - values, empty_list_as_null) + pa.array(offsets_np, mask=list_array_null_mask), values_non_null + ) + + +def MakeListArrayFromParentIndicesAndValues( + num_parents: int, # pylint: disable=invalid-name + parent_indices: pa.Array, + values: pa.Array, + empty_list_as_null: bool = True, +): + """Makes an Arrow LargeListArray from parent indices and values. + + For example, if `num_parents = 6`, `parent_indices = [0, 1, 1, 3, 3]` and + `values` is (an arrow Array of) `[0, 1, 2, 3, 4]`, then the result will + be a `pa.LargeListArray` of integers: + `[[0], [1, 2], , [3, 4], ]` + where `` is `null` if `empty_list_as_null` is True, or `[]` if + False. + + Args: + ---- + num_parents: integer, number of sub-list. Must be greater than or equal to + `max(parent_indices) + 1`. + parent_indices: an int64 pa.Array. Must be sorted in increasing order. + values: a pa.Array. Its length must equal to the length of `parent_indices`. + empty_list_as_null: if True, empty sub-lists will become null elements + in the result ListArray. Otherwise they become empty sub-lists. + + Returns: + ------- + A LargeListArray. + """ + return _MakeListArrayFromParentIndicesAndValues( + num_parents, parent_indices, values, empty_list_as_null + ) def is_list_like(data_type: pa.DataType) -> bool: - """Returns true if an Arrow type is list-like.""" - return pa.types.is_list(data_type) or pa.types.is_large_list(data_type) + """Returns true if an Arrow type is list-like.""" + return pa.types.is_list(data_type) or pa.types.is_large_list(data_type) def get_innermost_nested_type(arrow_type: pa.DataType) -> pa.DataType: - """Returns the innermost type of a nested list type.""" - while is_list_like(arrow_type): - arrow_type = arrow_type.value_type - return arrow_type + """Returns the innermost type of a nested list type.""" + while is_list_like(arrow_type): + arrow_type = arrow_type.value_type + return arrow_type def flatten_nested( array: pa.Array, return_parent_indices: bool = False - ) -> Tuple[pa.Array, Optional[np.ndarray]]: - """Flattens all the list arrays nesting an array. - - If `array` is not list-like, itself will be returned. - - Args: - array: pa.Array to flatten. - return_parent_indices: If True, also returns the parent indices array. - - Returns: - A tuple. The first term is the flattened array. The second term is None - if `return_parent_indices` is False; otherwise it's a parent indices array - parallel to the flattened array: if parent_indices[i] = j, then - flattened_array[i] belongs to the j-th element of the input array. - """ - parent_indices = None - - while is_list_like(array.type): - if return_parent_indices: - cur_parent_indices = GetFlattenedArrayParentIndices( - array).to_numpy() - if parent_indices is None: - parent_indices = cur_parent_indices - else: - parent_indices = parent_indices[cur_parent_indices] - array = array.flatten() - - # the array is not nested at the first place. - if return_parent_indices and parent_indices is None: - parent_indices = np.arange(len(array)) - return array, parent_indices +) -> Tuple[pa.Array, Optional[np.ndarray]]: + """Flattens all the list arrays nesting an array. + + If `array` is not list-like, itself will be returned. + + Args: + ---- + array: pa.Array to flatten. + return_parent_indices: If True, also returns the parent indices array. + + Returns: + ------- + A tuple. The first term is the flattened array. The second term is None + if `return_parent_indices` is False; otherwise it's a parent indices array + parallel to the flattened array: if parent_indices[i] = j, then + flattened_array[i] belongs to the j-th element of the input array. + """ + parent_indices = None + + while is_list_like(array.type): + if return_parent_indices: + cur_parent_indices = GetFlattenedArrayParentIndices(array).to_numpy() + if parent_indices is None: + parent_indices = cur_parent_indices + else: + parent_indices = parent_indices[cur_parent_indices] + array = array.flatten() + + # the array is not nested at the first place. + if return_parent_indices and parent_indices is None: + parent_indices = np.arange(len(array)) + return array, parent_indices def get_field(struct_array: pa.StructArray, field: Union[str, int]) -> pa.Array: - """Returns struct_array.field(field) with null propagation. - - This function is equivalent to struct_array.field() but correctly handles - null propagation (the parent struct's null values are propagated to children). - - Args: - struct_array: A struct array which should be queried. - field: The request field to retrieve. - - Returns: - A pa.Array containing the requested field. - - Raises: - KeyError: If field is not a child field in struct_array. - """ - child_array = struct_array.field(field) - - # In case all values are present then there's no need for special handling. - # We can return child_array as is to avoid a performance penalty caused by - # constructing and flattening the returned array. - if struct_array.null_count == 0: - return child_array - # is_valid returns a BooleanArray with two buffers the buffer at offset - # 0 is always None and buffer 1 contains the data on which fields are - # valid/not valid. - # (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout) - validity_bitmap_buffer = struct_array.is_valid().buffers()[1] - - # Construct a new struct array with a single field. Calling flatten() on the - # new array guarantees validity bitmaps are merged correctly. - new_type = pa.struct([pa.field(field, child_array.type)]) - filtered_struct = pa.StructArray.from_buffers( - new_type, - len(struct_array), [validity_bitmap_buffer], - null_count=struct_array.null_count, - children=[child_array]) - return filtered_struct.flatten()[0] - + """Returns struct_array.field(field) with null propagation. + + This function is equivalent to struct_array.field() but correctly handles + null propagation (the parent struct's null values are propagated to children). + + Args: + ---- + struct_array: A struct array which should be queried. + field: The request field to retrieve. + + Returns: + ------- + A pa.Array containing the requested field. + + Raises: + ------ + KeyError: If field is not a child field in struct_array. + """ + child_array = struct_array.field(field) + + # In case all values are present then there's no need for special handling. + # We can return child_array as is to avoid a performance penalty caused by + # constructing and flattening the returned array. + if struct_array.null_count == 0: + return child_array + # is_valid returns a BooleanArray with two buffers the buffer at offset + # 0 is always None and buffer 1 contains the data on which fields are + # valid/not valid. + # (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout) + validity_bitmap_buffer = struct_array.is_valid().buffers()[1] + + # Construct a new struct array with a single field. Calling flatten() on the + # new array guarantees validity bitmaps are merged correctly. + new_type = pa.struct([pa.field(field, child_array.type)]) + filtered_struct = pa.StructArray.from_buffers( + new_type, + len(struct_array), + [validity_bitmap_buffer], + null_count=struct_array.null_count, + children=[child_array], + ) + return filtered_struct.flatten()[0] diff --git a/tfx_bsl/arrow/array_util_test.py b/tfx_bsl/arrow/array_util_test.py index 9c146817..9e7036d1 100644 --- a/tfx_bsl/arrow/array_util_test.py +++ b/tfx_bsl/arrow/array_util_test.py @@ -17,13 +17,10 @@ import numpy as np import pyarrow as pa +from absl.testing import absltest, parameterized from tfx_bsl.arrow import array_util -from absl.testing import absltest -from absl.testing import parameterized - - _LIST_TYPE_PARAMETERS = [ dict(testcase_name="list", list_type_factory=pa.list_), dict(testcase_name="large_list", list_type_factory=pa.large_list), @@ -31,219 +28,245 @@ class ArrayUtilTest(parameterized.TestCase): + def test_invalid_input_type(self): + functions_expecting_list_array = [ + array_util.GetFlattenedArrayParentIndices, + ] + functions_expecting_array = [array_util.GetArrayNullBitmapAsByteArray] + functions_expecting_binary_array = [array_util.GetBinaryArrayTotalByteSize] + for f in itertools.chain( + functions_expecting_list_array, + functions_expecting_array, + functions_expecting_binary_array, + ): + with self.assertRaises((TypeError, RuntimeError)): + f(1) + + for f in functions_expecting_list_array: + with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): + f(pa.array([1, 2, 3])) + + for f in functions_expecting_binary_array: + with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): + f(pa.array([[1, 2, 3]])) + + @parameterized.named_parameters(*_LIST_TYPE_PARAMETERS) + def test_list_lengths(self, list_type_factory): + list_lengths = array_util.ListLengthsFromListArray( + pa.array([], type=list_type_factory(pa.int64())) + ) + self.assertTrue(list_lengths.equals(pa.array([], type=pa.int64()))) + list_lengths = array_util.ListLengthsFromListArray( + pa.array([[1.0, 2.0], [], [3.0]], type=list_type_factory(pa.float32())) + ) + self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) + list_lengths = array_util.ListLengthsFromListArray( + pa.array([[1.0, 2.0], None, [3.0]], type=list_type_factory(pa.float64())) + ) + self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) + + @parameterized.named_parameters(*_LIST_TYPE_PARAMETERS) + def test_element_lengths_list_array(self, list_type_factory): + list_lengths = array_util.GetElementLengths( + pa.array([], type=list_type_factory(pa.int64())) + ) + self.assertTrue(list_lengths.equals(pa.array([], type=pa.int64()))) + list_lengths = array_util.GetElementLengths( + pa.array([[1.0, 2.0], [], [3.0]], list_type_factory(pa.float32())) + ) + self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) + list_lengths = array_util.GetElementLengths( + pa.array([[1.0, 2.0], None, [3.0]], list_type_factory(pa.float64())) + ) + self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) + + @parameterized.named_parameters( + *[ + dict(testcase_name="binary", binary_like_type=pa.binary()), + dict(testcase_name="string", binary_like_type=pa.string()), + dict(testcase_name="large_binary", binary_like_type=pa.large_binary()), + dict(testcase_name="large_string", binary_like_type=pa.large_string()), + ] + ) + def test_element_lengths_binary_like(self, binary_like_type): + list_lengths = array_util.GetElementLengths( + pa.array([b"a", b"bb", None, b"", b"ccc"], type=binary_like_type) + ) + self.assertTrue(list_lengths.equals(pa.array([1, 2, 0, 0, 3], type=pa.int64()))) + + def test_element_lengths_unsupported_type(self): + with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): + array_util.GetElementLengths(pa.array([1, 2, 3], type=pa.int32())) + + def test_get_array_null_bitmap_as_byte_array(self): + array = pa.array([], type=pa.int32()) + null_masks = array_util.GetArrayNullBitmapAsByteArray(array) + self.assertTrue(null_masks.equals(pa.array([], type=pa.uint8()))) + + array = pa.array([1, 2, None, 3, None], type=pa.int32()) + null_masks = array_util.GetArrayNullBitmapAsByteArray(array) + self.assertTrue(null_masks.equals(pa.array([0, 0, 1, 0, 1], type=pa.uint8()))) + + array = pa.array([1, 2, 3]) + null_masks = array_util.GetArrayNullBitmapAsByteArray(array) + self.assertTrue(null_masks.equals(pa.array([0, 0, 0], type=pa.uint8()))) + + array = pa.array([None, None, None], type=pa.int32()) + null_masks = array_util.GetArrayNullBitmapAsByteArray(array) + self.assertTrue(null_masks.equals(pa.array([1, 1, 1], type=pa.uint8()))) + # Demonstrate that the returned array can be converted to a numpy boolean + # array w/o copying + np.testing.assert_equal( + np.array([True, True, True]), null_masks.to_numpy().view(bool) + ) - def test_invalid_input_type(self): + @parameterized.named_parameters( + *[ + dict( + testcase_name="list", + list_type_factory=pa.list_, + parent_indices_type=pa.int32(), + ), + dict( + testcase_name="large_list", + list_type_factory=pa.large_list, + parent_indices_type=pa.int64(), + ), + ] + ) + def test_get_flattened_array_parent_indices( + self, list_type_factory, parent_indices_type + ): + indices = array_util.GetFlattenedArrayParentIndices( + pa.array([], type=list_type_factory(pa.int32())) + ) + self.assertTrue(indices.equals(pa.array([], type=parent_indices_type))) - functions_expecting_list_array = [ - array_util.GetFlattenedArrayParentIndices, - ] - functions_expecting_array = [array_util.GetArrayNullBitmapAsByteArray] - functions_expecting_binary_array = [array_util.GetBinaryArrayTotalByteSize] - for f in itertools.chain(functions_expecting_list_array, - functions_expecting_array, - functions_expecting_binary_array): - with self.assertRaises((TypeError, RuntimeError)): - f(1) - - for f in functions_expecting_list_array: - with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): - f(pa.array([1, 2, 3])) - - for f in functions_expecting_binary_array: - with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): - f(pa.array([[1, 2, 3]])) - - @parameterized.named_parameters(*_LIST_TYPE_PARAMETERS) - def test_list_lengths(self, list_type_factory): - list_lengths = array_util.ListLengthsFromListArray( - pa.array([], type=list_type_factory(pa.int64()))) - self.assertTrue(list_lengths.equals(pa.array([], type=pa.int64()))) - list_lengths = array_util.ListLengthsFromListArray( - pa.array([[1., 2.], [], [3.]], type=list_type_factory(pa.float32()))) - self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) - list_lengths = array_util.ListLengthsFromListArray( - pa.array([[1., 2.], None, [3.]], type=list_type_factory(pa.float64()))) - self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) - - @parameterized.named_parameters(*_LIST_TYPE_PARAMETERS) - def test_element_lengths_list_array(self, list_type_factory): - list_lengths = array_util.GetElementLengths( - pa.array([], type=list_type_factory(pa.int64()))) - self.assertTrue(list_lengths.equals(pa.array([], type=pa.int64()))) - list_lengths = array_util.GetElementLengths( - pa.array([[1., 2.], [], [3.]], list_type_factory(pa.float32()))) - self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) - list_lengths = array_util.GetElementLengths( - pa.array([[1., 2.], None, [3.]], list_type_factory(pa.float64()))) - self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int64()))) - - @parameterized.named_parameters(*[ - dict(testcase_name="binary", binary_like_type=pa.binary()), - dict(testcase_name="string", binary_like_type=pa.string()), - dict(testcase_name="large_binary", binary_like_type=pa.large_binary()), - dict(testcase_name="large_string", binary_like_type=pa.large_string()), - ]) - def test_element_lengths_binary_like(self, binary_like_type): - - list_lengths = array_util.GetElementLengths( - pa.array([b"a", b"bb", None, b"", b"ccc"], type=binary_like_type)) - self.assertTrue(list_lengths.equals(pa.array([1, 2, 0, 0, 3], - type=pa.int64()))) - - def test_element_lengths_unsupported_type(self): - with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): - array_util.GetElementLengths(pa.array([1, 2, 3], type=pa.int32())) - - def test_get_array_null_bitmap_as_byte_array(self): - array = pa.array([], type=pa.int32()) - null_masks = array_util.GetArrayNullBitmapAsByteArray(array) - self.assertTrue(null_masks.equals(pa.array([], type=pa.uint8()))) - - array = pa.array([1, 2, None, 3, None], type=pa.int32()) - null_masks = array_util.GetArrayNullBitmapAsByteArray(array) - self.assertTrue( - null_masks.equals(pa.array([0, 0, 1, 0, 1], type=pa.uint8()))) - - array = pa.array([1, 2, 3]) - null_masks = array_util.GetArrayNullBitmapAsByteArray(array) - self.assertTrue(null_masks.equals(pa.array([0, 0, 0], type=pa.uint8()))) - - array = pa.array([None, None, None], type=pa.int32()) - null_masks = array_util.GetArrayNullBitmapAsByteArray(array) - self.assertTrue(null_masks.equals(pa.array([1, 1, 1], type=pa.uint8()))) - # Demonstrate that the returned array can be converted to a numpy boolean - # array w/o copying - np.testing.assert_equal( - np.array([True, True, True]), null_masks.to_numpy().view(bool)) - - @parameterized.named_parameters(*[ - dict( - testcase_name="list", - list_type_factory=pa.list_, - parent_indices_type=pa.int32()), - dict( - testcase_name="large_list", - list_type_factory=pa.large_list, - parent_indices_type=pa.int64()), - ]) - def test_get_flattened_array_parent_indices(self, list_type_factory, - parent_indices_type): - indices = array_util.GetFlattenedArrayParentIndices( - pa.array([], type=list_type_factory(pa.int32()))) - self.assertTrue(indices.equals(pa.array([], type=parent_indices_type))) - - indices = array_util.GetFlattenedArrayParentIndices( - pa.array([[1.], [2.], [], [3., 4.]], - type=list_type_factory(pa.float32()))) - self.assertTrue( - indices.equals(pa.array([0, 1, 3, 3], type=parent_indices_type))) - - indices = array_util.GetFlattenedArrayParentIndices( - pa.array([[1.], [2.], [], [3., 4.]], - type=list_type_factory(pa.float32())).slice(1)) - self.assertTrue( - indices.equals(pa.array([0, 2, 2], type=parent_indices_type))) - - indices = array_util.GetFlattenedArrayParentIndices( - pa.array([list(range(1024))], - type=list_type_factory(pa.int64()))) - self.assertTrue( - indices.equals(pa.array([0] * 1024, type=parent_indices_type))) - - @parameterized.named_parameters(*[ - dict(testcase_name="binary", binary_like_type=pa.binary()), - dict(testcase_name="string", binary_like_type=pa.string()), - dict(testcase_name="large_binary", binary_like_type=pa.large_binary()), - dict(testcase_name="large_string", binary_like_type=pa.large_string()), - ]) - def test_get_binary_array_total_byte_size(self, binary_like_type): - array = pa.array([b"abc", None, b"def", b"", b"ghi"], type=binary_like_type) - self.assertEqual(9, array_util.GetBinaryArrayTotalByteSize(array)) - sliced_1_2 = array.slice(1, 2) - self.assertEqual(3, array_util.GetBinaryArrayTotalByteSize(sliced_1_2)) - sliced_2 = array.slice(2) - self.assertEqual(6, array_util.GetBinaryArrayTotalByteSize(sliced_2)) - - empty_array = pa.array([], type=binary_like_type) - self.assertEqual(0, array_util.GetBinaryArrayTotalByteSize(empty_array)) - - def test_indexin_integer(self): - values = pa.array([99, 42, 3, None]) - # TODO(b/203116559): Change this back to [3, 3, 99] once arrow >= 5.0 - # is required by TFDV. - value_set = pa.array([3, 4, 99]) - actual = array_util.IndexIn(values, value_set) - actual.validate() - self.assertTrue( - actual.equals(pa.array([2, None, 0, None], type=pa.int32()))) - - @parameterized.parameters( - *(list( - itertools.product([pa.binary(), pa.large_binary()], - [pa.binary(), pa.large_binary()])) + - list( - itertools.product([pa.string(), pa.large_string()], - [pa.string(), pa.large_string()])))) - def test_indexin_binary_alike(self, values_type, value_set_type): - # Case #1: value_set does not contain null. - values = pa.array(["aa", "bb", "cc", None], values_type) - value_set = pa.array(["cc", "cc", "aa"], value_set_type) - actual = array_util.IndexIn(values, value_set) - actual.validate() - self.assertTrue( - actual.equals(pa.array([1, None, 0, None], type=pa.int32())), - "actual: {}".format(actual)) - - # Case #2: value_set contains nulls. - values = pa.array(["aa", "bb", "cc", None], values_type) - value_set = pa.array(["cc", None, None, "bb"], value_set_type) - actual = array_util.IndexIn(values, value_set) - actual.validate() - self.assertTrue( - actual.equals(pa.array([None, 2, 0, 1], type=pa.int32())), - "actual: {}".format(actual)) - - def test_is_list_like(self): - for t in (pa.list_(pa.int64()), pa.large_list(pa.int64())): - self.assertTrue(array_util.is_list_like(t)) - - for t in (pa.binary(), pa.int64(), pa.large_string()): - self.assertFalse(array_util.is_list_like(t)) - - def test_get_innermost_nested_type_nested_input(self): - for inner_type in pa.int64(), pa.float32(), pa.binary(): - for t in (pa.list_(inner_type), pa.large_list(inner_type)): + indices = array_util.GetFlattenedArrayParentIndices( + pa.array( + [[1.0], [2.0], [], [3.0, 4.0]], type=list_type_factory(pa.float32()) + ) + ) self.assertTrue( - array_util.get_innermost_nested_type(t).equals(inner_type) + indices.equals(pa.array([0, 1, 3, 3], type=parent_indices_type)) ) - def test_get_innermost_nested_type_non_nested_input(self): - for t in pa.int64(), pa.float32(), pa.binary(): - self.assertTrue(array_util.get_innermost_nested_type(t).equals(t)) + indices = array_util.GetFlattenedArrayParentIndices( + pa.array( + [[1.0], [2.0], [], [3.0, 4.0]], type=list_type_factory(pa.float32()) + ).slice(1) + ) + self.assertTrue(indices.equals(pa.array([0, 2, 2], type=parent_indices_type))) - def test_flatten_nested(self): - input_array = pa.array([[[1, 2]], None, [None, [3]]]) - flattened, parent_indices = array_util.flatten_nested( - input_array, return_parent_indices=False + indices = array_util.GetFlattenedArrayParentIndices( + pa.array([list(range(1024))], type=list_type_factory(pa.int64())) + ) + self.assertTrue(indices.equals(pa.array([0] * 1024, type=parent_indices_type))) + + @parameterized.named_parameters( + *[ + dict(testcase_name="binary", binary_like_type=pa.binary()), + dict(testcase_name="string", binary_like_type=pa.string()), + dict(testcase_name="large_binary", binary_like_type=pa.large_binary()), + dict(testcase_name="large_string", binary_like_type=pa.large_string()), + ] ) - expected = pa.array([1, 2, 3]) - expected_parent_indices = [0, 0, 2] - self.assertIs(parent_indices, None) - self.assertTrue(flattened.equals(expected)) - - flattened, parent_indices = array_util.flatten_nested( - input_array, return_parent_indices=True + def test_get_binary_array_total_byte_size(self, binary_like_type): + array = pa.array([b"abc", None, b"def", b"", b"ghi"], type=binary_like_type) + self.assertEqual(9, array_util.GetBinaryArrayTotalByteSize(array)) + sliced_1_2 = array.slice(1, 2) + self.assertEqual(3, array_util.GetBinaryArrayTotalByteSize(sliced_1_2)) + sliced_2 = array.slice(2) + self.assertEqual(6, array_util.GetBinaryArrayTotalByteSize(sliced_2)) + + empty_array = pa.array([], type=binary_like_type) + self.assertEqual(0, array_util.GetBinaryArrayTotalByteSize(empty_array)) + + def test_indexin_integer(self): + values = pa.array([99, 42, 3, None]) + # TODO(b/203116559): Change this back to [3, 3, 99] once arrow >= 5.0 + # is required by TFDV. + value_set = pa.array([3, 4, 99]) + actual = array_util.IndexIn(values, value_set) + actual.validate() + self.assertTrue(actual.equals(pa.array([2, None, 0, None], type=pa.int32()))) + + @parameterized.parameters( + *( + list( + itertools.product( + [pa.binary(), pa.large_binary()], [pa.binary(), pa.large_binary()] + ) + ) + + list( + itertools.product( + [pa.string(), pa.large_string()], [pa.string(), pa.large_string()] + ) + ) + ) ) - self.assertTrue(flattened.equals(expected)) - np.testing.assert_array_equal(parent_indices, expected_parent_indices) + def test_indexin_binary_alike(self, values_type, value_set_type): + # Case #1: value_set does not contain null. + values = pa.array(["aa", "bb", "cc", None], values_type) + value_set = pa.array(["cc", "cc", "aa"], value_set_type) + actual = array_util.IndexIn(values, value_set) + actual.validate() + self.assertTrue( + actual.equals(pa.array([1, None, 0, None], type=pa.int32())), + f"actual: {actual}", + ) - def test_flatten_nested_non_list(self): - input_array = pa.array([1, 2]) - flattened, parent_indices = array_util.flatten_nested( - input_array, return_parent_indices=True - ) - self.assertTrue(flattened.equals(pa.array([1, 2]))) - np.testing.assert_array_equal(parent_indices, [0, 1]) + # Case #2: value_set contains nulls. + values = pa.array(["aa", "bb", "cc", None], values_type) + value_set = pa.array(["cc", None, None, "bb"], value_set_type) + actual = array_util.IndexIn(values, value_set) + actual.validate() + self.assertTrue( + actual.equals(pa.array([None, 2, 0, 1], type=pa.int32())), + f"actual: {actual}", + ) + + def test_is_list_like(self): + for t in (pa.list_(pa.int64()), pa.large_list(pa.int64())): + self.assertTrue(array_util.is_list_like(t)) + + for t in (pa.binary(), pa.int64(), pa.large_string()): + self.assertFalse(array_util.is_list_like(t)) + + def test_get_innermost_nested_type_nested_input(self): + for inner_type in pa.int64(), pa.float32(), pa.binary(): + for t in (pa.list_(inner_type), pa.large_list(inner_type)): + self.assertTrue( + array_util.get_innermost_nested_type(t).equals(inner_type) + ) + + def test_get_innermost_nested_type_non_nested_input(self): + for t in pa.int64(), pa.float32(), pa.binary(): + self.assertTrue(array_util.get_innermost_nested_type(t).equals(t)) + + def test_flatten_nested(self): + input_array = pa.array([[[1, 2]], None, [None, [3]]]) + flattened, parent_indices = array_util.flatten_nested( + input_array, return_parent_indices=False + ) + expected = pa.array([1, 2, 3]) + expected_parent_indices = [0, 0, 2] + self.assertIs(parent_indices, None) + self.assertTrue(flattened.equals(expected)) + + flattened, parent_indices = array_util.flatten_nested( + input_array, return_parent_indices=True + ) + self.assertTrue(flattened.equals(expected)) + np.testing.assert_array_equal(parent_indices, expected_parent_indices) + + def test_flatten_nested_non_list(self): + input_array = pa.array([1, 2]) + flattened, parent_indices = array_util.flatten_nested( + input_array, return_parent_indices=True + ) + self.assertTrue(flattened.equals(pa.array([1, 2]))) + np.testing.assert_array_equal(parent_indices, [0, 1]) _MAKE_LIST_ARRAY_INVALID_INPUT_TEST_CASES = [ @@ -253,15 +276,15 @@ def test_flatten_nested_non_list(self): parent_indices=pa.array([0], type=pa.int32()), values=pa.array([1]), expected_error=RuntimeError, - expected_error_regexp="must be int64" - ), + expected_error_regexp="must be int64", + ), dict( testcase_name="parent_indices_length_not_equal_to_values_length", num_parents=1, parent_indices=pa.array([0], type=pa.int64()), values=pa.array([1, 2]), expected_error=RuntimeError, - expected_error_regexp="values array and parent indices array must be of the same length" + expected_error_regexp="values array and parent indices array must be of the same length", ), dict( testcase_name="num_parents_too_small", @@ -269,8 +292,8 @@ def test_flatten_nested_non_list(self): parent_indices=pa.array([1], type=pa.int64()), values=pa.array([1]), expected_error=RuntimeError, - expected_error_regexp="Found a parent index 1 while num_parents was 1" - ) + expected_error_regexp="Found a parent index 1 while num_parents was 1", + ), ] @@ -281,8 +304,10 @@ def test_flatten_nested_non_list(self): parent_indices=pa.array([], type=pa.int64()), values=pa.array([], type=pa.int64()), empty_list_as_null=True, - expected=pa.array([None, None, None, None, None], - type=pa.large_list(pa.int64()))), + expected=pa.array( + [None, None, None, None, None], type=pa.large_list(pa.int64()) + ), + ), dict( testcase_name="leading_nulls", num_parents=3, @@ -297,16 +322,18 @@ def test_flatten_nested_non_list(self): parent_indices=pa.array([0, 0, 0, 3, 3], type=pa.int64()), values=pa.array(["a", "b", "c", "d", "e"], type=pa.binary()), empty_list_as_null=True, - expected=pa.array([["a", "b", "c"], None, None, ["d", "e"]], - type=pa.large_list(pa.binary()))), + expected=pa.array( + [["a", "b", "c"], None, None, ["d", "e"]], type=pa.large_list(pa.binary()) + ), + ), dict( testcase_name="parents_are_all_empty", num_parents=5, parent_indices=pa.array([], type=pa.int64()), values=pa.array([], type=pa.int64()), empty_list_as_null=False, - expected=pa.array([[], [], [], [], []], - type=pa.large_list(pa.int64()))), + expected=pa.array([[], [], [], [], []], type=pa.large_list(pa.int64())), + ), dict( testcase_name="leading_empties", num_parents=3, @@ -321,32 +348,36 @@ def test_flatten_nested_non_list(self): parent_indices=pa.array([0, 0, 0, 3, 3], type=pa.int64()), values=pa.array(["a", "b", "c", "d", "e"], type=pa.binary()), empty_list_as_null=False, - expected=pa.array([["a", "b", "c"], [], [], ["d", "e"]], - type=pa.large_list(pa.binary())), + expected=pa.array( + [["a", "b", "c"], [], [], ["d", "e"]], type=pa.large_list(pa.binary()) ), + ), ] class MakeListArrayFromParentIndicesAndValuesTest(parameterized.TestCase): - - @parameterized.named_parameters(*_MAKE_LIST_ARRAY_INVALID_INPUT_TEST_CASES) - def testInvalidInput(self, num_parents, parent_indices, values, - expected_error, expected_error_regexp): - with self.assertRaisesRegex(expected_error, expected_error_regexp): - array_util.MakeListArrayFromParentIndicesAndValues( - num_parents, parent_indices, values) - - @parameterized.named_parameters(*_MAKE_LIST_ARRAY_TEST_CASES) - def testMakeListArray(self, num_parents, parent_indices, values, - empty_list_as_null, expected): - actual = array_util.MakeListArrayFromParentIndicesAndValues( - num_parents, parent_indices, values, empty_list_as_null) - actual.validate() - if not empty_list_as_null: - self.assertEqual(actual.null_count, 0) - self.assertTrue( - actual.equals(expected), - "actual: {}, expected: {}".format(actual, expected)) + @parameterized.named_parameters(*_MAKE_LIST_ARRAY_INVALID_INPUT_TEST_CASES) + def testInvalidInput( + self, num_parents, parent_indices, values, expected_error, expected_error_regexp + ): + with self.assertRaisesRegex(expected_error, expected_error_regexp): + array_util.MakeListArrayFromParentIndicesAndValues( + num_parents, parent_indices, values + ) + + @parameterized.named_parameters(*_MAKE_LIST_ARRAY_TEST_CASES) + def testMakeListArray( + self, num_parents, parent_indices, values, empty_list_as_null, expected + ): + actual = array_util.MakeListArrayFromParentIndicesAndValues( + num_parents, parent_indices, values, empty_list_as_null + ) + actual.validate() + if not empty_list_as_null: + self.assertEqual(actual.null_count, 0) + self.assertTrue( + actual.equals(expected), f"actual: {actual}, expected: {expected}" + ) _COO_FROM_LIST_ARRAY_TEST_CASES = [ @@ -369,15 +400,15 @@ def testMakeListArray(self, num_parents, parent_indices, values, list_array=[[]], expected_coo=[], expected_dense_shape=[1, 0], - array_types=[pa.list_(pa.int64()), - pa.large_list(pa.string())]), + array_types=[pa.list_(pa.int64()), pa.large_list(pa.string())], + ), dict( testcase_name="2d_ragged", list_array=[["a", "b"], ["c"], [], ["d", "e"]], expected_coo=[0, 0, 0, 1, 1, 0, 3, 0, 3, 1], expected_dense_shape=[4, 2], - array_types=[pa.list_(pa.string()), - pa.large_list(pa.large_binary())]), + array_types=[pa.list_(pa.string()), pa.large_list(pa.large_binary())], + ), dict( testcase_name="3d_ragged", list_array=[[["a", "b"], ["c"]], [[], ["d", "e"]]], @@ -394,26 +425,27 @@ def testMakeListArray(self, num_parents, parent_indices, values, class CooFromListArrayTest(parameterized.TestCase): - - @parameterized.named_parameters(*_COO_FROM_LIST_ARRAY_TEST_CASES) - def testCooFromListArray( - self, list_array, expected_coo, expected_dense_shape, array_types): - - for array_type in array_types: - for input_array in [ - pa.array(list_array, type=array_type), - # it should work for sliced arrays. - pa.array(list_array + list_array, - type=array_type).slice(0, len(list_array)), - pa.array(list_array + list_array, - type=array_type).slice(len(list_array)), - ]: - coo, dense_shape = array_util.CooFromListArray(input_array) - self.assertTrue(coo.type.equals(pa.int64())) - self.assertTrue(dense_shape.type.equals(pa.int64())) - - self.assertEqual(expected_coo, coo.to_pylist()) - self.assertEqual(expected_dense_shape, dense_shape.to_pylist()) + @parameterized.named_parameters(*_COO_FROM_LIST_ARRAY_TEST_CASES) + def testCooFromListArray( + self, list_array, expected_coo, expected_dense_shape, array_types + ): + for array_type in array_types: + for input_array in [ + pa.array(list_array, type=array_type), + # it should work for sliced arrays. + pa.array(list_array + list_array, type=array_type).slice( + 0, len(list_array) + ), + pa.array(list_array + list_array, type=array_type).slice( + len(list_array) + ), + ]: + coo, dense_shape = array_util.CooFromListArray(input_array) + self.assertTrue(coo.type.equals(pa.int64())) + self.assertTrue(dense_shape.type.equals(pa.int64())) + + self.assertEqual(expected_coo, coo.to_pylist()) + self.assertEqual(expected_dense_shape, dense_shape.to_pylist()) _FILL_NULL_LISTS_TEST_CASES = [ @@ -465,114 +497,133 @@ def testCooFromListArray( value_type=pa.large_binary(), fill_with=["x", "x"], expected=[["a"], ["b"], ["c"], ["x", "x"], ["d"], ["x", "x"], ["e"]], - ) + ), ] def _cross_named_parameters(*named_parameters_dicts): - result = [] - for product in itertools.product(*named_parameters_dicts): - crossed = dict(product[0]) - testcase_name = crossed["testcase_name"] - for d in product[1:]: - testcase_name += "_" + d["testcase_name"] - crossed.update(d) - crossed["testcase_name"] = testcase_name - result.append(crossed) - return result + result = [] + for product in itertools.product(*named_parameters_dicts): + crossed = dict(product[0]) + testcase_name = crossed["testcase_name"] + for d in product[1:]: + testcase_name += "_" + d["testcase_name"] + crossed.update(d) + crossed["testcase_name"] = testcase_name + result.append(crossed) + return result class FillNullListsTest(parameterized.TestCase): + @parameterized.named_parameters( + *_cross_named_parameters(_FILL_NULL_LISTS_TEST_CASES, _LIST_TYPE_PARAMETERS) + ) + def testFillNullLists( + self, list_array, value_type, fill_with, expected, list_type_factory + ): + actual = array_util.FillNullLists( + pa.array(list_array, type=list_type_factory(value_type)), + pa.array(fill_with, type=value_type), + ) + self.assertTrue( + actual.equals(pa.array(expected, type=list_type_factory(value_type))), + f"{actual} vs {expected}", + ) - @parameterized.named_parameters(*_cross_named_parameters( - _FILL_NULL_LISTS_TEST_CASES, _LIST_TYPE_PARAMETERS)) - def testFillNullLists( - self, list_array, value_type, fill_with, expected, list_type_factory): - actual = array_util.FillNullLists( - pa.array(list_array, type=list_type_factory(value_type)), - pa.array(fill_with, type=value_type)) - self.assertTrue( - actual.equals(pa.array(expected, type=list_type_factory(value_type))), - "{} vs {}".format(actual, expected)) - - def testNonListArray(self): - with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): - array_util.FillNullLists(pa.array([1, 2, 3]), pa.array([4])) + def testNonListArray(self): + with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): + array_util.FillNullLists(pa.array([1, 2, 3]), pa.array([4])) - def testValueTypeDoesNotEqualFillType(self): - with self.assertRaisesRegex(RuntimeError, "to be of the same type"): - array_util.FillNullLists(pa.array([[1]]), pa.array(["a"])) + def testValueTypeDoesNotEqualFillType(self): + with self.assertRaisesRegex(RuntimeError, "to be of the same type"): + array_util.FillNullLists(pa.array([[1]]), pa.array(["a"])) def _all_false_null_bitmap_size(size): - if pa.__version__ < "0.17": - return size - # starting from arrow 0.17, the array factory won't create a null bitmap if - # no element is null. - # TODO(zhuo): clean up this shim once tfx_bsl supports arrow 0.17+ - # exclusively. - return 0 + if pa.__version__ < "0.17": + return size + # starting from arrow 0.17, the array factory won't create a null bitmap if + # no element is null. + # TODO(zhuo): clean up this shim once tfx_bsl supports arrow 0.17+ + # exclusively. + return 0 def _get_numeric_byte_size_test_cases(): - result = [] - for array_type, sizeof in [ - (pa.int8(), 1), - (pa.uint8(), 1), - (pa.int16(), 2), - (pa.uint16(), 2), - (pa.int32(), 4), - (pa.uint32(), 4), - (pa.int64(), 8), - (pa.uint64(), 8), - (pa.float32(), 4), - (pa.float64(), 8), - ]: - result.append( - dict( - testcase_name=str(array_type), - array=pa.array(range(9), type=array_type), - slice_offset=2, - slice_length=3, - expected_size=(_all_false_null_bitmap_size(2) + sizeof * 9), - expected_sliced_size=(_all_false_null_bitmap_size(1) + sizeof * 3))) - return result + result = [] + for array_type, sizeof in [ + (pa.int8(), 1), + (pa.uint8(), 1), + (pa.int16(), 2), + (pa.uint16(), 2), + (pa.int32(), 4), + (pa.uint32(), 4), + (pa.int64(), 8), + (pa.uint64(), 8), + (pa.float32(), 4), + (pa.float64(), 8), + ]: + result.append( + dict( + testcase_name=str(array_type), + array=pa.array(range(9), type=array_type), + slice_offset=2, + slice_length=3, + expected_size=(_all_false_null_bitmap_size(2) + sizeof * 9), + expected_sliced_size=(_all_false_null_bitmap_size(1) + sizeof * 3), + ) + ) + return result def _get_binary_like_byte_size_test_cases(): - result = [] - for array_type, sizeof_offsets in [ - (pa.binary(), 4), - (pa.string(), 4), - (pa.large_binary(), 8), - (pa.large_string(), 8), - ]: - result.append( - dict( - testcase_name=str(array_type), - array=pa.array([ - "a", "bb", "ccc", "dddd", "eeeee", "ffffff", "ggggggg", - "hhhhhhhh", "iiiiiiiii" - ], - type=array_type), - slice_offset=1, - slice_length=3, - # contents: 45 - # offsets: 10 * sizeof_offsets - # null bitmap: 2 - expected_size=(45 + sizeof_offsets * 10 + - _all_false_null_bitmap_size(2)), - # contents: 9 - # offsets: 4 * sizeof_offsets - # null bitmap: 1 - expected_sliced_size=(9 + sizeof_offsets * 4 + - _all_false_null_bitmap_size(1)))) - return result + result = [] + for array_type, sizeof_offsets in [ + (pa.binary(), 4), + (pa.string(), 4), + (pa.large_binary(), 8), + (pa.large_string(), 8), + ]: + result.append( + dict( + testcase_name=str(array_type), + array=pa.array( + [ + "a", + "bb", + "ccc", + "dddd", + "eeeee", + "ffffff", + "ggggggg", + "hhhhhhhh", + "iiiiiiiii", + ], + type=array_type, + ), + slice_offset=1, + slice_length=3, + # contents: 45 + # offsets: 10 * sizeof_offsets + # null bitmap: 2 + expected_size=( + 45 + sizeof_offsets * 10 + _all_false_null_bitmap_size(2) + ), + # contents: 9 + # offsets: 4 * sizeof_offsets + # null bitmap: 1 + expected_sliced_size=( + 9 + sizeof_offsets * 4 + _all_false_null_bitmap_size(1) + ), + ) + ) + return result _GET_BYTE_SIZE_TEST_CASES = ( - _get_numeric_byte_size_test_cases() + - _get_binary_like_byte_size_test_cases() + [ + _get_numeric_byte_size_test_cases() + + _get_binary_like_byte_size_test_cases() + + [ dict( testcase_name="bool", array=pa.array([False] * 9, type=pa.bool_()), @@ -583,11 +634,13 @@ def _get_binary_like_byte_size_test_cases(): expected_size=(_all_false_null_bitmap_size(2) + 2), # contents: 1 # null bitmap: 1 - expected_sliced_size=(_all_false_null_bitmap_size(1) + 1)), + expected_sliced_size=(_all_false_null_bitmap_size(1) + 1), + ), dict( testcase_name="list", - array=pa.array([[1], [1, 1], [1, 1, 1], [1, 1, 1, 1]], - type=pa.list_(pa.int64())), + array=pa.array( + [[1], [1, 1], [1, 1, 1], [1, 1, 1, 1]], type=pa.list_(pa.int64()) + ), slice_offset=1, slice_length=2, # offsets: 5 * 4 @@ -601,12 +654,13 @@ def _get_binary_like_byte_size_test_cases(): # contents: # null bitmap: 1 # contents: 5 * 8 - expected_sliced_size=(3 * 4 + _all_false_null_bitmap_size(1 + 1) - + 5 * 8)), + expected_sliced_size=(3 * 4 + _all_false_null_bitmap_size(1 + 1) + 5 * 8), + ), dict( testcase_name="large_list", - array=pa.array([[1], [1, 1], [1, 1, 1], [1, 1, 1, 1]], - type=pa.large_list(pa.int64())), + array=pa.array( + [[1], [1, 1], [1, 1, 1], [1, 1, 1, 1]], type=pa.large_list(pa.int64()) + ), slice_offset=1, slice_length=2, # offsets: 5 * 8 @@ -620,72 +674,80 @@ def _get_binary_like_byte_size_test_cases(): # contents: # null bitmap: 1 # contents: 5 * 8 - expected_sliced_size=( - 3 * 8 + _all_false_null_bitmap_size(1 + 1) + 5 * 8)), + expected_sliced_size=(3 * 8 + _all_false_null_bitmap_size(1 + 1) + 5 * 8), + ), dict( testcase_name="deeply_nested_list", - array=pa.array([[["aaa"], ["bb", ""], None], - None, - [["c"], [], ["def", "g"]], - [["h"]]], - type=pa.list_(pa.list_(pa.binary()))), + array=pa.array( + [[["aaa"], ["bb", ""], None], None, [["c"], [], ["def", "g"]], [["h"]]], + type=pa.list_(pa.list_(pa.binary())), + ), slice_offset=1, slice_length=2, # innermost binary array: 1 + 11 + 8 * 4 # mid list array: 1 + 8 * 4 # outmost list array: 1 + 5 * 4 - expected_size=(97 + - # innermost binary array does not have null - _all_false_null_bitmap_size(1)), + expected_size=( + 97 + + + # innermost binary array does not have null + _all_false_null_bitmap_size(1) + ), # innermost binary array (["c", "def", "g"]): 1 + 5 + 4 * 4 # mid list array: ([["c"], [], ["def, "g]]): 1 + 4 * 4 # outmost list array: 1 + 3 * 4 expected_sliced_size=( - 51 + + 51 + + # innermost binary array does not have null - _all_false_null_bitmap_size(1))), + _all_false_null_bitmap_size(1) + ), + ), dict( testcase_name="null", array=pa.array([None] * 1000), slice_offset=4, slice_length=100, expected_size=0, - expected_sliced_size=0), + expected_sliced_size=0, + ), dict( testcase_name="struct", array=pa.array( - [{ - "a": 1, - "b": 2 - }] * 10, - type=pa.struct( - [pa.field("a", pa.int64()), - pa.field("b", pa.int64())])), + [{"a": 1, "b": 2}] * 10, + type=pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.int64())]), + ), slice_offset=2, slice_length=1, - expected_size=(_all_false_null_bitmap_size(2) + - (_all_false_null_bitmap_size(2) + 10 * 8) * 2), - expected_sliced_size=(_all_false_null_bitmap_size(1) + - (_all_false_null_bitmap_size(1) + 8) * 2)) - ]) + expected_size=( + _all_false_null_bitmap_size(2) + + (_all_false_null_bitmap_size(2) + 10 * 8) * 2 + ), + expected_sliced_size=( + _all_false_null_bitmap_size(1) + + (_all_false_null_bitmap_size(1) + 8) * 2 + ), + ), + ] +) class GetByteSizeTest(parameterized.TestCase): + @parameterized.named_parameters(*_GET_BYTE_SIZE_TEST_CASES) + def testGetByteSize( + self, array, slice_offset, slice_length, expected_size, expected_sliced_size + ): + # make sure the empty array case does not crash. + array_util.GetByteSize(pa.array([], array.type)) - @parameterized.named_parameters(*_GET_BYTE_SIZE_TEST_CASES) - def testGetByteSize(self, array, slice_offset, slice_length, expected_size, - expected_sliced_size): - # make sure the empty array case does not crash. - array_util.GetByteSize(pa.array([], array.type)) - - self.assertEqual(array_util.GetByteSize(array), expected_size) + self.assertEqual(array_util.GetByteSize(array), expected_size) - sliced = array.slice(slice_offset, slice_length) - self.assertEqual(array_util.GetByteSize(sliced), expected_sliced_size) + sliced = array.slice(slice_offset, slice_length) + self.assertEqual(array_util.GetByteSize(sliced), expected_sliced_size) - def testUnsupported(self): - with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): - array_util.GetByteSize(pa.array([], type=pa.timestamp("s"))) + def testUnsupported(self): + with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED"): + array_util.GetByteSize(pa.array([], type=pa.timestamp("s"))) _TO_SINGLETON_LIST_ARRAY_TEST_CASES = [ @@ -697,33 +759,33 @@ def testUnsupported(self): dict( testcase_name="no_null", array=pa.array([1, 2, 3]), - expected_result=pa.array([[1], [2], [3]], - type=pa.large_list(pa.int64())), + expected_result=pa.array([[1], [2], [3]], type=pa.large_list(pa.int64())), ), dict( testcase_name="all_nulls", array=pa.array([None, None, None], type=pa.binary()), - expected_result=pa.array([None, None, None], - type=pa.large_list(pa.binary())), + expected_result=pa.array([None, None, None], type=pa.large_list(pa.binary())), ), dict( testcase_name="some_nulls", array=pa.array([None, None, 2, 3, None, 4, None, None]), - expected_result=pa.array([None, None, [2], [3], None, [4], None, None], - type=pa.large_list(pa.int64())), + expected_result=pa.array( + [None, None, [2], [3], None, [4], None, None], + type=pa.large_list(pa.int64()), + ), ), ] class ToSingletonListArrayTest(parameterized.TestCase): - - @parameterized.named_parameters(*_TO_SINGLETON_LIST_ARRAY_TEST_CASES) - def testToSingletonListArray(self, array, expected_result): - result = array_util.ToSingletonListArray(array) - result.validate() - self.assertTrue( - result.equals(expected_result), - "expected: {}; got: {}".format(expected_result, result)) + @parameterized.named_parameters(*_TO_SINGLETON_LIST_ARRAY_TEST_CASES) + def testToSingletonListArray(self, array, expected_result): + result = array_util.ToSingletonListArray(array) + result.validate() + self.assertTrue( + result.equals(expected_result), + f"expected: {expected_result}; got: {result}", + ) _COUNT_INVALID_UTF8_TEST_CASES = [ @@ -744,8 +806,7 @@ def testToSingletonListArray(self, array, expected_result): ), dict( testcase_name="some_valid_binary_array", - array=pa.array([b"a", b"b", b"\xfc\xa1\xa1\xa1\xa1\xa1"], - type="binary"), + array=pa.array([b"a", b"b", b"\xfc\xa1\xa1\xa1\xa1\xa1"], type="binary"), expected_count=1, ), dict( @@ -757,16 +818,15 @@ def testToSingletonListArray(self, array, expected_result): class CountInvalidUtf8(parameterized.TestCase): - - @parameterized.named_parameters(*_COUNT_INVALID_UTF8_TEST_CASES) - def test_count_utf8(self, array, expected_count=None, expected_error=None): - if expected_error: - with self.assertRaisesRegex(RuntimeError, expected_error): - array_util.CountInvalidUTF8(array) - else: - count = array_util.CountInvalidUTF8(array) - self.assertEqual(expected_count, count) + @parameterized.named_parameters(*_COUNT_INVALID_UTF8_TEST_CASES) + def test_count_utf8(self, array, expected_count=None, expected_error=None): + if expected_error: + with self.assertRaisesRegex(RuntimeError, expected_error): + array_util.CountInvalidUTF8(array) + else: + count = array_util.CountInvalidUTF8(array) + self.assertEqual(expected_count, count) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/tfx_bsl/arrow/path.py b/tfx_bsl/arrow/path.py index ef1052fd..2d6f5e05 100644 --- a/tfx_bsl/arrow/path.py +++ b/tfx_bsl/arrow/path.py @@ -18,120 +18,130 @@ from tensorflow_metadata.proto.v0 import path_pb2 -class ColumnPath(object): - """ColumnPath addresses a column potentially nested under a StructArray.""" +class ColumnPath: + """ColumnPath addresses a column potentially nested under a StructArray.""" + + __slot__ = ["_steps"] + + def __init__(self, steps: Union[Iterable[str], str]): + """If a single Step is specified, constructs a Path of that step.""" + if isinstance(steps, str): + steps = (steps,) + self._steps = tuple(steps) + + def to_proto(self) -> path_pb2.Path: + """Creates a tensorflow_metadata path proto this ColumnPath.""" + return path_pb2.Path(step=self._steps) - __slot__ = ["_steps"] + @staticmethod + def from_proto(path_proto: path_pb2.Path): + """Creates a ColumnPath from a tensorflow_metadata path proto. - def __init__(self, steps: Union[Iterable[Text], Text]): - """If a single Step is specified, constructs a Path of that step.""" - if isinstance(steps, Text): - steps = (steps,) - self._steps = tuple(steps) + Args: + ---- + path_proto: a tensorflow_metadata path proto. - def to_proto(self) -> path_pb2.Path: - """Creates a tensorflow_metadata path proto this ColumnPath.""" - return path_pb2.Path(step=self._steps) + Returns: + ------- + A ColumnPath representing the path proto's steps. + """ + return ColumnPath(path_proto.step) - @staticmethod - def from_proto(path_proto: path_pb2.Path): - """Creates a ColumnPath from a tensorflow_metadata path proto. + def steps(self) -> Tuple[str, ...]: + """Returns the tuple of steps that represents this ColumnPath.""" + return self._steps - Args: - path_proto: a tensorflow_metadata path proto. + def parent(self) -> "ColumnPath": + """Gets the parent path of the current ColumnPath. - Returns: - A ColumnPath representing the path proto's steps. - """ - return ColumnPath(path_proto.step) + example: ColumnPath(["this", "is", "my", "path"]).parent() will + return a ColumnPath representing "this.is.my". - def steps(self) -> Tuple[Text, ...]: - """Returns the tuple of steps that represents this ColumnPath.""" - return self._steps + Returns + ------- + A ColumnPath with the leaf step removed. + """ + if not self._steps: + raise ValueError("Root does not have parent.") + return ColumnPath(self._steps[:-1]) - def parent(self) -> "ColumnPath": - """Gets the parent path of the current ColumnPath. + def child(self, child_step: str) -> "ColumnPath": + """Creates a new ColumnPath with a new child. - example: ColumnPath(["this", "is", "my", "path"]).parent() will - return a ColumnPath representing "this.is.my". + example: ColumnPath(["this", "is", "my", "path"]).child("new_step") will + return a ColumnPath representing "this.is.my.path.new_step". - Returns: - A ColumnPath with the leaf step removed. - """ - if not self._steps: - raise ValueError("Root does not have parent.") - return ColumnPath(self._steps[:-1]) + Args: + ---- + child_step: name of the new child step to append. - def child(self, child_step: Text) -> "ColumnPath": - """Creates a new ColumnPath with a new child. + Returns: + ------- + A ColumnPath with the new child_step + """ + return ColumnPath(self._steps + (child_step,)) - example: ColumnPath(["this", "is", "my", "path"]).child("new_step") will - return a ColumnPath representing "this.is.my.path.new_step". + def prefix(self, ending_index: int) -> "ColumnPath": + """Creates a new ColumnPath, taking the prefix until the ending_index. - Args: - child_step: name of the new child step to append. + example: ColumnPath(["this", "is", "my", "path"]).prefix(1) will return a + ColumnPath representing "this.is.my". - Returns: - A ColumnPath with the new child_step - """ - return ColumnPath(self._steps + (child_step,)) + Args: + ---- + ending_index: where to end the prefix. - def prefix(self, ending_index: int) -> "ColumnPath": - """Creates a new ColumnPath, taking the prefix until the ending_index. + Returns: + ------- + A ColumnPath representing the prefix of this ColumnPath. + """ + return ColumnPath(self._steps[:ending_index]) - example: ColumnPath(["this", "is", "my", "path"]).prefix(1) will return a - ColumnPath representing "this.is.my". + def suffix(self, starting_index: int) -> "ColumnPath": + """Creates a new ColumnPath, taking the suffix from the starting_index. - Args: - ending_index: where to end the prefix. + example: ColumnPath(["this", "is", "my", "path"]).suffix(1) will return a + ColumnPath representing "is.my.path". - Returns: - A ColumnPath representing the prefix of this ColumnPath. - """ - return ColumnPath(self._steps[:ending_index]) + Args: + ---- + starting_index: where to start the suffix. - def suffix(self, starting_index: int) -> "ColumnPath": - """Creates a new ColumnPath, taking the suffix from the starting_index. + Returns: + ------- + A ColumnPath representing the suffix of this ColumnPath. + """ + return ColumnPath(self._steps[starting_index:]) - example: ColumnPath(["this", "is", "my", "path"]).suffix(1) will return a - ColumnPath representing "is.my.path". + def initial_step(self) -> str: + """Returns the first step of this path. - Args: - starting_index: where to start the suffix. + Raises + ------ + ValueError: if the path is empty. + """ + if not self._steps: + raise ValueError("This ColumnPath does not have any steps.") + return self._steps[0] + + def __str__(self) -> str: + return ".".join(self._steps) - Returns: - A ColumnPath representing the suffix of this ColumnPath. - """ - return ColumnPath(self._steps[starting_index:]) + def __repr__(self) -> str: + return self.__str__() + + def __eq__(self, other) -> bool: + return self._steps == other._steps # pylint: disable=protected-access - def initial_step(self) -> Text: - """Returns the first step of this path. + def __lt__(self, other) -> bool: + # lexicographic order. + return self._steps < other._steps # pylint: disable=protected-access - Raises: - ValueError: if the path is empty. - """ - if not self._steps: - raise ValueError("This ColumnPath does not have any steps.") - return self._steps[0] + def __hash__(self) -> int: + return hash(self._steps) - def __str__(self) -> Text: - return u".".join(self._steps) + def __len__(self) -> int: + return len(self._steps) - def __repr__(self) -> Text: - return self.__str__() - - def __eq__(self, other) -> bool: - return self._steps == other._steps # pylint: disable=protected-access - - def __lt__(self, other) -> bool: - # lexicographic order. - return self._steps < other._steps # pylint: disable=protected-access - - def __hash__(self) -> int: - return hash(self._steps) - - def __len__(self) -> int: - return len(self._steps) - - def __bool__(self) -> bool: - return bool(self._steps) + def __bool__(self) -> bool: + return bool(self._steps) diff --git a/tfx_bsl/arrow/table_util.py b/tfx_bsl/arrow/table_util.py index a11987ab..d803625a 100644 --- a/tfx_bsl/arrow/table_util.py +++ b/tfx_bsl/arrow/table_util.py @@ -18,21 +18,26 @@ import numpy as np import pyarrow as pa -from tfx_bsl.arrow import array_util -from tfx_bsl.arrow import path + +from tfx_bsl.arrow import array_util, path # pytype: disable=import-error # pylint: disable=unused-import # pylint: disable=g-import-not-at-top # See b/148667210 for why the ImportError is ignored. try: - from tfx_bsl.cc.tfx_bsl_extension.arrow.table_util import RecordBatchTake - from tfx_bsl.cc.tfx_bsl_extension.arrow.table_util import MergeRecordBatches as _MergeRecordBatches - from tfx_bsl.cc.tfx_bsl_extension.arrow.table_util import TotalByteSize as _TotalByteSize + from tfx_bsl.cc.tfx_bsl_extension.arrow.table_util import ( + MergeRecordBatches as _MergeRecordBatches, + ) + from tfx_bsl.cc.tfx_bsl_extension.arrow.table_util import RecordBatchTake + from tfx_bsl.cc.tfx_bsl_extension.arrow.table_util import ( + TotalByteSize as _TotalByteSize, + ) except ImportError as err: - sys.stderr.write("Error importing tfx_bsl_extension.arrow.table_util. " - "Some tfx_bsl functionalities are not available: {}" - .format(err)) + sys.stderr.write( + "Error importing tfx_bsl_extension.arrow.table_util. " + f"Some tfx_bsl functionalities are not available: {err}" + ) # pylint: enable=g-import-not-at-top # pytype: enable=import-error # pylint: enable=unused-import @@ -51,244 +56,264 @@ } -def TotalByteSize(table_or_batch: Union[pa.Table, pa.RecordBatch], - ignore_unsupported=False): - """Returns the in-memory size of a record batch or a table.""" - if isinstance(table_or_batch, pa.Table): - return sum([ - _TotalByteSize(b, ignore_unsupported) - for b in table_or_batch.to_batches(max_chunksize=None) - ]) - else: - return _TotalByteSize(table_or_batch, ignore_unsupported) +def TotalByteSize( + table_or_batch: Union[pa.Table, pa.RecordBatch], ignore_unsupported=False +): + """Returns the in-memory size of a record batch or a table.""" + if isinstance(table_or_batch, pa.Table): + return sum( + [ + _TotalByteSize(b, ignore_unsupported) + for b in table_or_batch.to_batches(max_chunksize=None) + ] + ) + else: + return _TotalByteSize(table_or_batch, ignore_unsupported) def NumpyKindToArrowType(kind: str) -> Optional[pa.DataType]: - return _NUMPY_KIND_TO_ARROW_TYPE.get(kind) + return _NUMPY_KIND_TO_ARROW_TYPE.get(kind) def MergeRecordBatches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch: - """Merges a list of arrow RecordBatches into one. Similar to MergeTables.""" - if not record_batches: - return _EMPTY_RECORD_BATCH - first_schema = record_batches[0].schema - assert any([r.num_rows > 0 for r in record_batches]), ( - "Unable to merge empty RecordBatches.") - if (all([r.schema.equals(first_schema) for r in record_batches[1:]]) - # combine_chunks() cannot correctly handle the case where there are - # 0 column. (ARROW-11232) - and first_schema): - one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks() - batches = one_chunk_table.to_batches(max_chunksize=None) - assert len(batches) == 1 - return batches[0] - else: - # Our implementation of _MergeRecordBatches is different than - # pa.Table.concat_tables( - # [pa.Table.from_batches([rb]) for rb in record_batches], - # promote=True).combine_chunks().to_batches()[0] - # in its handling of struct-typed columns -- if two record batches have a - # column of the same name but of different struct types, _MergeRecordBatches - # will try merging (recursively) those struct types while concat_tables - # will not. We should consider upstreaming our implementation because it's a - # generalization - return _MergeRecordBatches(record_batches) + """Merges a list of arrow RecordBatches into one. Similar to MergeTables.""" + if not record_batches: + return _EMPTY_RECORD_BATCH + first_schema = record_batches[0].schema + assert any( + [r.num_rows > 0 for r in record_batches] + ), "Unable to merge empty RecordBatches." + if ( + all([r.schema.equals(first_schema) for r in record_batches[1:]]) + # combine_chunks() cannot correctly handle the case where there are + # 0 column. (ARROW-11232) + and first_schema + ): + one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks() + batches = one_chunk_table.to_batches(max_chunksize=None) + assert len(batches) == 1 + return batches[0] + else: + # Our implementation of _MergeRecordBatches is different than + # pa.Table.concat_tables( + # [pa.Table.from_batches([rb]) for rb in record_batches], + # promote=True).combine_chunks().to_batches()[0] + # in its handling of struct-typed columns -- if two record batches have a + # column of the same name but of different struct types, _MergeRecordBatches + # will try merging (recursively) those struct types while concat_tables + # will not. We should consider upstreaming our implementation because it's a + # generalization + return _MergeRecordBatches(record_batches) def _CanonicalizeType(arrow_type: pa.DataType) -> pa.DataType: - """Returns canonical version of the given type.""" - if pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): - return pa.large_list(_CanonicalizeType(arrow_type.value_type)) - else: - result = NumpyKindToArrowType(np.dtype(arrow_type.to_pandas_dtype()).kind) - if result is None: - raise NotImplementedError(f"Type {arrow_type} is not supported.") - return result + """Returns canonical version of the given type.""" + if pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): + return pa.large_list(_CanonicalizeType(arrow_type.value_type)) + else: + result = NumpyKindToArrowType(np.dtype(arrow_type.to_pandas_dtype()).kind) + if result is None: + raise NotImplementedError(f"Type {arrow_type} is not supported.") + return result def CanonicalizeRecordBatch( - record_batch_with_primitive_arrays: pa.RecordBatch) -> pa.RecordBatch: - """Converts primitive arrays in a pyarrow.RecordBatch to LargeListArrays. - - The produced LargeListArrays' elements are lists that contain single element - of the array of the canonical pyarrow type. - - Args: - record_batch_with_primitive_arrays: A pyarrow.RecordBatch where values are - stored in primitive arrays or list arrays. - - Returns: - pyArrow.RecordBatch with LargeListArray columns. - """ - arrays = [] - for column_array in record_batch_with_primitive_arrays.columns: - canonical_type = _CanonicalizeType(column_array.type) - if canonical_type != column_array.type: - column_array = column_array.cast(canonical_type) - if pa.types.is_large_list(canonical_type): - arrays.append(column_array) - else: - arrays.append(array_util.ToSingletonListArray(column_array)) - return pa.RecordBatch.from_arrays( - arrays, record_batch_with_primitive_arrays.schema.names) + record_batch_with_primitive_arrays: pa.RecordBatch, +) -> pa.RecordBatch: + """Converts primitive arrays in a pyarrow.RecordBatch to LargeListArrays. + + The produced LargeListArrays' elements are lists that contain single element + of the array of the canonical pyarrow type. + + Args: + ---- + record_batch_with_primitive_arrays: A pyarrow.RecordBatch where values are + stored in primitive arrays or list arrays. + + Returns: + ------- + pyArrow.RecordBatch with LargeListArray columns. + """ + arrays = [] + for column_array in record_batch_with_primitive_arrays.columns: + canonical_type = _CanonicalizeType(column_array.type) + if canonical_type != column_array.type: + column_array = column_array.cast(canonical_type) + if pa.types.is_large_list(canonical_type): + arrays.append(column_array) + else: + arrays.append(array_util.ToSingletonListArray(column_array)) + return pa.RecordBatch.from_arrays( + arrays, record_batch_with_primitive_arrays.schema.names + ) def enumerate_arrays( # pylint: disable=invalid-name record_batch: pa.RecordBatch, enumerate_leaves_only: bool, - wrap_flat_struct_in_list: bool = True + wrap_flat_struct_in_list: bool = True, ) -> Iterable[Tuple[path.ColumnPath, pa.Array]]: - """Enumerates arrays in a RecordBatch. - - Define: - primitive: primitive arrow arrays (e.g. Int64Array). - nested_list := list | list | null - # note: a null array can be seen as a list, which contains only - # nulls and the type of the primitive is unknown. - # example: - # null, - # list, # like list> with only null values. - # list>, - struct := struct<{field: nested_list | struct}> | list - # example: - # struct<{"foo": list}, - # list}>>, - # struct<{"foo": struct<{"bar": list>}>}> - - This function assumes `record_batch` contains only nested_list and struct - columns. It enumerates each column in `record_batch`, and if that column is - a struct, it flattens the outer lists wrapping it (if any), and recursively - enumerates the array of each field in the struct (also see - `enumerate_leaves_only`). - - A ColumnPath is included in the result to address the enumerated array. - Note that the ColumnPath merely addresses in the `record_batch` and struct - arrays. It does not indicate whether / how a struct array is nested. - - Args: - record_batch: The RecordBatch whose arrays to be visited. - enumerate_leaves_only: If True, only enumerate leaf arrays. A leaf array - is an array whose type does not have any struct nested in. - Otherwise, also enumerate the struct arrays where the leaf arrays are - contained. - wrap_flat_struct_in_list: if True, and if a struct<[Ts]> array is - encountered, it will be wrapped in a list array, so it becomes a - list>, in which each sub-list contains one element. - A caller can make use of this option to assume all the arrays enumerated - here are list. - Yields: - A tuple. The first term is the path of the feature, and the second term is - the feature array. - """ - - def _recursion_helper( # pylint: disable=invalid-name - feature_path: path.ColumnPath, array: pa.Array, - ) -> Iterable[Tuple[path.ColumnPath, pa.Array]]: - """Recursion helper.""" - array_type = array.type - innermost_nested_type = array_util.get_innermost_nested_type(array_type) - if pa.types.is_struct(innermost_nested_type): - if not enumerate_leaves_only: - # special handing for a flat struct array -- wrap it in a ListArray - # whose elements are singleton lists. This way downstream can keep - # assuming the enumerated arrays are list<*>. - to_yield = array - if pa.types.is_struct(array_type) and wrap_flat_struct_in_list: - to_yield = array_util.ToSingletonListArray(array) - yield (feature_path, to_yield) - flat_struct_array, _ = array_util.flatten_nested(array) - for field in flat_struct_array.type: - field_name = field.name - yield from _recursion_helper( - feature_path.child(field_name), - array_util.get_field(flat_struct_array, field_name)) - else: - yield (feature_path, array) - - for column_name, column in zip(record_batch.schema.names, - record_batch.columns): - yield from _recursion_helper( - path.ColumnPath([column_name]), column) - - -def get_array( # pylint: disable=invalid-name + """Enumerates arrays in a RecordBatch. + + Define: + primitive: primitive arrow arrays (e.g. Int64Array). + nested_list := list | list | null + # note: a null array can be seen as a list, which contains only + # nulls and the type of the primitive is unknown. + # example: + # null, + # list, # like list> with only null values. + # list>, + struct := struct<{field: nested_list | struct}> | list + # example: + # struct<{"foo": list}, + # list}>>, + # struct<{"foo": struct<{"bar": list>}>}> + + This function assumes `record_batch` contains only nested_list and struct + columns. It enumerates each column in `record_batch`, and if that column is + a struct, it flattens the outer lists wrapping it (if any), and recursively + enumerates the array of each field in the struct (also see + `enumerate_leaves_only`). + + A ColumnPath is included in the result to address the enumerated array. + Note that the ColumnPath merely addresses in the `record_batch` and struct + arrays. It does not indicate whether / how a struct array is nested. + + Args: + ---- + record_batch: The RecordBatch whose arrays to be visited. + enumerate_leaves_only: If True, only enumerate leaf arrays. A leaf array + is an array whose type does not have any struct nested in. + Otherwise, also enumerate the struct arrays where the leaf arrays are + contained. + wrap_flat_struct_in_list: if True, and if a struct<[Ts]> array is + encountered, it will be wrapped in a list array, so it becomes a + list>, in which each sub-list contains one element. + A caller can make use of this option to assume all the arrays enumerated + here are list. + + Yields: + ------ + A tuple. The first term is the path of the feature, and the second term is + the feature array. + """ + + def _recursion_helper( # pylint: disable=invalid-name + feature_path: path.ColumnPath, + array: pa.Array, + ) -> Iterable[Tuple[path.ColumnPath, pa.Array]]: + """Recursion helper.""" + array_type = array.type + innermost_nested_type = array_util.get_innermost_nested_type(array_type) + if pa.types.is_struct(innermost_nested_type): + if not enumerate_leaves_only: + # special handing for a flat struct array -- wrap it in a ListArray + # whose elements are singleton lists. This way downstream can keep + # assuming the enumerated arrays are list<*>. + to_yield = array + if pa.types.is_struct(array_type) and wrap_flat_struct_in_list: + to_yield = array_util.ToSingletonListArray(array) + yield (feature_path, to_yield) + flat_struct_array, _ = array_util.flatten_nested(array) + for field in flat_struct_array.type: + field_name = field.name + yield from _recursion_helper( + feature_path.child(field_name), + array_util.get_field(flat_struct_array, field_name), + ) + else: + yield (feature_path, array) + + for column_name, column in zip(record_batch.schema.names, record_batch.columns): + yield from _recursion_helper(path.ColumnPath([column_name]), column) + + +def get_array( # pylint: disable=invalid-name record_batch: pa.RecordBatch, query_path: path.ColumnPath, return_example_indices: bool, wrap_flat_struct_in_list: bool = True, ) -> Tuple[pa.Array, Optional[np.ndarray]]: - """Retrieve a nested array (and optionally example indices) from RecordBatch. - - This function has the same assumption over `record_batch` as - `enumerate_arrays()` does. - - If the provided path refers to a leaf in the `record_batch`, then a - "nested_list" will be returned. If the provided path does not refer to a leaf, - a "struct" will be returned. - - See `enumerate_arrays()` for definition of "nested_list" and "struct". - - Args: - record_batch: The RecordBatch whose arrays to be visited. - query_path: The ColumnPath to lookup in the record_batch. - return_example_indices: Whether to return an additional array containing the - example indices of the elements in the array corresponding to the - query_path. - wrap_flat_struct_in_list: if True, and if the query_path leads to a - struct<[Ts]> array, it will be wrapped in a list array, where each - sub-list contains one element. Caller can make use of this option to - assume this function always returns a list. - - Returns: - A tuple. The first term is the feature array and the second term is the - example_indices array for the feature array (i.e. array[i] came from the - example at row example_indices[i] in the record_batch.). - - Raises: - KeyError: When the query_path is empty, or cannot be found in the - record_batch and its nested struct arrays. - """ - - def _recursion_helper( # pylint: disable=invalid-name - query_path: path.ColumnPath, array: pa.Array, - example_indices: Optional[np.ndarray] - ) -> Tuple[pa.Array, Optional[np.ndarray]]: - """Recursion helper.""" - array_type = array.type + """Retrieve a nested array (and optionally example indices) from RecordBatch. + + This function has the same assumption over `record_batch` as + `enumerate_arrays()` does. + + If the provided path refers to a leaf in the `record_batch`, then a + "nested_list" will be returned. If the provided path does not refer to a leaf, + a "struct" will be returned. + + See `enumerate_arrays()` for definition of "nested_list" and "struct". + + Args: + ---- + record_batch: The RecordBatch whose arrays to be visited. + query_path: The ColumnPath to lookup in the record_batch. + return_example_indices: Whether to return an additional array containing the + example indices of the elements in the array corresponding to the + query_path. + wrap_flat_struct_in_list: if True, and if the query_path leads to a + struct<[Ts]> array, it will be wrapped in a list array, where each + sub-list contains one element. Caller can make use of this option to + assume this function always returns a list. + + Returns: + ------- + A tuple. The first term is the feature array and the second term is the + example_indices array for the feature array (i.e. array[i] came from the + example at row example_indices[i] in the record_batch.). + + Raises: + ------ + KeyError: When the query_path is empty, or cannot be found in the + record_batch and its nested struct arrays. + """ + + def _recursion_helper( # pylint: disable=invalid-name + query_path: path.ColumnPath, + array: pa.Array, + example_indices: Optional[np.ndarray], + ) -> Tuple[pa.Array, Optional[np.ndarray]]: + """Recursion helper.""" + array_type = array.type + if not query_path: + if pa.types.is_struct(array_type) and wrap_flat_struct_in_list: + array = array_util.ToSingletonListArray(array) + return array, example_indices + if not pa.types.is_struct(array_util.get_innermost_nested_type(array_type)): + raise KeyError( + f"Cannot process query_path ({query_path}) inside an array of type " + f"{array_type}. Expecting a struct<...> or " + "(large_)list...>." + ) + flat_struct_array, parent_indices = array_util.flatten_nested( + array, example_indices is not None + ) + flat_indices = None + if example_indices is not None: + flat_indices = example_indices[parent_indices] + + step = query_path.steps()[0] + + try: + child_array = array_util.get_field(flat_struct_array, step) + except KeyError as exception: + raise KeyError(f"query_path step ({step}) not in struct.") from exception + + relative_path = path.ColumnPath(query_path.steps()[1:]) + return _recursion_helper(relative_path, child_array, flat_indices) + if not query_path: - if pa.types.is_struct(array_type) and wrap_flat_struct_in_list: - array = array_util.ToSingletonListArray(array) - return array, example_indices - if not pa.types.is_struct(array_util.get_innermost_nested_type(array_type)): - raise KeyError("Cannot process query_path ({}) inside an array of type " - "{}. Expecting a struct<...> or " - "(large_)list...>.".format( - query_path, array_type)) - flat_struct_array, parent_indices = array_util.flatten_nested( - array, example_indices is not None) - flat_indices = None - if example_indices is not None: - flat_indices = example_indices[parent_indices] - - step = query_path.steps()[0] - - try: - child_array = array_util.get_field(flat_struct_array, step) - except KeyError as exception: - raise KeyError(f"query_path step ({step}) not in struct.") from exception - - relative_path = path.ColumnPath(query_path.steps()[1:]) - return _recursion_helper(relative_path, child_array, flat_indices) - - if not query_path: - raise KeyError("query_path must be non-empty.") - column_name = query_path.steps()[0] - field_index = record_batch.schema.get_field_index(column_name) - if field_index < 0: - raise KeyError(f"query_path step 0 ({column_name}) not in record batch.") - array = record_batch.column(field_index) - array_path = path.ColumnPath(query_path.steps()[1:]) - - example_indices = np.arange( - record_batch.num_rows) if return_example_indices else None - return _recursion_helper(array_path, array, example_indices) + raise KeyError("query_path must be non-empty.") + column_name = query_path.steps()[0] + field_index = record_batch.schema.get_field_index(column_name) + if field_index < 0: + raise KeyError(f"query_path step 0 ({column_name}) not in record batch.") + array = record_batch.column(field_index) + array_path = path.ColumnPath(query_path.steps()[1:]) + + example_indices = ( + np.arange(record_batch.num_rows) if return_example_indices else None + ) + return _recursion_helper(array_path, array, example_indices) diff --git a/tfx_bsl/arrow/table_util_test.py b/tfx_bsl/arrow/table_util_test.py index 24aaf154..53ea9e80 100644 --- a/tfx_bsl/arrow/table_util_test.py +++ b/tfx_bsl/arrow/table_util_test.py @@ -15,19 +15,14 @@ import collections import itertools - from typing import Dict, Iterable, NamedTuple import numpy as np import pyarrow as pa import six -from tfx_bsl.arrow import array_util -from tfx_bsl.arrow import path -from tfx_bsl.arrow import table_util - -from absl.testing import absltest -from absl.testing import parameterized +from absl.testing import absltest, parameterized +from tfx_bsl.arrow import array_util, path, table_util _MERGE_TEST_CASES = [ dict( @@ -44,14 +39,12 @@ "uint64": pa.array([1, None, 3], type=pa.uint64()), "int32": pa.array([1, None, 3], type=pa.int32()), "uint32": pa.array([1, None, 3], type=pa.uint32()), - "float": pa.array([1., None, 3.], type=pa.float32()), - "double": pa.array([1., None, 3.], type=pa.float64()), + "float": pa.array([1.0, None, 3.0], type=pa.float32()), + "double": pa.array([1.0, None, 3.0], type=pa.float64()), "bytes": pa.array([b"abc", None, b"ghi"], type=pa.binary()), - "large_bytes": pa.array([b"abc", None, b"ghi"], - type=pa.large_binary()), - "unicode": pa.array([u"abc", None, u"ghi"], type=pa.utf8()), - "large_unicode": pa.array([u"abc", None, u"ghi"], - type=pa.large_utf8()), + "large_bytes": pa.array([b"abc", None, b"ghi"], type=pa.large_binary()), + "unicode": pa.array(["abc", None, "ghi"], type=pa.utf8()), + "large_unicode": pa.array(["abc", None, "ghi"], type=pa.large_utf8()), }, { "bool": pa.array([None, False], type=pa.bool_()), @@ -59,48 +52,39 @@ "uint64": pa.array([None, 4], type=pa.uint64()), "int32": pa.array([None, 4], type=pa.int32()), "uint32": pa.array([None, 4], type=pa.uint32()), - "float": pa.array([None, 4.], type=pa.float32()), - "double": pa.array([None, 4.], type=pa.float64()), + "float": pa.array([None, 4.0], type=pa.float32()), + "double": pa.array([None, 4.0], type=pa.float64()), "bytes": pa.array([None, b"jkl"], type=pa.binary()), "large_bytes": pa.array([None, b"jkl"], type=pa.large_binary()), - "unicode": pa.array([None, u"jkl"], type=pa.utf8()), - "large_unicode": pa.array([None, u"jkl"], type=pa.large_utf8()), + "unicode": pa.array([None, "jkl"], type=pa.utf8()), + "large_unicode": pa.array([None, "jkl"], type=pa.large_utf8()), }, ], expected_output={ - "bool": - pa.array([False, None, True, None, False], type=pa.bool_()), - "int64": - pa.array([1, None, 3, None, 4], type=pa.int64()), - "uint64": - pa.array([1, None, 3, None, 4], type=pa.uint64()), - "int32": - pa.array([1, None, 3, None, 4], type=pa.int32()), - "uint32": - pa.array([1, None, 3, None, 4], type=pa.uint32()), - "float": - pa.array([1., None, 3., None, 4.], type=pa.float32()), - "double": - pa.array([1., None, 3., None, 4.], type=pa.float64()), - "bytes": - pa.array([b"abc", None, b"ghi", None, b"jkl"], - type=pa.binary()), - "large_bytes": - pa.array([b"abc", None, b"ghi", None, b"jkl"], - type=pa.large_binary()), - "unicode": - pa.array([u"abc", None, u"ghi", None, u"jkl"], - type=pa.utf8()), - "large_unicode": - pa.array([u"abc", None, u"ghi", None, u"jkl"], - type=pa.large_utf8()), - }), + "bool": pa.array([False, None, True, None, False], type=pa.bool_()), + "int64": pa.array([1, None, 3, None, 4], type=pa.int64()), + "uint64": pa.array([1, None, 3, None, 4], type=pa.uint64()), + "int32": pa.array([1, None, 3, None, 4], type=pa.int32()), + "uint32": pa.array([1, None, 3, None, 4], type=pa.uint32()), + "float": pa.array([1.0, None, 3.0, None, 4.0], type=pa.float32()), + "double": pa.array([1.0, None, 3.0, None, 4.0], type=pa.float64()), + "bytes": pa.array([b"abc", None, b"ghi", None, b"jkl"], type=pa.binary()), + "large_bytes": pa.array( + [b"abc", None, b"ghi", None, b"jkl"], type=pa.large_binary() + ), + "unicode": pa.array(["abc", None, "ghi", None, "jkl"], type=pa.utf8()), + "large_unicode": pa.array( + ["abc", None, "ghi", None, "jkl"], type=pa.large_utf8() + ), + }, + ), dict( testcase_name="list", inputs=[ { - "list": - pa.array([[1, None, 3], None], type=pa.list_(pa.int32())), + "list": pa.array( + [[1, None, 3], None], type=pa.list_(pa.int32()) + ), }, { "list": pa.array([None], type=pa.list_(pa.int32())), @@ -113,300 +97,325 @@ }, ], expected_output={ - "list": - pa.array([[1, None, 3], None, None, []], - type=pa.list_(pa.int32())) - }), + "list": pa.array( + [[1, None, 3], None, None, []], type=pa.list_(pa.int32()) + ) + }, + ), dict( testcase_name="large_list", inputs=[ { - "large_list": - pa.array([[1, None, 3], None], - type=pa.large_list(pa.int32())), + "large_list": pa.array( + [[1, None, 3], None], type=pa.large_list(pa.int32()) + ), }, { - "large_list": - pa.array([None], type=pa.large_list(pa.int32())), + "large_list": pa.array([None], type=pa.large_list(pa.int32())), }, { - "large_list": - pa.array([], type=pa.large_list(pa.int32())), + "large_list": pa.array([], type=pa.large_list(pa.int32())), }, { - "large_list": - pa.array([[]], type=pa.large_list(pa.int32())), + "large_list": pa.array([[]], type=pa.large_list(pa.int32())), }, ], expected_output={ - "large_list": - pa.array([[1, None, 3], None, None, []], - type=pa.large_list(pa.int32())) - }), + "large_list": pa.array( + [[1, None, 3], None, None, []], type=pa.large_list(pa.int32()) + ) + }, + ), dict( testcase_name="struct", - inputs=[{ - "struct>": - pa.StructArray.from_arrays([ - pa.array([b"abc", None, b"def"]), - pa.array([[None], [1, 2], []], type=pa.list_(pa.int32())) - ], ["f1", "f2"]) - }, { - "struct>": - pa.StructArray.from_arrays([ - pa.array([b"ghi"]), - pa.array([[3]], type=pa.list_(pa.int32())) - ], ["f1", "f2"]) - }], - expected_output={ - "struct>": - pa.StructArray.from_arrays([ - pa.array([b"abc", None, b"def", b"ghi"]), - pa.array([[None], [1, 2], [], [3]], - type=pa.list_(pa.int32())) - ], ["f1", "f2"]) - }), - dict( - testcase_name="missing_or_null_column_fixed_width", inputs=[ { - "int32": pa.array([None, None], type=pa.null()) + "struct>": pa.StructArray.from_arrays( + [ + pa.array([b"abc", None, b"def"]), + pa.array([[None], [1, 2], []], type=pa.list_(pa.int32())), + ], + ["f1", "f2"], + ) }, { - "int64": pa.array([None, None], type=pa.null()) - }, - { - "int64": pa.array([123], type=pa.int64()) - }, - { - "int32": pa.array([456], type=pa.int32()) + "struct>": pa.StructArray.from_arrays( + [pa.array([b"ghi"]), pa.array([[3]], type=pa.list_(pa.int32()))], + ["f1", "f2"], + ) }, ], expected_output={ - "int32": - pa.array([None, None, None, None, None, 456], type=pa.int32()), - "int64": - pa.array([None, None, None, None, 123, None], type=pa.int64()), - }), + "struct>": pa.StructArray.from_arrays( + [ + pa.array([b"abc", None, b"def", b"ghi"]), + pa.array([[None], [1, 2], [], [3]], type=pa.list_(pa.int32())), + ], + ["f1", "f2"], + ) + }, + ), + dict( + testcase_name="missing_or_null_column_fixed_width", + inputs=[ + {"int32": pa.array([None, None], type=pa.null())}, + {"int64": pa.array([None, None], type=pa.null())}, + {"int64": pa.array([123], type=pa.int64())}, + {"int32": pa.array([456], type=pa.int32())}, + ], + expected_output={ + "int32": pa.array([None, None, None, None, None, 456], type=pa.int32()), + "int64": pa.array([None, None, None, None, 123, None], type=pa.int64()), + }, + ), dict( testcase_name="missing_or_null_column_list_alike", inputs=[ - { - "list": pa.array([None, None], type=pa.null()) - }, - { - "utf8": pa.array([None, None], type=pa.null()) - }, - { - "utf8": pa.array([u"abc"], type=pa.utf8()) - }, - { - "list": - pa.array([None, [123, 456]], type=pa.list_(pa.int32())) - }, + {"list": pa.array([None, None], type=pa.null())}, + {"utf8": pa.array([None, None], type=pa.null())}, + {"utf8": pa.array(["abc"], type=pa.utf8())}, + {"list": pa.array([None, [123, 456]], type=pa.list_(pa.int32()))}, ], expected_output={ - "list": - pa.array([None, None, None, None, None, None, [123, 456]], - type=pa.list_(pa.int32())), - "utf8": - pa.array([None, None, None, None, u"abc", None, None], - type=pa.utf8()), - }), + "list": pa.array( + [None, None, None, None, None, None, [123, 456]], + type=pa.list_(pa.int32()), + ), + "utf8": pa.array( + [None, None, None, None, "abc", None, None], type=pa.utf8() + ), + }, + ), dict( testcase_name="missing_or_null_column_struct", - inputs=[{ - "struct>": pa.array([None, None], type=pa.null()) - }, { - "list": pa.array([None, None], type=pa.null()) - }, { - "struct>": - pa.StructArray.from_arrays([ - pa.array([1, 2, None], type=pa.int32()), - pa.array([[1], None, [3, 4]], type=pa.list_(pa.int32())) - ], ["f1", "f2"]) - }, { - "list": pa.array([u"abc", None], type=pa.utf8()) - }], + inputs=[ + {"struct>": pa.array([None, None], type=pa.null())}, + {"list": pa.array([None, None], type=pa.null())}, + { + "struct>": pa.StructArray.from_arrays( + [ + pa.array([1, 2, None], type=pa.int32()), + pa.array([[1], None, [3, 4]], type=pa.list_(pa.int32())), + ], + ["f1", "f2"], + ) + }, + {"list": pa.array(["abc", None], type=pa.utf8())}, + ], expected_output={ - "list": - pa.array( - [None, None, None, None, None, None, None, u"abc", None], - type=pa.utf8()), - "struct>": - pa.array([ - None, None, None, None, (1, [1]), (2, None), - (None, [3, 4]), None, None + "list": pa.array( + [None, None, None, None, None, None, None, "abc", None], type=pa.utf8() + ), + "struct>": pa.array( + [ + None, + None, + None, + None, + (1, [1]), + (2, None), + (None, [3, 4]), + None, + None, ], - type=pa.struct([ - pa.field("f1", pa.int32()), - pa.field("f2", pa.list_(pa.int32())) - ])), - }), + type=pa.struct( + [pa.field("f1", pa.int32()), pa.field("f2", pa.list_(pa.int32()))] + ), + ), + }, + ), dict( testcase_name="merge_list_of_null_and_list_of_list", - inputs=[{ - "f": pa.array([[None, None], None], type=pa.list_(pa.null())) - }, { - "f": pa.array([[[123]], None], type=pa.list_(pa.list_(pa.int32()))) - }], + inputs=[ + {"f": pa.array([[None, None], None], type=pa.list_(pa.null()))}, + {"f": pa.array([[[123]], None], type=pa.list_(pa.list_(pa.int32())))}, + ], expected_output={ - "f": - pa.array([[None, None], None, [[123]], None], - type=pa.list_(pa.list_(pa.int32()))) - }), + "f": pa.array( + [[None, None], None, [[123]], None], type=pa.list_(pa.list_(pa.int32())) + ) + }, + ), dict( testcase_name="merge_large_list_of_null_and_list_of_list", - inputs=[{ - "f": pa.array([[None, None], None], type=pa.large_list(pa.null())) - }, { - "f": pa.array([[[123]], None], - type=pa.large_list(pa.large_list(pa.int32()))) - }], + inputs=[ + {"f": pa.array([[None, None], None], type=pa.large_list(pa.null()))}, + { + "f": pa.array( + [[[123]], None], type=pa.large_list(pa.large_list(pa.int32())) + ) + }, + ], expected_output={ - "f": - pa.array([[None, None], None, [[123]], None], - type=pa.large_list(pa.large_list(pa.int32()))) - }), + "f": pa.array( + [[None, None], None, [[123]], None], + type=pa.large_list(pa.large_list(pa.int32())), + ) + }, + ), dict( testcase_name="merge_sliced_list_of_null_and_list_of_list", - inputs=[{ - "f": pa.array( - [None, [None, None], None], type=pa.list_(pa.null())).slice(1) - }, { - "f": pa.array([[[123]], None], type=pa.list_(pa.list_(pa.int32()))) - }], + inputs=[ + { + "f": pa.array( + [None, [None, None], None], type=pa.list_(pa.null()) + ).slice(1) + }, + {"f": pa.array([[[123]], None], type=pa.list_(pa.list_(pa.int32())))}, + ], expected_output={ - "f": - pa.array([[None, None], None, [[123]], None], - type=pa.list_(pa.list_(pa.int32()))) - }), + "f": pa.array( + [[None, None], None, [[123]], None], type=pa.list_(pa.list_(pa.int32())) + ) + }, + ), dict( testcase_name="merge_list_of_list_and_list_of_null", - inputs=[{ - "f": pa.array([[[123]], None], type=pa.list_(pa.list_(pa.int32()))) - }, { - "f": pa.array([[None, None], None], type=pa.list_(pa.null())) - }], + inputs=[ + {"f": pa.array([[[123]], None], type=pa.list_(pa.list_(pa.int32())))}, + {"f": pa.array([[None, None], None], type=pa.list_(pa.null()))}, + ], expected_output={ - "f": - pa.array([[[123]], None, [None, None], None], - type=pa.list_(pa.list_(pa.int32()))) - }), + "f": pa.array( + [[[123]], None, [None, None], None], type=pa.list_(pa.list_(pa.int32())) + ) + }, + ), dict( testcase_name="merge_list_of_null_and_null", - inputs=[{ - "f": pa.array([None], type=pa.null()) - }, { - "f": pa.array([[None, None], None], type=pa.list_(pa.null())) - }], + inputs=[ + {"f": pa.array([None], type=pa.null())}, + {"f": pa.array([[None, None], None], type=pa.list_(pa.null()))}, + ], expected_output={ "f": pa.array([None, [None, None], None], type=pa.list_(pa.null())) - }), + }, + ), dict( testcase_name="merge_compatible_struct_missing_field", - inputs=[{ - "f": pa.array([{"a": [1]}, {"a": [2, 3]}]), - }, { - "f": pa.array([{"b": [1.0]}]), - }], + inputs=[ + { + "f": pa.array([{"a": [1]}, {"a": [2, 3]}]), + }, + { + "f": pa.array([{"b": [1.0]}]), + }, + ], expected_output={ - "f": pa.array([ - {"a": [1], "b": None}, - {"a": [2, 3], "b": None}, - {"a": None, "b": [1.0]}]) - }), + "f": pa.array( + [ + {"a": [1], "b": None}, + {"a": [2, 3], "b": None}, + {"a": None, "b": [1.0]}, + ] + ) + }, + ), dict( testcase_name="merge_compatible_struct_null_type", - inputs=[{ - "f": - pa.array([{"a": [[1]]}], - type=pa.struct([ - pa.field("a", - pa.large_list(pa.large_list(pa.int32()))) - ])), - }, { - "f": - pa.array([{"a": None}, {"a": None}], - type=pa.struct([pa.field("a", pa.null())])), - }], + inputs=[ + { + "f": pa.array( + [{"a": [[1]]}], + type=pa.struct( + [pa.field("a", pa.large_list(pa.large_list(pa.int32())))] + ), + ), + }, + { + "f": pa.array( + [{"a": None}, {"a": None}], + type=pa.struct([pa.field("a", pa.null())]), + ), + }, + ], expected_output={ - "f": - pa.array([{"a": [[1]]}, - {"a": None}, - {"a": None}], - type=pa.struct([ - pa.field("a", - pa.large_list(pa.large_list(pa.int32()))) - ])) - }), + "f": pa.array( + [{"a": [[1]]}, {"a": None}, {"a": None}], + type=pa.struct( + [pa.field("a", pa.large_list(pa.large_list(pa.int32())))] + ), + ) + }, + ), dict( testcase_name="merge_compatible_struct_in_struct", - inputs=[{ - "f": pa.array([{}, {}]), - }, { - "f": pa.array([ - {"a": [{"b": 1}]}, - {"a": [{"b": 2}]}, - ]) - }, { - "f": pa.array([ - {"a": [{"b": 3, "c": 1}]}, - ]) - }], + inputs=[ + { + "f": pa.array([{}, {}]), + }, + { + "f": pa.array( + [ + {"a": [{"b": 1}]}, + {"a": [{"b": 2}]}, + ] + ) + }, + { + "f": pa.array( + [ + {"a": [{"b": 3, "c": 1}]}, + ] + ) + }, + ], expected_output={ - "f": pa.array([ - {"a": None}, - {"a": None}, - {"a": [{"b": 1, "c": None}]}, - {"a": [{"b": 2, "c": None}]}, - {"a": [{"b": 3, "c": 1}]}]) - }) + "f": pa.array( + [ + {"a": None}, + {"a": None}, + {"a": [{"b": 1, "c": None}]}, + {"a": [{"b": 2, "c": None}]}, + {"a": [{"b": 3, "c": 1}]}, + ] + ) + }, + ), ] _MERGE_INVALID_INPUT_TEST_CASES = [ dict( testcase_name="column_type_differs", inputs=[ - pa.RecordBatch.from_arrays([pa.array([1, 2, 3], type=pa.int32())], - ["f1"]), - pa.RecordBatch.from_arrays([pa.array([4, 5, 6], type=pa.int64())], - ["f1"]) + pa.RecordBatch.from_arrays([pa.array([1, 2, 3], type=pa.int32())], ["f1"]), + pa.RecordBatch.from_arrays([pa.array([4, 5, 6], type=pa.int64())], ["f1"]), ], - expected_error_regexp="Unable to merge incompatible type"), + expected_error_regexp="Unable to merge incompatible type", + ), ] class MergeRecordBatchesTest(parameterized.TestCase): - - @parameterized.named_parameters(*_MERGE_INVALID_INPUT_TEST_CASES) - def test_invalid_inputs(self, inputs, expected_error_regexp): - with self.assertRaisesRegex(Exception, expected_error_regexp): - _ = table_util.MergeRecordBatches(inputs) - - @parameterized.named_parameters(*_MERGE_TEST_CASES) - def test_merge_record_batches(self, inputs, expected_output): - input_record_batches = [ - pa.RecordBatch.from_arrays(list(in_dict.values()), list(in_dict.keys())) - for in_dict in inputs - ] - merged = table_util.MergeRecordBatches(input_record_batches) - - self.assertLen(expected_output, merged.num_columns) - for column, column_name in zip(merged.columns, merged.schema.names): - self.assertTrue( - expected_output[column_name].equals(column), - "Column {}:\nexpected:{}\ngot: {}".format( - column_name, expected_output[column_name], column)) - - def test_merge_0_column_record_batches(self): - record_batches = ([ - pa.table([pa.array([1, 2, 3])], - ["ignore"]).remove_column(0).to_batches(max_chunksize=None)[0] - ] * 3) - merged = table_util.MergeRecordBatches(record_batches) - self.assertEqual(merged.num_rows, 9) - self.assertEqual(merged.num_columns, 0) + @parameterized.named_parameters(*_MERGE_INVALID_INPUT_TEST_CASES) + def test_invalid_inputs(self, inputs, expected_error_regexp): + with self.assertRaisesRegex(Exception, expected_error_regexp): + _ = table_util.MergeRecordBatches(inputs) + + @parameterized.named_parameters(*_MERGE_TEST_CASES) + def test_merge_record_batches(self, inputs, expected_output): + input_record_batches = [ + pa.RecordBatch.from_arrays(list(in_dict.values()), list(in_dict.keys())) + for in_dict in inputs + ] + merged = table_util.MergeRecordBatches(input_record_batches) + + self.assertLen(expected_output, merged.num_columns) + for column, column_name in zip(merged.columns, merged.schema.names): + self.assertTrue( + expected_output[column_name].equals(column), + f"Column {column_name}:\nexpected:{expected_output[column_name]}\ngot: {column}", + ) + + def test_merge_0_column_record_batches(self): + record_batches = [ + pa.table([pa.array([1, 2, 3])], ["ignore"]) + .remove_column(0) + .to_batches(max_chunksize=None)[0] + ] * 3 + merged = table_util.MergeRecordBatches(record_batches) + self.assertEqual(merged.num_rows, 9) + self.assertEqual(merged.num_columns, 0) _GET_TOTAL_BYTE_SIZE_TEST_NAMED_PARAMS = [ @@ -416,589 +425,601 @@ def test_merge_0_column_record_batches(self): class GetTotalByteSizeTest(parameterized.TestCase): + @parameterized.named_parameters(*_GET_TOTAL_BYTE_SIZE_TEST_NAMED_PARAMS) + def test_simple(self, factory): + # 3 int64 values + # 5 int32 offsets + # 1 null bitmap byte for outer ListArray + # 1 null bitmap byte for inner Int64Array + # 46 bytes in total. + list_array = pa.array([[1, 2], [None], None, None], type=pa.list_(pa.int64())) + + # 1 null bitmap byte for outer StructArray. + # 1 null bitmap byte for inner Int64Array. + # 4 int64 values. + # 34 bytes in total + struct_array = pa.array( + [{"a": 1}, {"a": 2}, {"a": None}, None], + type=pa.struct([pa.field("a", pa.int64())]), + ) + entity = factory([list_array, struct_array], ["a1", "a2"]) - @parameterized.named_parameters(*_GET_TOTAL_BYTE_SIZE_TEST_NAMED_PARAMS) - def test_simple(self, factory): - # 3 int64 values - # 5 int32 offsets - # 1 null bitmap byte for outer ListArray - # 1 null bitmap byte for inner Int64Array - # 46 bytes in total. - list_array = pa.array([[1, 2], [None], None, None], - type=pa.list_(pa.int64())) - - # 1 null bitmap byte for outer StructArray. - # 1 null bitmap byte for inner Int64Array. - # 4 int64 values. - # 34 bytes in total - struct_array = pa.array([{"a": 1}, {"a": 2}, {"a": None}, None], - type=pa.struct([pa.field("a", pa.int64())])) - entity = factory([list_array, struct_array], ["a1", "a2"]) - - self.assertEqual(46 + 34, table_util.TotalByteSize(entity)) + self.assertEqual(46 + 34, table_util.TotalByteSize(entity)) _TAKE_TEST_CASES = [ dict( testcase_name="no_index", row_indices=[], - expected_output=pa.RecordBatch.from_arrays([ - pa.array([], type=pa.list_(pa.int32())), - pa.array([], type=pa.list_(pa.binary())) - ], ["f1", "f2"])), + expected_output=pa.RecordBatch.from_arrays( + [ + pa.array([], type=pa.list_(pa.int32())), + pa.array([], type=pa.list_(pa.binary())), + ], + ["f1", "f2"], + ), + ), dict( testcase_name="one_index", row_indices=[1], - expected_output=pa.RecordBatch.from_arrays([ - pa.array([None], type=pa.list_(pa.int32())), - pa.array([["b", "c"]], type=pa.list_(pa.binary())) - ], ["f1", "f2"])), + expected_output=pa.RecordBatch.from_arrays( + [ + pa.array([None], type=pa.list_(pa.int32())), + pa.array([["b", "c"]], type=pa.list_(pa.binary())), + ], + ["f1", "f2"], + ), + ), dict( testcase_name="consecutive_first_row_included", row_indices=[0, 1, 2, 3], expected_output=pa.RecordBatch.from_arrays( [ pa.array([[1, 2, 3], None, [4], []], type=pa.list_(pa.int32())), - pa.array([["a"], ["b", "c"], None, []], - type=pa.list_(pa.binary())) + pa.array([["a"], ["b", "c"], None, []], type=pa.list_(pa.binary())), ], ["f1", "f2"], - )), + ), + ), dict( testcase_name="consecutive_last_row_included", row_indices=[5, 6, 7, 8], expected_output=pa.RecordBatch.from_arrays( [ pa.array([[7], [8, 9], [10], []], type=pa.list_(pa.int32())), - pa.array([["d", "e"], ["f"], None, ["g"]], - type=pa.list_(pa.binary())) + pa.array([["d", "e"], ["f"], None, ["g"]], type=pa.list_(pa.binary())), ], ["f1", "f2"], - )), + ), + ), dict( testcase_name="inconsecutive", row_indices=[1, 2, 3, 5], expected_output=pa.RecordBatch.from_arrays( [ pa.array([None, [4], [], [7]], type=pa.list_(pa.int32())), - pa.array([["b", "c"], None, [], ["d", "e"]], - type=pa.list_(pa.binary())) + pa.array( + [["b", "c"], None, [], ["d", "e"]], type=pa.list_(pa.binary()) + ), ], ["f1", "f2"], - )), + ), + ), dict( testcase_name="inconsecutive_last_row_included", row_indices=[2, 3, 4, 5, 7, 8], expected_output=pa.RecordBatch.from_arrays( [ - pa.array([[4], [], [5, 6], [7], [10], []], - type=pa.list_(pa.int32())), - pa.array([None, [], None, ["d", "e"], None, ["g"]], - type=pa.list_(pa.binary())) + pa.array([[4], [], [5, 6], [7], [10], []], type=pa.list_(pa.int32())), + pa.array( + [None, [], None, ["d", "e"], None, ["g"]], + type=pa.list_(pa.binary()), + ), ], ["f1", "f2"], - )), + ), + ), ] class RecordBatchTakeTest(parameterized.TestCase): + @parameterized.named_parameters(*_TAKE_TEST_CASES) + def test_success(self, row_indices, expected_output): + record_batch = pa.RecordBatch.from_arrays( + [ + pa.array( + [[1, 2, 3], None, [4], [], [5, 6], [7], [8, 9], [10], []], + type=pa.list_(pa.int32()), + ), + pa.array( + [["a"], ["b", "c"], None, [], None, ["d", "e"], ["f"], None, ["g"]], + type=pa.list_(pa.binary()), + ), + ], + ["f1", "f2"], + ) - @parameterized.named_parameters(*_TAKE_TEST_CASES) - def test_success(self, row_indices, expected_output): - record_batch = pa.RecordBatch.from_arrays([ - pa.array([[1, 2, 3], None, [4], [], [5, 6], [7], [8, 9], [10], []], - type=pa.list_(pa.int32())), - pa.array( - [["a"], ["b", "c"], None, [], None, ["d", "e"], ["f"], None, ["g"]], - type=pa.list_(pa.binary())), - ], ["f1", "f2"]) - - for row_indices_type in (pa.int32(), pa.int64()): - sliced = table_util.RecordBatchTake( - record_batch, pa.array(row_indices, type=row_indices_type)) - self.assertTrue( - sliced.equals(expected_output), - "Expected {}, got {}".format(expected_output, sliced)) + for row_indices_type in (pa.int32(), pa.int64()): + sliced = table_util.RecordBatchTake( + record_batch, pa.array(row_indices, type=row_indices_type) + ) + self.assertTrue( + sliced.equals(expected_output), + f"Expected {expected_output}, got {sliced}", + ) class CanonicalizeRecordBatchTest(parameterized.TestCase): + def test_canonicalize_record_batch(self): + rb_data = pa.RecordBatch.from_arrays( + [ + pa.array([17, 30], pa.int32()), + pa.array(["english", "spanish"]), + pa.array([False, True]), + pa.array([False, True]), + pa.array([["ne"], ["s", "ted"]]), + ], + ["age", "language", "prediction", "label", "nested"], + ) - def test_canonicalize_record_batch(self): - rb_data = pa.RecordBatch.from_arrays([ - pa.array([17, 30], pa.int32()), - pa.array(["english", "spanish"]), - pa.array([False, True]), - pa.array([False, True]), - pa.array([["ne"], ["s", "ted"]]) - ], ["age", "language", "prediction", "label", "nested"]) - - canonicalized_rb_data = table_util.CanonicalizeRecordBatch(rb_data) - self.assertEqual(canonicalized_rb_data.schema.names, rb_data.schema.names) - - expected_age_column = pa.array([[17], [30]], type=pa.large_list(pa.int64())) - expected_language_column = pa.array([["english"], ["spanish"]], - type=pa.large_list(pa.large_binary())) - expected_prediction_column = pa.array([[0], [1]], - type=pa.large_list(pa.int8())) - expected_label_column = pa.array([[0], [1]], type=pa.large_list(pa.int8())) - expected_nested_column = pa.array([["ne"], ["s", "ted"]], - type=pa.large_list(pa.large_binary())) - self.assertTrue( - canonicalized_rb_data.column( - canonicalized_rb_data.schema.get_field_index("age")).equals( - expected_age_column)) - self.assertTrue( - canonicalized_rb_data.column( - canonicalized_rb_data.schema.get_field_index("language")).equals( - expected_language_column)) - self.assertTrue( - canonicalized_rb_data.column( - canonicalized_rb_data.schema.get_field_index("prediction")).equals( - expected_prediction_column)) - self.assertTrue( - canonicalized_rb_data.column( - canonicalized_rb_data.schema.get_field_index("label")).equals( - expected_label_column)) - self.assertTrue( - canonicalized_rb_data.column( - canonicalized_rb_data.schema.get_field_index("nested")).equals( - expected_nested_column)) - - -_INPUT_RECORD_BATCH = pa.RecordBatch.from_arrays([ - pa.array([[1], [2, 3]]), - pa.array([[{ - "sf1": ["a", "b"] - }], [{ - "sf2": [{ - "ssf1": [3] - }, { - "ssf1": [4] - }] - }]]), - pa.array([ - { - "sf1": [[1, 2], [3]], - "sf2": [None], - }, - None, - ]), -], ["f1", "f2", "f3"]) + canonicalized_rb_data = table_util.CanonicalizeRecordBatch(rb_data) + self.assertEqual(canonicalized_rb_data.schema.names, rb_data.schema.names) + expected_age_column = pa.array([[17], [30]], type=pa.large_list(pa.int64())) + expected_language_column = pa.array( + [["english"], ["spanish"]], type=pa.large_list(pa.large_binary()) + ) + expected_prediction_column = pa.array([[0], [1]], type=pa.large_list(pa.int8())) + expected_label_column = pa.array([[0], [1]], type=pa.large_list(pa.int8())) + expected_nested_column = pa.array( + [["ne"], ["s", "ted"]], type=pa.large_list(pa.large_binary()) + ) + self.assertTrue( + canonicalized_rb_data.column( + canonicalized_rb_data.schema.get_field_index("age") + ).equals(expected_age_column) + ) + self.assertTrue( + canonicalized_rb_data.column( + canonicalized_rb_data.schema.get_field_index("language") + ).equals(expected_language_column) + ) + self.assertTrue( + canonicalized_rb_data.column( + canonicalized_rb_data.schema.get_field_index("prediction") + ).equals(expected_prediction_column) + ) + self.assertTrue( + canonicalized_rb_data.column( + canonicalized_rb_data.schema.get_field_index("label") + ).equals(expected_label_column) + ) + self.assertTrue( + canonicalized_rb_data.column( + canonicalized_rb_data.schema.get_field_index("nested") + ).equals(expected_nested_column) + ) + + +_INPUT_RECORD_BATCH = pa.RecordBatch.from_arrays( + [ + pa.array([[1], [2, 3]]), + pa.array([[{"sf1": ["a", "b"]}], [{"sf2": [{"ssf1": [3]}, {"ssf1": [4]}]}]]), + pa.array( + [ + { + "sf1": [[1, 2], [3]], + "sf2": [None], + }, + None, + ] + ), + ], + ["f1", "f2", "f3"], +) -ExpectedArray = collections.namedtuple( - "ExpectedArray", ["array", "parent_indices"]) + +ExpectedArray = collections.namedtuple("ExpectedArray", ["array", "parent_indices"]) _FEATURES_TO_ARRAYS = { - path.ColumnPath(["f1"]): ExpectedArray( - pa.array([[1], [2, 3]]), [0, 1]), - path.ColumnPath(["f2"]): ExpectedArray(pa.array([[{ - "sf1": ["a", "b"] - }], [{ - "sf2": [{ - "ssf1": [3] - }, { - "ssf1": [4] - }] - }]]), [0, 1]), - path.ColumnPath(["f3"]): ExpectedArray(pa.array([{ - "sf1": [[1, 2], [3]], - "sf2": [None], - }, None]), [0, 1]), - path.ColumnPath(["f2", "sf1"]): ExpectedArray( - pa.array([["a", "b"], None]), [0, 1]), + path.ColumnPath(["f1"]): ExpectedArray(pa.array([[1], [2, 3]]), [0, 1]), + path.ColumnPath(["f2"]): ExpectedArray( + pa.array([[{"sf1": ["a", "b"]}], [{"sf2": [{"ssf1": [3]}, {"ssf1": [4]}]}]]), + [0, 1], + ), + path.ColumnPath(["f3"]): ExpectedArray( + pa.array( + [ + { + "sf1": [[1, 2], [3]], + "sf2": [None], + }, + None, + ] + ), + [0, 1], + ), + path.ColumnPath(["f2", "sf1"]): ExpectedArray(pa.array([["a", "b"], None]), [0, 1]), path.ColumnPath(["f2", "sf2"]): ExpectedArray( - pa.array([None, [{ - "ssf1": [3] - }, { - "ssf1": [4] - }]]), [0, 1]), - path.ColumnPath(["f2", "sf2", "ssf1"]): ExpectedArray( - pa.array([[3], [4]]), [1, 1]), - path.ColumnPath(["f3", "sf1"]): ExpectedArray(pa.array( - [[[1, 2], [3]], None]), [0, 1]), - path.ColumnPath(["f3", "sf2"]): ExpectedArray( - pa.array([[None], None]), [0, 1]), + pa.array([None, [{"ssf1": [3]}, {"ssf1": [4]}]]), [0, 1] + ), + path.ColumnPath(["f2", "sf2", "ssf1"]): ExpectedArray(pa.array([[3], [4]]), [1, 1]), + path.ColumnPath(["f3", "sf1"]): ExpectedArray( + pa.array([[[1, 2], [3]], None]), [0, 1] + ), + path.ColumnPath(["f3", "sf2"]): ExpectedArray(pa.array([[None], None]), [0, 1]), } class EnumerateStructNullValueTestData(NamedTuple): - """Inputs and outputs for enumeration with pa.StructArrays with null values.""" - description: str - """Summary of test""" - batch: pa.RecordBatch - """Input Record Batch""" - expected_results: Dict[path.ColumnPath, pa.array] - """The expected output.""" - - -def _make_enumerate_data_with_missing_data_at_leaves( - ) -> Iterable[EnumerateStructNullValueTestData]: - """Test that having only nulls at leaf values gets translated correctly.""" - test_data_type = pa.list_(pa.struct([("f2", pa.list_(pa.float64()))])) - struct_column_as_list_dicts = [ - [], # first element of 'c'; note this is not counted as missing. - [ # second element of 'c' -- a list of length 2. - { - "f2": [2.0], - }, - None, # f2 is missing - ], - [ # third element of 'c' - None, # f2 is missing - ], - [], # fourth element of 'c'; note this is not counted as missing. - ] - - array = pa.array(struct_column_as_list_dicts, type=test_data_type) - - batch = pa.RecordBatch.from_arrays([array], ["c"]) - - full_expected_results = { - path.ColumnPath(["c"]): - pa.array([[], [{ - "f2": [2.0] - }, None], [None], []]), - path.ColumnPath(["c", "f2"]): - pa.array([[2.0], None, None]), - } - yield "Basic", batch, full_expected_results - - -def _make_enumerate_test_data_with_null_values_and_sliced_batches( - ) -> Iterable[EnumerateStructNullValueTestData]: - """Yields test data for sliced data where all slicing is consistent. - - Pyarrow slices with zero copy, sometimes subtle bugs can - arise when processing sliced data. - """ - test_data_type = pa.list_(pa.struct([("f2", pa.list_(pa.float64()))])) - struct_column_as_list_dicts = [ - [], # first element of 'c'; note this is not counted as missing. - [ # second element of 'c' -- a list of length 2. - { - "f2": [2.0], - }, - None, # f2 is missing - ], - [ # third element of 'c' - None, # f2 is missing - ], - [], # fourth element of 'c'; note this is not counted as missing. - ] - - array = pa.array(struct_column_as_list_dicts, type=test_data_type) - - batch = pa.RecordBatch.from_arrays([array], ["c"]) - slice_start, slice_end = 1, 3 - batch = pa.RecordBatch.from_arrays([array[slice_start:slice_end]], ["c"]) - - sliced_expected_results = { - path.ColumnPath(["c"]): pa.array([[{ - "f2": [2.0] - }, None], [None]]), - path.ColumnPath(["c", "f2"]): pa.array([[2.0], None, None]), - } - # Test case 1: slicing the array. - yield "SlicedArray", batch, sliced_expected_results - - batch = pa.RecordBatch.from_arrays([array], ["c"])[slice_start:slice_end] - # Test case 2: slicing the RecordBatch. - yield "SlicedRecordBatch", batch, sliced_expected_results - - -def _make_enumerate_test_data_with_null_top_level( - ) -> Iterable[EnumerateStructNullValueTestData]: - """Yields test data with a top level list element is missing.""" - test_data_type = pa.list_(pa.struct([("f2", pa.list_(pa.float64()))])) - struct_column_as_list_dicts = [ - [], # first element of 'c'; note this is not counted as missing. - None, # c is missing. - [ # third element of 'c' - None, # f2 is missing - ], - [], # fourth element of 'c'; note this is not counted as missing. - ] - array = pa.array( - struct_column_as_list_dicts, type=test_data_type) - validity_buffer_with_null = array.buffers()[0] - array_with_null_indicator = pa.Array.from_buffers( - array.type, - len(array) + array.offset, - [validity_buffer_with_null, array.buffers()[1]], - offset=0, - children=[array.values]) - batch_with_missing_entry = pa.RecordBatch.from_arrays( - [array_with_null_indicator], ["c"]) - missing_expected_results = { - path.ColumnPath(["c"]): - pa.array([[], None, [None], []], type=test_data_type), - path.ColumnPath(["c", "f2"]): - pa.array([None], type=pa.list_(pa.float64())), - } - yield ("ValuesPresentWithNullIndicator", batch_with_missing_entry, - missing_expected_results) - - -def _make_enumerate_test_data_with_slices_at_different_offsets( - ) -> Iterable[EnumerateStructNullValueTestData]: - """Yields a test cases constructed from array slices with different offsets. - - Slicing in pyarrow is zero copy, which can have subtle bugs, so ensure - the code works under more obscure situations. - """ - total_size = 10 - values_array = pa.array(range(total_size), type=pa.int64()) - # create 5 pyarrow.Array object each of size from the original array ([0,1], - # [2,3], etc - slices = [ - values_array[start:end] for (start, end) - in zip(range(0, total_size + 1, 2), range(2, total_size + 1, 2)) - ] # pyformat: disable - validity = pa.array([True, False], type=pa.bool_()) - # Label fields from "0" to "5" - new_type = pa.struct([pa.field(str(sl[0].as_py() // 2), sl.type) - for sl in slices]) - # Using the value buffer of validity as composed_struct's validity bitmap - # buffer. - composed_struct = pa.StructArray.from_buffers( - new_type, len(slices[0]), [validity.buffers()[1]], children=slices) - sliced_batch = pa.RecordBatch.from_arrays([composed_struct], ["c"]) - sliced_expected_results = { - path.ColumnPath(["c"]): - pa.array([ - [{"0": 0, "1": 2, "2": 4, "3": 6, "4": 8}], - None, - ]), - path.ColumnPath(["c", "0"]): pa.array([0, None], type=pa.int64()), - path.ColumnPath(["c", "1"]): pa.array([2, None], type=pa.int64()), - path.ColumnPath(["c", "2"]): pa.array([4, None], type=pa.int64()), - path.ColumnPath(["c", "3"]): pa.array([6, None], type=pa.int64()), - path.ColumnPath(["c", "4"]): pa.array([8, None], type=pa.int64()), - } # pyformat: disable - yield ("SlicedArrayWithOffests", sliced_batch, sliced_expected_results) - + """Inputs and outputs for enumeration with pa.StructArrays with null values.""" + + description: str + """Summary of test""" + batch: pa.RecordBatch + """Input Record Batch""" + expected_results: Dict[path.ColumnPath, pa.array] + """The expected output.""" + + +def _make_enumerate_data_with_missing_data_at_leaves() -> ( + Iterable[EnumerateStructNullValueTestData] +): + """Test that having only nulls at leaf values gets translated correctly.""" + test_data_type = pa.list_(pa.struct([("f2", pa.list_(pa.float64()))])) + struct_column_as_list_dicts = [ + [], # first element of 'c'; note this is not counted as missing. + [ # second element of 'c' -- a list of length 2. + { + "f2": [2.0], + }, + None, # f2 is missing + ], + [ # third element of 'c' + None, # f2 is missing + ], + [], # fourth element of 'c'; note this is not counted as missing. + ] -def _normalize(array: pa.Array) -> pa.Array: - """Round trips array through python objects. + array = pa.array(struct_column_as_list_dicts, type=test_data_type) - Comparing nested arrays with slices is buggy in Arrow 2.0 this method - is useful comparing two such arrays for logical equality. The bugs - appears to be fixed as of Arrow 5.0 this should be removable once that - becomes the minimum version. + batch = pa.RecordBatch.from_arrays([array], ["c"]) - Args: - array: The array to normalize. + full_expected_results = { + path.ColumnPath(["c"]): pa.array([[], [{"f2": [2.0]}, None], [None], []]), + path.ColumnPath(["c", "f2"]): pa.array([[2.0], None, None]), + } + yield "Basic", batch, full_expected_results - Returns: - An array that doesn't have any more zero copy slices in itself or - it's children. Note the schema might be slightly different for - all null arrays. - """ - return pa.array(array.to_pylist()) +def _make_enumerate_test_data_with_null_values_and_sliced_batches() -> ( + Iterable[EnumerateStructNullValueTestData] +): + """Yields test data for sliced data where all slicing is consistent. -class TableUtilTest(parameterized.TestCase): + Pyarrow slices with zero copy, sometimes subtle bugs can + arise when processing sliced data. + """ + test_data_type = pa.list_(pa.struct([("f2", pa.list_(pa.float64()))])) + struct_column_as_list_dicts = [ + [], # first element of 'c'; note this is not counted as missing. + [ # second element of 'c' -- a list of length 2. + { + "f2": [2.0], + }, + None, # f2 is missing + ], + [ # third element of 'c' + None, # f2 is missing + ], + [], # fourth element of 'c'; note this is not counted as missing. + ] - def test_get_array_empty_path(self): - with self.assertRaisesRegex(KeyError, r"query_path must be non-empty.*"): - table_util.get_array( - pa.RecordBatch.from_arrays([pa.array([[1], [2, 3]])], ["v"]), - query_path=path.ColumnPath([]), - return_example_indices=False, - ) - - def test_get_array_column_missing(self): - with self.assertRaisesRegex( - KeyError, r"query_path step 0 \(x\) not in record batch.*" - ): - table_util.get_array( - pa.RecordBatch.from_arrays([pa.array([[1], [2]])], ["y"]), - query_path=path.ColumnPath(["x"]), - return_example_indices=False, - ) - - def test_get_array_step_missing(self): - with self.assertRaisesRegex( - KeyError, r"query_path step \(ssf3\) not in struct.*" - ): - table_util.get_array( - _INPUT_RECORD_BATCH, - query_path=path.ColumnPath(["f2", "sf2", "ssf3"]), - return_example_indices=False, - ) - - def test_get_array_return_example_indices(self): - record_batch = pa.RecordBatch.from_arrays( - [ - pa.array([ - [{"sf": [{"ssf": [1]}, {"ssf": [2]}]}], - [{"sf": [{"ssf": [3, 4]}]}], - ]), - pa.array([["one"], ["two"]]), + array = pa.array(struct_column_as_list_dicts, type=test_data_type) + + batch = pa.RecordBatch.from_arrays([array], ["c"]) + slice_start, slice_end = 1, 3 + batch = pa.RecordBatch.from_arrays([array[slice_start:slice_end]], ["c"]) + + sliced_expected_results = { + path.ColumnPath(["c"]): pa.array([[{"f2": [2.0]}, None], [None]]), + path.ColumnPath(["c", "f2"]): pa.array([[2.0], None, None]), + } + # Test case 1: slicing the array. + yield "SlicedArray", batch, sliced_expected_results + + batch = pa.RecordBatch.from_arrays([array], ["c"])[slice_start:slice_end] + # Test case 2: slicing the RecordBatch. + yield "SlicedRecordBatch", batch, sliced_expected_results + + +def _make_enumerate_test_data_with_null_top_level() -> ( + Iterable[EnumerateStructNullValueTestData] +): + """Yields test data with a top level list element is missing.""" + test_data_type = pa.list_(pa.struct([("f2", pa.list_(pa.float64()))])) + struct_column_as_list_dicts = [ + [], # first element of 'c'; note this is not counted as missing. + None, # c is missing. + [ # third element of 'c' + None, # f2 is missing ], - ["f", "w"], + [], # fourth element of 'c'; note this is not counted as missing. + ] + array = pa.array(struct_column_as_list_dicts, type=test_data_type) + validity_buffer_with_null = array.buffers()[0] + array_with_null_indicator = pa.Array.from_buffers( + array.type, + len(array) + array.offset, + [validity_buffer_with_null, array.buffers()[1]], + offset=0, + children=[array.values], ) - feature = path.ColumnPath(["f", "sf", "ssf"]) - actual_arr, actual_indices = table_util.get_array( - record_batch, feature, return_example_indices=True + batch_with_missing_entry = pa.RecordBatch.from_arrays( + [array_with_null_indicator], ["c"] ) - expected_arr = pa.array([[1], [2], [3, 4]]) - expected_indices = np.array([0, 0, 1]) - self.assertTrue( - actual_arr.equals(expected_arr), - "\nfeature: {};\nexpected:\n{};\nactual:\n{}".format( - feature, expected_arr, actual_arr - ), + missing_expected_results = { + path.ColumnPath(["c"]): pa.array([[], None, [None], []], type=test_data_type), + path.ColumnPath(["c", "f2"]): pa.array([None], type=pa.list_(pa.float64())), + } + yield ( + "ValuesPresentWithNullIndicator", + batch_with_missing_entry, + missing_expected_results, ) - np.testing.assert_array_equal(expected_indices, actual_indices) - def test_get_array_subpath_missing(self): - with self.assertRaisesRegex( - KeyError, r"Cannot process .* \(sssf\) inside .* list.*" - ): - table_util.get_array( - _INPUT_RECORD_BATCH, - query_path=path.ColumnPath(["f2", "sf2", "ssf1", "sssf"]), - return_example_indices=False, - ) - - @parameterized.named_parameters( - ((str(f), f, expected) for (f, expected) in _FEATURES_TO_ARRAYS.items()) - ) - def test_get_array(self, feature, expected): - actual_arr, actual_indices = table_util.get_array( - _INPUT_RECORD_BATCH, - feature, - return_example_indices=True, - wrap_flat_struct_in_list=False, - ) - expected_arr, expected_indices = expected - self.assertTrue( - actual_arr.equals(expected_arr), - "\nfeature: {};\nexpected:\n{};\nactual:\n{}".format( - feature, expected_arr, actual_arr - ), - ) - np.testing.assert_array_equal(expected_indices, actual_indices) - - @parameterized.named_parameters( - ((str(f), f, expected) for (f, expected) in _FEATURES_TO_ARRAYS.items()) - ) - def test_get_array_no_broadcast(self, feature, expected): - actual_arr, actual_indices = table_util.get_array( - _INPUT_RECORD_BATCH, - feature, - return_example_indices=False, - wrap_flat_struct_in_list=False, + +def _make_enumerate_test_data_with_slices_at_different_offsets() -> ( + Iterable[EnumerateStructNullValueTestData] +): + """Yields a test cases constructed from array slices with different offsets. + + Slicing in pyarrow is zero copy, which can have subtle bugs, so ensure + the code works under more obscure situations. + """ + total_size = 10 + values_array = pa.array(range(total_size), type=pa.int64()) + # create 5 pyarrow.Array object each of size from the original array ([0,1], + # [2,3], etc + slices = [ + values_array[start:end] + for (start, end) in zip( + range(0, total_size + 1, 2), range(2, total_size + 1, 2) + ) + ] # pyformat: disable + validity = pa.array([True, False], type=pa.bool_()) + # Label fields from "0" to "5" + new_type = pa.struct([pa.field(str(sl[0].as_py() // 2), sl.type) for sl in slices]) + # Using the value buffer of validity as composed_struct's validity bitmap + # buffer. + composed_struct = pa.StructArray.from_buffers( + new_type, len(slices[0]), [validity.buffers()[1]], children=slices ) - expected_arr, _ = expected - self.assertTrue( - actual_arr.equals(expected_arr), - "\nfeature: {};\nexpected:\n{};\nactual:\n{}".format( - feature, expected_arr, actual_arr + sliced_batch = pa.RecordBatch.from_arrays([composed_struct], ["c"]) + sliced_expected_results = { + path.ColumnPath(["c"]): pa.array( + [ + [{"0": 0, "1": 2, "2": 4, "3": 6, "4": 8}], + None, + ] ), + path.ColumnPath(["c", "0"]): pa.array([0, None], type=pa.int64()), + path.ColumnPath(["c", "1"]): pa.array([2, None], type=pa.int64()), + path.ColumnPath(["c", "2"]): pa.array([4, None], type=pa.int64()), + path.ColumnPath(["c", "3"]): pa.array([6, None], type=pa.int64()), + path.ColumnPath(["c", "4"]): pa.array([8, None], type=pa.int64()), + } # pyformat: disable + yield ("SlicedArrayWithOffests", sliced_batch, sliced_expected_results) + + +def _normalize(array: pa.Array) -> pa.Array: + """Round trips array through python objects. + + Comparing nested arrays with slices is buggy in Arrow 2.0 this method + is useful comparing two such arrays for logical equality. The bugs + appears to be fixed as of Arrow 5.0 this should be removable once that + becomes the minimum version. + + Args: + ---- + array: The array to normalize. + + Returns: + ------- + An array that doesn't have any more zero copy slices in itself or + it's children. Note the schema might be slightly different for + all null arrays. + """ + return pa.array(array.to_pylist()) + + +class TableUtilTest(parameterized.TestCase): + def test_get_array_empty_path(self): + with self.assertRaisesRegex(KeyError, r"query_path must be non-empty.*"): + table_util.get_array( + pa.RecordBatch.from_arrays([pa.array([[1], [2, 3]])], ["v"]), + query_path=path.ColumnPath([]), + return_example_indices=False, + ) + + def test_get_array_column_missing(self): + with self.assertRaisesRegex( + KeyError, r"query_path step 0 \(x\) not in record batch.*" + ): + table_util.get_array( + pa.RecordBatch.from_arrays([pa.array([[1], [2]])], ["y"]), + query_path=path.ColumnPath(["x"]), + return_example_indices=False, + ) + + def test_get_array_step_missing(self): + with self.assertRaisesRegex( + KeyError, r"query_path step \(ssf3\) not in struct.*" + ): + table_util.get_array( + _INPUT_RECORD_BATCH, + query_path=path.ColumnPath(["f2", "sf2", "ssf3"]), + return_example_indices=False, + ) + + def test_get_array_return_example_indices(self): + record_batch = pa.RecordBatch.from_arrays( + [ + pa.array( + [ + [{"sf": [{"ssf": [1]}, {"ssf": [2]}]}], + [{"sf": [{"ssf": [3, 4]}]}], + ] + ), + pa.array([["one"], ["two"]]), + ], + ["f", "w"], + ) + feature = path.ColumnPath(["f", "sf", "ssf"]) + actual_arr, actual_indices = table_util.get_array( + record_batch, feature, return_example_indices=True + ) + expected_arr = pa.array([[1], [2], [3, 4]]) + expected_indices = np.array([0, 0, 1]) + self.assertTrue( + actual_arr.equals(expected_arr), + f"\nfeature: {feature};\nexpected:\n{expected_arr};\nactual:\n{actual_arr}", + ) + np.testing.assert_array_equal(expected_indices, actual_indices) + + def test_get_array_subpath_missing(self): + with self.assertRaisesRegex( + KeyError, r"Cannot process .* \(sssf\) inside .* list.*" + ): + table_util.get_array( + _INPUT_RECORD_BATCH, + query_path=path.ColumnPath(["f2", "sf2", "ssf1", "sssf"]), + return_example_indices=False, + ) + + @parameterized.named_parameters( + (str(f), f, expected) for (f, expected) in _FEATURES_TO_ARRAYS.items() ) - self.assertIsNone(actual_indices) - - @parameterized.named_parameters( - ((str(f), f, expected) for (f, expected) in _FEATURES_TO_ARRAYS.items()) - ) - def test_get_array_wrap_flat_struct_array(self, feature, expected): - actual_arr, actual_indices = table_util.get_array( - _INPUT_RECORD_BATCH, - feature, - return_example_indices=True, - wrap_flat_struct_in_list=True, - ) - expected_arr, expected_indices = expected - if pa.types.is_struct(expected_arr.type): - expected_arr = array_util.ToSingletonListArray(expected_arr) - self.assertTrue( - actual_arr.equals(expected_arr), - "\nfeature: {};\nexpected:\n{};\nactual:\n{}".format( - feature, expected_arr, actual_arr - ), + def test_get_array(self, feature, expected): + actual_arr, actual_indices = table_util.get_array( + _INPUT_RECORD_BATCH, + feature, + return_example_indices=True, + wrap_flat_struct_in_list=False, + ) + expected_arr, expected_indices = expected + self.assertTrue( + actual_arr.equals(expected_arr), + f"\nfeature: {feature};\nexpected:\n{expected_arr};\nactual:\n{actual_arr}", + ) + np.testing.assert_array_equal(expected_indices, actual_indices) + + @parameterized.named_parameters( + (str(f), f, expected) for (f, expected) in _FEATURES_TO_ARRAYS.items() ) - np.testing.assert_array_equal(expected_indices, actual_indices) + def test_get_array_no_broadcast(self, feature, expected): + actual_arr, actual_indices = table_util.get_array( + _INPUT_RECORD_BATCH, + feature, + return_example_indices=False, + wrap_flat_struct_in_list=False, + ) + expected_arr, _ = expected + self.assertTrue( + actual_arr.equals(expected_arr), + f"\nfeature: {feature};\nexpected:\n{expected_arr};\nactual:\n{actual_arr}", + ) + self.assertIsNone(actual_indices) - def test_enumerate_arrays(self): - for leaves_only, wrap_flat_struct_in_list in itertools.product( - [True, False], [True, False] - ): - actual_results = {} - for feature_path, feature_array in table_util.enumerate_arrays( - _INPUT_RECORD_BATCH, - leaves_only, - wrap_flat_struct_in_list, - ): - actual_results[feature_path] = feature_array - - expected_results = {} - # leaf fields - for p in [ - ["f1"], - ["f2", "sf1"], - ["f2", "sf2", "ssf1"], - ["f3", "sf1"], - ["f3", "sf2"], - ]: - feature_path = path.ColumnPath(p) - expected_results[feature_path] = ( - _FEATURES_TO_ARRAYS[feature_path].array + @parameterized.named_parameters( + (str(f), f, expected) for (f, expected) in _FEATURES_TO_ARRAYS.items() + ) + def test_get_array_wrap_flat_struct_array(self, feature, expected): + actual_arr, actual_indices = table_util.get_array( + _INPUT_RECORD_BATCH, + feature, + return_example_indices=True, + wrap_flat_struct_in_list=True, ) - if not leaves_only: - for p in [["f2"], ["f2", "sf2"], ["f3"]]: - feature_path = path.ColumnPath(p) - expected_array = _FEATURES_TO_ARRAYS[feature_path][0] - if wrap_flat_struct_in_list and pa.types.is_struct( - expected_array.type - ): - expected_array = array_util.ToSingletonListArray(expected_array) - expected_results[feature_path] = expected_array - - self.assertLen(actual_results, len(expected_results)) - for k, v in six.iteritems(expected_results): - self.assertIn(k, actual_results) - actual = actual_results[k] + expected_arr, expected_indices = expected + if pa.types.is_struct(expected_arr.type): + expected_arr = array_util.ToSingletonListArray(expected_arr) self.assertTrue( - actual[0].equals(v[0]), - "leaves_only={}; " - "wrap_flat_struct_in_list={} feature={}; expected: {}; actual: {}" - .format( - leaves_only, wrap_flat_struct_in_list, k, v, actual - ), + actual_arr.equals(expected_arr), + f"\nfeature: {feature};\nexpected:\n{expected_arr};\nactual:\n{actual_arr}", + ) + np.testing.assert_array_equal(expected_indices, actual_indices) + + def test_enumerate_arrays(self): + for leaves_only, wrap_flat_struct_in_list in itertools.product( + [True, False], [True, False] + ): + actual_results = {} + for feature_path, feature_array in table_util.enumerate_arrays( + _INPUT_RECORD_BATCH, + leaves_only, + wrap_flat_struct_in_list, + ): + actual_results[feature_path] = feature_array + + expected_results = {} + # leaf fields + for p in [ + ["f1"], + ["f2", "sf1"], + ["f2", "sf2", "ssf1"], + ["f3", "sf1"], + ["f3", "sf2"], + ]: + feature_path = path.ColumnPath(p) + expected_results[feature_path] = _FEATURES_TO_ARRAYS[feature_path].array + if not leaves_only: + for p in [["f2"], ["f2", "sf2"], ["f3"]]: + feature_path = path.ColumnPath(p) + expected_array = _FEATURES_TO_ARRAYS[feature_path][0] + if wrap_flat_struct_in_list and pa.types.is_struct( + expected_array.type + ): + expected_array = array_util.ToSingletonListArray(expected_array) + expected_results[feature_path] = expected_array + + self.assertLen(actual_results, len(expected_results)) + for k, v in six.iteritems(expected_results): + self.assertIn(k, actual_results) + actual = actual_results[k] + self.assertTrue( + actual[0].equals(v[0]), + f"leaves_only={leaves_only}; " + f"wrap_flat_struct_in_list={wrap_flat_struct_in_list} feature={k}; expected: {v}; actual: {actual}", + ) + np.testing.assert_array_equal(actual[1], v[1]) + + @parameterized.named_parameters( + itertools.chain( + _make_enumerate_data_with_missing_data_at_leaves(), + _make_enumerate_test_data_with_null_values_and_sliced_batches(), + _make_enumerate_test_data_with_null_top_level(), + _make_enumerate_test_data_with_slices_at_different_offsets(), ) - np.testing.assert_array_equal(actual[1], v[1]) - - @parameterized.named_parameters( - itertools.chain( - _make_enumerate_data_with_missing_data_at_leaves(), - _make_enumerate_test_data_with_null_values_and_sliced_batches(), - _make_enumerate_test_data_with_null_top_level(), - _make_enumerate_test_data_with_slices_at_different_offsets(), - ) - ) - def test_enumerate_missing_propogated_in_flattened_struct( - self, batch, expected_results - ): - actual_results = {} - for feature_path, feature_array in table_util.enumerate_arrays( - batch, enumerate_leaves_only=False + ) + def test_enumerate_missing_propogated_in_flattened_struct( + self, batch, expected_results ): - actual_results[feature_path] = feature_array - self.assertLen(actual_results, len(expected_results)) - for k, v in six.iteritems(expected_results): - assert k in actual_results, (k, list(actual_results.keys())) - self.assertIn(k, actual_results) - actual = _normalize(actual_results[k]) - v = _normalize(v) - self.assertTrue( - actual.equals(v), - "feature={}; expected: {}; actual: {}; diff: {}".format( - k, v, actual, actual.diff(v) - ), - ) + actual_results = {} + for feature_path, feature_array in table_util.enumerate_arrays( + batch, enumerate_leaves_only=False + ): + actual_results[feature_path] = feature_array + self.assertLen(actual_results, len(expected_results)) + for k, v in six.iteritems(expected_results): + assert k in actual_results, (k, list(actual_results.keys())) + self.assertIn(k, actual_results) + actual = _normalize(actual_results[k]) + v = _normalize(v) + self.assertTrue( + actual.equals(v), + f"feature={k}; expected: {v}; actual: {actual}; diff: {actual.diff(v)}", + ) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/tfx_bsl/beam/pickle_helpers.py b/tfx_bsl/beam/pickle_helpers.py index 6b64e41d..e572cf57 100644 --- a/tfx_bsl/beam/pickle_helpers.py +++ b/tfx_bsl/beam/pickle_helpers.py @@ -21,121 +21,121 @@ # TODO(b/281148738): Remove this once all supported Beam versions depend on dill # with updated pickling logic or this is fixed in Beam. def fix_code_type_pickling() -> None: - """Overrides `CodeType` pickling to prevent segfaults in Python 3.10.""" - # Based on the `save_code` from dill-0.3.6. - # https://github.com/uqfoundation/dill/blob/d5c4dccbe19fb27bfd757cb60abd2899fd9e59ba/dill/_dill.py#L1105 - # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) - # Copyright (c) 2008-2015 California Institute of Technology. - # Copyright (c) 2016-2023 The Uncertainty Quantification Foundation. - # License: 3-clause BSD. The full license text is available at: - # - https://github.com/uqfoundation/dill/blob/master/LICENSE + """Overrides `CodeType` pickling to prevent segfaults in Python 3.10.""" + # Based on the `save_code` from dill-0.3.6. + # https://github.com/uqfoundation/dill/blob/d5c4dccbe19fb27bfd757cb60abd2899fd9e59ba/dill/_dill.py#L1105 + # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) + # Copyright (c) 2008-2015 California Institute of Technology. + # Copyright (c) 2016-2023 The Uncertainty Quantification Foundation. + # License: 3-clause BSD. The full license text is available at: + # - https://github.com/uqfoundation/dill/blob/master/LICENSE - # The following function is also based on 'save_codeobject' from 'cloudpickle' - # Copyright (c) 2012, Regents of the University of California. - # Copyright (c) 2009 `PiCloud, Inc. `_. - # License: 3-clause BSD. The full license text is available at: - # - https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE - @dill.register(types.CodeType) - def save_code(pickler, obj): # pylint: disable=unused-variable - if hasattr(obj, 'co_endlinetable'): # python 3.11a (20 args) - args = ( - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - obj.co_filename, - obj.co_name, - obj.co_qualname, - obj.co_firstlineno, - obj.co_linetable, - obj.co_endlinetable, - obj.co_columntable, - obj.co_exceptiontable, - obj.co_freevars, - obj.co_cellvars, - ) - elif hasattr(obj, 'co_exceptiontable'): # python 3.11 (18 args) - args = ( - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - obj.co_filename, - obj.co_name, - obj.co_qualname, - obj.co_firstlineno, - obj.co_linetable, - obj.co_exceptiontable, - obj.co_freevars, - obj.co_cellvars, - ) - elif hasattr(obj, 'co_linetable'): # python 3.10 (16 args) - args = ( - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - obj.co_filename, - obj.co_name, - obj.co_firstlineno, - obj.co_linetable, - obj.co_freevars, - obj.co_cellvars, - ) - elif hasattr(obj, 'co_posonlyargcount'): # python 3.8 (16 args) - args = ( - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - obj.co_filename, - obj.co_name, - obj.co_firstlineno, - obj.co_lnotab, - obj.co_freevars, - obj.co_cellvars, - ) - else: # python 3.7 (15 args) - args = ( - obj.co_argcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - obj.co_filename, - obj.co_name, - obj.co_firstlineno, - obj.co_lnotab, - obj.co_freevars, - obj.co_cellvars, - ) + # The following function is also based on 'save_codeobject' from 'cloudpickle' + # Copyright (c) 2012, Regents of the University of California. + # Copyright (c) 2009 `PiCloud, Inc. `_. + # License: 3-clause BSD. The full license text is available at: + # - https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE + @dill.register(types.CodeType) + def save_code(pickler, obj): # pylint: disable=unused-variable + if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args) + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + obj.co_filename, + obj.co_name, + obj.co_qualname, + obj.co_firstlineno, + obj.co_linetable, + obj.co_endlinetable, + obj.co_columntable, + obj.co_exceptiontable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_exceptiontable"): # python 3.11 (18 args) + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + obj.co_filename, + obj.co_name, + obj.co_qualname, + obj.co_firstlineno, + obj.co_linetable, + obj.co_exceptiontable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_linetable"): # python 3.10 (16 args) + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + obj.co_filename, + obj.co_name, + obj.co_firstlineno, + obj.co_linetable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_posonlyargcount"): # python 3.8 (16 args) + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + obj.co_filename, + obj.co_name, + obj.co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + else: # python 3.7 (15 args) + args = ( + obj.co_argcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + obj.co_filename, + obj.co_name, + obj.co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) - pickler.save_reduce(types.CodeType, args, obj=obj) + pickler.save_reduce(types.CodeType, args, obj=obj) diff --git a/tfx_bsl/beam/run_inference.py b/tfx_bsl/beam/run_inference.py index 8c5d42b0..a95fdee3 100644 --- a/tfx_bsl/beam/run_inference.py +++ b/tfx_bsl/beam/run_inference.py @@ -15,46 +15,65 @@ import abc import base64 -from concurrent import futures import functools import importlib import os -from typing import Any, Callable, Dict, Iterable, List, Mapping, NamedTuple, Optional, Sequence, Text, Tuple, TypeVar, Union +from concurrent import futures +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Text, + Tuple, + TypeVar, + Union, +) -from absl import logging import apache_beam as beam +import googleapiclient +import numpy as np +import tensorflow as tf +from absl import logging from apache_beam.ml.inference import base from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.transforms import resources from apache_beam.utils import retry -import googleapiclient -from googleapiclient import discovery -from googleapiclient import http -import numpy as np -import tensorflow as tf -from tfx_bsl.public.proto import model_spec_pb2 -from tfx_bsl.telemetry import util +from googleapiclient import discovery, http # TODO(b/140306674): stop using the internal TF API. -from tensorflow.python.saved_model import loader_impl # pylint: disable=g-direct-tensorflow-import -from tensorflow_serving.apis import classification_pb2 -from tensorflow_serving.apis import prediction_log_pb2 -from tensorflow_serving.apis import regression_pb2 +from tensorflow.python.saved_model import ( + loader_impl, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow_serving.apis import ( + classification_pb2, + prediction_log_pb2, + regression_pb2, +) +from tfx_bsl.public.proto import model_spec_pb2 +from tfx_bsl.telemetry import util # TODO(b/131873699): Remove once 1.x support is dropped. try: - # pylint: disable=g-import-not-at-top - # We need to import this in order to register all quantiles ops, even though - # it's not directly used. - from tensorflow.contrib.boosted_trees.python.ops import quantile_ops as _ # pylint: disable=unused-import + # pylint: disable=g-import-not-at-top + # We need to import this in order to register all quantiles ops, even though + # it's not directly used. + from tensorflow.contrib.boosted_trees.python.ops import ( + quantile_ops as _, # pylint: disable=unused-import + ) except ImportError: - pass + pass -_METRICS_DESCRIPTOR_INFERENCE = 'BulkInferrer' -_METRICS_DESCRIPTOR_IN_PROCESS = 'InProcess' -_METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION = 'CloudAIPlatformPrediction' +_METRICS_DESCRIPTOR_INFERENCE = "BulkInferrer" +_METRICS_DESCRIPTOR_IN_PROCESS = "InProcess" +_METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION = "CloudAIPlatformPrediction" _REMOTE_INFERENCE_NUM_RETRIES = 5 # We define the following aliases of Any because the actual types are not @@ -65,1131 +84,1275 @@ # TODO(b/151468119): Converts this into enum once we stop supporting Python 2.7 -class _OperationType(object): - CLASSIFICATION = 'CLASSIFICATION' - REGRESSION = 'REGRESSION' - MULTI_INFERENCE = 'MULTI_INFERENCE' - PREDICTION = 'PREDICTION' +class _OperationType: + CLASSIFICATION = "CLASSIFICATION" + REGRESSION = "REGRESSION" + MULTI_INFERENCE = "MULTI_INFERENCE" + PREDICTION = "PREDICTION" -_K = TypeVar('_K') +_K = TypeVar("_K") InputType = Union[tf.train.Example, tf.train.SequenceExample, bytes] LoadOverrideFnType = Callable[[str, Sequence[str]], Any] _OUTPUT_TYPE = prediction_log_pb2.PredictionLog def _is_list_type(input_type: beam.typehints.typehints.TypeConstraint) -> bool: - if hasattr(input_type, 'inner_type'): - return input_type == beam.typehints.List[input_type.inner_type] - return False + if hasattr(input_type, "inner_type"): + return input_type == beam.typehints.List[input_type.inner_type] + return False def _key_and_result_type(input_type: beam.typehints.typehints.TypeConstraint): - """Get typehints for key and result type given an input typehint.""" - tuple_types = getattr(input_type, 'tuple_types', None) - if tuple_types is not None and len(tuple_types) == 2: - key_type = tuple_types[0] - value_type = tuple_types[1] - else: - key_type = None - value_type = input_type - if _is_list_type(value_type): - result_type = beam.typehints.List[_OUTPUT_TYPE] - else: - result_type = _OUTPUT_TYPE - return key_type, result_type + """Get typehints for key and result type given an input typehint.""" + tuple_types = getattr(input_type, "tuple_types", None) + if tuple_types is not None and len(tuple_types) == 2: + key_type = tuple_types[0] + value_type = tuple_types[1] + else: + key_type = None + value_type = input_type + if _is_list_type(value_type): + result_type = beam.typehints.List[_OUTPUT_TYPE] + else: + result_type = _OUTPUT_TYPE + return key_type, result_type def _using_in_process_inference( - inference_spec_type: model_spec_pb2.InferenceSpecType) -> bool: - return inference_spec_type.WhichOneof('type') == 'saved_model_spec' + inference_spec_type: model_spec_pb2.InferenceSpecType, +) -> bool: + return inference_spec_type.WhichOneof("type") == "saved_model_spec" def create_model_handler( inference_spec_type: model_spec_pb2.InferenceSpecType, load_override_fn: Optional[LoadOverrideFnType], - options_project_id: Optional[str]) -> base.ModelHandler: - """Creates a ModelHandler based on the InferenceSpecType. + options_project_id: Optional[str], +) -> base.ModelHandler: + """Creates a ModelHandler based on the InferenceSpecType. - Args: - inference_spec_type: Model inference endpoint. - load_override_fn: An option function to load the model, only used with - saved models. - options_project_id: The project id from pipeline options, only used if - there was no project_id specified in the inference_spec_type proto. - - Returns: - A ModelHandler appropriate for the inference_spec_type. - """ - if _using_in_process_inference(inference_spec_type): - return _get_saved_model_handler(inference_spec_type, load_override_fn) - return _RemotePredictModelHandler(inference_spec_type, options_project_id) + Args: + ---- + inference_spec_type: Model inference endpoint. + load_override_fn: An option function to load the model, only used with + saved models. + options_project_id: The project id from pipeline options, only used if + there was no project_id specified in the inference_spec_type proto. + + Returns: + ------- + A ModelHandler appropriate for the inference_spec_type. + """ + if _using_in_process_inference(inference_spec_type): + return _get_saved_model_handler(inference_spec_type, load_override_fn) + return _RemotePredictModelHandler(inference_spec_type, options_project_id) # Output type is inferred from input. -@beam.typehints.with_input_types(Union[InputType, Tuple[_K, InputType], - Tuple[_K, List[InputType]]]) +@beam.typehints.with_input_types( + Union[InputType, Tuple[_K, InputType], Tuple[_K, List[InputType]]] +) class RunInferenceImpl(beam.PTransform): - """Implementation of RunInference API.""" - - def __init__(self, - inference_spec_type: model_spec_pb2.InferenceSpecType, - load_override_fn: Optional[LoadOverrideFnType] = None): - """Initializes transform. - - Args: - inference_spec_type: InferenceSpecType proto. - load_override_fn: If provided, overrides the model loader fn of the - underlying ModelHandler. This takes a model path and sequence of tags, - and should return a model with interface compatible with tf.SavedModel. - """ - self._inference_spec_type = inference_spec_type - self._load_override_fn = load_override_fn - - # LINT.IfChange(close_to_resources) - @staticmethod - def _model_size_bytes(path: str) -> int: - # We might be unable to compute the size of the model during pipeline - # construction, but the model might still be accessible during pipeline - # execution. In such cases we will provide a default value for the model - # size. In general, it is a lot more costly to underestimate the size of - # the model than to overestimate it. - default_model_size = 1 << 30 # 1 GB. - - def file_size(directory, file): - return max(tf.io.gfile.stat(os.path.join(directory, file)).length, 0) - - try: - result = 0 - with futures.ThreadPoolExecutor() as executor: - for directory, _, files in tf.io.gfile.walk(path): - result += sum( - executor.map(functools.partial(file_size, directory), files)) - if result == 0: - result = default_model_size - return result - except OSError: - return default_model_size - - @staticmethod - def _make_close_to_resources( - inference_spec_type: model_spec_pb2.InferenceSpecType) -> str: - """Proximity resources not otherwise known (or visible) to Beam.""" - - if _using_in_process_inference(inference_spec_type): - # The model is expected to be loaded once per worker (as opposed to - # once per thread), due to the use of beam.Shared in pertinent DoFns. - # - # The exact value of this constant is not important; it aims to signify - # that there might be a non-trivial number of model loads. - # - # TODO(katsiapis): Auto(tune) this. - estimated_num_workers = 100 - model_path = inference_spec_type.saved_model_spec.model_path - model_size_bytes = RunInferenceImpl._model_size_bytes(model_path) - return f'{model_path}[{model_size_bytes * estimated_num_workers}]' - else: - # The model is available remotely, so the size of the RPC traffic is - # proportional to the size of the input. - # - # The exact value of this constant is not important; it aims to signify - # that there might be a non-trivial amount of RPC traffic. - # - # TODO(katsiapis): Auto(tune) this. - estimated_rpc_traffic_size_bytes = 1 << 40 # 1 TB. - - # TODO(katsiapis): Is it possible to query the AI platform to see what - # zones the model is available in, so that we can instead provide a - # descriptor along the lines of: f'zone1|zone2|...|zoneN[size]'? - del estimated_rpc_traffic_size_bytes - return '' - # LINT.ThenChange(../../../../learning/serving/contrib/servables/tensorflow/flume/bulk-inference.cc:close_to_resources) - - def infer_output_type(self, input_type): - key_type, result_type = _key_and_result_type(input_type) - if key_type is not None: - return beam.typehints.Tuple[key_type, result_type] - return result_type - - def expand(self, examples: beam.PCollection) -> beam.PCollection: - logging.info('RunInference on model: %s', self._inference_spec_type) - output_type = self.infer_output_type(examples.element_type) - # TODO(b/217271822): Do this unconditionally after BEAM-13690 is resolved. - if resources.ResourceHint.is_registered('close_to_resources'): - examples |= ( - 'CloseToResources' >> beam.Map(lambda x: x).with_resource_hints( - close_to_resources=self._make_close_to_resources( - self._inference_spec_type))) - handler = create_model_handler( - self._inference_spec_type, self._load_override_fn, - examples.pipeline.options.view_as(GoogleCloudOptions).project) - handler = _ModelHandlerWrapper(handler) - return examples | 'BulkInference' >> base.RunInference( - handler).with_output_types(output_type) + """Implementation of RunInference API.""" + + def __init__( + self, + inference_spec_type: model_spec_pb2.InferenceSpecType, + load_override_fn: Optional[LoadOverrideFnType] = None, + ): + """Initializes transform. + + Args: + ---- + inference_spec_type: InferenceSpecType proto. + load_override_fn: If provided, overrides the model loader fn of the + underlying ModelHandler. This takes a model path and sequence of tags, + and should return a model with interface compatible with tf.SavedModel. + """ + self._inference_spec_type = inference_spec_type + self._load_override_fn = load_override_fn + + # LINT.IfChange(close_to_resources) + @staticmethod + def _model_size_bytes(path: str) -> int: + # We might be unable to compute the size of the model during pipeline + # construction, but the model might still be accessible during pipeline + # execution. In such cases we will provide a default value for the model + # size. In general, it is a lot more costly to underestimate the size of + # the model than to overestimate it. + default_model_size = 1 << 30 # 1 GB. + + def file_size(directory, file): + return max(tf.io.gfile.stat(os.path.join(directory, file)).length, 0) + + try: + result = 0 + with futures.ThreadPoolExecutor() as executor: + for directory, _, files in tf.io.gfile.walk(path): + result += sum( + executor.map(functools.partial(file_size, directory), files) + ) + if result == 0: + result = default_model_size + return result + except OSError: + return default_model_size + + @staticmethod + def _make_close_to_resources( + inference_spec_type: model_spec_pb2.InferenceSpecType, + ) -> str: + """Proximity resources not otherwise known (or visible) to Beam.""" + if _using_in_process_inference(inference_spec_type): + # The model is expected to be loaded once per worker (as opposed to + # once per thread), due to the use of beam.Shared in pertinent DoFns. + # + # The exact value of this constant is not important; it aims to signify + # that there might be a non-trivial number of model loads. + # + # TODO(katsiapis): Auto(tune) this. + estimated_num_workers = 100 + model_path = inference_spec_type.saved_model_spec.model_path + model_size_bytes = RunInferenceImpl._model_size_bytes(model_path) + return f"{model_path}[{model_size_bytes * estimated_num_workers}]" + else: + # The model is available remotely, so the size of the RPC traffic is + # proportional to the size of the input. + # + # The exact value of this constant is not important; it aims to signify + # that there might be a non-trivial amount of RPC traffic. + # + # TODO(katsiapis): Auto(tune) this. + estimated_rpc_traffic_size_bytes = 1 << 40 # 1 TB. + + # TODO(katsiapis): Is it possible to query the AI platform to see what + # zones the model is available in, so that we can instead provide a + # descriptor along the lines of: f'zone1|zone2|...|zoneN[size]'? + del estimated_rpc_traffic_size_bytes + return "" + + # LINT.ThenChange(../../../../learning/serving/contrib/servables/tensorflow/flume/bulk-inference.cc:close_to_resources) + + def infer_output_type(self, input_type): + key_type, result_type = _key_and_result_type(input_type) + if key_type is not None: + return beam.typehints.Tuple[key_type, result_type] + return result_type + + def expand(self, examples: beam.PCollection) -> beam.PCollection: + logging.info("RunInference on model: %s", self._inference_spec_type) + output_type = self.infer_output_type(examples.element_type) + # TODO(b/217271822): Do this unconditionally after BEAM-13690 is resolved. + if resources.ResourceHint.is_registered("close_to_resources"): + examples |= "CloseToResources" >> beam.Map(lambda x: x).with_resource_hints( + close_to_resources=self._make_close_to_resources( + self._inference_spec_type + ) + ) + handler = create_model_handler( + self._inference_spec_type, + self._load_override_fn, + examples.pipeline.options.view_as(GoogleCloudOptions).project, + ) + handler = _ModelHandlerWrapper(handler) + return examples | "BulkInference" >> base.RunInference( + handler + ).with_output_types(output_type) def _get_saved_model_handler( inference_spec_type: model_spec_pb2.InferenceSpecType, - load_override_fn: Optional[LoadOverrideFnType]) -> base.ModelHandler: - """Get an in-process ModelHandler.""" - operation_type = _get_operation_type(inference_spec_type) - if operation_type == _OperationType.CLASSIFICATION: - return _ClassifyModelHandler(inference_spec_type, load_override_fn) - elif operation_type == _OperationType.REGRESSION: - return _RegressModelHandler(inference_spec_type, load_override_fn) - elif operation_type == _OperationType.MULTI_INFERENCE: - return _MultiInferenceModelHandler(inference_spec_type, load_override_fn) - elif operation_type == _OperationType.PREDICTION: - return _PredictModelHandler(inference_spec_type, load_override_fn) - else: - raise ValueError('Unsupported operation_type %s' % operation_type) + load_override_fn: Optional[LoadOverrideFnType], +) -> base.ModelHandler: + """Get an in-process ModelHandler.""" + operation_type = _get_operation_type(inference_spec_type) + if operation_type == _OperationType.CLASSIFICATION: + return _ClassifyModelHandler(inference_spec_type, load_override_fn) + elif operation_type == _OperationType.REGRESSION: + return _RegressModelHandler(inference_spec_type, load_override_fn) + elif operation_type == _OperationType.MULTI_INFERENCE: + return _MultiInferenceModelHandler(inference_spec_type, load_override_fn) + elif operation_type == _OperationType.PREDICTION: + return _PredictModelHandler(inference_spec_type, load_override_fn) + else: + raise ValueError("Unsupported operation_type %s" % operation_type) # Output type is inferred from input. -@beam.typehints.with_input_types(Union[InputType, Tuple[_K, InputType], - Tuple[_K, List[InputType]]]) +@beam.typehints.with_input_types( + Union[InputType, Tuple[_K, InputType], Tuple[_K, List[InputType]]] +) class RunInferencePerModelImpl(beam.PTransform): - """Implementation of the vectorized variant of the RunInference API.""" - - def __init__(self, - inference_spec_types: Iterable[model_spec_pb2.InferenceSpecType], - load_override_fn: Optional[LoadOverrideFnType] = None): - """Initializes transform. - - Args: - inference_spec_types: InferenceSpecType proto. - load_override_fn: If provided, overrides the model loader fn of the - underlying ModelHandler. This takes a model path and sequence of tags, - and should return a model with interface compatible with tf.SavedModel. - """ - self._inference_spec_types = tuple(inference_spec_types) - self._load_override_fn = load_override_fn - - def infer_output_type(self, input_type): - key_type, result_type = _key_and_result_type(input_type) - result_type = beam.typehints.Tuple[(result_type,) * - len(self._inference_spec_types)] - if key_type is not None: - return beam.typehints.Tuple[key_type, result_type] - return result_type - - def expand(self, examples: beam.PCollection) -> beam.PCollection: - output_type = self.infer_output_type(examples.element_type) - - # TODO(b/217442215): Obviate the need for this block (and instead rely - # solely on the one within RunInferenceImpl::expand). - # TODO(b/217271822): Do this unconditionally after BEAM-13690 is resolved. - if resources.ResourceHint.is_registered('close_to_resources'): - examples |= ( - 'CloseToResources' >> beam.Map(lambda x: x).with_resource_hints( - close_to_resources=','.join([ - RunInferenceImpl._make_close_to_resources(s) # pylint: disable=protected-access - for s in self._inference_spec_types - ]))) - - tuple_types = getattr(examples.element_type, 'tuple_types', None) - if tuple_types is None or len(tuple_types) != 2: - # The input is not a KV, so pair with a dummy key, run the inferences, and - # drop the dummy key afterwards. - return (examples - | 'PairWithNone' >> beam.Map(lambda x: (None, x)) - | 'ApplyOnKeyedInput' >> RunInferencePerModelImpl( - self._inference_spec_types) - | 'DropNone' >> beam.Values().with_output_types(output_type)) - - def infer_iteration_output_type(input_type): - """Infers ouput typehint for Iteration Ptransform based on input_type.""" - tuple_types = getattr(input_type, 'tuple_types', None) - output_tuple_components = [] - if tuple_types is not None: - output_tuple_components.extend(tuple_types) - example_type = tuple_types[1] - else: - output_tuple_components.append(input_type) - example_type = input_type - - if _is_list_type(example_type): - inference_result_type = beam.typehints.List[_OUTPUT_TYPE] - else: - inference_result_type = _OUTPUT_TYPE - output_tuple_components.append(inference_result_type) - return beam.typehints.Tuple[output_tuple_components] - - @beam.ptransform_fn - def Iteration(pcoll, inference_spec_type): # pylint: disable=invalid-name - return (pcoll - | 'PairWithInput' >> beam.Map(lambda x: (x, x[1])) - | 'RunInferenceImpl' >> RunInferenceImpl(inference_spec_type, - self._load_override_fn) - | 'ExtendResults' >> - beam.MapTuple(lambda k, v: k + (v,)).with_output_types( - infer_iteration_output_type(pcoll.element_type))) - - result = examples - for i, inference_spec_type in enumerate(self._inference_spec_types): - result |= f'Model[{i}]' >> Iteration(inference_spec_type) # pylint: disable=no-value-for-parameter - result |= 'ExtractResults' >> beam.Map( - lambda tup: (tup[0], tuple(tup[2:]))).with_output_types(output_type) - return result - - -_IOTensorSpec = NamedTuple('_IOTensorSpec', - [('input_tensor_alias', Text), - ('input_tensor_name', Text), - ('output_alias_tensor_names', Dict[Text, Text])]) - -_Signature = NamedTuple('_Signature', [('name', Text), - ('signature_def', _SignatureDef)]) + """Implementation of the vectorized variant of the RunInference API.""" + + def __init__( + self, + inference_spec_types: Iterable[model_spec_pb2.InferenceSpecType], + load_override_fn: Optional[LoadOverrideFnType] = None, + ): + """Initializes transform. + + Args: + ---- + inference_spec_types: InferenceSpecType proto. + load_override_fn: If provided, overrides the model loader fn of the + underlying ModelHandler. This takes a model path and sequence of tags, + and should return a model with interface compatible with tf.SavedModel. + """ + self._inference_spec_types = tuple(inference_spec_types) + self._load_override_fn = load_override_fn + + def infer_output_type(self, input_type): + key_type, result_type = _key_and_result_type(input_type) + result_type = beam.typehints.Tuple[ + (result_type,) * len(self._inference_spec_types) + ] + if key_type is not None: + return beam.typehints.Tuple[key_type, result_type] + return result_type + + def expand(self, examples: beam.PCollection) -> beam.PCollection: + output_type = self.infer_output_type(examples.element_type) + + # TODO(b/217442215): Obviate the need for this block (and instead rely + # solely on the one within RunInferenceImpl::expand). + # TODO(b/217271822): Do this unconditionally after BEAM-13690 is resolved. + if resources.ResourceHint.is_registered("close_to_resources"): + examples |= "CloseToResources" >> beam.Map(lambda x: x).with_resource_hints( + close_to_resources=",".join( + [ + RunInferenceImpl._make_close_to_resources(s) # pylint: disable=protected-access + for s in self._inference_spec_types + ] + ) + ) + + tuple_types = getattr(examples.element_type, "tuple_types", None) + if tuple_types is None or len(tuple_types) != 2: + # The input is not a KV, so pair with a dummy key, run the inferences, and + # drop the dummy key afterwards. + return ( + examples + | "PairWithNone" >> beam.Map(lambda x: (None, x)) + | "ApplyOnKeyedInput" + >> RunInferencePerModelImpl(self._inference_spec_types) + | "DropNone" >> beam.Values().with_output_types(output_type) + ) + + def infer_iteration_output_type(input_type): + """Infers ouput typehint for Iteration Ptransform based on input_type.""" + tuple_types = getattr(input_type, "tuple_types", None) + output_tuple_components = [] + if tuple_types is not None: + output_tuple_components.extend(tuple_types) + example_type = tuple_types[1] + else: + output_tuple_components.append(input_type) + example_type = input_type + + if _is_list_type(example_type): + inference_result_type = beam.typehints.List[_OUTPUT_TYPE] + else: + inference_result_type = _OUTPUT_TYPE + output_tuple_components.append(inference_result_type) + return beam.typehints.Tuple[output_tuple_components] + + @beam.ptransform_fn + def Iteration(pcoll, inference_spec_type): # pylint: disable=invalid-name + return ( + pcoll + | "PairWithInput" >> beam.Map(lambda x: (x, x[1])) + | "RunInferenceImpl" + >> RunInferenceImpl(inference_spec_type, self._load_override_fn) + | "ExtendResults" + >> beam.MapTuple(lambda k, v: k + (v,)).with_output_types( + infer_iteration_output_type(pcoll.element_type) + ) + ) + + result = examples + for i, inference_spec_type in enumerate(self._inference_spec_types): + result |= f"Model[{i}]" >> Iteration(inference_spec_type) # pylint: disable=no-value-for-parameter + result |= "ExtractResults" >> beam.Map( + lambda tup: (tup[0], tuple(tup[2:])) + ).with_output_types(output_type) + return result + + +class _IOTensorSpec(NamedTuple): + input_tensor_alias: str + input_tensor_name: str + output_alias_tensor_names: Dict[str, str] + + +class _Signature(NamedTuple): + name: str + signature_def: _SignatureDef def _retry_on_unavailable_and_resource_error_filter(exception: Exception): - """Retries for HttpError. - - Retries if error is unavailable (503) or resource exhausted (429). - Resource exhausted may happen when qps or bandwidth exceeds quota. + """Retries for HttpError. - Args: - exception: Exception from inference http request execution. + Retries if error is unavailable (503) or resource exhausted (429). + Resource exhausted may happen when qps or bandwidth exceeds quota. - Returns: - A boolean of whether retry. - """ + Args: + ---- + exception: Exception from inference http request execution. - return (isinstance(exception, googleapiclient.errors.HttpError) and - exception.resp.status in (503, 429)) + Returns: + ------- + A boolean of whether retry. + """ + return isinstance( + exception, googleapiclient.errors.HttpError + ) and exception.resp.status in (503, 429) class _BaseModelHandler(base.ModelHandler, metaclass=abc.ABCMeta): - """A basic TFX implementation of ModelHandler.""" - - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): - super().__init__() - operation_type = _get_operation_type(inference_spec_type) - proximity_descriptor = ( - _METRICS_DESCRIPTOR_IN_PROCESS - if _using_in_process_inference(inference_spec_type) else - _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION) - self._metrics_namespace = util.MakeTfxNamespace( - [_METRICS_DESCRIPTOR_INFERENCE, operation_type, proximity_descriptor]) - self._batch_elements_kwargs = {} - for desc, val in inference_spec_type.batch_parameters.ListFields(): - self._batch_elements_kwargs[desc.name] = val - - def run_inference( - self, - examples: List[InputType], - model: Any, - inference_args=None) -> Iterable[prediction_log_pb2.PredictionLog]: - serialized_examples = [ - e if isinstance(e, bytes) else e.SerializeToString() for e in examples - ] - self._check_examples(examples) - outputs = self._run_inference(examples, serialized_examples, model) - return self._post_process(examples, serialized_examples, outputs) - - def _check_examples(self, examples): - pass - - def get_num_bytes( - self, examples: Iterable[prediction_log_pb2.PredictionLog]) -> int: - serialized_examples = [ - e if isinstance(e, bytes) else e.SerializeToString() for e in examples - ] - return sum(len(se) for se in serialized_examples) + """A basic TFX implementation of ModelHandler.""" + + def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): + super().__init__() + operation_type = _get_operation_type(inference_spec_type) + proximity_descriptor = ( + _METRICS_DESCRIPTOR_IN_PROCESS + if _using_in_process_inference(inference_spec_type) + else _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION + ) + self._metrics_namespace = util.MakeTfxNamespace( + [_METRICS_DESCRIPTOR_INFERENCE, operation_type, proximity_descriptor] + ) + self._batch_elements_kwargs = {} + for desc, val in inference_spec_type.batch_parameters.ListFields(): + self._batch_elements_kwargs[desc.name] = val + + def run_inference( + self, examples: List[InputType], model: Any, inference_args=None + ) -> Iterable[prediction_log_pb2.PredictionLog]: + serialized_examples = [ + e if isinstance(e, bytes) else e.SerializeToString() for e in examples + ] + self._check_examples(examples) + outputs = self._run_inference(examples, serialized_examples, model) + return self._post_process(examples, serialized_examples, outputs) + + def _check_examples(self, examples): + pass + + def get_num_bytes( + self, examples: Iterable[prediction_log_pb2.PredictionLog] + ) -> int: + serialized_examples = [ + e if isinstance(e, bytes) else e.SerializeToString() for e in examples + ] + return sum(len(se) for se in serialized_examples) + + def get_metrics_namespace(self): + return self._metrics_namespace + + def batch_elements_kwargs(self) -> Mapping[str, Any]: + return self._batch_elements_kwargs + + @abc.abstractmethod + def _post_process( + self, + examples: List[InputType], + serialized_examples: List[bytes], + outputs: List[Mapping[str, Union[np.ndarray, Any]]], + ) -> List[prediction_log_pb2.PredictionLog]: + raise NotImplementedError + + @abc.abstractmethod + def _run_inference( + self, examples: List[InputType], serialized_examples: List[bytes], model + ) -> List[Mapping[str, Any]]: + raise NotImplementedError - def get_metrics_namespace(self): - return self._metrics_namespace - def batch_elements_kwargs(self) -> Mapping[str, Any]: - return self._batch_elements_kwargs +# TODO(b/151468119): Consider to re-batch with online serving request size +# limit, and re-batch with RPC failures(InvalidArgument) regarding request size. +class _RemotePredictModelHandler(_BaseModelHandler): + """Performs predictions from a cloud-hosted TensorFlow model. - @abc.abstractmethod - def _post_process( - self, examples: List[InputType], serialized_examples: List[bytes], - outputs: List[Mapping[Text, Union[np.ndarray, Any]]] - ) -> List[prediction_log_pb2.PredictionLog]: - raise NotImplementedError + Supports both batch and streaming processing modes. + NOTE: Does not work on DirectRunner for streaming jobs [BEAM-7885]. - @abc.abstractmethod - def _run_inference(self, examples: List[InputType], - serialized_examples: List[bytes], - model) -> List[Mapping[Text, Any]]: - raise NotImplementedError + In order to request predictions, you must deploy your trained model to AI + Platform Prediction in the TensorFlow SavedModel format. See + [Exporting a SavedModel for prediction] + (https://cloud.google.com/ai-platform/prediction/docs/exporting-savedmodel-for-prediction) + for more details. + To send binary data, you have to make sure that the name of an input ends in + `_bytes`. -# TODO(b/151468119): Consider to re-batch with online serving request size -# limit, and re-batch with RPC failures(InvalidArgument) regarding request size. -class _RemotePredictModelHandler(_BaseModelHandler): - """Performs predictions from a cloud-hosted TensorFlow model. - - Supports both batch and streaming processing modes. - NOTE: Does not work on DirectRunner for streaming jobs [BEAM-7885]. - - In order to request predictions, you must deploy your trained model to AI - Platform Prediction in the TensorFlow SavedModel format. See - [Exporting a SavedModel for prediction] - (https://cloud.google.com/ai-platform/prediction/docs/exporting-savedmodel-for-prediction) - for more details. - - To send binary data, you have to make sure that the name of an input ends in - `_bytes`. - - NOTE: The returned `PredictLog` instances do not have `PredictRequest` part - filled. The reason is that it is difficult to determine the input tensor name - without having access to cloud-hosted model's signatures. - """ - - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType, - pipeline_options_project_id: Optional[str]): - super().__init__(inference_spec_type) - self._ai_platform_prediction_model_spec = ( - inference_spec_type.ai_platform_prediction_model_spec) - self._api_client = None - project_id = ( - inference_spec_type.ai_platform_prediction_model_spec.project_id or - pipeline_options_project_id) - if not project_id: - raise ValueError('Either a non-empty project id or project flag in ' - ' beam pipeline options needs be provided.') - - model_name = ( - inference_spec_type.ai_platform_prediction_model_spec.model_name) - if not model_name: - raise ValueError('A non-empty model name must be provided.') - - version_name = ( - inference_spec_type.ai_platform_prediction_model_spec.version_name) - name_spec = 'projects/{}/models/{}' - # If version is not specified, the default version for a model is used. - if version_name: - name_spec += '/versions/{}' - self._full_model_name = name_spec.format(project_id, model_name, - version_name) - - # Retry _REMOTE_INFERENCE_NUM_RETRIES times with exponential backoff. - @retry.with_exponential_backoff( - initial_delay_secs=1.0, - num_retries=_REMOTE_INFERENCE_NUM_RETRIES, - retry_filter=_retry_on_unavailable_and_resource_error_filter) - def _execute_request( - self, - request: http.HttpRequest) -> Mapping[Text, Sequence[Mapping[Text, Any]]]: - result = request.execute() - if 'error' in result: - raise ValueError(result['error']) - return result + NOTE: The returned `PredictLog` instances do not have `PredictRequest` part + filled. The reason is that it is difficult to determine the input tensor name + without having access to cloud-hosted model's signatures. + """ - def _make_instances( - self, - examples: List[Union[tf.train.Example, tf.train.SequenceExample]], - serialized_examples: List[bytes] - )-> List[Mapping[Text, Any]]: - if self._ai_platform_prediction_model_spec.use_serialization_config: - return [{'b64': base64.b64encode(se).decode()} - for se in serialized_examples] - else: - result = [] - for example in examples: - instance = {} - for name, feature in example.features.feature.items(): - attribute_kind = feature.WhichOneof('kind') - if attribute_kind is None: - continue - values = self._make_values(name, feature, attribute_kind) - instance[name] = values[0] if len(values) == 1 else values - result.append(instance) - return result - - @staticmethod - def _make_values(name: Text, feature: Any, attribute_kind: Text) -> List[Any]: - values = getattr(feature, attribute_kind).value - if name.endswith('_bytes'): - return [{'b64': base64.b64encode(x).decode()} for x in values] - elif attribute_kind == 'bytes_list': - return [x.decode() for x in values] - else: - # Converts proto RepeatedScalarContainer to list so it is - # JSON-serializable. - return list(values) - - def load_model(self): - # TODO(b/151468119): Add tfx_bsl_version and tfx_bsl_py_version to - # user agent once custom header is supported in googleapiclient. - self._api_client = discovery.build('ml', 'v1') - # load_model returns a locally hosted model. Since all these inferences - # are run on vertexAI, no local model is present. - return None - - def _check_examples(self, examples: List[InputType]): - # TODO(b/131873699): Add support for tf.train.SequenceExample even when - # use_serialization_config is not enabled (by appropriately modifying - # _make_instances). - allowed_types = ( - (tf.train.Example, tf.train.SequenceExample, bytes) - if self._ai_platform_prediction_model_spec.use_serialization_config - else tf.train.Example) - if not all(isinstance(e, allowed_types) for e in examples): - raise NotImplementedError( - 'RemotePredict supports raw and serialized tf.train.Example, raw and ' - 'serialized tf.SequenceExample and raw bytes (the ' - 'latter three only when use_serialization_config is true)') - - def _run_inference(self, examples: List[InputType], - serialized_examples: List[bytes], - model) -> List[Mapping[Text, Any]]: - self._check_examples(examples) - body = {'instances': self._make_instances(examples, serialized_examples)} - if self._api_client is None: - raise ValueError( - 'API client is not initialized. Call load_model() first.' - ) - request = self._api_client.projects().predict( - name=self._full_model_name, body=body) - response = self._execute_request(request) - return response['predictions'] - - def _post_process( - self, examples: List[InputType], serialized_examples: List[bytes], - outputs: List[Mapping[Text, - Any]]) -> List[prediction_log_pb2.PredictionLog]: - del examples - result = [] - for i, serialized_example in enumerate(serialized_examples): - prediction_log = prediction_log_pb2.PredictionLog() - predict_log = prediction_log.predict_log - input_tensor_proto = predict_log.request.inputs[ - tf.saved_model.PREDICT_INPUTS] - input_tensor_proto.dtype = tf.string.as_datatype_enum - input_tensor_proto.tensor_shape.dim.add().size = 1 - input_tensor_proto.string_val.append(serialized_example) - for output_alias, values in outputs[i].items(): - values = np.array(values) - tensor_proto = tf.make_tensor_proto( - values=values, - dtype=tf.as_dtype(values.dtype).as_datatype_enum, - shape=np.expand_dims(values, axis=0).shape) - predict_log.response.outputs[output_alias].CopyFrom(tensor_proto) - result.append(prediction_log) - return result + def __init__( + self, + inference_spec_type: model_spec_pb2.InferenceSpecType, + pipeline_options_project_id: Optional[str], + ): + super().__init__(inference_spec_type) + self._ai_platform_prediction_model_spec = ( + inference_spec_type.ai_platform_prediction_model_spec + ) + self._api_client = None + project_id = ( + inference_spec_type.ai_platform_prediction_model_spec.project_id + or pipeline_options_project_id + ) + if not project_id: + raise ValueError( + "Either a non-empty project id or project flag in " + " beam pipeline options needs be provided." + ) + + model_name = inference_spec_type.ai_platform_prediction_model_spec.model_name + if not model_name: + raise ValueError("A non-empty model name must be provided.") + + version_name = ( + inference_spec_type.ai_platform_prediction_model_spec.version_name + ) + name_spec = "projects/{}/models/{}" + # If version is not specified, the default version for a model is used. + if version_name: + name_spec += "/versions/{}" + self._full_model_name = name_spec.format(project_id, model_name, version_name) + + # Retry _REMOTE_INFERENCE_NUM_RETRIES times with exponential backoff. + @retry.with_exponential_backoff( + initial_delay_secs=1.0, + num_retries=_REMOTE_INFERENCE_NUM_RETRIES, + retry_filter=_retry_on_unavailable_and_resource_error_filter, + ) + def _execute_request( + self, request: http.HttpRequest + ) -> Mapping[str, Sequence[Mapping[str, Any]]]: + result = request.execute() + if "error" in result: + raise ValueError(result["error"]) + return result + + def _make_instances( + self, + examples: List[Union[tf.train.Example, tf.train.SequenceExample]], + serialized_examples: List[bytes], + ) -> List[Mapping[str, Any]]: + if self._ai_platform_prediction_model_spec.use_serialization_config: + return [ + {"b64": base64.b64encode(se).decode()} for se in serialized_examples + ] + else: + result = [] + for example in examples: + instance = {} + for name, feature in example.features.feature.items(): + attribute_kind = feature.WhichOneof("kind") + if attribute_kind is None: + continue + values = self._make_values(name, feature, attribute_kind) + instance[name] = values[0] if len(values) == 1 else values + result.append(instance) + return result + + @staticmethod + def _make_values(name: str, feature: Any, attribute_kind: str) -> List[Any]: + values = getattr(feature, attribute_kind).value + if name.endswith("_bytes"): + return [{"b64": base64.b64encode(x).decode()} for x in values] + elif attribute_kind == "bytes_list": + return [x.decode() for x in values] + else: + # Converts proto RepeatedScalarContainer to list so it is + # JSON-serializable. + return list(values) + + def load_model(self): + # TODO(b/151468119): Add tfx_bsl_version and tfx_bsl_py_version to + # user agent once custom header is supported in googleapiclient. + self._api_client = discovery.build("ml", "v1") + # load_model returns a locally hosted model. Since all these inferences + # are run on vertexAI, no local model is present. + return + + def _check_examples(self, examples: List[InputType]): + # TODO(b/131873699): Add support for tf.train.SequenceExample even when + # use_serialization_config is not enabled (by appropriately modifying + # _make_instances). + allowed_types = ( + (tf.train.Example, tf.train.SequenceExample, bytes) + if self._ai_platform_prediction_model_spec.use_serialization_config + else tf.train.Example + ) + if not all(isinstance(e, allowed_types) for e in examples): + raise NotImplementedError( + "RemotePredict supports raw and serialized tf.train.Example, raw and " + "serialized tf.SequenceExample and raw bytes (the " + "latter three only when use_serialization_config is true)" + ) + + def _run_inference( + self, examples: List[InputType], serialized_examples: List[bytes], model + ) -> List[Mapping[str, Any]]: + self._check_examples(examples) + body = {"instances": self._make_instances(examples, serialized_examples)} + if self._api_client is None: + raise ValueError("API client is not initialized. Call load_model() first.") + request = self._api_client.projects().predict( + name=self._full_model_name, body=body + ) + response = self._execute_request(request) + return response["predictions"] + + def _post_process( + self, + examples: List[InputType], + serialized_examples: List[bytes], + outputs: List[Mapping[str, Any]], + ) -> List[prediction_log_pb2.PredictionLog]: + del examples + result = [] + for i, serialized_example in enumerate(serialized_examples): + prediction_log = prediction_log_pb2.PredictionLog() + predict_log = prediction_log.predict_log + input_tensor_proto = predict_log.request.inputs[ + tf.saved_model.PREDICT_INPUTS + ] + input_tensor_proto.dtype = tf.string.as_datatype_enum + input_tensor_proto.tensor_shape.dim.add().size = 1 + input_tensor_proto.string_val.append(serialized_example) + for output_alias, values in outputs[i].items(): + values = np.array(values) + tensor_proto = tf.make_tensor_proto( + values=values, + dtype=tf.as_dtype(values.dtype).as_datatype_enum, + shape=np.expand_dims(values, axis=0).shape, + ) + predict_log.response.outputs[output_alias].CopyFrom(tensor_proto) + result.append(prediction_log) + return result class _BaseSavedModelHandler(_BaseModelHandler): - """A spec that runs in-process batch inference with a model. + """A spec that runs in-process batch inference with a model. Models need to have the required serving signature as mentioned in [Tensorflow Serving](https://www.tensorflow.org/tfx/serving/signature_defs) This function will check model signatures first. Then it will load and run model inference in batch. - """ - - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType, - load_override_fn: Optional[LoadOverrideFnType]): - super().__init__(inference_spec_type) - self._inference_spec_type = inference_spec_type - self._model_path = inference_spec_type.saved_model_spec.model_path - if not self._model_path: - raise ValueError('Model path is not valid.') - self._tags = _get_tags(inference_spec_type) - self._signatures = _get_signatures( - inference_spec_type.saved_model_spec.model_path, - inference_spec_type.saved_model_spec.signature_name, self._tags) - self._io_tensor_spec = self._make_io_tensor_spec() - if self._has_tpu_tag(): - # TODO(b/161563144): Support TPU inference. - raise NotImplementedError('TPU inference is not supported yet.') - self._load_override_fn = load_override_fn - - def _has_tpu_tag(self) -> bool: - return (len(self._tags) == 2 and tf.saved_model.SERVING in self._tags and - tf.saved_model.TPU in self._tags) - - # TODO(b/159982957): Replace this with a mechinism that registers any custom - # op. - def _maybe_register_addon_ops(self): - - def _try_import(name): - try: - importlib.import_module(name) - except (ImportError, tf.errors.NotFoundError): - logging.info('%s is not available.', name) - - _try_import('tensorflow_text') - _try_import('tensorflow_decision_forests') - _try_import('struct2tensor') - - def load_model(self): - if self._load_override_fn: - return self._load_override_fn(self._model_path, self._tags) - self._maybe_register_addon_ops() - result = tf.compat.v1.Session(graph=tf.compat.v1.Graph()) - tf.compat.v1.saved_model.loader.load(result, self._tags, self._model_path) - return result + """ - def _make_io_tensor_spec(self) -> _IOTensorSpec: - # Pre process functions will validate for each signature. - io_tensor_specs = [] - for signature in self._signatures: - if len(signature.signature_def.inputs) != 1: - raise ValueError('Signature should have 1 and only 1 inputs') - if (list(signature.signature_def.inputs.values())[0].dtype != - tf.string.as_datatype_enum): - raise ValueError( - 'Input dtype is expected to be %s, got %s' % - (tf.string.as_datatype_enum, - list(signature.signature_def.inputs.values())[0].dtype)) - io_tensor_specs.append(_signature_pre_process(signature.signature_def)) - input_tensor_name = '' - input_tensor_alias = '' - output_alias_tensor_names = {} - for io_tensor_spec in io_tensor_specs: - if not input_tensor_name: - input_tensor_name = io_tensor_spec.input_tensor_name - input_tensor_alias = io_tensor_spec.input_tensor_alias - elif input_tensor_name != io_tensor_spec.input_tensor_name: - raise ValueError('Input tensor must be the same for all Signatures.') - for alias, tensor_name in io_tensor_spec.output_alias_tensor_names.items( - ): - output_alias_tensor_names[alias] = tensor_name - if (not output_alias_tensor_names or not input_tensor_name or - not input_tensor_alias): - raise ValueError('No valid fetch tensors or feed tensors.') - return _IOTensorSpec(input_tensor_alias, input_tensor_name, - output_alias_tensor_names) - - def _run_inference(self, examples: List[InputType], # pytype: disable=signature-mismatch # overriding-return-type-checks - serialized_examples: List[bytes], - model: Any) -> Mapping[Text, np.ndarray]: - result = model.run( - self._io_tensor_spec.output_alias_tensor_names, - feed_dict={self._io_tensor_spec.input_tensor_name: serialized_examples}) - if len(result) != len(self._io_tensor_spec.output_alias_tensor_names): - raise RuntimeError('Output length does not match fetches') - return result + def __init__( + self, + inference_spec_type: model_spec_pb2.InferenceSpecType, + load_override_fn: Optional[LoadOverrideFnType], + ): + super().__init__(inference_spec_type) + self._inference_spec_type = inference_spec_type + self._model_path = inference_spec_type.saved_model_spec.model_path + if not self._model_path: + raise ValueError("Model path is not valid.") + self._tags = _get_tags(inference_spec_type) + self._signatures = _get_signatures( + inference_spec_type.saved_model_spec.model_path, + inference_spec_type.saved_model_spec.signature_name, + self._tags, + ) + self._io_tensor_spec = self._make_io_tensor_spec() + if self._has_tpu_tag(): + # TODO(b/161563144): Support TPU inference. + raise NotImplementedError("TPU inference is not supported yet.") + self._load_override_fn = load_override_fn + + def _has_tpu_tag(self) -> bool: + return ( + len(self._tags) == 2 + and tf.saved_model.SERVING in self._tags + and tf.saved_model.TPU in self._tags + ) + + # TODO(b/159982957): Replace this with a mechinism that registers any custom + # op. + def _maybe_register_addon_ops(self): + def _try_import(name): + try: + importlib.import_module(name) + except (ImportError, tf.errors.NotFoundError): + logging.info("%s is not available.", name) + + _try_import("tensorflow_text") + _try_import("tensorflow_decision_forests") + _try_import("struct2tensor") + + def load_model(self): + if self._load_override_fn: + return self._load_override_fn(self._model_path, self._tags) + self._maybe_register_addon_ops() + result = tf.compat.v1.Session(graph=tf.compat.v1.Graph()) + tf.compat.v1.saved_model.loader.load(result, self._tags, self._model_path) + return result + + def _make_io_tensor_spec(self) -> _IOTensorSpec: + # Pre process functions will validate for each signature. + io_tensor_specs = [] + for signature in self._signatures: + if len(signature.signature_def.inputs) != 1: + raise ValueError("Signature should have 1 and only 1 inputs") + if ( + list(signature.signature_def.inputs.values())[0].dtype + != tf.string.as_datatype_enum + ): + raise ValueError( + "Input dtype is expected to be %s, got %s" + % ( + tf.string.as_datatype_enum, + list(signature.signature_def.inputs.values())[0].dtype, + ) + ) + io_tensor_specs.append(_signature_pre_process(signature.signature_def)) + input_tensor_name = "" + input_tensor_alias = "" + output_alias_tensor_names = {} + for io_tensor_spec in io_tensor_specs: + if not input_tensor_name: + input_tensor_name = io_tensor_spec.input_tensor_name + input_tensor_alias = io_tensor_spec.input_tensor_alias + elif input_tensor_name != io_tensor_spec.input_tensor_name: + raise ValueError("Input tensor must be the same for all Signatures.") + for alias, tensor_name in io_tensor_spec.output_alias_tensor_names.items(): + output_alias_tensor_names[alias] = tensor_name + if ( + not output_alias_tensor_names + or not input_tensor_name + or not input_tensor_alias + ): + raise ValueError("No valid fetch tensors or feed tensors.") + return _IOTensorSpec( + input_tensor_alias, input_tensor_name, output_alias_tensor_names + ) + + def _run_inference( + self, + examples: List[ + InputType + ], # pytype: disable=signature-mismatch # overriding-return-type-checks + serialized_examples: List[bytes], + model: Any, + ) -> Mapping[str, np.ndarray]: + result = model.run( + self._io_tensor_spec.output_alias_tensor_names, + feed_dict={self._io_tensor_spec.input_tensor_name: serialized_examples}, + ) + if len(result) != len(self._io_tensor_spec.output_alias_tensor_names): + raise RuntimeError("Output length does not match fetches") + return result class _ClassifyModelHandler(_BaseSavedModelHandler): - """Implements a spec for classification.""" - - def _check_examples(self, examples: List[InputType]): - if not all(isinstance(e, (tf.train.Example, bytes)) for e in examples): - raise ValueError( - 'Classify only supports raw or serialized tf.train.Example') - - def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks - self, examples: List[Union[tf.train.Example, - bytes]], serialized_examples: List[bytes], - outputs: Mapping[Text, - np.ndarray]) -> List[prediction_log_pb2.PredictionLog]: - del serialized_examples - # TODO(b/131873699): Can we fold prediction_log_pb2.PredictionLog building - # into _post_process_classify? - classifications = _post_process_classify( - self._io_tensor_spec.output_alias_tensor_names, examples, outputs) - result = [] - for example, classification in zip(examples, classifications): - prediction_log = prediction_log_pb2.PredictionLog() - input_example = (prediction_log.classify_log.request.input.example_list - .examples.add()) - (input_example.ParseFromString - if isinstance(example, bytes) - else input_example.CopyFrom)(example) - (prediction_log.classify_log.response.result.classifications.add() - .CopyFrom(classification)) - result.append(prediction_log) - return result + """Implements a spec for classification.""" + + def _check_examples(self, examples: List[InputType]): + if not all(isinstance(e, (tf.train.Example, bytes)) for e in examples): + raise ValueError( + "Classify only supports raw or serialized tf.train.Example" + ) + + def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks + self, + examples: List[Union[tf.train.Example, bytes]], + serialized_examples: List[bytes], + outputs: Mapping[str, np.ndarray], + ) -> List[prediction_log_pb2.PredictionLog]: + del serialized_examples + # TODO(b/131873699): Can we fold prediction_log_pb2.PredictionLog building + # into _post_process_classify? + classifications = _post_process_classify( + self._io_tensor_spec.output_alias_tensor_names, examples, outputs + ) + result = [] + for example, classification in zip(examples, classifications): + prediction_log = prediction_log_pb2.PredictionLog() + input_example = ( + prediction_log.classify_log.request.input.example_list.examples.add() + ) + ( + input_example.ParseFromString + if isinstance(example, bytes) + else input_example.CopyFrom + )(example) + ( + prediction_log.classify_log.response.result.classifications.add().CopyFrom( + classification + ) + ) + result.append(prediction_log) + return result class _RegressModelHandler(_BaseSavedModelHandler): - """A DoFn that run inference on regression model.""" - - def _check_examples(self, examples: List[InputType]): - if not all(isinstance(e, (tf.train.Example, bytes)) for e in examples): - raise ValueError( - 'Regress only supports raw or serialized tf.train.Example') - - def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks - self, examples: List[Union[tf.train.Example, - bytes]], serialized_examples: List[bytes], - outputs: Mapping[Text, - np.ndarray]) -> List[prediction_log_pb2.PredictionLog]: - del serialized_examples - # TODO(b/131873699): Can we fold prediction_log_pb2.PredictionLog building - # into _post_process_regress? - regressions = _post_process_regress(examples, outputs) - result = [] - for example, regression in zip(examples, regressions): - prediction_log = prediction_log_pb2.PredictionLog() - input_example = (prediction_log.regress_log.request.input.example_list - .examples.add()) - (input_example.ParseFromString - if isinstance(example, bytes) - else input_example.CopyFrom)(example) - prediction_log.regress_log.response.result.regressions.add().CopyFrom( - regression) - result.append(prediction_log) - return result + """A DoFn that run inference on regression model.""" + + def _check_examples(self, examples: List[InputType]): + if not all(isinstance(e, (tf.train.Example, bytes)) for e in examples): + raise ValueError("Regress only supports raw or serialized tf.train.Example") + + def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks + self, + examples: List[Union[tf.train.Example, bytes]], + serialized_examples: List[bytes], + outputs: Mapping[str, np.ndarray], + ) -> List[prediction_log_pb2.PredictionLog]: + del serialized_examples + # TODO(b/131873699): Can we fold prediction_log_pb2.PredictionLog building + # into _post_process_regress? + regressions = _post_process_regress(examples, outputs) + result = [] + for example, regression in zip(examples, regressions): + prediction_log = prediction_log_pb2.PredictionLog() + input_example = ( + prediction_log.regress_log.request.input.example_list.examples.add() + ) + ( + input_example.ParseFromString + if isinstance(example, bytes) + else input_example.CopyFrom + )(example) + prediction_log.regress_log.response.result.regressions.add().CopyFrom( + regression + ) + result.append(prediction_log) + return result class _MultiInferenceModelHandler(_BaseSavedModelHandler): - """A DoFn that runs inference on multi-head model.""" - - def _check_examples(self, examples: List[InputType]): - if not all(isinstance(e, (tf.train.Example, bytes)) for e in examples): - raise ValueError( - 'Multi inference only supports raw or serialized tf.train.Example') - - def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks - self, examples: List[Union[tf.train.Example, - bytes]], serialized_examples: List[bytes], - outputs: Mapping[Text, - np.ndarray]) -> List[prediction_log_pb2.PredictionLog]: - del serialized_examples - classifications = None - regressions = None - for signature in self._signatures: - signature_def = signature.signature_def - if signature_def.method_name == tf.saved_model.CLASSIFY_METHOD_NAME: - classifications = _post_process_classify( - self._io_tensor_spec.output_alias_tensor_names, examples, outputs) - elif signature_def.method_name == tf.saved_model.REGRESS_METHOD_NAME: - regressions = _post_process_regress(examples, outputs) - else: - raise ValueError('Signature method %s is not supported for ' - 'multi inference' % signature_def.method_name) - result = [] - for i, example in enumerate(examples): - prediction_log = prediction_log_pb2.PredictionLog() - input_example = (prediction_log.multi_inference_log.request.input - .example_list.examples.add()) - (input_example.ParseFromString - if isinstance(example, bytes) - else input_example.CopyFrom)(example) - response = prediction_log.multi_inference_log.response - for signature in self._signatures: - signature_def = signature.signature_def - inference_result = response.results.add() - if (signature_def.method_name == tf.saved_model.CLASSIFY_METHOD_NAME and - classifications): - inference_result.classification_result.classifications.add().CopyFrom( - classifications[i]) - elif ( - signature_def.method_name == tf.saved_model.REGRESS_METHOD_NAME and - regressions): - inference_result.regression_result.regressions.add().CopyFrom( - regressions[i]) - else: - raise ValueError('Signature method %s is not supported for ' - 'multi inference' % signature_def.method_name) - inference_result.model_spec.signature_name = signature.name - if len(response.results) != len(self._signatures): - raise RuntimeError('Multi inference response result length does not ' - 'match the number of signatures') - result.append(prediction_log) - return result + """A DoFn that runs inference on multi-head model.""" + + def _check_examples(self, examples: List[InputType]): + if not all(isinstance(e, (tf.train.Example, bytes)) for e in examples): + raise ValueError( + "Multi inference only supports raw or serialized tf.train.Example" + ) + + def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks + self, + examples: List[Union[tf.train.Example, bytes]], + serialized_examples: List[bytes], + outputs: Mapping[str, np.ndarray], + ) -> List[prediction_log_pb2.PredictionLog]: + del serialized_examples + classifications = None + regressions = None + for signature in self._signatures: + signature_def = signature.signature_def + if signature_def.method_name == tf.saved_model.CLASSIFY_METHOD_NAME: + classifications = _post_process_classify( + self._io_tensor_spec.output_alias_tensor_names, examples, outputs + ) + elif signature_def.method_name == tf.saved_model.REGRESS_METHOD_NAME: + regressions = _post_process_regress(examples, outputs) + else: + raise ValueError( + "Signature method %s is not supported for " + "multi inference" % signature_def.method_name + ) + result = [] + for i, example in enumerate(examples): + prediction_log = prediction_log_pb2.PredictionLog() + input_example = prediction_log.multi_inference_log.request.input.example_list.examples.add() + ( + input_example.ParseFromString + if isinstance(example, bytes) + else input_example.CopyFrom + )(example) + response = prediction_log.multi_inference_log.response + for signature in self._signatures: + signature_def = signature.signature_def + inference_result = response.results.add() + if ( + signature_def.method_name == tf.saved_model.CLASSIFY_METHOD_NAME + and classifications + ): + inference_result.classification_result.classifications.add().CopyFrom( + classifications[i] + ) + elif ( + signature_def.method_name == tf.saved_model.REGRESS_METHOD_NAME + and regressions + ): + inference_result.regression_result.regressions.add().CopyFrom( + regressions[i] + ) + else: + raise ValueError( + "Signature method %s is not supported for " + "multi inference" % signature_def.method_name + ) + inference_result.model_spec.signature_name = signature.name + if len(response.results) != len(self._signatures): + raise RuntimeError( + "Multi inference response result length does not " + "match the number of signatures" + ) + result.append(prediction_log) + return result class _PredictModelHandler(_BaseSavedModelHandler): - """A DoFn that runs inference on predict model.""" - - def _check_examples(self, examples: List[InputType]): - pass - - def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks - self, examples: List[InputType], serialized_examples: List[bytes], - outputs: Mapping[Text, - np.ndarray]) -> List[prediction_log_pb2.PredictionLog]: - del examples - input_tensor_alias = self._io_tensor_spec.input_tensor_alias - signature_name = self._signatures[0].name - batch_size = len(serialized_examples) - for output_alias, output in outputs.items(): - if len(output.shape) < 1 or output.shape[0] != batch_size: - raise ValueError( - 'Expected output tensor %s to have at least one ' - 'dimension, with the first having a size equal to the input batch ' - 'size %s. Instead found %s' % - (output_alias, batch_size, output.shape)) - result = [] - for i, serialized_example in enumerate(serialized_examples): - prediction_log = prediction_log_pb2.PredictionLog() - predict_log = prediction_log.predict_log - input_tensor_proto = predict_log.request.inputs[input_tensor_alias] - input_tensor_proto.dtype = tf.string.as_datatype_enum - input_tensor_proto.tensor_shape.dim.add().size = 1 - input_tensor_proto.string_val.append(serialized_example) - predict_log.request.model_spec.signature_name = signature_name - predict_log.response.model_spec.signature_name = signature_name - for output_alias, output in outputs.items(): - # Mimic tensor::Split - values = output[i] - tensor_proto = tf.make_tensor_proto( - values=values, - dtype=tf.as_dtype(values.dtype).as_datatype_enum, - shape=np.expand_dims(values, axis=0).shape) - predict_log.response.outputs[output_alias].CopyFrom(tensor_proto) - result.append(prediction_log) - return result + """A DoFn that runs inference on predict model.""" + + def _check_examples(self, examples: List[InputType]): + pass + + def _post_process( # pytype: disable=signature-mismatch # overriding-parameter-type-checks + self, + examples: List[InputType], + serialized_examples: List[bytes], + outputs: Mapping[str, np.ndarray], + ) -> List[prediction_log_pb2.PredictionLog]: + del examples + input_tensor_alias = self._io_tensor_spec.input_tensor_alias + signature_name = self._signatures[0].name + batch_size = len(serialized_examples) + for output_alias, output in outputs.items(): + if len(output.shape) < 1 or output.shape[0] != batch_size: + raise ValueError( + "Expected output tensor %s to have at least one " + "dimension, with the first having a size equal to the input batch " + "size %s. Instead found %s" + % (output_alias, batch_size, output.shape) + ) + result = [] + for i, serialized_example in enumerate(serialized_examples): + prediction_log = prediction_log_pb2.PredictionLog() + predict_log = prediction_log.predict_log + input_tensor_proto = predict_log.request.inputs[input_tensor_alias] + input_tensor_proto.dtype = tf.string.as_datatype_enum + input_tensor_proto.tensor_shape.dim.add().size = 1 + input_tensor_proto.string_val.append(serialized_example) + predict_log.request.model_spec.signature_name = signature_name + predict_log.response.model_spec.signature_name = signature_name + for output_alias, output in outputs.items(): + # Mimic tensor::Split + values = output[i] + tensor_proto = tf.make_tensor_proto( + values=values, + dtype=tf.as_dtype(values.dtype).as_datatype_enum, + shape=np.expand_dims(values, axis=0).shape, + ) + predict_log.response.outputs[output_alias].CopyFrom(tensor_proto) + result.append(prediction_log) + return result def _post_process_classify( - output_alias_tensor_names: Mapping[Text, Text], - examples: List[tf.train.Example], outputs: Mapping[Text, np.ndarray] + output_alias_tensor_names: Mapping[str, str], + examples: List[tf.train.Example], + outputs: Mapping[str, np.ndarray], ) -> List[classification_pb2.Classifications]: - """Returns classifications from inference output.""" - - # This is to avoid error "The truth value of an array with - # more than one example is ambiguous." - has_classes = False - has_scores = False - if tf.saved_model.CLASSIFY_OUTPUT_CLASSES in output_alias_tensor_names: - classes = outputs[tf.saved_model.CLASSIFY_OUTPUT_CLASSES] - has_classes = True - if tf.saved_model.CLASSIFY_OUTPUT_SCORES in output_alias_tensor_names: - scores = outputs[tf.saved_model.CLASSIFY_OUTPUT_SCORES] - has_scores = True - if has_classes: - if classes.ndim != 2: - raise ValueError('Expected Tensor shape: [batch_size num_classes] but ' - 'got %s' % classes.shape) - if classes.dtype != tf.string.as_numpy_dtype: - raise ValueError('Expected classes Tensor of %s. Got: %s' % - (tf.string.as_numpy_dtype, classes.dtype)) - if classes.shape[0] != len(examples): - raise ValueError('Expected classes output batch size of %s, got %s' % - (len(examples), classes.shape[0])) - if has_scores: - if scores.ndim != 2: - raise ValueError("""Expected Tensor shape: [batch_size num_classes] but - got %s""" % scores.shape) - if scores.dtype != tf.float32.as_numpy_dtype: - raise ValueError('Expected classes Tensor of %s. Got: %s' % - (tf.float32.as_numpy_dtype, scores.dtype)) - if scores.shape[0] != len(examples): - raise ValueError('Expected classes output batch size of %s, got %s' % - (len(examples), scores.shape[0])) - num_classes = 0 - if has_classes and has_scores: - if scores.shape[1] != classes.shape[1]: - raise ValueError('Tensors class and score should match in shape[1]. ' - 'Got %s vs %s' % (classes.shape[1], scores.shape[1])) - num_classes = classes.shape[1] - elif has_classes: - num_classes = classes.shape[1] - elif has_scores: - num_classes = scores.shape[1] - - result = [] - for i in range(len(examples)): - classifications = classification_pb2.Classifications() - for c in range(num_classes): - klass = classifications.classes.add() - if has_classes: - klass.label = classes[i][c] - if has_scores: - klass.score = scores[i][c] - result.append(classifications) - return result + """Returns classifications from inference output.""" + # This is to avoid error "The truth value of an array with + # more than one example is ambiguous." + has_classes = False + has_scores = False + if tf.saved_model.CLASSIFY_OUTPUT_CLASSES in output_alias_tensor_names: + classes = outputs[tf.saved_model.CLASSIFY_OUTPUT_CLASSES] + has_classes = True + if tf.saved_model.CLASSIFY_OUTPUT_SCORES in output_alias_tensor_names: + scores = outputs[tf.saved_model.CLASSIFY_OUTPUT_SCORES] + has_scores = True + if has_classes: + if classes.ndim != 2: + raise ValueError( + "Expected Tensor shape: [batch_size num_classes] but " + "got %s" % classes.shape + ) + if classes.dtype != tf.string.as_numpy_dtype: + raise ValueError( + "Expected classes Tensor of %s. Got: %s" + % (tf.string.as_numpy_dtype, classes.dtype) + ) + if classes.shape[0] != len(examples): + raise ValueError( + "Expected classes output batch size of %s, got %s" + % (len(examples), classes.shape[0]) + ) + if has_scores: + if scores.ndim != 2: + raise ValueError( + """Expected Tensor shape: [batch_size num_classes] but + got %s""" + % scores.shape + ) + if scores.dtype != tf.float32.as_numpy_dtype: + raise ValueError( + "Expected classes Tensor of %s. Got: %s" + % (tf.float32.as_numpy_dtype, scores.dtype) + ) + if scores.shape[0] != len(examples): + raise ValueError( + "Expected classes output batch size of %s, got %s" + % (len(examples), scores.shape[0]) + ) + num_classes = 0 + if has_classes and has_scores: + if scores.shape[1] != classes.shape[1]: + raise ValueError( + "Tensors class and score should match in shape[1]. " + "Got %s vs %s" % (classes.shape[1], scores.shape[1]) + ) + num_classes = classes.shape[1] + elif has_classes: + num_classes = classes.shape[1] + elif has_scores: + num_classes = scores.shape[1] + + result = [] + for i in range(len(examples)): + classifications = classification_pb2.Classifications() + for c in range(num_classes): + klass = classifications.classes.add() + if has_classes: + klass.label = classes[i][c] + if has_scores: + klass.score = scores[i][c] + result.append(classifications) + return result def _post_process_regress( - examples: List[tf.train.Example], - outputs: Mapping[Text, np.ndarray]) -> List[regression_pb2.Regression]: - """Returns regressions from inference output.""" - - if tf.saved_model.REGRESS_OUTPUTS not in outputs: - raise ValueError('No regression outputs found in outputs: %s' % - outputs.keys()) - output = outputs[tf.saved_model.REGRESS_OUTPUTS] - batch_size = len(examples) - if not (output.ndim == 1 or (output.ndim == 2 and output.shape[1] == 1)): - raise ValueError("""Expected output Tensor shape to be either [batch_size] - or [batch_size, 1] but got %s""" % output.shape) - if batch_size != output.shape[0]: - raise ValueError( - 'Input batch size did not match output batch size: %s vs %s' % - (batch_size, output.shape[0])) - if output.dtype != tf.float32.as_numpy_dtype: - raise ValueError('Expected output Tensor of %s. Got: %s' % - (tf.float32.as_numpy_dtype, output.dtype)) - if output.size != batch_size: - raise ValueError('Expected output batch size to be %s. Got: %s' % - (batch_size, output.size)) - flatten_output = output.flatten() - result = [] - for value in flatten_output: - regression = regression_pb2.Regression() - regression.value = value - result.append(regression) - # Add additional check to save downstream consumer checks. - if len(result) != len(examples): - raise RuntimeError('Regression length does not match examples') - return result + examples: List[tf.train.Example], outputs: Mapping[str, np.ndarray] +) -> List[regression_pb2.Regression]: + """Returns regressions from inference output.""" + if tf.saved_model.REGRESS_OUTPUTS not in outputs: + raise ValueError("No regression outputs found in outputs: %s" % outputs.keys()) + output = outputs[tf.saved_model.REGRESS_OUTPUTS] + batch_size = len(examples) + if not (output.ndim == 1 or (output.ndim == 2 and output.shape[1] == 1)): + raise ValueError( + """Expected output Tensor shape to be either [batch_size] + or [batch_size, 1] but got %s""" + % output.shape + ) + if batch_size != output.shape[0]: + raise ValueError( + "Input batch size did not match output batch size: %s vs %s" + % (batch_size, output.shape[0]) + ) + if output.dtype != tf.float32.as_numpy_dtype: + raise ValueError( + "Expected output Tensor of %s. Got: %s" + % (tf.float32.as_numpy_dtype, output.dtype) + ) + if output.size != batch_size: + raise ValueError( + "Expected output batch size to be %s. Got: %s" % (batch_size, output.size) + ) + flatten_output = output.flatten() + result = [] + for value in flatten_output: + regression = regression_pb2.Regression() + regression.value = value + result.append(regression) + # Add additional check to save downstream consumer checks. + if len(result) != len(examples): + raise RuntimeError("Regression length does not match examples") + return result def _signature_pre_process(signature: _SignatureDef) -> _IOTensorSpec: - """Returns IOTensorSpec from signature.""" - - if len(signature.inputs) != 1: - raise ValueError('Signature should have 1 and only 1 inputs') - input_tensor_alias = list(signature.inputs.keys())[0] - if list(signature.inputs.values())[0].dtype != tf.string.as_datatype_enum: - raise ValueError( - 'Input dtype is expected to be %s, got %s' % tf.string.as_datatype_enum, - list(signature.inputs.values())[0].dtype) - if signature.method_name == tf.saved_model.CLASSIFY_METHOD_NAME: - input_tensor_name, output_alias_tensor_names = ( - _signature_pre_process_classify(signature)) - elif signature.method_name == tf.saved_model.REGRESS_METHOD_NAME: - input_tensor_name, output_alias_tensor_names = ( - _signature_pre_process_regress(signature)) - elif signature.method_name == tf.saved_model.PREDICT_METHOD_NAME: - input_tensor_name, output_alias_tensor_names = ( - _signature_pre_process_predict(signature)) - else: - raise ValueError('Signature method %s is not supported' % - signature.method_name) - return _IOTensorSpec(input_tensor_alias, input_tensor_name, - output_alias_tensor_names) + """Returns IOTensorSpec from signature.""" + if len(signature.inputs) != 1: + raise ValueError("Signature should have 1 and only 1 inputs") + input_tensor_alias = list(signature.inputs.keys())[0] + if list(signature.inputs.values())[0].dtype != tf.string.as_datatype_enum: + raise ValueError( + "Input dtype is expected to be %s, got %s" % tf.string.as_datatype_enum, + list(signature.inputs.values())[0].dtype, + ) + if signature.method_name == tf.saved_model.CLASSIFY_METHOD_NAME: + input_tensor_name, output_alias_tensor_names = _signature_pre_process_classify( + signature + ) + elif signature.method_name == tf.saved_model.REGRESS_METHOD_NAME: + input_tensor_name, output_alias_tensor_names = _signature_pre_process_regress( + signature + ) + elif signature.method_name == tf.saved_model.PREDICT_METHOD_NAME: + input_tensor_name, output_alias_tensor_names = _signature_pre_process_predict( + signature + ) + else: + raise ValueError("Signature method %s is not supported" % signature.method_name) + return _IOTensorSpec( + input_tensor_alias, input_tensor_name, output_alias_tensor_names + ) def _signature_pre_process_classify( - signature: _SignatureDef) -> Tuple[Text, Dict[Text, Text]]: - """Returns input tensor name and output alias tensor names from signature. - - Args: - signature: SignatureDef - - Returns: - A tuple of input tensor name and output alias tensor names. - """ - - if len(signature.outputs) != 1 and len(signature.outputs) != 2: - raise ValueError('Classify signature should have 1 or 2 outputs') - if tf.saved_model.CLASSIFY_INPUTS not in signature.inputs: - raise ValueError('No classification inputs found in SignatureDef: %s' % - signature.inputs) - input_tensor_name = signature.inputs[tf.saved_model.CLASSIFY_INPUTS].name - output_alias_tensor_names = {} - if (tf.saved_model.CLASSIFY_OUTPUT_CLASSES not in signature.outputs and - tf.saved_model.CLASSIFY_OUTPUT_SCORES not in signature.outputs): - raise ValueError( - """Expected classification signature outputs to contain at - least one of %s or %s. Signature was: %s""" % - tf.saved_model.CLASSIFY_OUTPUT_CLASSES, - tf.saved_model.CLASSIFY_OUTPUT_SCORES, signature) - if tf.saved_model.CLASSIFY_OUTPUT_CLASSES in signature.outputs: - output_alias_tensor_names[tf.saved_model.CLASSIFY_OUTPUT_CLASSES] = ( - signature.outputs[tf.saved_model.CLASSIFY_OUTPUT_CLASSES].name) - if tf.saved_model.CLASSIFY_OUTPUT_SCORES in signature.outputs: - output_alias_tensor_names[tf.saved_model.CLASSIFY_OUTPUT_SCORES] = ( - signature.outputs[tf.saved_model.CLASSIFY_OUTPUT_SCORES].name) - return input_tensor_name, output_alias_tensor_names + signature: _SignatureDef, +) -> Tuple[str, Dict[str, str]]: + """Returns input tensor name and output alias tensor names from signature. + + Args: + ---- + signature: SignatureDef + + Returns: + ------- + A tuple of input tensor name and output alias tensor names. + """ + if len(signature.outputs) != 1 and len(signature.outputs) != 2: + raise ValueError("Classify signature should have 1 or 2 outputs") + if tf.saved_model.CLASSIFY_INPUTS not in signature.inputs: + raise ValueError( + "No classification inputs found in SignatureDef: %s" % signature.inputs + ) + input_tensor_name = signature.inputs[tf.saved_model.CLASSIFY_INPUTS].name + output_alias_tensor_names = {} + if ( + tf.saved_model.CLASSIFY_OUTPUT_CLASSES not in signature.outputs + and tf.saved_model.CLASSIFY_OUTPUT_SCORES not in signature.outputs + ): + raise ValueError( + """Expected classification signature outputs to contain at + least one of %s or %s. Signature was: %s""" + % tf.saved_model.CLASSIFY_OUTPUT_CLASSES, + tf.saved_model.CLASSIFY_OUTPUT_SCORES, + signature, + ) + if tf.saved_model.CLASSIFY_OUTPUT_CLASSES in signature.outputs: + output_alias_tensor_names[tf.saved_model.CLASSIFY_OUTPUT_CLASSES] = ( + signature.outputs[tf.saved_model.CLASSIFY_OUTPUT_CLASSES].name + ) + if tf.saved_model.CLASSIFY_OUTPUT_SCORES in signature.outputs: + output_alias_tensor_names[tf.saved_model.CLASSIFY_OUTPUT_SCORES] = ( + signature.outputs[tf.saved_model.CLASSIFY_OUTPUT_SCORES].name + ) + return input_tensor_name, output_alias_tensor_names def _signature_pre_process_regress( - signature: _SignatureDef) -> Tuple[Text, Dict[Text, Text]]: - """Returns input tensor name and output alias tensor names from signature. - - Args: - signature: SignatureDef - - Returns: - A tuple of input tensor name and output alias tensor names. - """ - - if len(signature.outputs) != 1: - raise ValueError('Regress signature should have 1 output') - if tf.saved_model.REGRESS_INPUTS not in signature.inputs: - raise ValueError('No regression inputs found in SignatureDef: %s' % - signature.inputs) - input_tensor_name = signature.inputs[tf.saved_model.REGRESS_INPUTS].name - if tf.saved_model.REGRESS_OUTPUTS not in signature.outputs: - raise ValueError('No regression outputs found in SignatureDef: %s' % - signature.outputs) - output_alias_tensor_names = { - tf.saved_model.REGRESS_OUTPUTS: - signature.outputs[tf.saved_model.REGRESS_OUTPUTS].name - } - return input_tensor_name, output_alias_tensor_names + signature: _SignatureDef, +) -> Tuple[str, Dict[str, str]]: + """Returns input tensor name and output alias tensor names from signature. + + Args: + ---- + signature: SignatureDef + + Returns: + ------- + A tuple of input tensor name and output alias tensor names. + """ + if len(signature.outputs) != 1: + raise ValueError("Regress signature should have 1 output") + if tf.saved_model.REGRESS_INPUTS not in signature.inputs: + raise ValueError( + "No regression inputs found in SignatureDef: %s" % signature.inputs + ) + input_tensor_name = signature.inputs[tf.saved_model.REGRESS_INPUTS].name + if tf.saved_model.REGRESS_OUTPUTS not in signature.outputs: + raise ValueError( + "No regression outputs found in SignatureDef: %s" % signature.outputs + ) + output_alias_tensor_names = { + tf.saved_model.REGRESS_OUTPUTS: signature.outputs[ + tf.saved_model.REGRESS_OUTPUTS + ].name + } + return input_tensor_name, output_alias_tensor_names def _signature_pre_process_predict( - signature: _SignatureDef) -> Tuple[Text, Dict[Text, Text]]: - """Returns input tensor name and output alias tensor names from signature. - - Args: - signature: SignatureDef - - Returns: - A tuple of input tensor name and output alias tensor names. - """ - - input_tensor_name = list(signature.inputs.values())[0].name - output_alias_tensor_names = dict([ - (key, output.name) for key, output in signature.outputs.items() - ]) - return input_tensor_name, output_alias_tensor_names - - -def _get_signatures(model_path: Text, signatures: Sequence[Text], - tags: Sequence[Text]) -> Sequence[_Signature]: - """Returns a sequence of {model_signature_name: signature}.""" - - if signatures: - signature_names = signatures - else: - signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] - - saved_model_pb = loader_impl.parse_saved_model(model_path) - meta_graph_def = _get_meta_graph_def(saved_model_pb, tags) - result = [] - for signature_name in signature_names: - if signature_name in meta_graph_def.signature_def: - result.append( - _Signature(signature_name, - meta_graph_def.signature_def[signature_name])) - else: - raise RuntimeError('Signature %s could not be found in SavedModel' % - signature_name) - return result - - -def _get_operation_type( - inference_spec_type: model_spec_pb2.InferenceSpecType) -> Text: - if _using_in_process_inference(inference_spec_type): - signatures = _get_signatures( - inference_spec_type.saved_model_spec.model_path, - inference_spec_type.saved_model_spec.signature_name, - _get_tags(inference_spec_type)) - if not signatures: - raise ValueError('Model does not have valid signature to use') - - if len(signatures) == 1: - method_name = signatures[0].signature_def.method_name - if method_name == tf.saved_model.CLASSIFY_METHOD_NAME: - return _OperationType.CLASSIFICATION - elif method_name == tf.saved_model.REGRESS_METHOD_NAME: - return _OperationType.REGRESSION - elif method_name == tf.saved_model.PREDICT_METHOD_NAME: - return _OperationType.PREDICTION - else: - raise ValueError('Unsupported signature method_name %s' % method_name) + signature: _SignatureDef, +) -> Tuple[str, Dict[str, str]]: + """Returns input tensor name and output alias tensor names from signature. + + Args: + ---- + signature: SignatureDef + + Returns: + ------- + A tuple of input tensor name and output alias tensor names. + """ + input_tensor_name = list(signature.inputs.values())[0].name + output_alias_tensor_names = dict( + [(key, output.name) for key, output in signature.outputs.items()] + ) + return input_tensor_name, output_alias_tensor_names + + +def _get_signatures( + model_path: str, signatures: Sequence[str], tags: Sequence[str] +) -> Sequence[_Signature]: + """Returns a sequence of {model_signature_name: signature}.""" + if signatures: + signature_names = signatures else: - for signature in signatures: - method_name = signature.signature_def.method_name - if (method_name != tf.saved_model.CLASSIFY_METHOD_NAME and - method_name != tf.saved_model.REGRESS_METHOD_NAME): - raise ValueError('Unsupported signature method_name for multi-head ' - 'model inference: %s' % method_name) - return _OperationType.MULTI_INFERENCE - else: - # Remote inference supports predictions only. - return _OperationType.PREDICTION + signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + + saved_model_pb = loader_impl.parse_saved_model(model_path) + meta_graph_def = _get_meta_graph_def(saved_model_pb, tags) + result = [] + for signature_name in signature_names: + if signature_name in meta_graph_def.signature_def: + result.append( + _Signature(signature_name, meta_graph_def.signature_def[signature_name]) + ) + else: + raise RuntimeError( + "Signature %s could not be found in SavedModel" % signature_name + ) + return result -def _get_meta_graph_def(saved_model_pb: _SavedModel, - tags: Sequence[Text]) -> _MetaGraphDef: - """Returns MetaGraphDef from SavedModel.""" +def _get_operation_type(inference_spec_type: model_spec_pb2.InferenceSpecType) -> str: + if _using_in_process_inference(inference_spec_type): + signatures = _get_signatures( + inference_spec_type.saved_model_spec.model_path, + inference_spec_type.saved_model_spec.signature_name, + _get_tags(inference_spec_type), + ) + if not signatures: + raise ValueError("Model does not have valid signature to use") + + if len(signatures) == 1: + method_name = signatures[0].signature_def.method_name + if method_name == tf.saved_model.CLASSIFY_METHOD_NAME: + return _OperationType.CLASSIFICATION + elif method_name == tf.saved_model.REGRESS_METHOD_NAME: + return _OperationType.REGRESSION + elif method_name == tf.saved_model.PREDICT_METHOD_NAME: + return _OperationType.PREDICTION + else: + raise ValueError("Unsupported signature method_name %s" % method_name) + else: + for signature in signatures: + method_name = signature.signature_def.method_name + if ( + method_name != tf.saved_model.CLASSIFY_METHOD_NAME + and method_name != tf.saved_model.REGRESS_METHOD_NAME + ): + raise ValueError( + "Unsupported signature method_name for multi-head " + "model inference: %s" % method_name + ) + return _OperationType.MULTI_INFERENCE + else: + # Remote inference supports predictions only. + return _OperationType.PREDICTION - for meta_graph_def in saved_model_pb.meta_graphs: - if set(meta_graph_def.meta_info_def.tags) == set(tags): - return meta_graph_def - raise RuntimeError('MetaGraphDef associated with tags %s could not be ' - 'found in SavedModel' % tags) +def _get_meta_graph_def( + saved_model_pb: _SavedModel, tags: Sequence[str] +) -> _MetaGraphDef: + """Returns MetaGraphDef from SavedModel.""" + for meta_graph_def in saved_model_pb.meta_graphs: + if set(meta_graph_def.meta_info_def.tags) == set(tags): + return meta_graph_def + raise RuntimeError( + "MetaGraphDef associated with tags %s could not be " + "found in SavedModel" % tags + ) -def _get_tags( - inference_spec_type: model_spec_pb2.InferenceSpecType) -> Sequence[Text]: - """Returns tags from ModelSpec.""" - if inference_spec_type.saved_model_spec.tag: - return list(inference_spec_type.saved_model_spec.tag) - else: - return [tf.saved_model.SERVING] +def _get_tags(inference_spec_type: model_spec_pb2.InferenceSpecType) -> Sequence[str]: + """Returns tags from ModelSpec.""" + if inference_spec_type.saved_model_spec.tag: + return list(inference_spec_type.saved_model_spec.tag) + else: + return [tf.saved_model.SERVING] -_T = TypeVar('_T') +_T = TypeVar("_T") def _flatten_examples( - maybe_nested_examples: List[Union[_T, List[_T]]] + maybe_nested_examples: List[Union[_T, List[_T]]], ) -> Tuple[List[_T], Optional[List[int]], Optional[int]]: - """Flattens nested examples, and returns corresponding nested list indices.""" - if (not maybe_nested_examples or - not isinstance(maybe_nested_examples[0], list)): - return maybe_nested_examples, None, None - idx = [] - flattened = [] - for i in range(len(maybe_nested_examples)): - for ex in maybe_nested_examples[i]: - idx.append(i) - flattened.append(ex) - return flattened, idx, len(maybe_nested_examples) - - -def _nest_results(flat_results: Iterable[_T], idx: Optional[List[int]], - max_idx: Optional[int]) -> List[Union[_T, List[_T]]]: - """Reverses operation of _flatten_examples if indices are provided.""" - if idx is None: - return list(flat_results) - nested_results = [] - for _ in range(max_idx): - nested_results.append([]) - for result, i in zip(flat_results, idx): - nested_results[i].append(result) - return nested_results + """Flattens nested examples, and returns corresponding nested list indices.""" + if not maybe_nested_examples or not isinstance(maybe_nested_examples[0], list): + return maybe_nested_examples, None, None + idx = [] + flattened = [] + for i in range(len(maybe_nested_examples)): + for ex in maybe_nested_examples[i]: + idx.append(i) + flattened.append(ex) + return flattened, idx, len(maybe_nested_examples) + + +def _nest_results( + flat_results: Iterable[_T], idx: Optional[List[int]], max_idx: Optional[int] +) -> List[Union[_T, List[_T]]]: + """Reverses operation of _flatten_examples if indices are provided.""" + if idx is None: + return list(flat_results) + nested_results = [] + for _ in range(max_idx): + nested_results.append([]) + for result, i in zip(flat_results, idx): + nested_results[i].append(result) + return nested_results # TODO(b/231328769): Overload batch args when available. class _ModelHandlerWrapper(base.ModelHandler): - """Wrapper that handles key forwarding and pre-batching of inputs. + """Wrapper that handles key forwarding and pre-batching of inputs. - This wrapper accepts mapping ExampleType -> PredictType, - and itself maps either + This wrapper accepts mapping ExampleType -> PredictType, + and itself maps either - * ExampleType -> PredictType + * ExampleType -> PredictType - * Tuple[K, ExampleType] -> Tuple[K, PredictType] + * Tuple[K, ExampleType] -> Tuple[K, PredictType] - * Tuple[K, List[ExampleType]] -> Tuple[K, List[PredictType]] + * Tuple[K, List[ExampleType]] -> Tuple[K, List[PredictType]] - The second mode can support forwarding metadata with a one-to-one relationship - to examples, while the third supports forwarding metadata with a many-to-one - relationship. + The second mode can support forwarding metadata with a one-to-one relationship + to examples, while the third supports forwarding metadata with a many-to-one + relationship. - Note that ExampleType can not be a Tuple or a List. - """ + Note that ExampleType can not be a Tuple or a List. + """ - def __init__(self, model_handler: base.ModelHandler): - super().__init__() - self._model_handler = model_handler + def __init__(self, model_handler: base.ModelHandler): + super().__init__() + self._model_handler = model_handler - def load_model(self) -> Any: - return self._model_handler.load_model() + def load_model(self) -> Any: + return self._model_handler.load_model() - def run_inference(self, - batch: Sequence[Any], - model: Any, - inference_args=None) -> Sequence[Any]: - if not batch: - return [] - if isinstance(batch[0], tuple): - keys, examples = zip(*batch) - else: - keys, examples = None, batch - examples, nested_batch_idx, max_idx = _flatten_examples(examples) - predictions = self._model_handler.run_inference(examples, model) - predictions = _nest_results(predictions, nested_batch_idx, max_idx) - if keys: - return list(zip(keys, predictions)) - return predictions - - def get_num_bytes(self, batch: Any) -> int: - if isinstance(batch[0], tuple): - _, batch = zip(*batch) - batch, _, _ = _flatten_examples(batch) - return self._model_handler.get_num_bytes(batch) - - def get_metrics_namespace(self) -> str: - return self._model_handler.get_metrics_namespace() + def run_inference( + self, batch: Sequence[Any], model: Any, inference_args=None + ) -> Sequence[Any]: + if not batch: + return [] + if isinstance(batch[0], tuple): + keys, examples = zip(*batch) + else: + keys, examples = None, batch + examples, nested_batch_idx, max_idx = _flatten_examples(examples) + predictions = self._model_handler.run_inference(examples, model) + predictions = _nest_results(predictions, nested_batch_idx, max_idx) + if keys: + return list(zip(keys, predictions)) + return predictions + + def get_num_bytes(self, batch: Any) -> int: + if isinstance(batch[0], tuple): + _, batch = zip(*batch) + batch, _, _ = _flatten_examples(batch) + return self._model_handler.get_num_bytes(batch) + + def get_metrics_namespace(self) -> str: + return self._model_handler.get_metrics_namespace() diff --git a/tfx_bsl/beam/test_helpers.py b/tfx_bsl/beam/test_helpers.py index 6e3f7c71..ab258206 100644 --- a/tfx_bsl/beam/test_helpers.py +++ b/tfx_bsl/beam/test_helpers.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2022 Google Inc. All Rights Reserved. # @@ -17,6 +16,6 @@ def make_test_beam_pipeline_kwargs(): - # This is kwargs for apache_beam.Pipeline's __init__, using the default runner - # here. - return {} + # This is kwargs for apache_beam.Pipeline's __init__, using the default runner + # here. + return {} diff --git a/tfx_bsl/coders/batch_util.py b/tfx_bsl/coders/batch_util.py index 08a07582..cfbcf5bb 100644 --- a/tfx_bsl/coders/batch_util.py +++ b/tfx_bsl/coders/batch_util.py @@ -16,9 +16,10 @@ import inspect import math from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar -from absl import flags import apache_beam as beam +from absl import flags + from tfx_bsl.telemetry import util as telemetry_util # Beam might grow the batch size too large for Arrow BinaryArray / ListArray @@ -53,49 +54,49 @@ def _UseByteSizeBatching() -> bool: - """Cautious access to `tfxio_use_byte_size_batching` flag value.""" - return ( - _USE_BYTE_SIZE_BATCHING.value - if flags.FLAGS.is_parsed() - else _USE_BYTE_SIZE_BATCHING.default - ) + """Cautious access to `tfxio_use_byte_size_batching` flag value.""" + return ( + _USE_BYTE_SIZE_BATCHING.value + if flags.FLAGS.is_parsed() + else _USE_BYTE_SIZE_BATCHING.default + ) def GetBatchElementsKwargs( batch_size: Optional[int], element_size_fn: Callable[[Any], int] = len ) -> Dict[str, Any]: - """Returns the kwargs to pass to beam.BatchElements().""" - if batch_size is not None: - return { - "min_batch_size": batch_size, - "max_batch_size": batch_size, + """Returns the kwargs to pass to beam.BatchElements().""" + if batch_size is not None: + return { + "min_batch_size": batch_size, + "max_batch_size": batch_size, + } + if _UseByteSizeBatching(): + min_element_size = int( + math.ceil(_TARGET_BATCH_BYTES_SIZE / _BATCH_SIZE_CAP_WITH_BYTE_TARGET) + ) + return { + "min_batch_size": _TARGET_BATCH_BYTES_SIZE, + "max_batch_size": _TARGET_BATCH_BYTES_SIZE, + "element_size_fn": lambda e: max(element_size_fn(e), min_element_size), + } + # Allow `BatchElements` to tune the values with the given parameters. + # We fix the tuning parameters here to prevent Beam changes from immediately + # affecting all dependencies. + result = { + "min_batch_size": 1, + "max_batch_size": _BATCH_SIZE_CAP, + "target_batch_overhead": 0.05, + "target_batch_duration_secs": 1, + "variance": 0.25, } - if _UseByteSizeBatching(): - min_element_size = int( - math.ceil(_TARGET_BATCH_BYTES_SIZE / _BATCH_SIZE_CAP_WITH_BYTE_TARGET) - ) - return { - "min_batch_size": _TARGET_BATCH_BYTES_SIZE, - "max_batch_size": _TARGET_BATCH_BYTES_SIZE, - "element_size_fn": lambda e: max(element_size_fn(e), min_element_size), - } - # Allow `BatchElements` to tune the values with the given parameters. - # We fix the tuning parameters here to prevent Beam changes from immediately - # affecting all dependencies. - result = { - "min_batch_size": 1, - "max_batch_size": _BATCH_SIZE_CAP, - "target_batch_overhead": 0.05, - "target_batch_duration_secs": 1, - "variance": 0.25, - } - batch_elements_signature = inspect.signature(beam.BatchElements) - if ( - "target_batch_duration_secs_including_fixed_cost" - in batch_elements_signature.parameters - ): - result["target_batch_duration_secs_including_fixed_cost"] = 1 - return result + batch_elements_signature = inspect.signature(beam.BatchElements) + if ( + "target_batch_duration_secs_including_fixed_cost" + in batch_elements_signature.parameters + ): + result["target_batch_duration_secs_including_fixed_cost"] = 1 + return result def _MakeAndIncrementBatchingMetrics( @@ -103,16 +104,12 @@ def _MakeAndIncrementBatchingMetrics( batch_size: Optional[int], telemetry_descriptors: Optional[Sequence[str]], ) -> None: - """Increments metrics relevant to batching.""" - namespace = telemetry_util.MakeTfxNamespace( - telemetry_descriptors or ["Unknown"] - ) - beam.metrics.Metrics.counter(namespace, "tfxio_use_byte_size_batching").inc( - int(_UseByteSizeBatching()) - ) - beam.metrics.Metrics.counter(namespace, "desired_batch_size").inc( - batch_size or 0 - ) + """Increments metrics relevant to batching.""" + namespace = telemetry_util.MakeTfxNamespace(telemetry_descriptors or ["Unknown"]) + beam.metrics.Metrics.counter(namespace, "tfxio_use_byte_size_batching").inc( + int(_UseByteSizeBatching()) + ) + beam.metrics.Metrics.counter(namespace, "desired_batch_size").inc(batch_size or 0) T = TypeVar("T") @@ -127,30 +124,32 @@ def BatchRecords( telemetry_descriptors: Optional[Sequence[str]], record_size_fn: Callable[[T], int] = len, ) -> beam.PCollection: - """Batches collection of records tuning the batch size if not provided. - - Args: - records: A PCollection of records to batch. - batch_size: Desired batch size. If None, will be tuned for optimal - performance. - telemetry_descriptors: Descriptors to use for batching metrics. - record_size_fn: Function used to determine size of each record in bytes. - Only used if byte size-based batching is enabled. Defaults to `len` - function suitable for bytes records. - - Returns: - A PCollection of batched records. - """ - _ = ( - records.pipeline - | "CreateSole" >> beam.Create([None]) - | "IncrementMetrics" - >> beam.Map( - _MakeAndIncrementBatchingMetrics, - batch_size=batch_size, - telemetry_descriptors=telemetry_descriptors, - ) - ) - return records | "BatchElements" >> beam.BatchElements( - **GetBatchElementsKwargs(batch_size, record_size_fn) - ) + """Batches collection of records tuning the batch size if not provided. + + Args: + ---- + records: A PCollection of records to batch. + batch_size: Desired batch size. If None, will be tuned for optimal + performance. + telemetry_descriptors: Descriptors to use for batching metrics. + record_size_fn: Function used to determine size of each record in bytes. + Only used if byte size-based batching is enabled. Defaults to `len` + function suitable for bytes records. + + Returns: + ------- + A PCollection of batched records. + """ + _ = ( + records.pipeline + | "CreateSole" >> beam.Create([None]) + | "IncrementMetrics" + >> beam.Map( + _MakeAndIncrementBatchingMetrics, + batch_size=batch_size, + telemetry_descriptors=telemetry_descriptors, + ) + ) + return records | "BatchElements" >> beam.BatchElements( + **GetBatchElementsKwargs(batch_size, record_size_fn) + ) diff --git a/tfx_bsl/coders/batch_util_test.py b/tfx_bsl/coders/batch_util_test.py index c186be45..ee9f6f1b 100644 --- a/tfx_bsl/coders/batch_util_test.py +++ b/tfx_bsl/coders/batch_util_test.py @@ -13,16 +13,12 @@ # limitations under the License. """Tests for tfx_bsl.coders.batch_util.""" -import pytest - -from absl.testing import flagsaver - import apache_beam as beam +import pytest +from absl.testing import absltest, flagsaver, parameterized from apache_beam.testing import util as beam_testing_util from tfx_bsl.coders import batch_util -from absl.testing import absltest -from absl.testing import parameterized _BATCH_RECORDS_TEST_CASES = ( dict( @@ -82,115 +78,112 @@ class BatchUtilTest(parameterized.TestCase): - - @parameterized.named_parameters(*_BATCH_RECORDS_TEST_CASES) - def testGetBatchElementsKwargs( - self, - batch_size, - tfxio_use_byte_size_batching, - expected_kwargs, - element_size_fn=len, - expected_element_contributions=None, - ): - - if self._testMethodName in [ - "testGetBatchElementsKwargsbyte_size_batching", - "testGetBatchElementsKwargsbyte_size_batching_with_element_size_fn", - ]: - pytest.xfail(reason="Test fails and needs to be fixed. ") - - with flagsaver.flagsaver( - tfxio_use_byte_size_batching=tfxio_use_byte_size_batching + @parameterized.named_parameters(*_BATCH_RECORDS_TEST_CASES) + def testGetBatchElementsKwargs( + self, + batch_size, + tfxio_use_byte_size_batching, + expected_kwargs, + element_size_fn=len, + expected_element_contributions=None, ): - kwargs = batch_util.GetBatchElementsKwargs( - batch_size, element_size_fn=element_size_fn - ) - # This parameter may not be present in some Beam versions that we support. - target_batch_duration_secs_including_fixed_cost = kwargs.pop( - "target_batch_duration_secs_including_fixed_cost", None - ) - self.assertIn(target_batch_duration_secs_including_fixed_cost, {1, None}) - if expected_kwargs.pop("element_size_fn", None) is not None: - self.assertIn("element_size_fn", kwargs) - element_size_fn = kwargs.pop("element_size_fn") - for ( - element, - expected_contribution, - ) in expected_element_contributions.items(): - self.assertEqual( - element_size_fn(element), - expected_contribution, - msg=f"Unexpected contribution of element {element}", - ) - self.assertDictEqual(kwargs, expected_kwargs) - - @parameterized.named_parameters(*_BATCH_RECORDS_TEST_CASES) - def testBatchRecords( - self, - batch_size, - tfxio_use_byte_size_batching, - expected_kwargs, - element_size_fn=len, - expected_element_contributions=None, - ): - - if self._testMethodName in [ - "testBatchRecordsbatch_size_none", - "testBatchRecordsbyte_size_batching", - "testBatchRecordsbyte_size_batching_with_element_size_fn", - "testBatchRecordsfixed_batch_size", - "testBatchRecordsfixed_batch_size_byte_size_batching", - ]: - pytest.xfail(reason="PR 260 81 test fails and needs to be fixed. ") - - del expected_kwargs - telemetry_descriptors = ["TestComponent"] - input_records = ( - [b"asd", b"asds", b"123", b"gdgd" * 1000] - if expected_element_contributions is None - else expected_element_contributions.keys() - ) - - def AssertFn(batched_records): - # We can't validate the actual sizes since they depend on test - # environment. - self.assertNotEmpty(batched_records) - for batch in batched_records: - self.assertIsInstance(batch, list) - self.assertNotEmpty(batch) - - with flagsaver.flagsaver( - tfxio_use_byte_size_batching=tfxio_use_byte_size_batching + if self._testMethodName in [ + "testGetBatchElementsKwargsbyte_size_batching", + "testGetBatchElementsKwargsbyte_size_batching_with_element_size_fn", + ]: + pytest.xfail(reason="Test fails and needs to be fixed. ") + + with flagsaver.flagsaver( + tfxio_use_byte_size_batching=tfxio_use_byte_size_batching + ): + kwargs = batch_util.GetBatchElementsKwargs( + batch_size, element_size_fn=element_size_fn + ) + # This parameter may not be present in some Beam versions that we support. + target_batch_duration_secs_including_fixed_cost = kwargs.pop( + "target_batch_duration_secs_including_fixed_cost", None + ) + self.assertIn(target_batch_duration_secs_including_fixed_cost, {1, None}) + if expected_kwargs.pop("element_size_fn", None) is not None: + self.assertIn("element_size_fn", kwargs) + element_size_fn = kwargs.pop("element_size_fn") + for ( + element, + expected_contribution, + ) in expected_element_contributions.items(): + self.assertEqual( + element_size_fn(element), + expected_contribution, + msg=f"Unexpected contribution of element {element}", + ) + self.assertDictEqual(kwargs, expected_kwargs) + + @parameterized.named_parameters(*_BATCH_RECORDS_TEST_CASES) + def testBatchRecords( + self, + batch_size, + tfxio_use_byte_size_batching, + expected_kwargs, + element_size_fn=len, + expected_element_contributions=None, ): - p = beam.Pipeline() - batched_records_pcoll = ( - p - | beam.Create(input_records) - | batch_util.BatchRecords( - batch_size, telemetry_descriptors, record_size_fn=element_size_fn - ) - ) - beam_testing_util.assert_that(batched_records_pcoll, AssertFn) - pipeline_result = p.run() - pipeline_result.wait_until_finish() - all_metrics = pipeline_result.metrics() - maintained_metrics = all_metrics.query( - beam.metrics.metric.MetricsFilter().with_namespace( - "tfx." + ".".join(telemetry_descriptors) - ) - ) - self.assertIsNotNone(maintained_metrics) - counters = maintained_metrics[beam.metrics.metric.MetricResults.COUNTERS] - self.assertLen(counters, 2) - expected_counters = { - "tfxio_use_byte_size_batching": int(tfxio_use_byte_size_batching), - "desired_batch_size": batch_size or 0, - } - for counter in counters: - self.assertEqual( - counter.result, expected_counters[counter.key.metric.name] + if self._testMethodName in [ + "testBatchRecordsbatch_size_none", + "testBatchRecordsbyte_size_batching", + "testBatchRecordsbyte_size_batching_with_element_size_fn", + "testBatchRecordsfixed_batch_size", + "testBatchRecordsfixed_batch_size_byte_size_batching", + ]: + pytest.xfail(reason="PR 260 81 test fails and needs to be fixed. ") + + del expected_kwargs + telemetry_descriptors = ["TestComponent"] + input_records = ( + [b"asd", b"asds", b"123", b"gdgd" * 1000] + if expected_element_contributions is None + else expected_element_contributions.keys() ) + def AssertFn(batched_records): + # We can't validate the actual sizes since they depend on test + # environment. + self.assertNotEmpty(batched_records) + for batch in batched_records: + self.assertIsInstance(batch, list) + self.assertNotEmpty(batch) + + with flagsaver.flagsaver( + tfxio_use_byte_size_batching=tfxio_use_byte_size_batching + ): + p = beam.Pipeline() + batched_records_pcoll = ( + p + | beam.Create(input_records) + | batch_util.BatchRecords( + batch_size, telemetry_descriptors, record_size_fn=element_size_fn + ) + ) + beam_testing_util.assert_that(batched_records_pcoll, AssertFn) + pipeline_result = p.run() + pipeline_result.wait_until_finish() + all_metrics = pipeline_result.metrics() + maintained_metrics = all_metrics.query( + beam.metrics.metric.MetricsFilter().with_namespace( + "tfx." + ".".join(telemetry_descriptors) + ) + ) + self.assertIsNotNone(maintained_metrics) + counters = maintained_metrics[beam.metrics.metric.MetricResults.COUNTERS] + self.assertLen(counters, 2) + expected_counters = { + "tfxio_use_byte_size_batching": int(tfxio_use_byte_size_batching), + "desired_batch_size": batch_size or 0, + } + for counter in counters: + self.assertEqual( + counter.result, expected_counters[counter.key.metric.name] + ) + if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/tfx_bsl/coders/csv_decoder.py b/tfx_bsl/coders/csv_decoder.py index efd6becf..d6fc089d 100644 --- a/tfx_bsl/coders/csv_decoder.py +++ b/tfx_bsl/coders/csv_decoder.py @@ -16,375 +16,422 @@ import csv import enum -from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Text, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Set, + Text, + Tuple, + Union, +) import apache_beam as beam import numpy as np import pyarrow as pa -from tfx_bsl.coders import batch_util +from tensorflow_metadata.proto.v0 import schema_pb2, statistics_pb2 -from tensorflow_metadata.proto.v0 import schema_pb2 -from tensorflow_metadata.proto.v0 import statistics_pb2 +from tfx_bsl.coders import batch_util PARSE_CSV_LINE_YIELDS_RAW_RECORDS = True CSVCell = bytes -CSVLine = Text -ColumnName = Text +CSVLine = str +ColumnName = str class ColumnType(enum.IntEnum): - """Enum for the type of a CSV column.""" - # column will not be in the result RecordBatch - IGNORE = -2 - # column will be in the result RecordBatch but will be of Null type (which - # means this column contains only empty value). - UNKNOWN = -1 - INT = statistics_pb2.FeatureNameStatistics.INT - FLOAT = statistics_pb2.FeatureNameStatistics.FLOAT - STRING = statistics_pb2.FeatureNameStatistics.STRING - - # We need the following to hold for type inference to work. - assert UNKNOWN < INT - assert INT < FLOAT - assert FLOAT < STRING - - -ColumnInfo = NamedTuple("ColumnInfo", [ - ("name", ColumnName), - ("type", ColumnType), -]) + """Enum for the type of a CSV column.""" + + # column will not be in the result RecordBatch + IGNORE = -2 + # column will be in the result RecordBatch but will be of Null type (which + # means this column contains only empty value). + UNKNOWN = -1 + INT = statistics_pb2.FeatureNameStatistics.INT + FLOAT = statistics_pb2.FeatureNameStatistics.FLOAT + STRING = statistics_pb2.FeatureNameStatistics.STRING + + # We need the following to hold for type inference to work. + assert UNKNOWN < INT + assert INT < FLOAT + assert FLOAT < STRING + + +class ColumnInfo(NamedTuple): + name: ColumnName + type: ColumnType + _SCHEMA_TYPE_TO_COLUMN_TYPE = { schema_pb2.INT: ColumnType.INT, schema_pb2.FLOAT: ColumnType.FLOAT, - schema_pb2.BYTES: ColumnType.STRING + schema_pb2.BYTES: ColumnType.STRING, } _FEATURE_TYPE_TO_ARROW_TYPE = { - ColumnType.UNKNOWN: pa.null(), - ColumnType.INT: pa.large_list(pa.int64()), - ColumnType.FLOAT: pa.large_list(pa.float32()), - ColumnType.STRING: pa.large_list(pa.large_binary()) + ColumnType.UNKNOWN: pa.null(), + ColumnType.INT: pa.large_list(pa.int64()), + ColumnType.FLOAT: pa.large_list(pa.float32()), + ColumnType.STRING: pa.large_list(pa.large_binary()), } @beam.ptransform_fn @beam.typehints.with_input_types(CSVLine) @beam.typehints.with_output_types(pa.RecordBatch) -def CSVToRecordBatch(lines: beam.pvalue.PCollection, - column_names: List[Text], - desired_batch_size: Optional[int], - delimiter: Text = ",", - skip_blank_lines: bool = True, - schema: Optional[schema_pb2.Schema] = None, - multivalent_columns: Optional[List[Text]] = None, - secondary_delimiter: Optional[Text] = None, - raw_record_column_name: Optional[Text] = None): - """Decodes CSV records into Arrow RecordBatches. - - Args: - lines: The pcollection of raw records (csv lines). - column_names: List of feature names. Order must match the order in the CSV - file. - desired_batch_size: Batch size. The output Arrow RecordBatches will have as - many rows as the `desired_batch_size`. If None, the batch size is auto - tuned by beam. - delimiter: A one-character string used to separate fields. - skip_blank_lines: A boolean to indicate whether to skip over blank lines - rather than interpreting them as missing values. - schema: An optional schema of the input data. If this is provided, it must - contain a subset of columns in `column_names`. If a feature is in - `column_names` but not in the schema, it won't be in the result - RecordBatch. - multivalent_columns: Columns that can contain multiple values. If - secondary_delimiter is provided, this must also be provided. - secondary_delimiter: Delimiter used for parsing multivalent columns. If - multivalent_columns is provided, this must also be provided. - raw_record_column_name: Optional name for a column containing the raw csv - lines. If this is None, then this column will not be produced. This will - always be the last column in the record batch. - - Returns: - RecordBatches of the CSV lines. - - Raises: - ValueError: - * If the columns do not match the specified csv headers. - * If the schema has invalid feature types. - * If the schema does not contain all columns. - * If raw_record_column_name exists in column_names - """ - if (raw_record_column_name is not None and - raw_record_column_name in column_names): - raise ValueError( - "raw_record_column_name: {} is already an existing column name. " - "Please choose a different name.".format(raw_record_column_name)) - - csv_lines_and_raw_records = ( - lines | "ParseCSVLines" >> beam.ParDo(ParseCSVLine(delimiter))) - - if schema is not None: - column_infos = _GetColumnInfosFromSchema(schema, column_names) - else: - # TODO(b/72746442): Consider using a DeepCopy optimization similar to TFT. - # Do first pass to infer the feature types. - column_infos = beam.pvalue.AsSingleton( +def CSVToRecordBatch( + lines: beam.pvalue.PCollection, + column_names: List[str], + desired_batch_size: Optional[int], + delimiter: str = ",", + skip_blank_lines: bool = True, + schema: Optional[schema_pb2.Schema] = None, + multivalent_columns: Optional[List[str]] = None, + secondary_delimiter: Optional[str] = None, + raw_record_column_name: Optional[str] = None, +): + """Decodes CSV records into Arrow RecordBatches. + + Args: + ---- + lines: The pcollection of raw records (csv lines). + column_names: List of feature names. Order must match the order in the CSV + file. + desired_batch_size: Batch size. The output Arrow RecordBatches will have as + many rows as the `desired_batch_size`. If None, the batch size is auto + tuned by beam. + delimiter: A one-character string used to separate fields. + skip_blank_lines: A boolean to indicate whether to skip over blank lines + rather than interpreting them as missing values. + schema: An optional schema of the input data. If this is provided, it must + contain a subset of columns in `column_names`. If a feature is in + `column_names` but not in the schema, it won't be in the result + RecordBatch. + multivalent_columns: Columns that can contain multiple values. If + secondary_delimiter is provided, this must also be provided. + secondary_delimiter: Delimiter used for parsing multivalent columns. If + multivalent_columns is provided, this must also be provided. + raw_record_column_name: Optional name for a column containing the raw csv + lines. If this is None, then this column will not be produced. This will + always be the last column in the record batch. + + Returns: + ------- + RecordBatches of the CSV lines. + + Raises: + ------ + ValueError: + * If the columns do not match the specified csv headers. + * If the schema has invalid feature types. + * If the schema does not contain all columns. + * If raw_record_column_name exists in column_names + """ + if raw_record_column_name is not None and raw_record_column_name in column_names: + raise ValueError( + f"raw_record_column_name: {raw_record_column_name} is already an existing column name. " + "Please choose a different name." + ) + + csv_lines_and_raw_records = lines | "ParseCSVLines" >> beam.ParDo( + ParseCSVLine(delimiter) + ) + + if schema is not None: + column_infos = _GetColumnInfosFromSchema(schema, column_names) + else: + # TODO(b/72746442): Consider using a DeepCopy optimization similar to TFT. + # Do first pass to infer the feature types. + column_infos = beam.pvalue.AsSingleton( + csv_lines_and_raw_records + | "ExtractParsedCSVLines" >> beam.Keys() + | "InferColumnTypes" + >> beam.CombineGlobally( + ColumnTypeInferrer( + column_names=column_names, + skip_blank_lines=skip_blank_lines, + multivalent_columns=multivalent_columns, + secondary_delimiter=secondary_delimiter, + ) + ) + ) + + # Do second pass to generate the RecordBatches. + return ( csv_lines_and_raw_records - | "ExtractParsedCSVLines" >> beam.Keys() - | "InferColumnTypes" >> beam.CombineGlobally( - ColumnTypeInferrer( - column_names=column_names, + | "BatchCSVLines" + >> batch_util.BatchRecords( + desired_batch_size, + telemetry_descriptors=["CSVToRecordBatch"], + # The elements are tuples of parsed and unparsed CSVlines. + record_size_fn=lambda kv: len(kv[1]) << 1, + ) + | "BatchedCSVRowsToArrow" + >> beam.ParDo( + BatchedCSVRowsToRecordBatch( skip_blank_lines=skip_blank_lines, multivalent_columns=multivalent_columns, - secondary_delimiter=secondary_delimiter))) - - # Do second pass to generate the RecordBatches. - return ( - csv_lines_and_raw_records - | "BatchCSVLines" - >> batch_util.BatchRecords( - desired_batch_size, - telemetry_descriptors=["CSVToRecordBatch"], - # The elements are tuples of parsed and unparsed CSVlines. - record_size_fn=lambda kv: len(kv[1]) << 1, - ) - | "BatchedCSVRowsToArrow" - >> beam.ParDo( - BatchedCSVRowsToRecordBatch( - skip_blank_lines=skip_blank_lines, - multivalent_columns=multivalent_columns, - secondary_delimiter=secondary_delimiter, - raw_record_column_name=raw_record_column_name, - ), - column_infos, - ) - ) + secondary_delimiter=secondary_delimiter, + raw_record_column_name=raw_record_column_name, + ), + column_infos, + ) + ) @beam.typehints.with_input_types(CSVLine) @beam.typehints.with_output_types(Tuple[List[CSVCell], CSVLine]) class ParseCSVLine(beam.DoFn): - """A beam.DoFn to parse CSVLines into Tuple(List[CSVCell], CSVLine). + """A beam.DoFn to parse CSVLines into Tuple(List[CSVCell], CSVLine). - The CSVLine is the raw csv row. The raw csv row will always be output. - """ + The CSVLine is the raw csv row. The raw csv row will always be output. + """ - def __init__(self, delimiter: Text): - self._delimiter = delimiter - self._reader = None + def __init__(self, delimiter: str): + self._delimiter = delimiter + self._reader = None - def setup(self): - self._reader = _CSVRecordReader(self._delimiter) + def setup(self): + self._reader = _CSVRecordReader(self._delimiter) - def process(self, - csv_line: CSVLine) -> Iterable[Tuple[List[CSVCell], CSVLine]]: - assert self._reader is not None, "Reader uninitialized. Call setup() first." - line = self._reader.ReadLine(csv_line) - yield (line, csv_line) + def process(self, csv_line: CSVLine) -> Iterable[Tuple[List[CSVCell], CSVLine]]: + assert self._reader is not None, "Reader uninitialized. Call setup() first." + line = self._reader.ReadLine(csv_line) + yield (line, csv_line) @beam.typehints.with_input_types(List[CSVCell]) @beam.typehints.with_output_types(List[ColumnInfo]) class ColumnTypeInferrer(beam.CombineFn): - """A beam.CombineFn to infer CSV Column types. - - Its input can be produced by ParseCSVLine(). - """ - - def __init__( - self, - column_names: List[ColumnName], - skip_blank_lines: bool, - multivalent_columns: Optional[Set[ColumnName]] = None, - secondary_delimiter: Optional[Text] = None) -> None: - """Initializes a feature type inferrer combiner.""" - self._column_names = column_names - self._skip_blank_lines = skip_blank_lines - self._multivalent_columns = ( - multivalent_columns if multivalent_columns is not None else set()) - if multivalent_columns: - assert secondary_delimiter, ("secondary_delimiter must be specified if " - "there are multivalent columns") - self._multivalent_reader = _CSVRecordReader(secondary_delimiter) - - def create_accumulator(self) -> Dict[ColumnName, ColumnType]: - """Creates an empty accumulator to keep track of the feature types.""" - return {} - - def add_input(self, accumulator: Dict[ColumnName, ColumnType], - cells: List[CSVCell]) -> Dict[ColumnName, ColumnType]: - """Updates the feature types in the accumulator using the input row. + """A beam.CombineFn to infer CSV Column types. - Args: - accumulator: A dict containing the already inferred feature types. - cells: A list containing feature values of a CSV record. - - Returns: - A dict containing the updated feature types based on input row. - - Raises: - ValueError: If the columns do not match the specified csv headers. + Its input can be produced by ParseCSVLine(). """ - # If the row is empty and we don't want to skip blank lines, - # add an empty string to each column. - if not cells and not self._skip_blank_lines: - cells = ["" for _ in range(len(self._column_names))] - elif cells and len(cells) != len(self._column_names): - raise ValueError("Columns do not match specified csv headers: %s -> %s" % - (self._column_names, cells)) - - # Iterate over each feature value and update the type. - for column_name, cell in zip(self._column_names, cells): - - # Get the already inferred type of the feature. - previous_type = accumulator.get(column_name, None) - if column_name in self._multivalent_columns: - # the reader only accepts str but v is bytes. - values = self._multivalent_reader.ReadLine(cell.decode()) # pytype: disable=attribute-error # trace-all-classes - current_type = max([_InferValueType(value) for value in values - ]) if values else ColumnType.UNKNOWN - else: - current_type = _InferValueType(cell) - - # If the type inferred from the current value is higher in the type - # hierarchy compared to the already inferred type, we update the type. - # The type hierarchy is, - # INT (level 0) --> FLOAT (level 1) --> STRING (level 2) - if previous_type is None or current_type > previous_type: - accumulator[column_name] = current_type - return accumulator - - def merge_accumulators( - self, accumulators: List[Dict[ColumnName, ColumnType]] - ) -> Dict[ColumnName, ColumnType]: - """Merge the feature types inferred from the different partitions. - Args: - accumulators: A list of dicts containing the feature types inferred from - the different partitions of the data. - - Returns: - A dict containing the merged feature types. - """ - result = {} - for shard_types in accumulators: - # Merge the types inferred in each partition using the type hierarchy. - # Specifically, whenever we observe a type higher in the type hierarchy - # we update the type. - for feature_name, feature_type in shard_types.items(): - if feature_name not in result or feature_type > result[feature_name]: - result[feature_name] = feature_type - return result - - def extract_output( - self, accumulator: Dict[ColumnName, ColumnType]) -> List[ColumnInfo]: - """Return a list of tuples containing the column info.""" - return [ - ColumnInfo(col_name, accumulator.get(col_name, ColumnType.UNKNOWN)) - for col_name in self._column_names - ] - - -@beam.typehints.with_input_types( - List[Tuple[List[CSVCell], CSVLine]], - List[ColumnInfo]) + def __init__( + self, + column_names: List[ColumnName], + skip_blank_lines: bool, + multivalent_columns: Optional[Set[ColumnName]] = None, + secondary_delimiter: Optional[str] = None, + ) -> None: + """Initializes a feature type inferrer combiner.""" + self._column_names = column_names + self._skip_blank_lines = skip_blank_lines + self._multivalent_columns = ( + multivalent_columns if multivalent_columns is not None else set() + ) + if multivalent_columns: + assert secondary_delimiter, ( + "secondary_delimiter must be specified if " + "there are multivalent columns" + ) + self._multivalent_reader = _CSVRecordReader(secondary_delimiter) + + def create_accumulator(self) -> Dict[ColumnName, ColumnType]: + """Creates an empty accumulator to keep track of the feature types.""" + return {} + + def add_input( + self, accumulator: Dict[ColumnName, ColumnType], cells: List[CSVCell] + ) -> Dict[ColumnName, ColumnType]: + """Updates the feature types in the accumulator using the input row. + + Args: + ---- + accumulator: A dict containing the already inferred feature types. + cells: A list containing feature values of a CSV record. + + Returns: + ------- + A dict containing the updated feature types based on input row. + + Raises: + ------ + ValueError: If the columns do not match the specified csv headers. + """ + # If the row is empty and we don't want to skip blank lines, + # add an empty string to each column. + if not cells and not self._skip_blank_lines: + cells = ["" for _ in range(len(self._column_names))] + elif cells and len(cells) != len(self._column_names): + raise ValueError( + "Columns do not match specified csv headers: %s -> %s" + % (self._column_names, cells) + ) + + # Iterate over each feature value and update the type. + for column_name, cell in zip(self._column_names, cells): + # Get the already inferred type of the feature. + previous_type = accumulator.get(column_name) + if column_name in self._multivalent_columns: + # the reader only accepts str but v is bytes. + values = self._multivalent_reader.ReadLine( + cell.decode() + ) # pytype: disable=attribute-error # trace-all-classes + current_type = ( + max([_InferValueType(value) for value in values]) + if values + else ColumnType.UNKNOWN + ) + else: + current_type = _InferValueType(cell) + + # If the type inferred from the current value is higher in the type + # hierarchy compared to the already inferred type, we update the type. + # The type hierarchy is, + # INT (level 0) --> FLOAT (level 1) --> STRING (level 2) + if previous_type is None or current_type > previous_type: + accumulator[column_name] = current_type + return accumulator + + def merge_accumulators( + self, accumulators: List[Dict[ColumnName, ColumnType]] + ) -> Dict[ColumnName, ColumnType]: + """Merge the feature types inferred from the different partitions. + + Args: + ---- + accumulators: A list of dicts containing the feature types inferred from + the different partitions of the data. + + Returns: + ------- + A dict containing the merged feature types. + """ + result = {} + for shard_types in accumulators: + # Merge the types inferred in each partition using the type hierarchy. + # Specifically, whenever we observe a type higher in the type hierarchy + # we update the type. + for feature_name, feature_type in shard_types.items(): + if feature_name not in result or feature_type > result[feature_name]: + result[feature_name] = feature_type + return result + + def extract_output( + self, accumulator: Dict[ColumnName, ColumnType] + ) -> List[ColumnInfo]: + """Return a list of tuples containing the column info.""" + return [ + ColumnInfo(col_name, accumulator.get(col_name, ColumnType.UNKNOWN)) + for col_name in self._column_names + ] + + +@beam.typehints.with_input_types(List[Tuple[List[CSVCell], CSVLine]], List[ColumnInfo]) @beam.typehints.with_output_types(pa.RecordBatch) class BatchedCSVRowsToRecordBatch(beam.DoFn): - """DoFn to convert a batch of csv rows to a RecordBatch.""" - - def __init__(self, - skip_blank_lines: bool, - multivalent_columns: Optional[Set[ColumnName]] = None, - secondary_delimiter: Optional[Text] = None, - raw_record_column_name: Optional[Text] = None): - self._skip_blank_lines = skip_blank_lines - self._multivalent_columns = ( - multivalent_columns if multivalent_columns is not None else set()) - if multivalent_columns: - assert secondary_delimiter, ("secondary_delimiter must be specified if " - "there are multivalent columns") - self._multivalent_reader = _CSVRecordReader(secondary_delimiter) - self._raw_record_column_name = raw_record_column_name - self._raw_record_column_type = _FEATURE_TYPE_TO_ARROW_TYPE.get( - ColumnType.STRING) - - # Note that len(_column_handlers) == len(column_infos) but - # len(_column_names) and len(_column_arrow_types) may not equal to that, - # because columns of type IGNORE are not there. - self._column_handlers = None - self._column_names = [] - self._column_arrow_types = None - - def _get_column_handler( - self, column_info: ColumnInfo - ) -> Optional[Callable[[CSVCell], Optional[Iterable[Union[int, float, - bytes]]]]]: - if column_info.type == ColumnType.IGNORE: - return None - value_converter = _VALUE_CONVERTER_MAP.get(column_info.type) - assert value_converter is not None - if column_info.name in self._multivalent_columns: - # If the column is multivalent and unknown, we treat it as a univalent - # column. This will result in a null array instead of a list", as - # TFDV does not support list. - if column_info.type is ColumnType.UNKNOWN: - return lambda v: None - return lambda v: [ # pylint: disable=g-long-lambda - value_converter(sub_v) - # the reader only accepts str but v is bytes. - for sub_v in self._multivalent_reader.ReadLine(v.decode()) - ] - else: - return lambda v: (value_converter(v),) - - def _process_column_infos(self, column_infos: List[ColumnInfo]): - self._column_handlers = [self._get_column_handler(c) for c in column_infos] - self._column_arrow_types = [ - _FEATURE_TYPE_TO_ARROW_TYPE.get(c.type) - for c in column_infos - if c.type != ColumnType.IGNORE - ] - self._column_names = [ - c.name for c in column_infos if c.type != ColumnType.IGNORE] - - def process(self, batch_of_tuple: List[Tuple[List[CSVCell], CSVLine]], - column_infos: List[ColumnInfo]) -> Iterable[pa.RecordBatch]: - if self._column_handlers is None: - self._process_column_infos(column_infos) - - raw_records = [] - values_list_by_column = [[] for _ in self._column_names] - for (csv_row, raw_record) in batch_of_tuple: - if not csv_row: - if not self._skip_blank_lines: - for l in values_list_by_column: - l.append(None) - continue - if len(csv_row) != len(self._column_handlers): - raise ValueError( - "Encountered a row of unexpected number of columns: {} vs. {}" - .format(len(csv_row), len(self._column_handlers))) - column_idx = 0 - for csv_cell, handler in zip(csv_row, self._column_handlers): - if handler is None: - continue - values_list_by_column[column_idx].append( - handler(csv_cell) if csv_cell else None) - column_idx += 1 - if self._raw_record_column_name is not None: - raw_records.append([raw_record]) - - arrow_arrays = [ - pa.array(l, type=t) - for l, t in zip(values_list_by_column, self._column_arrow_types) - ] - - if self._raw_record_column_name is not None: - arrow_arrays.append( - pa.array(raw_records, type=self._raw_record_column_type)) - self._column_names.append(self._raw_record_column_name) - yield pa.RecordBatch.from_arrays(arrow_arrays, self._column_names) + """DoFn to convert a batch of csv rows to a RecordBatch.""" + + def __init__( + self, + skip_blank_lines: bool, + multivalent_columns: Optional[Set[ColumnName]] = None, + secondary_delimiter: Optional[str] = None, + raw_record_column_name: Optional[str] = None, + ): + self._skip_blank_lines = skip_blank_lines + self._multivalent_columns = ( + multivalent_columns if multivalent_columns is not None else set() + ) + if multivalent_columns: + assert secondary_delimiter, ( + "secondary_delimiter must be specified if " + "there are multivalent columns" + ) + self._multivalent_reader = _CSVRecordReader(secondary_delimiter) + self._raw_record_column_name = raw_record_column_name + self._raw_record_column_type = _FEATURE_TYPE_TO_ARROW_TYPE.get( + ColumnType.STRING + ) + + # Note that len(_column_handlers) == len(column_infos) but + # len(_column_names) and len(_column_arrow_types) may not equal to that, + # because columns of type IGNORE are not there. + self._column_handlers = None + self._column_names = [] + self._column_arrow_types = None + + def _get_column_handler( + self, column_info: ColumnInfo + ) -> Optional[Callable[[CSVCell], Optional[Iterable[Union[int, float, bytes]]]]]: + if column_info.type == ColumnType.IGNORE: + return None + value_converter = _VALUE_CONVERTER_MAP.get(column_info.type) + assert value_converter is not None + if column_info.name in self._multivalent_columns: + # If the column is multivalent and unknown, we treat it as a univalent + # column. This will result in a null array instead of a list", as + # TFDV does not support list. + if column_info.type is ColumnType.UNKNOWN: + return lambda v: None + return lambda v: [ # pylint: disable=g-long-lambda + value_converter(sub_v) + # the reader only accepts str but v is bytes. + for sub_v in self._multivalent_reader.ReadLine(v.decode()) + ] + else: + return lambda v: (value_converter(v),) + + def _process_column_infos(self, column_infos: List[ColumnInfo]): + self._column_handlers = [self._get_column_handler(c) for c in column_infos] + self._column_arrow_types = [ + _FEATURE_TYPE_TO_ARROW_TYPE.get(c.type) + for c in column_infos + if c.type != ColumnType.IGNORE + ] + self._column_names = [ + c.name for c in column_infos if c.type != ColumnType.IGNORE + ] + + def process( + self, + batch_of_tuple: List[Tuple[List[CSVCell], CSVLine]], + column_infos: List[ColumnInfo], + ) -> Iterable[pa.RecordBatch]: + if self._column_handlers is None: + self._process_column_infos(column_infos) + + raw_records = [] + values_list_by_column = [[] for _ in self._column_names] + for csv_row, raw_record in batch_of_tuple: + if not csv_row: + if not self._skip_blank_lines: + for l in values_list_by_column: + l.append(None) + continue + if len(csv_row) != len(self._column_handlers): + raise ValueError( + f"Encountered a row of unexpected number of columns: {len(csv_row)} vs. {len(self._column_handlers)}" + ) + column_idx = 0 + for csv_cell, handler in zip(csv_row, self._column_handlers): + if handler is None: + continue + values_list_by_column[column_idx].append( + handler(csv_cell) if csv_cell else None + ) + column_idx += 1 + if self._raw_record_column_name is not None: + raw_records.append([raw_record]) + + arrow_arrays = [ + pa.array(l, type=t) + for l, t in zip(values_list_by_column, self._column_arrow_types) + ] + + if self._raw_record_column_name is not None: + arrow_arrays.append( + pa.array(raw_records, type=self._raw_record_column_type) + ) + self._column_names.append(self._raw_record_column_name) + yield pa.RecordBatch.from_arrays(arrow_arrays, self._column_names) _VALUE_CONVERTER_MAP = { @@ -395,93 +442,100 @@ def process(self, batch_of_tuple: List[Tuple[List[CSVCell], CSVLine]], } -def GetArrowSchema(column_names: List[Text], - schema: schema_pb2.Schema, - raw_record_column_name: Optional[Text] = None) -> pa.Schema: - """Returns the arrow schema given columns and a TFMD schema. +def GetArrowSchema( + column_names: List[str], + schema: schema_pb2.Schema, + raw_record_column_name: Optional[str] = None, +) -> pa.Schema: + """Returns the arrow schema given columns and a TFMD schema. - Args: - column_names: List of feature names. This must match the features in schema. - schema: The schema proto to base the arrow schema from. - raw_record_column_name: An optional name for the column containing raw - records. If this is not set, the arrow schema will not contain a raw - records column. + Args: + ---- + column_names: List of feature names. This must match the features in schema. + schema: The schema proto to base the arrow schema from. + raw_record_column_name: An optional name for the column containing raw + records. If this is not set, the arrow schema will not contain a raw + records column. - Returns: - Arrow Schema based on the provided schema proto. + Returns: + ------- + Arrow Schema based on the provided schema proto. - Raises: - ValueError: - * If the schema contains a feature that does not exist in `column_names`. - * If the feature type does not map to an arrow type. - * If raw_record_column_name exists in column_names - """ - schema_feature_names = [f.name for f in schema.feature] - if not set(schema_feature_names).issubset(set(column_names)): - raise ValueError( - "Schema features are not a subset of column names: {} vs {}".format( - schema_feature_names, column_names)) + Raises: + ------ + ValueError: + * If the schema contains a feature that does not exist in `column_names`. + * If the feature type does not map to an arrow type. + * If raw_record_column_name exists in column_names + """ + schema_feature_names = [f.name for f in schema.feature] + if not set(schema_feature_names).issubset(set(column_names)): + raise ValueError( + f"Schema features are not a subset of column names: {schema_feature_names} vs {column_names}" + ) - fields = [] - column_name_to_schema_feature_map = {f.name: f for f in schema.feature} - for col in column_names: - feature = column_name_to_schema_feature_map.get(col) - if feature is None: - continue - arrow_type = _FEATURE_TYPE_TO_ARROW_TYPE.get( - _SCHEMA_TYPE_TO_COLUMN_TYPE.get(feature.type), None) - if arrow_type is None: - raise ValueError("Feature {} has unsupport type {}".format( - feature.name, feature.type)) - fields.append(pa.field(feature.name, arrow_type)) + fields = [] + column_name_to_schema_feature_map = {f.name: f for f in schema.feature} + for col in column_names: + feature = column_name_to_schema_feature_map.get(col) + if feature is None: + continue + arrow_type = _FEATURE_TYPE_TO_ARROW_TYPE.get( + _SCHEMA_TYPE_TO_COLUMN_TYPE.get(feature.type), None + ) + if arrow_type is None: + raise ValueError( + f"Feature {feature.name} has unsupport type {feature.type}" + ) + fields.append(pa.field(feature.name, arrow_type)) - if raw_record_column_name is not None: - if raw_record_column_name in column_names: - raise ValueError( - "raw_record_column_name: {} is already an existing column name. " - "Please choose a different name.".format(raw_record_column_name)) - raw_record_type = _FEATURE_TYPE_TO_ARROW_TYPE.get( - ColumnType.STRING) - fields.append(pa.field(raw_record_column_name, raw_record_type)) - return pa.schema(fields) + if raw_record_column_name is not None: + if raw_record_column_name in column_names: + raise ValueError( + f"raw_record_column_name: {raw_record_column_name} is already an existing column name. " + "Please choose a different name." + ) + raw_record_type = _FEATURE_TYPE_TO_ARROW_TYPE.get(ColumnType.STRING) + fields.append(pa.field(raw_record_column_name, raw_record_type)) + return pa.schema(fields) -class _CSVRecordReader(object): - """A picklable wrapper for csv.reader that can parse one record at a time.""" +class _CSVRecordReader: + """A picklable wrapper for csv.reader that can parse one record at a time.""" - def __init__(self, delimiter: Text): - self._delimiter = delimiter - self._line_iterator = _MutableRepeat() - self._reader = csv.reader(self._line_iterator, delimiter=delimiter) + def __init__(self, delimiter: str): + self._delimiter = delimiter + self._line_iterator = _MutableRepeat() + self._reader = csv.reader(self._line_iterator, delimiter=delimiter) - def ReadLine(self, csv_line: CSVLine) -> List[CSVCell]: - """Reads out bytes for PY2 and Unicode for PY3.""" - self._line_iterator.SetItem(csv_line) - return [cell.encode() for cell in next(self._reader)] + def ReadLine(self, csv_line: CSVLine) -> List[CSVCell]: + """Reads out bytes for PY2 and Unicode for PY3.""" + self._line_iterator.SetItem(csv_line) + return [cell.encode() for cell in next(self._reader)] - def __getstate__(self): - return (self._delimiter,) + def __getstate__(self): + return (self._delimiter,) - def __setstate__(self, state): - self.__init__(*state) + def __setstate__(self, state): + self.__init__(*state) -class _MutableRepeat(object): - """Similar to itertools.repeat, but the item can be set on the fly.""" +class _MutableRepeat: + """Similar to itertools.repeat, but the item can be set on the fly.""" - def __init__(self): - self._item = None + def __init__(self): + self._item = None - def SetItem(self, item: Any): - self._item = item + def SetItem(self, item: Any): + self._item = item - def __iter__(self) -> Any: - return self + def __iter__(self) -> Any: + return self - def __next__(self) -> Any: - return self._item + def __next__(self) -> Any: + return self._item - next = __next__ + next = __next__ _INT64_MIN = np.iinfo(np.int64).min @@ -489,41 +543,42 @@ def __next__(self) -> Any: def _InferValueType(value: CSVCell) -> ColumnType: - """Infer column type from the input value.""" - if not value: - return ColumnType.UNKNOWN - - # Check if the value is of type INT. - try: - if _INT64_MIN <= int(value) <= _INT64_MAX: - return ColumnType.INT - # We infer STRING type when we have long integer values. - return ColumnType.STRING - except ValueError: - # If the type is not INT, we next check for FLOAT type (according to our - # type hierarchy). If we can convert the string to a float value, we - # fix the type to be FLOAT. Else we resort to STRING type. + """Infer column type from the input value.""" + if not value: + return ColumnType.UNKNOWN + + # Check if the value is of type INT. try: - float(value) + if _INT64_MIN <= int(value) <= _INT64_MAX: + return ColumnType.INT + # We infer STRING type when we have long integer values. + return ColumnType.STRING except ValueError: - return ColumnType.STRING - return ColumnType.FLOAT + # If the type is not INT, we next check for FLOAT type (according to our + # type hierarchy). If we can convert the string to a float value, we + # fix the type to be FLOAT. Else we resort to STRING type. + try: + float(value) + except ValueError: + return ColumnType.STRING + return ColumnType.FLOAT def _GetColumnInfosFromSchema( - schema: schema_pb2.Schema, - column_names: List[Text]) -> List[ColumnInfo]: - """Get column name and type from the input schema.""" - feature_type_map = {} - for feature in schema.feature: - feature_type = _SCHEMA_TYPE_TO_COLUMN_TYPE.get(feature.type, None) - if feature_type is None: - raise ValueError("Schema contains invalid type: {}.".format( - schema_pb2.FeatureType.Name(feature.type))) - feature_type_map[feature.name] = feature_type - - column_infos = [] - for col_name in column_names: - feature_type = feature_type_map.get(col_name, ColumnType.IGNORE) - column_infos.append(ColumnInfo(col_name, feature_type)) - return column_infos + schema: schema_pb2.Schema, column_names: List[str] +) -> List[ColumnInfo]: + """Get column name and type from the input schema.""" + feature_type_map = {} + for feature in schema.feature: + feature_type = _SCHEMA_TYPE_TO_COLUMN_TYPE.get(feature.type, None) + if feature_type is None: + raise ValueError( + f"Schema contains invalid type: {schema_pb2.FeatureType.Name(feature.type)}." + ) + feature_type_map[feature.name] = feature_type + + column_infos = [] + for col_name in column_names: + feature_type = feature_type_map.get(col_name, ColumnType.IGNORE) + column_infos.append(ColumnInfo(col_name, feature_type)) + return column_infos diff --git a/tfx_bsl/coders/csv_decoder_test.py b/tfx_bsl/coders/csv_decoder_test.py index ab49fa01..17b2ef05 100644 --- a/tfx_bsl/coders/csv_decoder_test.py +++ b/tfx_bsl/coders/csv_decoder_test.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2018 Google LLC # @@ -16,393 +15,470 @@ """Tests for CSV decoder.""" -import pytest import apache_beam as beam -from apache_beam.testing import util as beam_test_util import numpy as np import pyarrow as pa -from tfx_bsl.coders import csv_decoder +import pytest +from absl.testing import absltest, parameterized +from apache_beam.testing import util as beam_test_util from google.protobuf import text_format -from absl.testing import absltest -from absl.testing import parameterized from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.coders import csv_decoder + _TEST_CASES = [ dict( - testcase_name='simple', - input_lines=['1,2.0,hello', '5,12.34,world'], - column_names=['int_feature', 'float_feature', 'str_feature'], + testcase_name="simple", + input_lines=["1,2.0,hello", "5,12.34,world"], + column_names=["int_feature", "float_feature", "str_feature"], expected_csv_cells=[ - [b'1', b'2.0', b'hello'], - [b'5', b'12.34', b'world'], + [b"1", b"2.0", b"hello"], + [b"5", b"12.34", b"world"], ], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.FLOAT, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1], [5]], pa.large_list(pa.int64())), - pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), - pa.array([[b'hello'], [b'world']], pa.large_list(pa.large_binary())) - ], ['int_feature', 'float_feature', 'str_feature'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1], [5]], pa.large_list(pa.int64())), + pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), + pa.array([[b"hello"], [b"world"]], pa.large_list(pa.large_binary())), + ], + ["int_feature", "float_feature", "str_feature"], + ), + ), dict( - testcase_name='missing_values', - input_lines=[',,', '1,,hello', ',12.34,'], - column_names=['f1', 'f2', 'f3'], + testcase_name="missing_values", + input_lines=[",,", "1,,hello", ",12.34,"], + column_names=["f1", "f2", "f3"], expected_csv_cells=[ - [b'', b'', b''], - [b'1', b'', b'hello'], - [b'', b'12.34', b''], + [b"", b"", b""], + [b"1", b"", b"hello"], + [b"", b"12.34", b""], ], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.FLOAT, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([None, [1], None], pa.large_list(pa.int64())), - pa.array([None, None, [12.34]], pa.large_list(pa.float32())), - pa.array([None, [b'hello'], None], pa.large_list( - pa.large_binary())), - ], ['f1', 'f2', 'f3'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([None, [1], None], pa.large_list(pa.int64())), + pa.array([None, None, [12.34]], pa.large_list(pa.float32())), + pa.array([None, [b"hello"], None], pa.large_list(pa.large_binary())), + ], + ["f1", "f2", "f3"], + ), + ), dict( - testcase_name='mixed_int_and_float', - input_lines=['2,1.5', '1.5,2'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'2', b'1.5'], [b'1.5', b'2']], + testcase_name="mixed_int_and_float", + input_lines=["2,1.5", "1.5,2"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"2", b"1.5"], [b"1.5", b"2"]], expected_types=[ csv_decoder.ColumnType.FLOAT, csv_decoder.ColumnType.FLOAT, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[2], [1.5]], pa.large_list(pa.float32())), - pa.array([[1.5], [2]], pa.large_list(pa.float32())) - ], ['f1', 'f2'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[2], [1.5]], pa.large_list(pa.float32())), + pa.array([[1.5], [2]], pa.large_list(pa.float32())), + ], + ["f1", "f2"], + ), + ), dict( - testcase_name='mixed_int_and_string', - input_lines=['2,abc', 'abc,2'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'2', b'abc'], [b'abc', b'2']], + testcase_name="mixed_int_and_string", + input_lines=["2,abc", "abc,2"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"2", b"abc"], [b"abc", b"2"]], expected_types=[ csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b'2'], [b'abc']], pa.large_list(pa.large_binary())), - pa.array([[b'abc'], [b'2']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[b"2"], [b"abc"]], pa.large_list(pa.large_binary())), + pa.array([[b"abc"], [b"2"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + ), dict( - testcase_name='mixed_float_and_string', - input_lines=['2.3,abc', 'abc,2.3'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'2.3', b'abc'], [b'abc', b'2.3']], + testcase_name="mixed_float_and_string", + input_lines=["2.3,abc", "abc,2.3"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"2.3", b"abc"], [b"abc", b"2.3"]], expected_types=[ csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b'2.3'], [b'abc']], pa.large_list(pa.large_binary())), - pa.array([[b'abc'], [b'2.3']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[b"2.3"], [b"abc"]], pa.large_list(pa.large_binary())), + pa.array([[b"abc"], [b"2.3"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + ), dict( - testcase_name='unicode', - input_lines=[u'\U0001f951'], - column_names=['f1'], - expected_csv_cells=[[u'\U0001f951'.encode('utf-8')]], + testcase_name="unicode", + input_lines=["\U0001f951"], + column_names=["f1"], + expected_csv_cells=[["\U0001f951".encode()]], expected_types=[ csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[u'\U0001f951'.encode('utf-8')]], - pa.large_list(pa.large_binary())) - ], ['f1'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [pa.array([["\U0001f951".encode()]], pa.large_list(pa.large_binary()))], + ["f1"], + ), + ), dict( - testcase_name='quotes', + testcase_name="quotes", input_lines=['1,"ab,cd,ef"', '5,"wx,xy,yz"'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1', b'ab,cd,ef'], [b'5', b'wx,xy,yz']], + column_names=["f1", "f2"], + expected_csv_cells=[[b"1", b"ab,cd,ef"], [b"5", b"wx,xy,yz"]], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1], [5]], pa.large_list(pa.int64())), - pa.array([[b'ab,cd,ef'], [b'wx,xy,yz']], - pa.large_list(pa.large_binary())) - ], ['f1', 'f2'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1], [5]], pa.large_list(pa.int64())), + pa.array( + [[b"ab,cd,ef"], [b"wx,xy,yz"]], pa.large_list(pa.large_binary()) + ), + ], + ["f1", "f2"], + ), + ), dict( - testcase_name='space_delimiter', + testcase_name="space_delimiter", input_lines=['1 "ab,cd,ef"', '5 "wx,xy,yz"'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1', b'ab,cd,ef'], [b'5', b'wx,xy,yz']], + column_names=["f1", "f2"], + expected_csv_cells=[[b"1", b"ab,cd,ef"], [b"5", b"wx,xy,yz"]], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1], [5]], pa.large_list(pa.int64())), - pa.array([[b'ab,cd,ef'], [b'wx,xy,yz']], - pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - delimiter=' '), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1], [5]], pa.large_list(pa.int64())), + pa.array( + [[b"ab,cd,ef"], [b"wx,xy,yz"]], pa.large_list(pa.large_binary()) + ), + ], + ["f1", "f2"], + ), + delimiter=" ", + ), dict( - testcase_name='tab_delimiter', - input_lines=['1\t"this is a \ttext"', '5\t'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1', b'this is a \ttext'], [b'5', b'']], + testcase_name="tab_delimiter", + input_lines=['1\t"this is a \ttext"', "5\t"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"1", b"this is a \ttext"], [b"5", b""]], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1], [5]], pa.large_list(pa.int64())), - pa.array([[b'this is a \ttext'], None], - pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - delimiter='\t'), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1], [5]], pa.large_list(pa.int64())), + pa.array( + [[b"this is a \ttext"], None], pa.large_list(pa.large_binary()) + ), + ], + ["f1", "f2"], + ), + delimiter="\t", + ), dict( - testcase_name='negative_values', - input_lines=['-1,-2.5'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'-1', b'-2.5']], + testcase_name="negative_values", + input_lines=["-1,-2.5"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"-1", b"-2.5"]], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.FLOAT, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[-1]], pa.large_list(pa.int64())), - pa.array([[-2.5]], pa.large_list(pa.float32())) - ], ['f1', 'f2'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[-1]], pa.large_list(pa.int64())), + pa.array([[-2.5]], pa.large_list(pa.float32())), + ], + ["f1", "f2"], + ), + ), dict( - testcase_name='int64_boundary', + testcase_name="int64_boundary", input_lines=[ - '%s,%s,%s,%s' % ( + "%s,%s,%s,%s" + % ( str(np.iinfo(np.int64).min), str(np.iinfo(np.int64).max), str(np.iinfo(np.int64).min - 1), str(np.iinfo(np.int64).max + 1), ) ], - column_names=['int64min', 'int64max', 'int64min-1', 'int64max+1'], - expected_csv_cells=[[ - b'-9223372036854775808', b'9223372036854775807', - b'-9223372036854775809', b'9223372036854775808' - ]], + column_names=["int64min", "int64max", "int64min-1", "int64max+1"], + expected_csv_cells=[ + [ + b"-9223372036854775808", + b"9223372036854775807", + b"-9223372036854775809", + b"9223372036854775808", + ] + ], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.INT, csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[-9223372036854775808]], pa.large_list(pa.int64())), - pa.array([[9223372036854775807]], pa.large_list(pa.int64())), - pa.array([[b'-9223372036854775809']], - pa.large_list(pa.large_binary())), - pa.array([[b'9223372036854775808']], pa.large_list( - pa.large_binary())) - ], ['int64min', 'int64max', 'int64min-1', 'int64max+1'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[-9223372036854775808]], pa.large_list(pa.int64())), + pa.array([[9223372036854775807]], pa.large_list(pa.int64())), + pa.array([[b"-9223372036854775809"]], pa.large_list(pa.large_binary())), + pa.array([[b"9223372036854775808"]], pa.large_list(pa.large_binary())), + ], + ["int64min", "int64max", "int64min-1", "int64max+1"], + ), + ), dict( - testcase_name='skip_blank_lines', - input_lines=['', '1,2'], + testcase_name="skip_blank_lines", + input_lines=["", "1,2"], skip_blank_lines=True, - column_names=['f1', 'f2'], - expected_csv_cells=[[], [b'1', b'2']], + column_names=["f1", "f2"], + expected_csv_cells=[[], [b"1", b"2"]], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.INT, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1]], pa.large_list(pa.int64())), - pa.array([[2]], pa.large_list(pa.int64())) - ], ['f1', 'f2'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1]], pa.large_list(pa.int64())), + pa.array([[2]], pa.large_list(pa.int64())), + ], + ["f1", "f2"], + ), + ), dict( - testcase_name='consider_blank_lines', - input_lines=['', '1,2'], + testcase_name="consider_blank_lines", + input_lines=["", "1,2"], skip_blank_lines=False, - column_names=['f1', 'f2'], - expected_csv_cells=[[], [b'1', b'2']], + column_names=["f1", "f2"], + expected_csv_cells=[[], [b"1", b"2"]], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.INT, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([None, [1]], pa.large_list(pa.int64())), - pa.array([None, [2]], pa.large_list(pa.int64())) - ], ['f1', 'f2'])), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([None, [1]], pa.large_list(pa.int64())), + pa.array([None, [2]], pa.large_list(pa.int64())), + ], + ["f1", "f2"], + ), + ), dict( - testcase_name='skip_blank_lines_single_column', - input_lines=['', '1'], + testcase_name="skip_blank_lines_single_column", + input_lines=["", "1"], skip_blank_lines=True, - column_names=['f1'], - expected_csv_cells=[[], [b'1']], + column_names=["f1"], + expected_csv_cells=[[], [b"1"]], expected_types=[ csv_decoder.ColumnType.INT, ], expected_record_batch=pa.RecordBatch.from_arrays( - [pa.array([[1]], pa.large_list(pa.int64()))], ['f1'])), + [pa.array([[1]], pa.large_list(pa.int64()))], ["f1"] + ), + ), dict( - testcase_name='consider_blank_lines_single_column', - input_lines=['', '1'], + testcase_name="consider_blank_lines_single_column", + input_lines=["", "1"], skip_blank_lines=False, - column_names=['f1'], - expected_csv_cells=[[], [b'1']], + column_names=["f1"], + expected_csv_cells=[[], [b"1"]], expected_types=[ csv_decoder.ColumnType.INT, ], expected_record_batch=pa.RecordBatch.from_arrays( - [pa.array([None, [1]], pa.large_list(pa.int64()))], ['f1'])), + [pa.array([None, [1]], pa.large_list(pa.int64()))], ["f1"] + ), + ), dict( - testcase_name='empty_csv', + testcase_name="empty_csv", input_lines=[], - column_names=['f1'], + column_names=["f1"], expected_csv_cells=[], expected_types=[csv_decoder.ColumnType.UNKNOWN], expected_record_batch=[], ), dict( - testcase_name='null_column', - input_lines=['', ''], - column_names=['f1'], + testcase_name="null_column", + input_lines=["", ""], + column_names=["f1"], expected_csv_cells=[[], []], expected_types=[csv_decoder.ColumnType.UNKNOWN], expected_record_batch=pa.RecordBatch.from_arrays( - [pa.array([None, None], pa.null())], ['f1'])), + [pa.array([None, None], pa.null())], ["f1"] + ), + ), dict( - testcase_name='size_2_vector_int_multivalent', - input_lines=['12|14'], - column_names=['x'], - expected_csv_cells=[[b'12|14']], + testcase_name="size_2_vector_int_multivalent", + input_lines=["12|14"], + column_names=["x"], + expected_csv_cells=[[b"12|14"]], expected_types=[csv_decoder.ColumnType.INT], expected_record_batch=pa.RecordBatch.from_arrays( - [pa.array([[12, 14]], pa.large_list(pa.int64()))], ['x']), - delimiter=' ', - multivalent_columns=['x'], - secondary_delimiter='|'), + [pa.array([[12, 14]], pa.large_list(pa.int64()))], ["x"] + ), + delimiter=" ", + multivalent_columns=["x"], + secondary_delimiter="|", + ), dict( - testcase_name='space_and_comma_delimiter', + testcase_name="space_and_comma_delimiter", input_lines=['1,2 "abcdef"', '5,1 "wxxyyz"'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1,2', b'abcdef'], [b'5,1', b'wxxyyz']], - expected_types=[ - csv_decoder.ColumnType.INT, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1, 2], [5, 1]], pa.large_list(pa.int64())), - pa.array([[b'abcdef'], [b'wxxyyz']], pa.large_list( - pa.large_binary())) - ], ['f1', 'f2']), - delimiter=' ', - multivalent_columns=['f1'], - secondary_delimiter=','), + column_names=["f1", "f2"], + expected_csv_cells=[[b"1,2", b"abcdef"], [b"5,1", b"wxxyyz"]], + expected_types=[csv_decoder.ColumnType.INT, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1, 2], [5, 1]], pa.large_list(pa.int64())), + pa.array([[b"abcdef"], [b"wxxyyz"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + delimiter=" ", + multivalent_columns=["f1"], + secondary_delimiter=",", + ), dict( - testcase_name='empty_multivalent_column', - input_lines=[',test'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'', b'test']], - expected_types=[ - csv_decoder.ColumnType.UNKNOWN, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([None], pa.null()), - pa.array([[b'test']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - multivalent_columns=['f1'], - secondary_delimiter='|'), + testcase_name="empty_multivalent_column", + input_lines=[",test"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"", b"test"]], + expected_types=[csv_decoder.ColumnType.UNKNOWN, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([None], pa.null()), + pa.array([[b"test"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + multivalent_columns=["f1"], + secondary_delimiter="|", + ), dict( - testcase_name='empty_values_multivalent_column', - input_lines=['|,test'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'|', b'test']], - expected_types=[ - csv_decoder.ColumnType.UNKNOWN, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([None], pa.null()), - pa.array([[b'test']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - multivalent_columns=['f1'], - secondary_delimiter='|'), + testcase_name="empty_values_multivalent_column", + input_lines=["|,test"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"|", b"test"]], + expected_types=[csv_decoder.ColumnType.UNKNOWN, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([None], pa.null()), + pa.array([[b"test"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + multivalent_columns=["f1"], + secondary_delimiter="|", + ), dict( - testcase_name='empty_string_multivalent_column', - input_lines=['|,test', 'a|b,test'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'|', b'test'], [b'a|b', b'test']], - expected_types=[ - csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b'', b''], [b'a', b'b']], pa.large_list( - pa.large_binary())), - pa.array([[b'test'], [b'test']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - multivalent_columns=['f1'], - secondary_delimiter='|'), + testcase_name="empty_string_multivalent_column", + input_lines=["|,test", "a|b,test"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"|", b"test"], [b"a|b", b"test"]], + expected_types=[csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[b"", b""], [b"a", b"b"]], pa.large_list(pa.large_binary())), + pa.array([[b"test"], [b"test"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + multivalent_columns=["f1"], + secondary_delimiter="|", + ), dict( - testcase_name='int_and_float_multivalent_column', - input_lines=['1|2.3,test'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1|2.3', b'test']], - expected_types=[ - csv_decoder.ColumnType.FLOAT, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1, 2.3]], pa.large_list(pa.float32())), - pa.array([[b'test']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - multivalent_columns=['f1'], - secondary_delimiter='|'), + testcase_name="int_and_float_multivalent_column", + input_lines=["1|2.3,test"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"1|2.3", b"test"]], + expected_types=[csv_decoder.ColumnType.FLOAT, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1, 2.3]], pa.large_list(pa.float32())), + pa.array([[b"test"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + multivalent_columns=["f1"], + secondary_delimiter="|", + ), dict( - testcase_name='float_and_string_multivalent_column', - input_lines=['2.3|abc,test'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'2.3|abc', b'test']], - expected_types=[ - csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b'2.3', b'abc']], pa.large_list(pa.large_binary())), - pa.array([[b'test']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - multivalent_columns=['f1'], - secondary_delimiter='|'), + testcase_name="float_and_string_multivalent_column", + input_lines=["2.3|abc,test"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"2.3|abc", b"test"]], + expected_types=[csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[b"2.3", b"abc"]], pa.large_list(pa.large_binary())), + pa.array([[b"test"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + multivalent_columns=["f1"], + secondary_delimiter="|", + ), dict( - testcase_name='int_and_string_multivalent_column', - input_lines=['1|abc,test'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1|abc', b'test']], - expected_types=[ - csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b'1', b'abc']], pa.large_list(pa.large_binary())), - pa.array([[b'test']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - multivalent_columns=['f1'], - secondary_delimiter='|'), + testcase_name="int_and_string_multivalent_column", + input_lines=["1|abc,test"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"1|abc", b"test"]], + expected_types=[csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[b"1", b"abc"]], pa.large_list(pa.large_binary())), + pa.array([[b"test"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + multivalent_columns=["f1"], + secondary_delimiter="|", + ), dict( - testcase_name='int_and_string_multivalent_column_multiple_lines', - input_lines=['1|abc,test', '2|2,test'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1|abc', b'test'], [b'2|2', b'test']], - expected_types=[ - csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING - ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b'1', b'abc'], [b'2', b'2']], - pa.large_list(pa.large_binary())), - pa.array([[b'test'], [b'test']], pa.large_list(pa.large_binary())) - ], ['f1', 'f2']), - multivalent_columns=['f1'], - secondary_delimiter='|'), + testcase_name="int_and_string_multivalent_column_multiple_lines", + input_lines=["1|abc,test", "2|2,test"], + column_names=["f1", "f2"], + expected_csv_cells=[[b"1|abc", b"test"], [b"2|2", b"test"]], + expected_types=[csv_decoder.ColumnType.STRING, csv_decoder.ColumnType.STRING], + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array( + [[b"1", b"abc"], [b"2", b"2"]], pa.large_list(pa.large_binary()) + ), + pa.array([[b"test"], [b"test"]], pa.large_list(pa.large_binary())), + ], + ["f1", "f2"], + ), + multivalent_columns=["f1"], + secondary_delimiter="|", + ), dict( - testcase_name='with_schema', - input_lines=['1,2.0,hello', '5,12.34,world'], - column_names=['int_feature', 'float_feature', 'str_feature'], + testcase_name="with_schema", + input_lines=["1,2.0,hello", "5,12.34,world"], + column_names=["int_feature", "float_feature", "str_feature"], expected_csv_cells=[ - [b'1', b'2.0', b'hello'], - [b'5', b'12.34', b'world'], + [b"1", b"2.0", b"hello"], + [b"5", b"12.34", b"world"], ], expected_types=[ csv_decoder.ColumnType.INT, @@ -422,42 +498,52 @@ name: "str_feature" type: BYTES } - """, schema_pb2.Schema()), - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1], [5]], pa.large_list(pa.int64())), - pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), - pa.array([[b'hello'], [b'world']], - pa.large_list(pa.large_binary())) - ], ['int_feature', 'float_feature', 'str_feature'])), + """, + schema_pb2.Schema(), + ), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1], [5]], pa.large_list(pa.int64())), + pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), + pa.array([[b"hello"], [b"world"]], pa.large_list(pa.large_binary())), + ], + ["int_feature", "float_feature", "str_feature"], + ), + ), dict( - testcase_name='attach_raw_records', - input_lines=['1,2.0,hello', '5,12.34,world'], - column_names=['int_feature', 'float_feature', 'str_feature'], + testcase_name="attach_raw_records", + input_lines=["1,2.0,hello", "5,12.34,world"], + column_names=["int_feature", "float_feature", "str_feature"], expected_csv_cells=[ - [b'1', b'2.0', b'hello'], - [b'5', b'12.34', b'world'], + [b"1", b"2.0", b"hello"], + [b"5", b"12.34", b"world"], ], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.FLOAT, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1], [5]], pa.large_list(pa.int64())), - pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), - pa.array([[b'hello'], [b'world']], pa.large_list( - pa.large_binary())), - pa.array([[b'1,2.0,hello'], [b'5,12.34,world']], - pa.large_list(pa.large_binary())) - ], ['int_feature', 'float_feature', 'str_feature', 'raw_records']), - raw_record_column_name='raw_records'), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1], [5]], pa.large_list(pa.int64())), + pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), + pa.array([[b"hello"], [b"world"]], pa.large_list(pa.large_binary())), + pa.array( + [[b"1,2.0,hello"], [b"5,12.34,world"]], + pa.large_list(pa.large_binary()), + ), + ], + ["int_feature", "float_feature", "str_feature", "raw_records"], + ), + raw_record_column_name="raw_records", + ), dict( - testcase_name='with_schema_attach_raw_record', - input_lines=['1,2.0,hello', '5,12.34,world'], - column_names=['int_feature', 'float_feature', 'str_feature'], + testcase_name="with_schema_attach_raw_record", + input_lines=["1,2.0,hello", "5,12.34,world"], + column_names=["int_feature", "float_feature", "str_feature"], expected_csv_cells=[ - [b'1', b'2.0', b'hello'], - [b'5', b'12.34', b'world'], + [b"1", b"2.0", b"hello"], + [b"5", b"12.34", b"world"], ], expected_types=[ csv_decoder.ColumnType.INT, @@ -477,236 +563,293 @@ name: "str_feature" type: BYTES } - """, schema_pb2.Schema()), - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1], [5]], pa.large_list(pa.int64())), - pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), - pa.array([[b'hello'], [b'world']], - pa.large_list(pa.large_binary())), - pa.array([[b'1,2.0,hello'], [b'5,12.34,world']], - pa.large_list(pa.large_binary())) - ], ['int_feature', 'float_feature', 'str_feature', 'raw_records' - ]), - raw_record_column_name='raw_records'), + """, + schema_pb2.Schema(), + ), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1], [5]], pa.large_list(pa.int64())), + pa.array([[2.0], [12.34]], pa.large_list(pa.float32())), + pa.array([[b"hello"], [b"world"]], pa.large_list(pa.large_binary())), + pa.array( + [[b"1,2.0,hello"], [b"5,12.34,world"]], + pa.large_list(pa.large_binary()), + ), + ], + ["int_feature", "float_feature", "str_feature", "raw_records"], + ), + raw_record_column_name="raw_records", + ), dict( - testcase_name='multivalent_attach_raw_records', + testcase_name="multivalent_attach_raw_records", input_lines=['1,2 "abcdef"', '5,1 "wxxyyz"'], - column_names=['f1', 'f2'], - expected_csv_cells=[[b'1,2', b'abcdef'], [b'5,1', b'wxxyyz']], + column_names=["f1", "f2"], + expected_csv_cells=[[b"1,2", b"abcdef"], [b"5,1", b"wxxyyz"]], expected_types=[ csv_decoder.ColumnType.INT, csv_decoder.ColumnType.STRING, ], - expected_record_batch=pa.RecordBatch.from_arrays([ - pa.array([[1, 2], [5, 1]], pa.large_list(pa.int64())), - pa.array([[b'abcdef'], [b'wxxyyz']], pa.large_list( - pa.large_binary())), - pa.array([[b'1,2 "abcdef"'], [b'5,1 "wxxyyz"']], - pa.large_list(pa.large_binary())) - ], ['f1', 'f2', 'raw_records']), - delimiter=' ', - multivalent_columns=['f1'], - secondary_delimiter=',', - raw_record_column_name='raw_records'), + expected_record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([[1, 2], [5, 1]], pa.large_list(pa.int64())), + pa.array([[b"abcdef"], [b"wxxyyz"]], pa.large_list(pa.large_binary())), + pa.array( + [[b'1,2 "abcdef"'], [b'5,1 "wxxyyz"']], + pa.large_list(pa.large_binary()), + ), + ], + ["f1", "f2", "raw_records"], + ), + delimiter=" ", + multivalent_columns=["f1"], + secondary_delimiter=",", + raw_record_column_name="raw_records", + ), ] class CSVDecoderTest(parameterized.TestCase): - """Tests for CSV decoder.""" - - @parameterized.named_parameters(_TEST_CASES) - def test_parse_csv_lines(self, - input_lines, - column_names, - expected_csv_cells, - expected_types, - expected_record_batch, - skip_blank_lines=False, - schema=None, - delimiter=',', - multivalent_columns=None, - secondary_delimiter=None, - raw_record_column_name=None): + """Tests for CSV decoder.""" - if self._testMethodName in [ - "test_parse_csv_lines_attach_raw_records", - "test_parse_csv_lines_consider_blank_lines", - "test_parse_csv_lines_consider_blank_lines_single_column", - "test_parse_csv_lines_empty_csv", - "test_parse_csv_lines_empty_multivalent_column", - "test_parse_csv_lines_empty_string_multivalent_column", - "test_parse_csv_lines_empty_values_multivalent_column", - "test_parse_csv_lines_float_and_string_multivalent_column", - "test_parse_csv_lines_int64_boundary", - "test_parse_csv_lines_int_and_float_multivalent_column", - "test_parse_csv_lines_int_and_string_multivalent_column", - "test_parse_csv_lines_int_and_string_multivalent_column_multiple_lines", - "test_parse_csv_lines_missing_values", - "test_parse_csv_lines_mixed_float_and_string", - "test_parse_csv_lines_mixed_int_and_float", - "test_parse_csv_lines_mixed_int_and_string", - "test_parse_csv_lines_multivalent_attach_raw_records", - "test_parse_csv_lines_negative_values", - "test_parse_csv_lines_null_column", - "test_parse_csv_lines_quotes", - "test_parse_csv_lines_simple", - "test_parse_csv_lines_size_2_vector_int_multivalent", - "test_parse_csv_lines_skip_blank_lines", - "test_parse_csv_lines_skip_blank_lines_single_column", - "test_parse_csv_lines_space_and_comma_delimiter", - "test_parse_csv_lines_space_delimiter", - "test_parse_csv_lines_tab_delimiter", - "test_parse_csv_lines_unicode", - "test_parse_csv_lines_with_schema", - "test_parse_csv_lines_with_schema_attach_raw_record", - ]: - pytest.xfail(reason="Test fails and needs to be fixed. ") + @parameterized.named_parameters(_TEST_CASES) + def test_parse_csv_lines( + self, + input_lines, + column_names, + expected_csv_cells, + expected_types, + expected_record_batch, + skip_blank_lines=False, + schema=None, + delimiter=",", + multivalent_columns=None, + secondary_delimiter=None, + raw_record_column_name=None, + ): + if self._testMethodName in [ + "test_parse_csv_lines_attach_raw_records", + "test_parse_csv_lines_consider_blank_lines", + "test_parse_csv_lines_consider_blank_lines_single_column", + "test_parse_csv_lines_empty_csv", + "test_parse_csv_lines_empty_multivalent_column", + "test_parse_csv_lines_empty_string_multivalent_column", + "test_parse_csv_lines_empty_values_multivalent_column", + "test_parse_csv_lines_float_and_string_multivalent_column", + "test_parse_csv_lines_int64_boundary", + "test_parse_csv_lines_int_and_float_multivalent_column", + "test_parse_csv_lines_int_and_string_multivalent_column", + "test_parse_csv_lines_int_and_string_multivalent_column_multiple_lines", + "test_parse_csv_lines_missing_values", + "test_parse_csv_lines_mixed_float_and_string", + "test_parse_csv_lines_mixed_int_and_float", + "test_parse_csv_lines_mixed_int_and_string", + "test_parse_csv_lines_multivalent_attach_raw_records", + "test_parse_csv_lines_negative_values", + "test_parse_csv_lines_null_column", + "test_parse_csv_lines_quotes", + "test_parse_csv_lines_simple", + "test_parse_csv_lines_size_2_vector_int_multivalent", + "test_parse_csv_lines_skip_blank_lines", + "test_parse_csv_lines_skip_blank_lines_single_column", + "test_parse_csv_lines_space_and_comma_delimiter", + "test_parse_csv_lines_space_delimiter", + "test_parse_csv_lines_tab_delimiter", + "test_parse_csv_lines_unicode", + "test_parse_csv_lines_with_schema", + "test_parse_csv_lines_with_schema_attach_raw_record", + ]: + pytest.xfail(reason="Test fails and needs to be fixed. ") - def _check_csv_cells(actual): - for i in range(len(actual)): - self.assertEqual(expected_csv_cells[i], actual[i][0]) - self.assertEqual(input_lines[i], actual[i][1]) + def _check_csv_cells(actual): + for i in range(len(actual)): + self.assertEqual(expected_csv_cells[i], actual[i][0]) + self.assertEqual(input_lines[i], actual[i][1]) - def _check_types(actual): - self.assertLen(actual, 1) - self.assertCountEqual([ - csv_decoder.ColumnInfo(n, t) - for n, t in zip(column_names, expected_types) - ], actual[0]) + def _check_types(actual): + self.assertLen(actual, 1) + self.assertCountEqual( + [ + csv_decoder.ColumnInfo(n, t) + for n, t in zip(column_names, expected_types) + ], + actual[0], + ) - def _check_record_batches(actual): - """Compares a list of pa.RecordBatch.""" - if actual: - self.assertTrue(actual[0].equals(expected_record_batch)) - else: - self.assertEqual(expected_record_batch, actual) + def _check_record_batches(actual): + """Compares a list of pa.RecordBatch.""" + if actual: + self.assertTrue(actual[0].equals(expected_record_batch)) + else: + self.assertEqual(expected_record_batch, actual) - def _check_arrow_schema(actual): - for record_batch in actual: - expected_arrow_schema = csv_decoder.GetArrowSchema( - column_names, schema, raw_record_column_name) - self.assertEqual(record_batch.schema, expected_arrow_schema) + def _check_arrow_schema(actual): + for record_batch in actual: + expected_arrow_schema = csv_decoder.GetArrowSchema( + column_names, schema, raw_record_column_name + ) + self.assertEqual(record_batch.schema, expected_arrow_schema) - with beam.Pipeline() as p: - parsed_csv_cells_and_raw_records = ( - p | beam.Create(input_lines, reshuffle=False) - | beam.ParDo(csv_decoder.ParseCSVLine(delimiter=delimiter))) - inferred_types = ( - parsed_csv_cells_and_raw_records - | beam.Keys() - | beam.CombineGlobally( - csv_decoder.ColumnTypeInferrer( - column_names, - skip_blank_lines=skip_blank_lines, - multivalent_columns=multivalent_columns, - secondary_delimiter=secondary_delimiter))) + with beam.Pipeline() as p: + parsed_csv_cells_and_raw_records = ( + p + | beam.Create(input_lines, reshuffle=False) + | beam.ParDo(csv_decoder.ParseCSVLine(delimiter=delimiter)) + ) + inferred_types = ( + parsed_csv_cells_and_raw_records + | beam.Keys() + | beam.CombineGlobally( + csv_decoder.ColumnTypeInferrer( + column_names, + skip_blank_lines=skip_blank_lines, + multivalent_columns=multivalent_columns, + secondary_delimiter=secondary_delimiter, + ) + ) + ) - beam_test_util.assert_that( - parsed_csv_cells_and_raw_records, - _check_csv_cells, - label='check_parsed_csv_cells') - beam_test_util.assert_that( - inferred_types, _check_types, label='check_types') + beam_test_util.assert_that( + parsed_csv_cells_and_raw_records, + _check_csv_cells, + label="check_parsed_csv_cells", + ) + beam_test_util.assert_that( + inferred_types, _check_types, label="check_types" + ) - record_batches = ( - parsed_csv_cells_and_raw_records - | beam.BatchElements(min_batch_size=1000) | beam.ParDo( - csv_decoder.BatchedCSVRowsToRecordBatch( - skip_blank_lines=skip_blank_lines, - multivalent_columns=multivalent_columns, - secondary_delimiter=secondary_delimiter, - raw_record_column_name=raw_record_column_name), - beam.pvalue.AsSingleton(inferred_types))) - beam_test_util.assert_that( - record_batches, _check_record_batches, label='check_record_batches') - if schema: - beam_test_util.assert_that( - record_batches, _check_arrow_schema, label='check_arrow_schema') + record_batches = ( + parsed_csv_cells_and_raw_records + | beam.BatchElements(min_batch_size=1000) + | beam.ParDo( + csv_decoder.BatchedCSVRowsToRecordBatch( + skip_blank_lines=skip_blank_lines, + multivalent_columns=multivalent_columns, + secondary_delimiter=secondary_delimiter, + raw_record_column_name=raw_record_column_name, + ), + beam.pvalue.AsSingleton(inferred_types), + ) + ) + beam_test_util.assert_that( + record_batches, _check_record_batches, label="check_record_batches" + ) + if schema: + beam_test_util.assert_that( + record_batches, _check_arrow_schema, label="check_arrow_schema" + ) - # Testing CSVToRecordBatch - with beam.Pipeline() as p: - record_batches = ( - p | 'CreatingPColl' >> beam.Create(input_lines, reshuffle=False) - | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch( - column_names=column_names, - delimiter=delimiter, - skip_blank_lines=skip_blank_lines, - desired_batch_size=1000, - schema=schema, - multivalent_columns=multivalent_columns, - secondary_delimiter=secondary_delimiter, - raw_record_column_name=raw_record_column_name)) - beam_test_util.assert_that( - record_batches, _check_record_batches, label='check_record_batches') + # Testing CSVToRecordBatch + with beam.Pipeline() as p: + record_batches = ( + p + | "CreatingPColl" >> beam.Create(input_lines, reshuffle=False) + | "CSVToRecordBatch" + >> csv_decoder.CSVToRecordBatch( + column_names=column_names, + delimiter=delimiter, + skip_blank_lines=skip_blank_lines, + desired_batch_size=1000, + schema=schema, + multivalent_columns=multivalent_columns, + secondary_delimiter=secondary_delimiter, + raw_record_column_name=raw_record_column_name, + ) + ) + beam_test_util.assert_that( + record_batches, _check_record_batches, label="check_record_batches" + ) - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_csv_to_recordbatch_schema_features_subset_of_column_names(self): - input_lines = ['1,2.0,hello', '5,12.34,world'] - column_names = ['int_feature', 'float_feature', 'str_feature'] - schema = text_format.Parse("""feature { name: "int_feature" type: INT }""", - schema_pb2.Schema()) - self.assertEqual( - csv_decoder.GetArrowSchema(column_names, schema), - pa.schema([pa.field('int_feature', pa.large_list(pa.int64()))])) + @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") + def test_csv_to_recordbatch_schema_features_subset_of_column_names(self): + input_lines = ["1,2.0,hello", "5,12.34,world"] + column_names = ["int_feature", "float_feature", "str_feature"] + schema = text_format.Parse( + """feature { name: "int_feature" type: INT }""", schema_pb2.Schema() + ) + self.assertEqual( + csv_decoder.GetArrowSchema(column_names, schema), + pa.schema([pa.field("int_feature", pa.large_list(pa.int64()))]), + ) - def _check_record_batches(record_batches): - self.assertLen(record_batches, 1) - self.assertTrue(record_batches[0].equals( - pa.RecordBatch.from_arrays( - [pa.array([[1], [5]], pa.large_list(pa.int64()))], - ['int_feature']))) + def _check_record_batches(record_batches): + self.assertLen(record_batches, 1) + self.assertTrue( + record_batches[0].equals( + pa.RecordBatch.from_arrays( + [pa.array([[1], [5]], pa.large_list(pa.int64()))], + ["int_feature"], + ) + ) + ) - with beam.Pipeline() as p: - record_batches = ( - p | 'CreatingPColl' >> beam.Create(input_lines, reshuffle=False) - | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch( - column_names=column_names, - delimiter=',', - desired_batch_size=1000, - schema=schema)) - beam_test_util.assert_that( - record_batches, _check_record_batches, label='check_record_batches') + with beam.Pipeline() as p: + record_batches = ( + p + | "CreatingPColl" >> beam.Create(input_lines, reshuffle=False) + | "CSVToRecordBatch" + >> csv_decoder.CSVToRecordBatch( + column_names=column_names, + delimiter=",", + desired_batch_size=1000, + schema=schema, + ) + ) + beam_test_util.assert_that( + record_batches, _check_record_batches, label="check_record_batches" + ) - def test_invalid_row(self): - input_lines = ['1,2.0,hello', '5,12.34'] - column_names = ['int_feature', 'float_feature', 'str_feature'] - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - ValueError, '.*Columns do not match specified csv headers.*'): - with beam.Pipeline() as p: - result = ( - p | beam.Create(input_lines, reshuffle=False) - | beam.ParDo(csv_decoder.ParseCSVLine(delimiter=',')) - | beam.Keys() - | beam.CombineGlobally( - csv_decoder.ColumnTypeInferrer( - column_names, skip_blank_lines=False))) - beam_test_util.assert_that(result, lambda _: None) + def test_invalid_row(self): + input_lines = ["1,2.0,hello", "5,12.34"] + column_names = ["int_feature", "float_feature", "str_feature"] + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, ".*Columns do not match specified csv headers.*" + ): + with beam.Pipeline() as p: + result = ( + p + | beam.Create(input_lines, reshuffle=False) + | beam.ParDo(csv_decoder.ParseCSVLine(delimiter=",")) + | beam.Keys() + | beam.CombineGlobally( + csv_decoder.ColumnTypeInferrer( + column_names, skip_blank_lines=False + ) + ) + ) + beam_test_util.assert_that(result, lambda _: None) - def test_invalid_schema_type(self): - input_lines = ['1'] - column_names = ['f1'] - schema = text_format.Parse( - """ + def test_invalid_schema_type(self): + input_lines = ["1"] + column_names = ["f1"] + schema = text_format.Parse( + """ feature { name: "struct_feature" type: STRUCT } - """, schema_pb2.Schema()) - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - ValueError, '.*Schema contains invalid type: STRUCT.*'): - with beam.Pipeline() as p: - result = ( - p | beam.Create(input_lines, reshuffle=False) - | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch( - column_names=column_names, - schema=schema, - desired_batch_size=1000)) - beam_test_util.assert_that(result, lambda _: None) + """, + schema_pb2.Schema(), + ) + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, ".*Schema contains invalid type: STRUCT.*" + ): + with beam.Pipeline() as p: + result = ( + p + | beam.Create(input_lines, reshuffle=False) + | "CSVToRecordBatch" + >> csv_decoder.CSVToRecordBatch( + column_names=column_names, + schema=schema, + desired_batch_size=1000, + ) + ) + beam_test_util.assert_that(result, lambda _: None) - def test_invalid_raw_record_column_name(self): - input_lines = ['1,2.0,hello', '5,12.34'] - schema = text_format.Parse( - """ + def test_invalid_raw_record_column_name(self): + input_lines = ["1,2.0,hello", "5,12.34"] + schema = text_format.Parse( + """ feature { name: "int_feature" type: INT @@ -719,26 +862,35 @@ def test_invalid_raw_record_column_name(self): name: "str_feature" type: BYTES } - """, schema_pb2.Schema()) - column_names = ['int_feature', 'float_feature', 'str_feature'] - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - ValueError, 'raw_record_column_name.* is already an existing column.*'): - with beam.Pipeline() as p: - result = ( - p | beam.Create(input_lines, reshuffle=False) - | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch( - column_names=column_names, - desired_batch_size=1000, - raw_record_column_name='int_feature')) - beam_test_util.assert_that(result, lambda _: None) - with self.assertRaisesRegex( - ValueError, 'raw_record_column_name.* is already an existing column.*'): - csv_decoder.GetArrowSchema( - column_names, schema, raw_record_column_name='int_feature') + """, + schema_pb2.Schema(), + ) + column_names = ["int_feature", "float_feature", "str_feature"] + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, "raw_record_column_name.* is already an existing column.*" + ): + with beam.Pipeline() as p: + result = ( + p + | beam.Create(input_lines, reshuffle=False) + | "CSVToRecordBatch" + >> csv_decoder.CSVToRecordBatch( + column_names=column_names, + desired_batch_size=1000, + raw_record_column_name="int_feature", + ) + ) + beam_test_util.assert_that(result, lambda _: None) + with self.assertRaisesRegex( + ValueError, "raw_record_column_name.* is already an existing column.*" + ): + csv_decoder.GetArrowSchema( + column_names, schema, raw_record_column_name="int_feature" + ) - def test_get_arrow_schema_schema_feature_not_subset_of_column_names(self): - schema = text_format.Parse( - """ + def test_get_arrow_schema_schema_feature_not_subset_of_column_names(self): + schema = text_format.Parse( + """ feature { name: "f1" type: INT @@ -747,12 +899,15 @@ def test_get_arrow_schema_schema_feature_not_subset_of_column_names(self): name: "f2" type: INT } - """, schema_pb2.Schema()) - column_names = ['f1'] - with self.assertRaisesRegex( - ValueError, 'Schema features are not a subset of column names'): - csv_decoder.GetArrowSchema(column_names, schema) + """, + schema_pb2.Schema(), + ) + column_names = ["f1"] + with self.assertRaisesRegex( + ValueError, "Schema features are not a subset of column names" + ): + csv_decoder.GetArrowSchema(column_names, schema) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tfx_bsl/coders/example_coder.py b/tfx_bsl/coders/example_coder.py index 5e41ad92..41973efa 100644 --- a/tfx_bsl/coders/example_coder.py +++ b/tfx_bsl/coders/example_coder.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Example coders.""" -from typing import List, Optional, Type, Tuple -import pyarrow as pa +from typing import List, Optional, Tuple, Type +import pyarrow as pa from tensorflow_metadata.proto.v0 import schema_pb2 # pylint: disable=unused-import @@ -23,85 +23,93 @@ # pylint: disable=g-import-not-at-top # See b/148667210 for why the ImportError is ignored. try: - from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder as ExamplesToRecordBatchDecoderCpp - from tfx_bsl.cc.tfx_bsl_extension.coders import ExampleToNumpyDict - from tfx_bsl.cc.tfx_bsl_extension.coders import RecordBatchToExamplesEncoder as RecordBatchToExamplesEncoderCpp + from tfx_bsl.cc.tfx_bsl_extension.coders import ( + ExamplesToRecordBatchDecoder as ExamplesToRecordBatchDecoderCpp, + ) + from tfx_bsl.cc.tfx_bsl_extension.coders import ExampleToNumpyDict + from tfx_bsl.cc.tfx_bsl_extension.coders import ( + RecordBatchToExamplesEncoder as RecordBatchToExamplesEncoderCpp, + ) except ImportError: - import sys - sys.stderr.write("Error importing tfx_bsl_extension.coders. " - "Some tfx_bsl functionalities are not available") + import sys + + sys.stderr.write( + "Error importing tfx_bsl_extension.coders. " + "Some tfx_bsl functionalities are not available" + ) # pylint: enable=g-import-not-at-top # pytype: enable=import-error # pylint: enable=unused-import class RecordBatchToExamplesEncoder: - """Encodes `pa.RecordBatch` as a list of serialized `tf.Example`s. + """Encodes `pa.RecordBatch` as a list of serialized `tf.Example`s. - Requires TFMD schema only if RecordBatches contains nested lists with - depth > 2 that represent TensorFlow's RaggedFeatures. - """ + Requires TFMD schema only if RecordBatches contains nested lists with + depth > 2 that represent TensorFlow's RaggedFeatures. + """ - __slots__ = ["_schema", "_coder"] + __slots__ = ["_schema", "_coder"] - def __init__(self, schema: Optional[schema_pb2.Schema] = None): - self._schema = schema - self._coder = RecordBatchToExamplesEncoderCpp( - None if schema is None else schema.SerializeToString() - ) + def __init__(self, schema: Optional[schema_pb2.Schema] = None): + self._schema = schema + self._coder = RecordBatchToExamplesEncoderCpp( + None if schema is None else schema.SerializeToString() + ) - def __reduce__( - self, - ) -> Tuple[ - Type["RecordBatchToExamplesEncoder"], Tuple[Optional[schema_pb2.Schema]] - ]: - return (self.__class__, (self._schema,)) + def __reduce__( + self, + ) -> Tuple[ + Type["RecordBatchToExamplesEncoder"], Tuple[Optional[schema_pb2.Schema]] + ]: + return (self.__class__, (self._schema,)) - def encode(self, record_batch: pa.RecordBatch) -> List[bytes]: # pylint: disable=invalid-name - return self._coder.Encode(record_batch) + def encode(self, record_batch: pa.RecordBatch) -> List[bytes]: # pylint: disable=invalid-name + return self._coder.Encode(record_batch) # TODO(b/271883540) Deprecate this. def RecordBatchToExamples(record_batch: pa.RecordBatch) -> List[bytes]: - """Stateless version of the encoder above.""" - return RecordBatchToExamplesEncoder().encode(record_batch) + """Stateless version of the encoder above.""" + return RecordBatchToExamplesEncoder().encode(record_batch) class ExamplesToRecordBatchDecoder: - """Decodes a list of serialized `tf.Example`s into `pa.RecordBatch`. + """Decodes a list of serialized `tf.Example`s into `pa.RecordBatch`. - If a schema is provided then the record batch will contain only the fields - from the schema, in the same order as the Schema. The data type of the - schema to determine the field types, with INT, BYTES and FLOAT fields in the - schema corresponding to the Arrow data types large_list[int64], - large_list[large_binary] and large_list[float32]. + If a schema is provided then the record batch will contain only the fields + from the schema, in the same order as the Schema. The data type of the + schema to determine the field types, with INT, BYTES and FLOAT fields in the + schema corresponding to the Arrow data types large_list[int64], + large_list[large_binary] and large_list[float32]. - If a schema is not provided then the data type will be inferred, and chosen - from list_type[int64], list_type[binary_type] and list_type[float32]. In the - case where no data type can be inferred the arrow null type will be inferred. + If a schema is not provided then the data type will be inferred, and chosen + from list_type[int64], list_type[binary_type] and list_type[float32]. In the + case where no data type can be inferred the arrow null type will be inferred. - This class wraps pybind11 class `ExamplesToRecordBatchDecoder` to make the - class and its member functions picklable. - """ + This class wraps pybind11 class `ExamplesToRecordBatchDecoder` to make the + class and its member functions picklable. + """ - __slots__ = ["_schema", "_coder"] + __slots__ = ["_schema", "_coder"] - def __init__(self, serialized_schema: Optional[bytes] = None): - """Initializes ExamplesToRecordBatchDecoder. + def __init__(self, serialized_schema: Optional[bytes] = None): + """Initializes ExamplesToRecordBatchDecoder. - Args: - serialized_schema: A serialized TFMD schema. - """ - self._schema = serialized_schema - self._coder = ExamplesToRecordBatchDecoderCpp(serialized_schema) + Args: + ---- + serialized_schema: A serialized TFMD schema. + """ + self._schema = serialized_schema + self._coder = ExamplesToRecordBatchDecoderCpp(serialized_schema) - def __reduce__( - self - ) -> Tuple[Type["ExamplesToRecordBatchDecoder"], Tuple[Optional[bytes]]]: - return (self.__class__, (self._schema,)) + def __reduce__( + self, + ) -> Tuple[Type["ExamplesToRecordBatchDecoder"], Tuple[Optional[bytes]]]: + return (self.__class__, (self._schema,)) - def DecodeBatch(self, examples: List[bytes]) -> pa.RecordBatch: - return self._coder.DecodeBatch(examples) + def DecodeBatch(self, examples: List[bytes]) -> pa.RecordBatch: + return self._coder.DecodeBatch(examples) - def ArrowSchema(self) -> pa.Schema: - return self._coder.ArrowSchema() + def ArrowSchema(self) -> pa.Schema: + return self._coder.ArrowSchema() diff --git a/tfx_bsl/coders/example_coder_test.py b/tfx_bsl/coders/example_coder_test.py index 8c874fd8..c77c2b43 100644 --- a/tfx_bsl/coders/example_coder_test.py +++ b/tfx_bsl/coders/example_coder_test.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for tfx_bsl.coders.example_coder.""" + import pickle + import pyarrow as pa import tensorflow as tf -from tfx_bsl.coders import example_coder -from tfx_bsl.tfxio import tensor_representation_util - +from absl.testing import absltest, parameterized from google.protobuf import text_format -from absl.testing import absltest -from absl.testing import parameterized from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.coders import example_coder +from tfx_bsl.tfxio import tensor_representation_util + _TEST_EXAMPLES = [ """ features { @@ -64,16 +65,22 @@ testcase_name="without_schema_simple", schema_text_proto=None, examples_text_proto=_TEST_EXAMPLES, - expected=pa.RecordBatch.from_arrays([ - pa.array([None, None, [1.0], None], - type=pa.large_list(pa.float32())), - pa.array([None, None, None, None], type=pa.null()), - pa.array([[b"a", b"b"], None, None, []], - type=pa.large_list(pa.large_binary())), - pa.array([[1.0, 2.0], None, None, []], - type=pa.large_list(pa.float32())), - pa.array([[4, 5], None, None, []], type=pa.large_list(pa.int64())) - ], ["v", "w", "x", "y", "z"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None, None, [1.0], None], type=pa.large_list(pa.float32())), + pa.array([None, None, None, None], type=pa.null()), + pa.array( + [[b"a", b"b"], None, None, []], + type=pa.large_list(pa.large_binary()), + ), + pa.array( + [[1.0, 2.0], None, None, []], type=pa.large_list(pa.float32()) + ), + pa.array([[4, 5], None, None, []], type=pa.large_list(pa.int64())), + ], + ["v", "w", "x", "y", "z"], + ), + ), dict( testcase_name="with_schema_simple", schema_text_proto=""" @@ -90,13 +97,20 @@ type: INT }""", examples_text_proto=_TEST_EXAMPLES, - expected=pa.RecordBatch.from_arrays([ - pa.array([[b"a", b"b"], None, None, []], - type=pa.large_list(pa.large_binary())), - pa.array([[1.0, 2.0], None, None, []], - type=pa.large_list(pa.float32())), - pa.array([[4, 5], None, None, []], type=pa.large_list(pa.int64())) - ], ["x", "y", "z"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array( + [[b"a", b"b"], None, None, []], + type=pa.large_list(pa.large_binary()), + ), + pa.array( + [[1.0, 2.0], None, None, []], type=pa.large_list(pa.float32()) + ), + pa.array([[4, 5], None, None, []], type=pa.large_list(pa.int64())), + ], + ["x", "y", "z"], + ), + ), dict( testcase_name="ignore_features_not_in_schema", schema_text_proto=""" @@ -110,12 +124,19 @@ } """, examples_text_proto=_TEST_EXAMPLES, - expected=pa.RecordBatch.from_arrays([ - pa.array([[b"a", b"b"], None, None, []], - type=pa.large_list(pa.large_binary())), - pa.array([[1.0, 2.0], None, None, []], - type=pa.large_list(pa.float32())), - ], ["x", "y"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array( + [[b"a", b"b"], None, None, []], + type=pa.large_list(pa.large_binary()), + ), + pa.array( + [[1.0, 2.0], None, None, []], type=pa.large_list(pa.float32()) + ), + ], + ["x", "y"], + ), + ), dict( testcase_name="build_nulls_for_unseen_feature", schema_text_proto=""" @@ -125,10 +146,15 @@ } """, examples_text_proto=_TEST_EXAMPLES, - expected=pa.RecordBatch.from_arrays([ - pa.array([None, None, None, None], - type=pa.large_list(pa.large_binary())), - ], ["a"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array( + [None, None, None, None], type=pa.large_list(pa.large_binary()) + ), + ], + ["a"], + ), + ), dict( testcase_name="build_null_for_unset_kind", schema_text_proto=""" @@ -142,9 +168,13 @@ features { feature { key: "a" value { } } } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None], type=pa.large_list(pa.large_binary())), - ], ["a"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None], type=pa.large_list(pa.large_binary())), + ], + ["a"], + ), + ), dict( testcase_name="duplicate_feature_names_in_schema", schema_text_proto=""" @@ -163,9 +193,13 @@ features { feature { key: "a" value { } } } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None], type=pa.large_list(pa.large_binary())), - ], ["a"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None], type=pa.large_list(pa.large_binary())), + ], + ["a"], + ), + ), ] _INVALID_INPUT_CASES = [ @@ -185,7 +219,8 @@ error=RuntimeError, error_msg_regex=( "Feature had wrong type, expected bytes_list, found float_list " - "for feature \"a\""), + 'for feature "a"' + ), ), dict( testcase_name="no_schema_mixed_type", @@ -193,94 +228,99 @@ examples_text_proto=[ """ features { feature { key: "a" value { float_list { value: [] } } } } - """, """ + """, + """ features { feature { key: "a" value { int64_list { value: [] } } } } - """ + """, ], error=RuntimeError, error_msg_regex=( "Feature had wrong type, expected float_list, found int64_list" - " for feature \"a\""), + ' for feature "a"' + ), ), ] class ExamplesToRecordBatchDecoderTest(parameterized.TestCase): - - @parameterized.named_parameters(*_DECODE_CASES) - def test_decode(self, schema_text_proto, examples_text_proto, expected): - serialized_examples = [ - text_format.Parse(pbtxt, tf.train.Example()).SerializeToString() - for pbtxt in examples_text_proto - ] - serialized_schema = None - if schema_text_proto is not None: - serialized_schema = text_format.Parse( - schema_text_proto, schema_pb2.Schema()).SerializeToString() - - coder = example_coder.ExamplesToRecordBatchDecoder(serialized_schema) - - result = coder.DecodeBatch(serialized_examples) - self.assertIsInstance(result, pa.RecordBatch) - self.assertTrue( - result.equals(expected), - ( - f"\nactual: {result.to_pydict()}\nactual schema:" - f" {result.schema}\nexpected:{expected.to_pydict()}\nexpected" - f" schema: {expected.schema}\nencoded: {serialized_examples}" - ), - ) - if serialized_schema: - self.assertTrue(expected.schema.equals(coder.ArrowSchema())) - - # Verify that coder and DecodeBatch can be properly pickled and unpickled. - # This is necessary for using them in beam.Map. - coder = pickle.loads(pickle.dumps(coder)) - decode = pickle.loads(pickle.dumps(coder.DecodeBatch)) - result = decode(serialized_examples) - self.assertIsInstance(result, pa.RecordBatch) - self.assertTrue( - result.equals(expected), - "actual: {}\n expected:{}".format(result, expected)) - if serialized_schema: - self.assertTrue(expected.schema.equals(coder.ArrowSchema())) - - @parameterized.named_parameters(*_INVALID_INPUT_CASES) - def test_invalid_input(self, schema_text_proto, examples_text_proto, error, - error_msg_regex): - serialized_examples = [ - text_format.Parse(pbtxt, tf.train.Example()).SerializeToString() - for pbtxt in examples_text_proto - ] - serialized_schema = None - if schema_text_proto is not None: - serialized_schema = text_format.Parse( - schema_text_proto, schema_pb2.Schema()).SerializeToString() - - if serialized_schema: - coder = example_coder.ExamplesToRecordBatchDecoder(serialized_schema) - else: - coder = example_coder.ExamplesToRecordBatchDecoder() - - with self.assertRaisesRegex(error, error_msg_regex): - coder.DecodeBatch(serialized_examples) - - def test_arrow_schema_not_available_if_tfmd_schema_not_available(self): - coder = example_coder.ExamplesToRecordBatchDecoder() - with self.assertRaisesRegex(RuntimeError, "Unable to get the arrow schema"): - _ = coder.ArrowSchema() - - def test_invalid_feature_type(self): - serialized_schema = text_format.Parse( - """ + @parameterized.named_parameters(*_DECODE_CASES) + def test_decode(self, schema_text_proto, examples_text_proto, expected): + serialized_examples = [ + text_format.Parse(pbtxt, tf.train.Example()).SerializeToString() + for pbtxt in examples_text_proto + ] + serialized_schema = None + if schema_text_proto is not None: + serialized_schema = text_format.Parse( + schema_text_proto, schema_pb2.Schema() + ).SerializeToString() + + coder = example_coder.ExamplesToRecordBatchDecoder(serialized_schema) + + result = coder.DecodeBatch(serialized_examples) + self.assertIsInstance(result, pa.RecordBatch) + self.assertTrue( + result.equals(expected), + ( + f"\nactual: {result.to_pydict()}\nactual schema:" + f" {result.schema}\nexpected:{expected.to_pydict()}\nexpected" + f" schema: {expected.schema}\nencoded: {serialized_examples}" + ), + ) + if serialized_schema: + self.assertTrue(expected.schema.equals(coder.ArrowSchema())) + + # Verify that coder and DecodeBatch can be properly pickled and unpickled. + # This is necessary for using them in beam.Map. + coder = pickle.loads(pickle.dumps(coder)) + decode = pickle.loads(pickle.dumps(coder.DecodeBatch)) + result = decode(serialized_examples) + self.assertIsInstance(result, pa.RecordBatch) + self.assertTrue( + result.equals(expected), f"actual: {result}\n expected:{expected}" + ) + if serialized_schema: + self.assertTrue(expected.schema.equals(coder.ArrowSchema())) + + @parameterized.named_parameters(*_INVALID_INPUT_CASES) + def test_invalid_input( + self, schema_text_proto, examples_text_proto, error, error_msg_regex + ): + serialized_examples = [ + text_format.Parse(pbtxt, tf.train.Example()).SerializeToString() + for pbtxt in examples_text_proto + ] + serialized_schema = None + if schema_text_proto is not None: + serialized_schema = text_format.Parse( + schema_text_proto, schema_pb2.Schema() + ).SerializeToString() + + if serialized_schema: + coder = example_coder.ExamplesToRecordBatchDecoder(serialized_schema) + else: + coder = example_coder.ExamplesToRecordBatchDecoder() + + with self.assertRaisesRegex(error, error_msg_regex): + coder.DecodeBatch(serialized_examples) + + def test_arrow_schema_not_available_if_tfmd_schema_not_available(self): + coder = example_coder.ExamplesToRecordBatchDecoder() + with self.assertRaisesRegex(RuntimeError, "Unable to get the arrow schema"): + _ = coder.ArrowSchema() + + def test_invalid_feature_type(self): + serialized_schema = text_format.Parse( + """ feature { name: "a" type: STRUCT } - """, schema_pb2.Schema()).SerializeToString() - with self.assertRaisesRegex(RuntimeError, - "Bad field type for feature: a.*"): - _ = example_coder.ExamplesToRecordBatchDecoder(serialized_schema) + """, + schema_pb2.Schema(), + ).SerializeToString() + with self.assertRaisesRegex(RuntimeError, "Bad field type for feature: a.*"): + _ = example_coder.ExamplesToRecordBatchDecoder(serialized_schema) _ENCODE_TEST_EXAMPLES = [ @@ -316,67 +356,83 @@ def test_invalid_feature_type(self): _ENCODE_CASES = [ dict( - record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b"a", b"b"], None, None, []], - type=pa.large_list(pa.large_binary())), - pa.array([[1.0, 2.0], None, None, []], type=pa.list_(pa.float32())), - pa.array([[4, 5], None, None, []], type=pa.large_list(pa.int64())) - ], ["x", "y", "z"]), - examples_text_proto=_ENCODE_TEST_EXAMPLES), + record_batch=pa.RecordBatch.from_arrays( + [ + pa.array( + [[b"a", b"b"], None, None, []], + type=pa.large_list(pa.large_binary()), + ), + pa.array([[1.0, 2.0], None, None, []], type=pa.list_(pa.float32())), + pa.array([[4, 5], None, None, []], type=pa.large_list(pa.int64())), + ], + ["x", "y", "z"], + ), + examples_text_proto=_ENCODE_TEST_EXAMPLES, + ), dict( - record_batch=pa.RecordBatch.from_arrays([ - pa.array([None, None, [b"a", b"b"]], - type=pa.large_list(pa.binary())), - pa.array([None, None, [1.0, 2.0]], type=pa.large_list( - pa.float32())), - pa.array([None, None, [4, 5]], type=pa.list_(pa.int64())) - ], ["x", "y", "z"]), - examples_text_proto=list(reversed(_ENCODE_TEST_EXAMPLES[:-1]))), + record_batch=pa.RecordBatch.from_arrays( + [ + pa.array([None, None, [b"a", b"b"]], type=pa.large_list(pa.binary())), + pa.array([None, None, [1.0, 2.0]], type=pa.large_list(pa.float32())), + pa.array([None, None, [4, 5]], type=pa.list_(pa.int64())), + ], + ["x", "y", "z"], + ), + examples_text_proto=list(reversed(_ENCODE_TEST_EXAMPLES[:-1])), + ), ] _INVALID_ENCODE_TYPE_CASES = [ dict( record_batch=pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["a"]), error=RuntimeError, - error_msg_regex="Expected ListArray or LargeListArray"), + error_msg_regex="Expected ListArray or LargeListArray", + ), dict( record_batch=pa.RecordBatch.from_arrays( - [pa.array([[True], [False]], type=pa.large_list(pa.bool_()))], - ["a"]), + [pa.array([[True], [False]], type=pa.large_list(pa.bool_()))], ["a"] + ), error=RuntimeError, - error_msg_regex="Bad field type"), + error_msg_regex="Bad field type", + ), dict( - record_batch=pa.RecordBatch.from_arrays([ - pa.array([[b"a", b"b"], None, None, []], - type=pa.large_list(pa.large_binary())), - pa.array([[1.0, 2.0], None, None, []], - type=pa.large_list(pa.float32())), - ], ["x", "x"]), + record_batch=pa.RecordBatch.from_arrays( + [ + pa.array( + [[b"a", b"b"], None, None, []], + type=pa.large_list(pa.large_binary()), + ), + pa.array( + [[1.0, 2.0], None, None, []], type=pa.large_list(pa.float32()) + ), + ], + ["x", "x"], + ), error=RuntimeError, - error_msg_regex="RecordBatch contains duplicate column names") + error_msg_regex="RecordBatch contains duplicate column names", + ), ] class RecordBatchToExamplesTest(parameterized.TestCase): - - @parameterized.parameters(*_ENCODE_CASES) - def test_encode(self, record_batch, examples_text_proto): - expected_examples = [ - text_format.Parse(pbtxt, tf.train.Example()) - for pbtxt in examples_text_proto - ] - coder = example_coder.RecordBatchToExamplesEncoder() - actual_examples = [ - tf.train.Example.FromString(encoded) - for encoded in coder.encode(record_batch) - ] - - self.assertEqual(actual_examples, expected_examples) - - @parameterized.parameters(*_INVALID_ENCODE_TYPE_CASES) - def test_invalid_input(self, record_batch, error, error_msg_regex): - with self.assertRaisesRegex(error, error_msg_regex): - example_coder.RecordBatchToExamplesEncoder().encode(record_batch) + @parameterized.parameters(*_ENCODE_CASES) + def test_encode(self, record_batch, examples_text_proto): + expected_examples = [ + text_format.Parse(pbtxt, tf.train.Example()) + for pbtxt in examples_text_proto + ] + coder = example_coder.RecordBatchToExamplesEncoder() + actual_examples = [ + tf.train.Example.FromString(encoded) + for encoded in coder.encode(record_batch) + ] + + self.assertEqual(actual_examples, expected_examples) + + @parameterized.parameters(*_INVALID_ENCODE_TYPE_CASES) + def test_invalid_input(self, record_batch, error, error_msg_regex): + with self.assertRaisesRegex(error, error_msg_regex): + example_coder.RecordBatchToExamplesEncoder().encode(record_batch) _ENCODE_NESTED_TEST_EXAMPLES = [ @@ -581,9 +637,7 @@ def test_invalid_input(self, record_batch, error, error_msg_regex): ["x", "y"], ), error=RuntimeError, - error_msg_regex=( - "conflicts with another source column in the same batch." - ), + error_msg_regex=("conflicts with another source column in the same batch."), schema=text_format.Parse( """ tensor_representation_group { @@ -668,74 +722,68 @@ def test_invalid_input(self, record_batch, error, error_msg_regex): ] -class RecordBatchToExamplesEncoderTest( - parameterized.TestCase, tf.test.TestCase -): - - @parameterized.parameters(*(_ENCODE_CASES + _ENCODE_NESTED_CASES)) - def test_encode(self, record_batch, examples_text_proto, schema=None): - expected_examples = [ - text_format.Parse(pbtxt, tf.train.Example()) - for pbtxt in examples_text_proto - ] - coder = example_coder.RecordBatchToExamplesEncoder(schema) - # Verify that coder can be properly pickled and unpickled. - coder = pickle.loads(pickle.dumps(coder)) - encoded = coder.encode(record_batch) - self.assertLen(encoded, len(expected_examples)) - for idx, (expected, actual) in enumerate(zip(expected_examples, encoded)): - self.assertProtoEquals( - expected, - tf.train.Example.FromString(actual), - msg=f" at position {idx}", - ) - - @parameterized.parameters(*(_INVALID_ENCODE_TYPE_CASES + - _INVALID_ENCODE_NESTED_TYPE_CASES)) - def test_invalid_input(self, - record_batch, - error, - error_msg_regex, - schema=None): - schema = (schema or schema_pb2.Schema()) - coder = example_coder.RecordBatchToExamplesEncoder(schema) - with self.assertRaisesRegex(error, error_msg_regex): - coder.encode(record_batch) - - def test_encode_is_consistent_with_parse_example(self): - coder = example_coder.RecordBatchToExamplesEncoder(_ENCODE_NESTED_SCHEMA) - encoded = tf.constant(coder.encode(_ENCODE_NESTED_RECORD_BATCH)) - tensor_representations = ( - tensor_representation_util.GetTensorRepresentationsFromSchema( - _ENCODE_NESTED_SCHEMA - ) +class RecordBatchToExamplesEncoderTest(parameterized.TestCase, tf.test.TestCase): + @parameterized.parameters(*(_ENCODE_CASES + _ENCODE_NESTED_CASES)) + def test_encode(self, record_batch, examples_text_proto, schema=None): + expected_examples = [ + text_format.Parse(pbtxt, tf.train.Example()) + for pbtxt in examples_text_proto + ] + coder = example_coder.RecordBatchToExamplesEncoder(schema) + # Verify that coder can be properly pickled and unpickled. + coder = pickle.loads(pickle.dumps(coder)) + encoded = coder.encode(record_batch) + self.assertLen(encoded, len(expected_examples)) + for idx, (expected, actual) in enumerate(zip(expected_examples, encoded)): + self.assertProtoEquals( + expected, + tf.train.Example.FromString(actual), + msg=f" at position {idx}", + ) + + @parameterized.parameters( + *(_INVALID_ENCODE_TYPE_CASES + _INVALID_ENCODE_NESTED_TYPE_CASES) ) - dtypes = { - "x": schema_pb2.FeatureType.BYTES, - "y": schema_pb2.FeatureType.FLOAT, - "z": schema_pb2.FeatureType.INT, - } - feature_spec = { - name: tensor_representation_util.CreateTfExampleParserConfig( - representation, dtypes[name] + def test_invalid_input(self, record_batch, error, error_msg_regex, schema=None): + schema = schema or schema_pb2.Schema() + coder = example_coder.RecordBatchToExamplesEncoder(schema) + with self.assertRaisesRegex(error, error_msg_regex): + coder.encode(record_batch) + + def test_encode_is_consistent_with_parse_example(self): + coder = example_coder.RecordBatchToExamplesEncoder(_ENCODE_NESTED_SCHEMA) + encoded = tf.constant(coder.encode(_ENCODE_NESTED_RECORD_BATCH)) + tensor_representations = ( + tensor_representation_util.GetTensorRepresentationsFromSchema( + _ENCODE_NESTED_SCHEMA + ) ) - for name, representation in tensor_representations.items() - } - decoded = tf.io.parse_example(encoded, feature_spec) - expected_values = { - "x": [[[b"a", b"b"]], [], [], []], - "y": [[[[1.0, 2.0]]], [[[3.0, 4.0]]], [], [[]]], - "z": [[[[[4], [5]]]], [], [[[[6], []]]], [[[[], []]]]], - } - expected_ragged_ranks = {"x": 1, "y": 2, "z": 4} - self.assertLen(decoded, len(expected_values)) - for name, expected in expected_values.items(): - actual = decoded[name] - self.assertEqual(actual.to_list(), expected, msg=f"For {name}") - self.assertEqual( - actual.ragged_rank, expected_ragged_ranks[name], msg=f"For {name}" - ) + dtypes = { + "x": schema_pb2.FeatureType.BYTES, + "y": schema_pb2.FeatureType.FLOAT, + "z": schema_pb2.FeatureType.INT, + } + feature_spec = { + name: tensor_representation_util.CreateTfExampleParserConfig( + representation, dtypes[name] + ) + for name, representation in tensor_representations.items() + } + decoded = tf.io.parse_example(encoded, feature_spec) + expected_values = { + "x": [[[b"a", b"b"]], [], [], []], + "y": [[[[1.0, 2.0]]], [[[3.0, 4.0]]], [], [[]]], + "z": [[[[[4], [5]]]], [], [[[[6], []]]], [[[[], []]]]], + } + expected_ragged_ranks = {"x": 1, "y": 2, "z": 4} + self.assertLen(decoded, len(expected_values)) + for name, expected in expected_values.items(): + actual = decoded[name] + self.assertEqual(actual.to_list(), expected, msg=f"For {name}") + self.assertEqual( + actual.ragged_rank, expected_ragged_ranks[name], msg=f"For {name}" + ) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/tfx_bsl/coders/example_numpy_decoder_test.py b/tfx_bsl/coders/example_numpy_decoder_test.py index d60d0d15..bfd6b2cd 100644 --- a/tfx_bsl/coders/example_numpy_decoder_test.py +++ b/tfx_bsl/coders/example_numpy_decoder_test.py @@ -15,128 +15,126 @@ import numpy as np import tensorflow as tf -from tfx_bsl.coders import example_coder - +from absl.testing import absltest, parameterized from google.protobuf import text_format -from absl.testing import absltest -from absl.testing import parameterized + +from tfx_bsl.coders import example_coder _TF_EXAMPLE_DECODER_TESTS = [ { - 'testcase_name': 'empty_input', - 'example_proto_text': '''features {}''', - 'decoded_example': {} + "testcase_name": "empty_input", + "example_proto_text": """features {}""", + "decoded_example": {}, }, { - 'testcase_name': 'int_feature_non_empty', - 'example_proto_text': ''' + "testcase_name": "int_feature_non_empty", + "example_proto_text": """ features { feature { key: 'x' value { int64_list { value: [ 1, 2, 3 ] } } } } - ''', - 'decoded_example': {'x': np.array([1, 2, 3], dtype=np.int64)} + """, + "decoded_example": {"x": np.array([1, 2, 3], dtype=np.int64)}, }, { - 'testcase_name': 'float_feature_non_empty', - 'example_proto_text': ''' + "testcase_name": "float_feature_non_empty", + "example_proto_text": """ features { feature { key: 'x' value { float_list { value: [ 4.0, 5.0 ] } } } } - ''', - 'decoded_example': {'x': np.array([4.0, 5.0], dtype=np.float32)} + """, + "decoded_example": {"x": np.array([4.0, 5.0], dtype=np.float32)}, }, { - 'testcase_name': 'str_feature_non_empty', - 'example_proto_text': ''' + "testcase_name": "str_feature_non_empty", + "example_proto_text": """ features { feature { key: 'x' value { bytes_list { value: [ 'string', 'list' ] } } } } - ''', - 'decoded_example': {'x': np.array([b'string', b'list'], - dtype=object)} + """, + "decoded_example": {"x": np.array([b"string", b"list"], dtype=object)}, }, { - 'testcase_name': 'int_feature_empty', - 'example_proto_text': ''' + "testcase_name": "int_feature_empty", + "example_proto_text": """ features { feature { key: 'x' value { int64_list { } } } } - ''', - 'decoded_example': {'x': np.array([], dtype=np.int64)} + """, + "decoded_example": {"x": np.array([], dtype=np.int64)}, }, { - 'testcase_name': 'float_feature_empty', - 'example_proto_text': ''' + "testcase_name": "float_feature_empty", + "example_proto_text": """ features { feature { key: 'x' value { float_list { } } } } - ''', - 'decoded_example': {'x': np.array([], dtype=np.float32)} + """, + "decoded_example": {"x": np.array([], dtype=np.float32)}, }, { - 'testcase_name': 'str_feature_empty', - 'example_proto_text': ''' + "testcase_name": "str_feature_empty", + "example_proto_text": """ features { feature { key: 'x' value { bytes_list { } } } } - ''', - 'decoded_example': {'x': np.array([], dtype=object)} + """, + "decoded_example": {"x": np.array([], dtype=object)}, }, { - 'testcase_name': 'feature_missing', - 'example_proto_text': ''' + "testcase_name": "feature_missing", + "example_proto_text": """ features { feature { key: 'x' value { } } } - ''', - 'decoded_example': {'x': None} + """, + "decoded_example": {"x": None}, }, ] class TFExampleDecoderTest(parameterized.TestCase): - """Tests for TFExampleDecoder.""" + """Tests for TFExampleDecoder.""" - def _check_decoding_results(self, actual, expected): - # Check that the numpy array dtypes match. - self.assertEqual(len(actual), len(expected)) - for key in actual: - if expected[key] is None: - self.assertEqual(actual[key], None) - else: - self.assertEqual(actual[key].dtype, expected[key].dtype) - np.testing.assert_equal(actual, expected) + def _check_decoding_results(self, actual, expected): + # Check that the numpy array dtypes match. + self.assertEqual(len(actual), len(expected)) + for key in actual: + if expected[key] is None: + self.assertEqual(actual[key], None) + else: + self.assertEqual(actual[key].dtype, expected[key].dtype) + np.testing.assert_equal(actual, expected) - @parameterized.named_parameters( - *_TF_EXAMPLE_DECODER_TESTS) - def test_decode_example(self, example_proto_text, decoded_example): - example = tf.train.Example() - text_format.Merge(example_proto_text, example) - self._check_decoding_results( - example_coder.ExampleToNumpyDict(example.SerializeToString()), - decoded_example) + @parameterized.named_parameters(*_TF_EXAMPLE_DECODER_TESTS) + def test_decode_example(self, example_proto_text, decoded_example): + example = tf.train.Example() + text_format.Merge(example_proto_text, example) + self._check_decoding_results( + example_coder.ExampleToNumpyDict(example.SerializeToString()), + decoded_example, + ) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tfx_bsl/coders/sequence_example_coder.py b/tfx_bsl/coders/sequence_example_coder.py index 3c6bc22c..962fc538 100644 --- a/tfx_bsl/coders/sequence_example_coder.py +++ b/tfx_bsl/coders/sequence_example_coder.py @@ -18,11 +18,14 @@ # pylint: disable=g-import-not-at-top # See b/148667210 for why the ImportError is ignored. try: - from tfx_bsl.cc.tfx_bsl_extension.coders import SequenceExamplesToRecordBatchDecoder + from tfx_bsl.cc.tfx_bsl_extension.coders import SequenceExamplesToRecordBatchDecoder except ImportError: - import sys - sys.stderr.write("Error importing tfx_bsl_extension.coders. " - "Some tfx_bsl functionalities are not available.") + import sys + + sys.stderr.write( + "Error importing tfx_bsl_extension.coders. " + "Some tfx_bsl functionalities are not available." + ) # pylint: enable=g-import-not-at-top # pytype: enable=import-error # pylint: enable=unused-import diff --git a/tfx_bsl/coders/sequence_example_coder_test.py b/tfx_bsl/coders/sequence_example_coder_test.py index 45f56243..993a25d1 100644 --- a/tfx_bsl/coders/sequence_example_coder_test.py +++ b/tfx_bsl/coders/sequence_example_coder_test.py @@ -15,13 +15,12 @@ import pyarrow as pa import tensorflow as tf -from tfx_bsl.coders import sequence_example_coder - +from absl.testing import absltest, parameterized from google.protobuf import text_format -from absl.testing import absltest -from absl.testing import parameterized from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.coders import sequence_example_coder + _TEST_SEQUENCE_COLUMN_NAME = "##SEQUENCE##" _TYPED_SEQUENCE_EXAMPLE = """ context { @@ -325,37 +324,55 @@ testcase_name="without_schema_first_example_typed", schema_text_proto=None, sequence_examples_text_proto=[ - _TYPED_SEQUENCE_EXAMPLE, _UNTYPED_SEQUENCE_EXAMPLE, + _TYPED_SEQUENCE_EXAMPLE, + _UNTYPED_SEQUENCE_EXAMPLE, _SOME_FEATURES_TYPED_SEQUENCE_EXAMPLE, - _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE + _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE, ], - expected=pa.RecordBatch.from_arrays([ - pa.array([[1], None, None, []], type=pa.large_list(pa.int64())), - pa.array([[1.0, 2.0], None, None, []], - type=pa.large_list(pa.float32())), - pa.array([[b"a", b"b", b"c"], None, None, []], - type=pa.large_list(pa.large_binary())), - pa.array([None, None, None, None], pa.null()), - pa.array([None, None, [1.0], None], - type=pa.large_list(pa.float32())), - pa.StructArray.from_arrays([ - pa.array([None, None, [[1.0]], None], - type=pa.large_list(pa.large_list(pa.float32()))), - pa.array([[[1, 2], [3]], [], [None, None, None], [[], []]], - type=pa.large_list(pa.large_list(pa.int64()))), - pa.array([[[3.0, 4.0], [1.0, 2.0]], [], [None], [[]]], - type=pa.large_list(pa.large_list(pa.float32()))), - pa.array([[[b"a", b"b"], [b"c"]], [], [None], [[]]], - type=pa.large_list(pa.large_list(pa.large_binary()))) + expected=pa.RecordBatch.from_arrays( + [ + pa.array([[1], None, None, []], type=pa.large_list(pa.int64())), + pa.array( + [[1.0, 2.0], None, None, []], type=pa.large_list(pa.float32()) + ), + pa.array( + [[b"a", b"b", b"c"], None, None, []], + type=pa.large_list(pa.large_binary()), + ), + pa.array([None, None, None, None], pa.null()), + pa.array([None, None, [1.0], None], type=pa.large_list(pa.float32())), + pa.StructArray.from_arrays( + [ + pa.array( + [None, None, [[1.0]], None], + type=pa.large_list(pa.large_list(pa.float32())), + ), + pa.array( + [[[1, 2], [3]], [], [None, None, None], [[], []]], + type=pa.large_list(pa.large_list(pa.int64())), + ), + pa.array( + [[[3.0, 4.0], [1.0, 2.0]], [], [None], [[]]], + type=pa.large_list(pa.large_list(pa.float32())), + ), + pa.array( + [[[b"a", b"b"], [b"c"]], [], [None], [[]]], + type=pa.large_list(pa.large_list(pa.large_binary())), + ), + ], + names=["sequence_v", "sequence_x", "sequence_y", "sequence_z"], + ), + ], + [ + "context_a", + "context_b", + "context_c", + "context_d", + "context_e", + _TEST_SEQUENCE_COLUMN_NAME, ], - names=[ - "sequence_v", "sequence_x", - "sequence_y", "sequence_z" - ]) - ], [ - "context_a", "context_b", "context_c", "context_d", "context_e", - _TEST_SEQUENCE_COLUMN_NAME - ])), + ), + ), dict( testcase_name="with_schema_first_example_typed", schema_text_proto=""" @@ -390,64 +407,95 @@ } }""", sequence_examples_text_proto=[ - _TYPED_SEQUENCE_EXAMPLE, _UNTYPED_SEQUENCE_EXAMPLE, + _TYPED_SEQUENCE_EXAMPLE, + _UNTYPED_SEQUENCE_EXAMPLE, _SOME_FEATURES_TYPED_SEQUENCE_EXAMPLE, - _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE + _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE, ], - expected=pa.RecordBatch.from_arrays([ - pa.array([[1], None, None, []], type=pa.large_list(pa.int64())), - pa.array([[1.0, 2.0], None, None, []], - type=pa.large_list(pa.float32())), - pa.array([[b"a", b"b", b"c"], None, None, []], - type=pa.large_list(pa.large_binary())), - pa.StructArray.from_arrays([ - pa.array([[[1, 2], [3]], [], [None, None, None], [[], []]], - type=pa.large_list(pa.large_list(pa.int64()))), - pa.array([[[3.0, 4.0], [1.0, 2.0]], [], [None], [[]]], - type=pa.large_list(pa.large_list(pa.float32()))), - pa.array([[[b"a", b"b"], [b"c"]], [], [None], [[]]], - type=pa.large_list(pa.large_list(pa.large_binary()))) + expected=pa.RecordBatch.from_arrays( + [ + pa.array([[1], None, None, []], type=pa.large_list(pa.int64())), + pa.array( + [[1.0, 2.0], None, None, []], type=pa.large_list(pa.float32()) + ), + pa.array( + [[b"a", b"b", b"c"], None, None, []], + type=pa.large_list(pa.large_binary()), + ), + pa.StructArray.from_arrays( + [ + pa.array( + [[[1, 2], [3]], [], [None, None, None], [[], []]], + type=pa.large_list(pa.large_list(pa.int64())), + ), + pa.array( + [[[3.0, 4.0], [1.0, 2.0]], [], [None], [[]]], + type=pa.large_list(pa.large_list(pa.float32())), + ), + pa.array( + [[[b"a", b"b"], [b"c"]], [], [None], [[]]], + type=pa.large_list(pa.large_list(pa.large_binary())), + ), + ], + names=["sequence_x", "sequence_y", "sequence_z"], + ), ], - names=[ - "sequence_x", "sequence_y", - "sequence_z" - ]) - ], ["context_a", "context_b", "context_c", _TEST_SEQUENCE_COLUMN_NAME - ])), + ["context_a", "context_b", "context_c", _TEST_SEQUENCE_COLUMN_NAME], + ), + ), dict( testcase_name="without_schema_untyped_then_typed_examples", schema_text_proto=None, sequence_examples_text_proto=[ - _UNTYPED_SEQUENCE_EXAMPLE, _SOME_FEATURES_TYPED_SEQUENCE_EXAMPLE, - _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE, _TYPED_SEQUENCE_EXAMPLE + _UNTYPED_SEQUENCE_EXAMPLE, + _SOME_FEATURES_TYPED_SEQUENCE_EXAMPLE, + _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE, + _TYPED_SEQUENCE_EXAMPLE, ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None, None, [], [1]], type=pa.large_list(pa.int64())), - pa.array([None, None, [], [1.0, 2.0]], - type=pa.large_list(pa.float32())), - pa.array([None, None, [], [b"a", b"b", b"c"]], - type=pa.large_list(pa.large_binary())), - pa.array([None, None, None, None], pa.null()), - pa.array([None, [1.0], None, None], - type=pa.large_list(pa.float32())), - pa.StructArray.from_arrays([ - pa.array([None, [[1.0]], None, None], - type=pa.large_list(pa.large_list(pa.float32()))), - pa.array([[], [None, None, None], [[], []], [[1, 2], [3]]], - type=pa.large_list(pa.large_list(pa.int64()))), - pa.array([[], [None], [[]], [[3.0, 4.0], [1.0, 2.0]]], - type=pa.large_list(pa.large_list(pa.float32()))), - pa.array([[], [None], [[]], [[b"a", b"b"], [b"c"]]], - type=pa.large_list(pa.large_list(pa.large_binary()))) + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None, None, [], [1]], type=pa.large_list(pa.int64())), + pa.array( + [None, None, [], [1.0, 2.0]], type=pa.large_list(pa.float32()) + ), + pa.array( + [None, None, [], [b"a", b"b", b"c"]], + type=pa.large_list(pa.large_binary()), + ), + pa.array([None, None, None, None], pa.null()), + pa.array([None, [1.0], None, None], type=pa.large_list(pa.float32())), + pa.StructArray.from_arrays( + [ + pa.array( + [None, [[1.0]], None, None], + type=pa.large_list(pa.large_list(pa.float32())), + ), + pa.array( + [[], [None, None, None], [[], []], [[1, 2], [3]]], + type=pa.large_list(pa.large_list(pa.int64())), + ), + pa.array( + [[], [None], [[]], [[3.0, 4.0], [1.0, 2.0]]], + type=pa.large_list(pa.large_list(pa.float32())), + ), + pa.array( + [[], [None], [[]], [[b"a", b"b"], [b"c"]]], + type=pa.large_list(pa.large_list(pa.large_binary())), + ), + ], + names=["sequence_v", "sequence_x", "sequence_y", "sequence_z"], + ), + ], + [ + "context_a", + "context_b", + "context_c", + "context_d", + "context_e", + _TEST_SEQUENCE_COLUMN_NAME, ], - names=[ - "sequence_v", "sequence_x", - "sequence_y", "sequence_z" - ]) - ], [ - "context_a", "context_b", "context_c", "context_d", "context_e", - _TEST_SEQUENCE_COLUMN_NAME - ])), + ), + ), dict( testcase_name="with_schema_untyped_then_typed_examples", schema_text_proto=""" @@ -482,50 +530,72 @@ } }""", sequence_examples_text_proto=[ - _UNTYPED_SEQUENCE_EXAMPLE, _SOME_FEATURES_TYPED_SEQUENCE_EXAMPLE, - _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE, _TYPED_SEQUENCE_EXAMPLE + _UNTYPED_SEQUENCE_EXAMPLE, + _SOME_FEATURES_TYPED_SEQUENCE_EXAMPLE, + _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE, + _TYPED_SEQUENCE_EXAMPLE, ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None, None, [], [1]], type=pa.large_list(pa.int64())), - pa.array([None, None, [], [1.0, 2.0]], - type=pa.large_list(pa.float32())), - pa.array([None, None, [], [b"a", b"b", b"c"]], - type=pa.large_list(pa.large_binary())), - pa.StructArray.from_arrays([ - pa.array([[], [None, None, None], [[], []], [[1, 2], [3]]], - type=pa.large_list(pa.large_list(pa.int64()))), - pa.array([[], [None], [[]], [[3.0, 4.0], [1.0, 2.0]]], - type=pa.large_list(pa.large_list(pa.float32()))), - pa.array([[], [None], [[]], [[b"a", b"b"], [b"c"]]], - type=pa.large_list(pa.large_list(pa.large_binary()))) + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None, None, [], [1]], type=pa.large_list(pa.int64())), + pa.array( + [None, None, [], [1.0, 2.0]], type=pa.large_list(pa.float32()) + ), + pa.array( + [None, None, [], [b"a", b"b", b"c"]], + type=pa.large_list(pa.large_binary()), + ), + pa.StructArray.from_arrays( + [ + pa.array( + [[], [None, None, None], [[], []], [[1, 2], [3]]], + type=pa.large_list(pa.large_list(pa.int64())), + ), + pa.array( + [[], [None], [[]], [[3.0, 4.0], [1.0, 2.0]]], + type=pa.large_list(pa.large_list(pa.float32())), + ), + pa.array( + [[], [None], [[]], [[b"a", b"b"], [b"c"]]], + type=pa.large_list(pa.large_list(pa.large_binary())), + ), + ], + names=["sequence_x", "sequence_y", "sequence_z"], + ), ], - names=[ - "sequence_x", "sequence_y", - "sequence_z" - ]) - ], ["context_a", "context_b", "context_c", _TEST_SEQUENCE_COLUMN_NAME - ])), + ["context_a", "context_b", "context_c", _TEST_SEQUENCE_COLUMN_NAME], + ), + ), dict( testcase_name="without_schema_no_typed_examples", schema_text_proto=None, sequence_examples_text_proto=_TEST_SEQUENCE_EXAMPLES_NONE_TYPED, - expected=pa.RecordBatch.from_arrays([ - pa.array([None, None], type=pa.null()), - pa.array([None, None], type=pa.null()), - pa.array([None, None], type=pa.null()), - pa.array([None, None], type=pa.null()), - pa.StructArray.from_arrays([ - pa.array([None, [None]], type=pa.large_list(pa.null())), - pa.array([[], [None]], type=pa.large_list(pa.null())), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None, None], type=pa.null()), + pa.array([None, None], type=pa.null()), + pa.array([None, None], type=pa.null()), + pa.array([None, None], type=pa.null()), + pa.StructArray.from_arrays( + [ + pa.array([None, [None]], type=pa.large_list(pa.null())), + pa.array([[], [None]], type=pa.large_list(pa.null())), + ], + names=[ + "sequence_w", + "sequence_x", + ], + ), + ], + [ + "context_a", + "context_b", + "context_c", + "context_d", + _TEST_SEQUENCE_COLUMN_NAME, ], - names=[ - "sequence_w", - "sequence_x", - ]) - ], [ - "context_a", "context_b", "context_c", "context_d", - _TEST_SEQUENCE_COLUMN_NAME - ])), + ), + ), dict( testcase_name="with_schema_no_typed_examples", schema_text_proto=""" @@ -560,24 +630,31 @@ } }""", sequence_examples_text_proto=_TEST_SEQUENCE_EXAMPLES_NONE_TYPED, - expected=pa.RecordBatch.from_arrays([ - pa.array([None, None], type=pa.large_list(pa.int64())), - pa.array([None, None], type=pa.large_list(pa.float32())), - pa.array([None, None], type=pa.large_list(pa.large_binary())), - pa.StructArray.from_arrays([ - pa.array([[], [None]], - type=pa.large_list(pa.large_list(pa.int64()))), - pa.array([None, None], - type=pa.large_list(pa.large_list(pa.float32()))), - pa.array([None, None], - type=pa.large_list(pa.large_list(pa.large_binary()))) + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None, None], type=pa.large_list(pa.int64())), + pa.array([None, None], type=pa.large_list(pa.float32())), + pa.array([None, None], type=pa.large_list(pa.large_binary())), + pa.StructArray.from_arrays( + [ + pa.array( + [[], [None]], type=pa.large_list(pa.large_list(pa.int64())) + ), + pa.array( + [None, None], + type=pa.large_list(pa.large_list(pa.float32())), + ), + pa.array( + [None, None], + type=pa.large_list(pa.large_list(pa.large_binary())), + ), + ], + names=["sequence_x", "sequence_y", "sequence_z"], + ), ], - names=[ - "sequence_x", "sequence_y", - "sequence_z" - ]) - ], ["context_a", "context_b", "context_c", _TEST_SEQUENCE_COLUMN_NAME - ])), + ["context_a", "context_b", "context_c", _TEST_SEQUENCE_COLUMN_NAME], + ), + ), dict( testcase_name="build_nulls_for_unseen_feature", schema_text_proto=""" @@ -597,19 +674,29 @@ } """, sequence_examples_text_proto=[ - _TYPED_SEQUENCE_EXAMPLE, _UNTYPED_SEQUENCE_EXAMPLE, + _TYPED_SEQUENCE_EXAMPLE, + _UNTYPED_SEQUENCE_EXAMPLE, _SOME_FEATURES_TYPED_SEQUENCE_EXAMPLE, - _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE + _EMPTY_VALUES_LIST_SEQUENCE_EXAMPLE, ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None, None, None, None], - type=pa.large_list(pa.large_binary())), - pa.StructArray.from_arrays([ - pa.array([None, None, None, None], - type=pa.large_list(pa.large_list(pa.int64()))) + expected=pa.RecordBatch.from_arrays( + [ + pa.array( + [None, None, None, None], type=pa.large_list(pa.large_binary()) + ), + pa.StructArray.from_arrays( + [ + pa.array( + [None, None, None, None], + type=pa.large_list(pa.large_list(pa.int64())), + ) + ], + names=["sequence_u"], + ), ], - names=["sequence_u"]), - ], ["context_u", _TEST_SEQUENCE_COLUMN_NAME])), + ["context_u", _TEST_SEQUENCE_COLUMN_NAME], + ), + ), dict( testcase_name="build_null_for_unset_kind", schema_text_proto=""" @@ -636,12 +723,17 @@ } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None], type=pa.large_list(pa.large_binary())), - pa.StructArray.from_arrays( - [pa.array([[]], type=pa.large_list(pa.large_list(pa.int64())))], - names=["sequence_a"]), - ], ["context_a", _TEST_SEQUENCE_COLUMN_NAME])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None], type=pa.large_list(pa.large_binary())), + pa.StructArray.from_arrays( + [pa.array([[]], type=pa.large_list(pa.large_list(pa.int64())))], + names=["sequence_a"], + ), + ], + ["context_a", _TEST_SEQUENCE_COLUMN_NAME], + ), + ), dict( testcase_name="schema_does_not_contain_sequence_feature", schema_text_proto=""" @@ -658,9 +750,13 @@ } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None], type=pa.large_list(pa.large_binary())), - ], ["context_a"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None], type=pa.large_list(pa.large_binary())), + ], + ["context_a"], + ), + ), dict( testcase_name="duplicate_context_feature_names_in_schema", schema_text_proto=""" @@ -682,9 +778,13 @@ } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.array([None], type=pa.large_list(pa.large_binary())), - ], ["context_a"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([None], type=pa.large_list(pa.large_binary())), + ], + ["context_a"], + ), + ), dict( testcase_name="duplicate_sequence_feature_names_in_schema", schema_text_proto=""" @@ -711,20 +811,31 @@ } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.StructArray.from_arrays( - [pa.array([[]], type=pa.large_list(pa.large_list(pa.int64())))], - names=["sequence_a"]), - ], [_TEST_SEQUENCE_COLUMN_NAME])), + expected=pa.RecordBatch.from_arrays( + [ + pa.StructArray.from_arrays( + [pa.array([[]], type=pa.large_list(pa.large_list(pa.int64())))], + names=["sequence_a"], + ), + ], + [_TEST_SEQUENCE_COLUMN_NAME], + ), + ), dict( testcase_name="feature_lists_with_no_sequence_features", schema_text_proto=None, - sequence_examples_text_proto=[""" + sequence_examples_text_proto=[ + """ feature_lists {} - """], - expected=pa.RecordBatch.from_arrays([ - pa.StructArray.from_buffers(pa.struct([]), 1, [None]), - ], [_TEST_SEQUENCE_COLUMN_NAME])), + """ + ], + expected=pa.RecordBatch.from_arrays( + [ + pa.StructArray.from_buffers(pa.struct([]), 1, [None]), + ], + [_TEST_SEQUENCE_COLUMN_NAME], + ), + ), dict( testcase_name="without_schema_only_context_features", schema_text_proto=None, @@ -742,9 +853,13 @@ } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.array([[1, 2]], type=pa.large_list(pa.int64())), - ], ["context_a"])), + expected=pa.RecordBatch.from_arrays( + [ + pa.array([[1, 2]], type=pa.large_list(pa.int64())), + ], + ["context_a"], + ), + ), dict( testcase_name="without_schema_only_sequence_features", schema_text_proto=None, @@ -764,13 +879,20 @@ } """ ], - expected=pa.RecordBatch.from_arrays([ - pa.StructArray.from_arrays([ - pa.array([[[1, 2]]], - type=pa.large_list(pa.large_list(pa.int64()))), + expected=pa.RecordBatch.from_arrays( + [ + pa.StructArray.from_arrays( + [ + pa.array( + [[[1, 2]]], type=pa.large_list(pa.large_list(pa.int64())) + ), + ], + names=["sequence_x"], + ) ], - names=["sequence_x"]) - ], [_TEST_SEQUENCE_COLUMN_NAME])), + [_TEST_SEQUENCE_COLUMN_NAME], + ), + ), ] _INVALID_INPUT_CASES = [ @@ -790,7 +912,8 @@ error=RuntimeError, error_msg_regex=( "Feature had wrong type, expected bytes_list, found float_list " - "for feature \"a\""), + 'for feature "a"' + ), ), dict( testcase_name="sequence_feature_actual_type_mismatches_schema_type", @@ -821,7 +944,8 @@ error=RuntimeError, error_msg_regex=( "Feature had wrong type, expected bytes_list, found float_list " - "for sequence feature \"a\""), + 'for sequence feature "a"' + ), ), dict( testcase_name="context_feature_no_schema_mixed_type", @@ -829,14 +953,16 @@ sequence_examples_text_proto=[ """ context { feature { key: "a" value { float_list { value: [] } } } } - """, """ + """, + """ context { feature { key: "a" value { int64_list { value: [] } } } } - """ + """, ], error=RuntimeError, error_msg_regex=( "Feature had wrong type, expected float_list, found int64_list" - " for feature \"a\""), + ' for feature "a"' + ), ), dict( testcase_name="sequence_feature_no_schema_mixed_type", @@ -851,7 +977,8 @@ } } } - """, """ + """, + """ feature_lists { feature_list { key: 'a' @@ -860,89 +987,95 @@ } } } - """ + """, ], error=RuntimeError, error_msg_regex=( "Feature had wrong type, expected float_list, found int64_list" - " for sequence feature \"a\""), + ' for sequence feature "a"' + ), ), ] class SequenceExamplesToRecordBatchDecoderTest(parameterized.TestCase): + @parameterized.named_parameters(*_DECODE_CASES) + def test_decode(self, schema_text_proto, sequence_examples_text_proto, expected): + serialized_sequence_examples = [ + text_format.Parse(pbtxt, tf.train.SequenceExample()).SerializeToString() + for pbtxt in sequence_examples_text_proto + ] + serialized_schema = None + if schema_text_proto is not None: + serialized_schema = text_format.Parse( + schema_text_proto, schema_pb2.Schema() + ).SerializeToString() - @parameterized.named_parameters(*_DECODE_CASES) - def test_decode(self, schema_text_proto, sequence_examples_text_proto, - expected): - serialized_sequence_examples = [ - text_format.Parse(pbtxt, - tf.train.SequenceExample()).SerializeToString() - for pbtxt in sequence_examples_text_proto - ] - serialized_schema = None - if schema_text_proto is not None: - serialized_schema = text_format.Parse( - schema_text_proto, schema_pb2.Schema()).SerializeToString() - - if serialized_schema: - coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( - _TEST_SEQUENCE_COLUMN_NAME, - serialized_schema) - else: - coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( - _TEST_SEQUENCE_COLUMN_NAME) + if serialized_schema: + coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( + _TEST_SEQUENCE_COLUMN_NAME, serialized_schema + ) + else: + coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( + _TEST_SEQUENCE_COLUMN_NAME + ) - result = coder.DecodeBatch(serialized_sequence_examples) - self.assertIsInstance(result, pa.RecordBatch) - self.assertTrue( - result.equals(expected), - "actual: {}\n expected:{}".format(result, expected)) + result = coder.DecodeBatch(serialized_sequence_examples) + self.assertIsInstance(result, pa.RecordBatch) + self.assertTrue( + result.equals(expected), f"actual: {result}\n expected:{expected}" + ) - if serialized_schema is not None: - self.assertTrue(coder.ArrowSchema().equals(result.schema)) + if serialized_schema is not None: + self.assertTrue(coder.ArrowSchema().equals(result.schema)) - @parameterized.named_parameters(*_INVALID_INPUT_CASES) - def test_invalid_input(self, schema_text_proto, sequence_examples_text_proto, - error, error_msg_regex): - serialized_sequence_examples = [ - text_format.Parse(pbtxt, - tf.train.SequenceExample()).SerializeToString() - for pbtxt in sequence_examples_text_proto - ] - serialized_schema = None - if schema_text_proto is not None: - serialized_schema = text_format.Parse( - schema_text_proto, schema_pb2.Schema()).SerializeToString() + @parameterized.named_parameters(*_INVALID_INPUT_CASES) + def test_invalid_input( + self, schema_text_proto, sequence_examples_text_proto, error, error_msg_regex + ): + serialized_sequence_examples = [ + text_format.Parse(pbtxt, tf.train.SequenceExample()).SerializeToString() + for pbtxt in sequence_examples_text_proto + ] + serialized_schema = None + if schema_text_proto is not None: + serialized_schema = text_format.Parse( + schema_text_proto, schema_pb2.Schema() + ).SerializeToString() - if serialized_schema: - coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( - _TEST_SEQUENCE_COLUMN_NAME, serialized_schema) - else: - coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( - _TEST_SEQUENCE_COLUMN_NAME) + if serialized_schema: + coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( + _TEST_SEQUENCE_COLUMN_NAME, serialized_schema + ) + else: + coder = sequence_example_coder.SequenceExamplesToRecordBatchDecoder( + _TEST_SEQUENCE_COLUMN_NAME + ) - with self.assertRaisesRegex(error, error_msg_regex): - coder.DecodeBatch(serialized_sequence_examples) + with self.assertRaisesRegex(error, error_msg_regex): + coder.DecodeBatch(serialized_sequence_examples) - def test_sequence_feature_column_name_not_struct_in_schema(self): - schema_text_proto = """ + def test_sequence_feature_column_name_not_struct_in_schema(self): + schema_text_proto = """ feature { name: "##SEQUENCE##" type: INT } """ - serialized_schema = text_format.Parse( - schema_text_proto, schema_pb2.Schema()).SerializeToString() + serialized_schema = text_format.Parse( + schema_text_proto, schema_pb2.Schema() + ).SerializeToString() - error_msg_regex = ( - "Found a feature in the schema with the sequence_feature_column_name " - r"\(i.e., ##SEQUENCE##\) that is not a struct.*") + error_msg_regex = ( + "Found a feature in the schema with the sequence_feature_column_name " + r"\(i.e., ##SEQUENCE##\) that is not a struct.*" + ) - with self.assertRaisesRegex(RuntimeError, error_msg_regex): - sequence_example_coder.SequenceExamplesToRecordBatchDecoder( - _TEST_SEQUENCE_COLUMN_NAME, serialized_schema) + with self.assertRaisesRegex(RuntimeError, error_msg_regex): + sequence_example_coder.SequenceExamplesToRecordBatchDecoder( + _TEST_SEQUENCE_COLUMN_NAME, serialized_schema + ) if __name__ == "__main__": - absltest.main() + absltest.main() diff --git a/tfx_bsl/coders/tf_graph_record_decoder.py b/tfx_bsl/coders/tf_graph_record_decoder.py index 8de4ddd0..8d78ebeb 100644 --- a/tfx_bsl/coders/tf_graph_record_decoder.py +++ b/tfx_bsl/coders/tf_graph_record_decoder.py @@ -17,9 +17,9 @@ from typing import Dict, Optional, Union import tensorflow as tf - -from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import - +from tensorflow.python.framework import ( + composite_tensor, # pylint: disable=g-direct-tensorflow-import +) TensorAlike = Union[tf.Tensor, composite_tensor.CompositeTensor] @@ -28,167 +28,179 @@ class TFGraphRecordDecoder(metaclass=abc.ABCMeta): - """Base class for decoders that turns a list of bytes to (composite) tensors. - - Sub-classes must implement `decode_record()` (see its docstring - for requirements). - - Decoder instances can be saved as a SavedModel by `save_decoder()`. - The SavedModel can be loaded back by `load_decoder()`. However, the loaded - decoder will always be of the type `LoadedDecoder` and only have the public - interfaces listed in this base class available. - """ - - def output_type_specs(self) -> Dict[str, tf.TypeSpec]: - """Returns the tf.TypeSpecs of the decoded tensors. - - Returns: - A dict whose keys are the same as keys of the dict returned by - `decode_record()` and values are the tf.TypeSpec of the corresponding - (composite) tensor. - """ - return { - k: tf.type_spec_from_value(v) for k, v in - self._make_concrete_decode_function().structured_outputs.items() - } - - @abc.abstractmethod - def decode_record(self, records: tf.Tensor) -> Dict[str, TensorAlike]: - """Sub-classes should implement this. - - Implementations must use TF ops to derive the result (composite) tensors, as - this function will be traced and become a tf.function (thus a TF Graph). - Note that autograph is not enabled in such tracing, which means any python - control flow / loops will not be converted to TF cond / loops automatically. - - The returned tensors must be batch-aligned (i.e. they should be at least - of rank 1, and their outer-most dimensions must be of the same size). They - do not have to be batch-aligned with the input tensor, but if that's the - case, an additional tensor must be provided among the results, to indicate - which input record a "row" in the output batch comes from. See - `record_index_tensor_name` for more details. - - Args: - records: a 1-D string tensor that contains the records to be decoded. - - Returns: - A dict of (composite) tensors. - """ - - @property - def record_index_tensor_name(self) -> Optional[str]: - """The name of the tensor indicating which record a slice is from. - - The decoded tensors are batch-aligned among themselves, but they don't - necessarily have to be batch-aligned with the input records. If not, - sub-classes should implement this method to tie the batch dimension - with the input record. - - The record index tensor must be a SparseTensor or a RaggedTensor of integral - type, and must be 2-D and must not contain "missing" values. + """Base class for decoders that turns a list of bytes to (composite) tensors. - A record index tensor like the following: - [[0], [0], [2]] - means that of 3 "rows" in the output "batch", the first two rows came - from the first record, and the 3rd row came from the third record. + Sub-classes must implement `decode_record()` (see its docstring + for requirements). - The name must not be an empty string. - - Returns: - The name of the record index tensor. + Decoder instances can be saved as a SavedModel by `save_decoder()`. + The SavedModel can be loaded back by `load_decoder()`. However, the loaded + decoder will always be of the type `LoadedDecoder` and only have the public + interfaces listed in this base class available. """ - return None - def _make_concrete_decode_function(self): - return ( - tf.function( + def output_type_specs(self) -> Dict[str, tf.TypeSpec]: + """Returns the tf.TypeSpecs of the decoded tensors. + + Returns + ------- + A dict whose keys are the same as keys of the dict returned by + `decode_record()` and values are the tf.TypeSpec of the corresponding + (composite) tensor. + """ + return { + k: tf.type_spec_from_value(v) + for k, v in self._make_concrete_decode_function().structured_outputs.items() + } + + @abc.abstractmethod + def decode_record(self, records: tf.Tensor) -> Dict[str, TensorAlike]: + """Sub-classes should implement this. + + Implementations must use TF ops to derive the result (composite) tensors, as + this function will be traced and become a tf.function (thus a TF Graph). + Note that autograph is not enabled in such tracing, which means any python + control flow / loops will not be converted to TF cond / loops automatically. + + The returned tensors must be batch-aligned (i.e. they should be at least + of rank 1, and their outer-most dimensions must be of the same size). They + do not have to be batch-aligned with the input tensor, but if that's the + case, an additional tensor must be provided among the results, to indicate + which input record a "row" in the output batch comes from. See + `record_index_tensor_name` for more details. + + Args: + ---- + records: a 1-D string tensor that contains the records to be decoded. + + Returns: + ------- + A dict of (composite) tensors. + """ + + @property + def record_index_tensor_name(self) -> Optional[str]: + """The name of the tensor indicating which record a slice is from. + + The decoded tensors are batch-aligned among themselves, but they don't + necessarily have to be batch-aligned with the input records. If not, + sub-classes should implement this method to tie the batch dimension + with the input record. + + The record index tensor must be a SparseTensor or a RaggedTensor of integral + type, and must be 2-D and must not contain "missing" values. + + A record index tensor like the following: + [[0], [0], [2]] + means that of 3 "rows" in the output "batch", the first two rows came + from the first record, and the 3rd row came from the third record. + + The name must not be an empty string. + + Returns + ------- + The name of the record index tensor. + """ + return None + + def _make_concrete_decode_function(self): + return tf.function( self.decode_record, input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)], - autograph=False) - .get_concrete_function()) + autograph=False, + ).get_concrete_function() - def save(self, path: str) -> None: - """Saves this TFGraphRecordDecoder to a SavedModel at `path`. + def save(self, path: str) -> None: + """Saves this TFGraphRecordDecoder to a SavedModel at `path`. - This functions the same as `tf_graph_record_decoder.save_decoder()`. This is - provided purely for convenience, and should not impact the actual saved - model, since only the `tf.function` from `_make_concrete_decode_function` is - saved. + This functions the same as `tf_graph_record_decoder.save_decoder()`. This is + provided purely for convenience, and should not impact the actual saved + model, since only the `tf.function` from `_make_concrete_decode_function` is + saved. - Args: - path: The path to where the saved_model is saved. - """ - save_decoder(self, path) + Args: + ---- + path: The path to where the saved_model is saved. + """ + save_decoder(self, path) -class LoadedDecoder(object): - """A decoder recovered from a SavedModel. +class LoadedDecoder: + """A decoder recovered from a SavedModel. - It has all the public interfaces of a TFGraphRecordDecoder. - """ + It has all the public interfaces of a TFGraphRecordDecoder. + """ - def __init__(self, loaded_module: tf.Module): - self._decode_fun = loaded_module.decode_fun - self._record_index_tensor_name = None + def __init__(self, loaded_module: tf.Module): + self._decode_fun = loaded_module.decode_fun + self._record_index_tensor_name = None - if hasattr(loaded_module, "signatures"): - for signature_name in loaded_module.signatures.keys(): - if signature_name.startswith( - _RECORD_INDEX_TENSOR_NAME_SIGNATURE_PREFIX): - record_index_tensor_name = signature_name[ - len(_RECORD_INDEX_TENSOR_NAME_SIGNATURE_PREFIX):] - assert record_index_tensor_name, ( - "Invalid (empty) record_index_tensor_name") - self._record_index_tensor_name = record_index_tensor_name + if hasattr(loaded_module, "signatures"): + for signature_name in loaded_module.signatures.keys(): + if signature_name.startswith( + _RECORD_INDEX_TENSOR_NAME_SIGNATURE_PREFIX + ): + record_index_tensor_name = signature_name[ + len(_RECORD_INDEX_TENSOR_NAME_SIGNATURE_PREFIX) : + ] + assert ( + record_index_tensor_name + ), "Invalid (empty) record_index_tensor_name" + self._record_index_tensor_name = record_index_tensor_name - assert isinstance(self._decode_fun.structured_outputs, dict) - # Note that a loaded concrete function's structured_outputs are already - # TensorSpecs (instead of TensorAlikes). - self._output_type_specs = self._decode_fun.structured_outputs.copy() + assert isinstance(self._decode_fun.structured_outputs, dict) + # Note that a loaded concrete function's structured_outputs are already + # TensorSpecs (instead of TensorAlikes). + self._output_type_specs = self._decode_fun.structured_outputs.copy() - def decode_record(self, record: tf.Tensor) -> Dict[str, TensorAlike]: - return self._decode_fun(record) + def decode_record(self, record: tf.Tensor) -> Dict[str, TensorAlike]: + return self._decode_fun(record) - def output_type_specs(self) -> Dict[str, tf.TypeSpec]: - return self._output_type_specs + def output_type_specs(self) -> Dict[str, tf.TypeSpec]: + return self._output_type_specs - @property - def record_index_tensor_name(self) -> Optional[str]: - return self._record_index_tensor_name + @property + def record_index_tensor_name(self) -> Optional[str]: + return self._record_index_tensor_name def save_decoder(decoder: TFGraphRecordDecoder, path: str) -> None: - """Saves a TFGraphRecordDecoder to a SavedModel.""" - m = tf.Module() - m.decode_fun = decoder._make_concrete_decode_function() # pylint:disable=protected-access - - signatures = dict() - if decoder.record_index_tensor_name is not None: - assert decoder.record_index_tensor_name, ( - "Invalid (empty) record_index_tensor_name") - assert decoder.record_index_tensor_name in decoder.output_type_specs(), ( - "Invalid decoder: record_index_tensor_name: {} not in output " - "tensors: {}".format(decoder.record_index_tensor_name, - decoder.output_type_specs().keys())) - - @tf.function(input_signature=[]) - def record_index_tensor_name_fun(): - return decoder.record_index_tensor_name - # We also encode the record index tensor name in the name of a signature. - # This way, we do not need to evaluate a tensor or a TF Function in order - # to know the name when loading a decoder back. - signatures = { - "%s%s" % (_RECORD_INDEX_TENSOR_NAME_SIGNATURE_PREFIX, - decoder.record_index_tensor_name): - record_index_tensor_name_fun.get_concrete_function() - } - - tf.saved_model.save(m, path, signatures=signatures) + """Saves a TFGraphRecordDecoder to a SavedModel.""" + m = tf.Module() + m.decode_fun = decoder._make_concrete_decode_function() # pylint:disable=protected-access + + signatures = dict() + if decoder.record_index_tensor_name is not None: + assert ( + decoder.record_index_tensor_name + ), "Invalid (empty) record_index_tensor_name" + assert decoder.record_index_tensor_name in decoder.output_type_specs(), ( + f"Invalid decoder: record_index_tensor_name: {decoder.record_index_tensor_name} not in output " + f"tensors: {decoder.output_type_specs().keys()}" + ) + + @tf.function(input_signature=[]) + def record_index_tensor_name_fun(): + return decoder.record_index_tensor_name + + # We also encode the record index tensor name in the name of a signature. + # This way, we do not need to evaluate a tensor or a TF Function in order + # to know the name when loading a decoder back. + signatures = { + "%s%s" + % ( + _RECORD_INDEX_TENSOR_NAME_SIGNATURE_PREFIX, + decoder.record_index_tensor_name, + ): record_index_tensor_name_fun.get_concrete_function() + } + + tf.saved_model.save(m, path, signatures=signatures) def load_decoder(path: str) -> LoadedDecoder: - """Loads a TFGraphRecordDecoder from a SavedModel.""" - loaded_module = tf.saved_model.load(path) - assert hasattr(loaded_module, "decode_fun"), ( - "the SavedModel is not a TFGraphRecordDecoder") - return LoadedDecoder(loaded_module) + """Loads a TFGraphRecordDecoder from a SavedModel.""" + loaded_module = tf.saved_model.load(path) + assert hasattr( + loaded_module, "decode_fun" + ), "the SavedModel is not a TFGraphRecordDecoder" + return LoadedDecoder(loaded_module) diff --git a/tfx_bsl/coders/tf_graph_record_decoder_test.py b/tfx_bsl/coders/tf_graph_record_decoder_test.py index 82250378..fdc39852 100644 --- a/tfx_bsl/coders/tf_graph_record_decoder_test.py +++ b/tfx_bsl/coders/tf_graph_record_decoder_test.py @@ -13,154 +13,159 @@ # limitations under the License. """Tests for tfx_bsl.coders.tf_graph_record_decoder.""" -import pytest import os import tempfile -from absl import flags +import pytest import tensorflow as tf -from tfx_bsl.coders import tf_graph_record_decoder +from absl import flags +from tfx_bsl.coders import tf_graph_record_decoder FLAGS = flags.FLAGS class _DecoderForTesting(tf_graph_record_decoder.TFGraphRecordDecoder): - - def decode_record(self, record): - indices = tf.transpose(tf.stack([ - tf.range(tf.size(record), dtype=tf.int64), - tf.zeros(tf.size(record), dtype=tf.int64) - ])) - sparse = tf.SparseTensor( - values=record, - indices=indices, - dense_shape=[tf.size(record), 1]) - return { - "sparse_tensor": sparse, - "ragged_tensor": tf.RaggedTensor.from_sparse(sparse), - "record_index": tf.RaggedTensor.from_row_splits( - values=tf.range(tf.size(record), dtype=tf.int64), - row_splits=tf.range(tf.size(record) + 1, dtype=tf.int64)), - "dense_tensor": record, - } + def decode_record(self, record): + indices = tf.transpose( + tf.stack( + [ + tf.range(tf.size(record), dtype=tf.int64), + tf.zeros(tf.size(record), dtype=tf.int64), + ] + ) + ) + sparse = tf.SparseTensor( + values=record, indices=indices, dense_shape=[tf.size(record), 1] + ) + return { + "sparse_tensor": sparse, + "ragged_tensor": tf.RaggedTensor.from_sparse(sparse), + "record_index": tf.RaggedTensor.from_row_splits( + values=tf.range(tf.size(record), dtype=tf.int64), + row_splits=tf.range(tf.size(record) + 1, dtype=tf.int64), + ), + "dense_tensor": record, + } class _DecoderForTestWithRecordIndexTensorName(_DecoderForTesting): - - @property - def record_index_tensor_name(self): - return "record_index" + @property + def record_index_tensor_name(self): + return "record_index" class _DecoderForTestWithInvalidRecordIndexTensorName(_DecoderForTesting): - - @property - def record_index_tensor_name(self): - return "does_not_exist" + @property + def record_index_tensor_name(self): + return "does_not_exist" class TfGraphRecordDecoderTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self._tmp_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir) - - def _assert_type_specs_equal(self, lhs, rhs): - self.assertLen(lhs, len(rhs)) - for k, spec in lhs.items(): - self.assertIn(k, rhs) - # special handling for tf.TensorSpec to ignore the difference in .name. - if isinstance(spec, tf.TensorSpec): - self.assertIsInstance(rhs[k], tf.TensorSpec) - self.assertEqual(spec.shape.as_list(), rhs[k].shape.as_list()) - self.assertEqual(spec.dtype, rhs[k].dtype) - continue - self.assertEqual(spec, rhs[k]) - - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_save_load_decode(self): - decoder = _DecoderForTestWithRecordIndexTensorName() - actual_type_specs = decoder.output_type_specs() - actual_sparse_tensor_spec = actual_type_specs.pop("sparse_tensor") - # The expected shape is [None, 1], but due to a TensorFlow bug, it could - # be [None, None] in older TF versions. - self.assertTrue(actual_sparse_tensor_spec == - tf.SparseTensorSpec(shape=[None, None], dtype=tf.string) or - actual_sparse_tensor_spec == tf.SparseTensorSpec( - shape=[None, 1], dtype=tf.string)) - self.assertEqual( - actual_type_specs, { - "ragged_tensor": - tf.RaggedTensorSpec( - shape=[None, None], dtype=tf.string, ragged_rank=1), - "record_index": - tf.RaggedTensorSpec( - shape=[None, None], dtype=tf.int64, ragged_rank=1), - "dense_tensor": - tf.TensorSpec(shape=[None], dtype=tf.string) - }) - self.assertEqual(decoder.record_index_tensor_name, "record_index") - tf_graph_record_decoder.save_decoder(decoder, self._tmp_dir) - loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) - self.assertEqual(loaded.record_index_tensor_name, "record_index") - - self._assert_type_specs_equal(decoder.output_type_specs(), - loaded.output_type_specs()) - - records = [b"abc", b"def"] - got = loaded.decode_record(records) - self.assertLen(got, len(loaded.output_type_specs())) - self.assertIn("sparse_tensor", got) - st = got["sparse_tensor"] - self.assertAllEqual(st.values, records) - self.assertAllEqual(st.indices, [[0, 0], [1, 0]]) - self.assertAllEqual(st.dense_shape, [2, 1]) - - rt = got["ragged_tensor"] - self.assertAllEqual(rt, tf.ragged.constant([[b"abc"], [b"def"]])) - - rt = got["record_index"] - self.assertAllEqual(rt, tf.ragged.constant([[0], [1]])) - - dt = got["dense_tensor"] - self.assertAllEqual(dt, records) - - # Also test that .record_index_tensor_name can be accessed in graph - # mode. - with tf.compat.v1.Graph().as_default(): - self.assertFalse(tf.executing_eagerly()) - loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) - self.assertEqual(loaded.record_index_tensor_name, "record_index") - - # Also test that the decoder's class method `save_decoder` works. - new_decoder_path = (os.path.join(self._tmp_dir, "decoder_2")) - decoder.save(new_decoder_path) - loaded = tf_graph_record_decoder.load_decoder(new_decoder_path) - self.assertEqual(loaded.record_index_tensor_name, "record_index") - - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_no_record_index_tensor_name(self): - decoder = _DecoderForTesting() - self.assertIsNone(decoder.record_index_tensor_name) - - tf_graph_record_decoder.save_decoder(decoder, self._tmp_dir) - loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) - self._assert_type_specs_equal(decoder.output_type_specs(), - loaded.output_type_specs()) - self.assertIsNone(loaded.record_index_tensor_name) - - with tf.compat.v1.Graph().as_default(): - self.assertFalse(tf.executing_eagerly()) - loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) - self.assertIsNone(loaded.record_index_tensor_name) - - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_do_not_save_if_record_index_tensor_name_invalid(self): - decoder = _DecoderForTestWithInvalidRecordIndexTensorName() - with self.assertRaisesRegex(AssertionError, "record_index_tensor_name"): - tf_graph_record_decoder.save_decoder(decoder, self._tmp_dir) + def setUp(self): + super().setUp() + self._tmp_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir) + + def _assert_type_specs_equal(self, lhs, rhs): + self.assertLen(lhs, len(rhs)) + for k, spec in lhs.items(): + self.assertIn(k, rhs) + # special handling for tf.TensorSpec to ignore the difference in .name. + if isinstance(spec, tf.TensorSpec): + self.assertIsInstance(rhs[k], tf.TensorSpec) + self.assertEqual(spec.shape.as_list(), rhs[k].shape.as_list()) + self.assertEqual(spec.dtype, rhs[k].dtype) + continue + self.assertEqual(spec, rhs[k]) + + @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") + def test_save_load_decode(self): + decoder = _DecoderForTestWithRecordIndexTensorName() + actual_type_specs = decoder.output_type_specs() + actual_sparse_tensor_spec = actual_type_specs.pop("sparse_tensor") + # The expected shape is [None, 1], but due to a TensorFlow bug, it could + # be [None, None] in older TF versions. + self.assertTrue( + actual_sparse_tensor_spec + == tf.SparseTensorSpec(shape=[None, None], dtype=tf.string) + or actual_sparse_tensor_spec + == tf.SparseTensorSpec(shape=[None, 1], dtype=tf.string) + ) + self.assertEqual( + actual_type_specs, + { + "ragged_tensor": tf.RaggedTensorSpec( + shape=[None, None], dtype=tf.string, ragged_rank=1 + ), + "record_index": tf.RaggedTensorSpec( + shape=[None, None], dtype=tf.int64, ragged_rank=1 + ), + "dense_tensor": tf.TensorSpec(shape=[None], dtype=tf.string), + }, + ) + self.assertEqual(decoder.record_index_tensor_name, "record_index") + tf_graph_record_decoder.save_decoder(decoder, self._tmp_dir) + loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) + self.assertEqual(loaded.record_index_tensor_name, "record_index") + + self._assert_type_specs_equal( + decoder.output_type_specs(), loaded.output_type_specs() + ) + + records = [b"abc", b"def"] + got = loaded.decode_record(records) + self.assertLen(got, len(loaded.output_type_specs())) + self.assertIn("sparse_tensor", got) + st = got["sparse_tensor"] + self.assertAllEqual(st.values, records) + self.assertAllEqual(st.indices, [[0, 0], [1, 0]]) + self.assertAllEqual(st.dense_shape, [2, 1]) + + rt = got["ragged_tensor"] + self.assertAllEqual(rt, tf.ragged.constant([[b"abc"], [b"def"]])) + + rt = got["record_index"] + self.assertAllEqual(rt, tf.ragged.constant([[0], [1]])) + + dt = got["dense_tensor"] + self.assertAllEqual(dt, records) + + # Also test that .record_index_tensor_name can be accessed in graph + # mode. + with tf.compat.v1.Graph().as_default(): + self.assertFalse(tf.executing_eagerly()) + loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) + self.assertEqual(loaded.record_index_tensor_name, "record_index") + + # Also test that the decoder's class method `save_decoder` works. + new_decoder_path = os.path.join(self._tmp_dir, "decoder_2") + decoder.save(new_decoder_path) + loaded = tf_graph_record_decoder.load_decoder(new_decoder_path) + self.assertEqual(loaded.record_index_tensor_name, "record_index") + + @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") + def test_no_record_index_tensor_name(self): + decoder = _DecoderForTesting() + self.assertIsNone(decoder.record_index_tensor_name) + + tf_graph_record_decoder.save_decoder(decoder, self._tmp_dir) + loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) + self._assert_type_specs_equal( + decoder.output_type_specs(), loaded.output_type_specs() + ) + self.assertIsNone(loaded.record_index_tensor_name) + + with tf.compat.v1.Graph().as_default(): + self.assertFalse(tf.executing_eagerly()) + loaded = tf_graph_record_decoder.load_decoder(self._tmp_dir) + self.assertIsNone(loaded.record_index_tensor_name) + + @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") + def test_do_not_save_if_record_index_tensor_name_invalid(self): + decoder = _DecoderForTestWithInvalidRecordIndexTensorName() + with self.assertRaisesRegex(AssertionError, "record_index_tensor_name"): + tf_graph_record_decoder.save_decoder(decoder, self._tmp_dir) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tfx_bsl/docs/schema_interpretation.md b/tfx_bsl/docs/schema_interpretation.md index 82c2acd3..192b9048 100644 --- a/tfx_bsl/docs/schema_interpretation.md +++ b/tfx_bsl/docs/schema_interpretation.md @@ -43,7 +43,7 @@ give some examples of advanced usage. ### Primitive types -