Skip to content

Commit 10f93f6

Browse files
ybaturinatf-text-github-robot
authored andcommitted
Create tf_text wheel build rule.
PiperOrigin-RevId: 816530468
1 parent aa839b1 commit 10f93f6

File tree

8 files changed

+407
-22
lines changed

8 files changed

+407
-22
lines changed

WORKSPACE

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,28 @@ workspace(name = "org_tensorflow_text")
22

33
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
44

5+
# Toolchains for ML projects hermetic builds.
6+
# Details: https://github.com/google-ml-infra/rules_ml_toolchain
7+
http_archive(
8+
name = "rules_ml_toolchain",
9+
sha256 = "de3b14418657eeacd8afc2aa89608be6ec8d66cd6a5de81c4f693e77bc41bee1",
10+
strip_prefix = "rules_ml_toolchain-5653e5a0ca87c1272069b4b24864e55ce7f129a1",
11+
urls = [
12+
"https://github.com/google-ml-infra/rules_ml_toolchain/archive/5653e5a0ca87c1272069b4b24864e55ce7f129a1.tar.gz",
13+
],
14+
)
15+
16+
load(
17+
"@rules_ml_toolchain//cc_toolchain/deps:cc_toolchain_deps.bzl",
18+
"cc_toolchain_deps",
19+
)
20+
21+
cc_toolchain_deps()
22+
23+
register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64")
24+
25+
register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64_cuda")
26+
527
http_archive(
628
name = "icu",
729
strip_prefix = "icu-release-64-2",
@@ -56,10 +78,10 @@ http_archive(
5678

5779
http_archive(
5880
name = "org_tensorflow",
59-
strip_prefix = "tensorflow-40998f44c0c500ce0f6e3b1658dfbc54f838a82a",
60-
sha256 = "5a5bc4599964c71277dcac0d687435291e5810d2ac2f6283cc96736febf73aaf",
81+
sha256 = "1a25308b15036bf8006ada5c9955ddc9a217792e6fc24deee04626ec07013f2c",
82+
strip_prefix = "tensorflow-72fbba3d20f4616d7312b5e2b7f79daf6e82f2fa",
6183
urls = [
62-
"https://github.com/tensorflow/tensorflow/archive/40998f44c0c500ce0f6e3b1658dfbc54f838a82a.zip"
84+
"https://github.com/tensorflow/tensorflow/archive/72fbba3d20f4616d7312b5e2b7f79daf6e82f2fa.zip",
6385
],
6486
)
6587

@@ -134,6 +156,14 @@ load("@pypi//:requirements.bzl", "install_deps")
134156

135157
install_deps()
136158

159+
load("//oss_scripts/pip_package:tensorflow_text_python_wheel.bzl", "tensorflow_text_python_wheel_repository")
160+
161+
tensorflow_text_python_wheel_repository(
162+
name = "tensorflow_text_wheel",
163+
version_key = "__version__",
164+
version_source = "//tensorflow_text:__init__.py",
165+
)
166+
137167
# Initialize TensorFlow dependencies.
138168
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
139169
tf_workspace3()
@@ -151,14 +181,16 @@ load("@local_config_android//:android.bzl", "android_workspace")
151181
android_workspace()
152182

153183
load(
154-
"@local_xla//third_party/py:python_wheel.bzl",
184+
"@org_tensorflow//third_party/xla/third_party/py:python_wheel.bzl",
155185
"python_wheel_version_suffix_repository",
156186
)
157187

158-
python_wheel_version_suffix_repository(name = "tf_wheel_version_suffix")
188+
python_wheel_version_suffix_repository(
189+
name = "tf_wheel_version_suffix",
190+
)
159191

160192
load(
161-
"@local_xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
193+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
162194
"cuda_json_init_repository",
163195
)
164196

@@ -170,7 +202,7 @@ load(
170202
"CUDNN_REDISTRIBUTIONS",
171203
)
172204
load(
173-
"@local_xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
205+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
174206
"cuda_redist_init_repositories",
175207
"cudnn_redist_init_repository",
176208
)
@@ -184,21 +216,21 @@ cudnn_redist_init_repository(
184216
)
185217

186218
load(
187-
"@local_xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
219+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
188220
"cuda_configure",
189221
)
190222

191223
cuda_configure(name = "local_config_cuda")
192224

193225
load(
194-
"@local_xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
226+
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
195227
"nccl_redist_init_repository",
196228
)
197229

198230
nccl_redist_init_repository()
199231

200232
load(
201-
"@local_xla//third_party/nccl/hermetic:nccl_configure.bzl",
233+
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_configure.bzl",
202234
"nccl_configure",
203235
)
204236

oss_scripts/configure.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ else
4141
if [[ "$IS_NIGHTLY" == "nightly" ]]; then
4242
pip install tf-nightly
4343
else
44-
pip install tensorflow==2.18.0
44+
pip install tensorflow==2.20.0
4545
fi
4646
fi
4747

oss_scripts/pip_package/BUILD

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
2+
load("@org_tensorflow//third_party/xla/third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
3+
14
# Tools for building the TF.Text pip package.
25
load("@python//:defs.bzl", "compile_pip_requirements")
36
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
7+
load("//oss_scripts/pip_package:wheel.bzl", "tensorflow_text_source_package", "tensorflow_text_wheel")
48

59
package(default_visibility = ["//visibility:private"])
610

@@ -27,14 +31,62 @@ py_binary(
2731
],
2832
)
2933

30-
sh_binary(
31-
name = "build_pip_package",
32-
srcs = ["build_pip_package.sh"],
33-
data = [
34+
string_flag(
35+
name = "output_path",
36+
build_setting_default = "dist",
37+
)
38+
39+
py_binary(
40+
name = "build_wheel_py",
41+
srcs = ["build_wheel.py"],
42+
main = "build_wheel.py",
43+
deps = [
44+
#":build_utils",
45+
#"@bazel_tools//tools/python/runfiles",
46+
#"@pypi//build",
47+
#"@pypi//setuptools",
48+
#"@pypi//wheel",
49+
],
50+
)
51+
52+
filegroup(
53+
name = "wheel_sources",
54+
srcs = [
3455
"LICENSE",
3556
"MANIFEST.in",
3657
"setup.nightly.py",
37-
"setup.py",
38-
"//tensorflow_text",
58+
":transitive_data_deps",
59+
":transitive_py_deps",
3960
],
4061
)
62+
63+
transitive_py_deps(
64+
name = "transitive_py_deps",
65+
deps = ["//tensorflow_text"],
66+
)
67+
68+
collect_data_files(
69+
name = "transitive_data_deps",
70+
deps = ["//tensorflow_text"],
71+
)
72+
73+
tensorflow_text_wheel(
74+
name = "tensorflow_text_wheel",
75+
srcs = [":wheel_sources"],
76+
)
77+
78+
tensorflow_text_source_package(
79+
name = "tensorflow_text_source_package",
80+
srcs = [":wheel_sources"],
81+
)
82+
#sh_binary(
83+
# name = "build_pip_package",
84+
# srcs = ["build_pip_package.sh"],
85+
# data = [
86+
# "LICENSE",
87+
# "MANIFEST.in",
88+
# "setup.nightly.py",
89+
# "setup.py",
90+
# "//tensorflow_text",
91+
# ],
92+
#)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# coding=utf-8
2+
# Copyright 2025 TF.Text Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
17+
#
18+
# Licensed under the Apache License, Version 2.0 (the "License");
19+
# you may not use this file except in compliance with the License.
20+
# You may obtain a copy of the License at
21+
#
22+
# http://www.apache.org/licenses/LICENSE-2.0
23+
#
24+
# Unless required by applicable law or agreed to in writing, software
25+
# distributed under the License is distributed on an "AS IS" BASIS,
26+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27+
# See the License for the specific language governing permissions and
28+
# limitations under the License.
29+
# ==============================================================================
30+
"""Script that builds a tf text wheel, intended to be run via bazel."""
31+
32+
import argparse
33+
import os
34+
import pathlib
35+
import shutil
36+
import subprocess
37+
import sys
38+
import tempfile
39+
40+
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
41+
parser.add_argument(
42+
"--output_path",
43+
default=None,
44+
required=True,
45+
help="Path to which the output wheel should be written. Required.",
46+
)
47+
parser.add_argument(
48+
"--srcs", help="source files for the wheel", action="append"
49+
)
50+
parser.add_argument(
51+
"--build-wheel-only",
52+
default=False,
53+
help="Whether to build the wheel only. Optional.",
54+
)
55+
parser.add_argument(
56+
"--build-source-package-only",
57+
default=False,
58+
help="Whether to build the source package only. Optional.",
59+
)
60+
parser.add_argument(
61+
"--platform",
62+
default="",
63+
required=False,
64+
help="Platform name to be passed to setup.py",
65+
)
66+
args = parser.parse_args()
67+
68+
69+
def copy_file(
70+
src_file: str,
71+
dst_dir: str,
72+
) -> None:
73+
"""Copy a file to the destination directory.
74+
75+
Args:
76+
src_file: file to be copied
77+
dst_dir: destination directory
78+
"""
79+
80+
dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
81+
os.makedirs(dest_dir_path, exist_ok=True)
82+
shutil.copy(src_file, dest_dir_path)
83+
os.chmod(os.path.join(dst_dir, src_file), 0o644)
84+
85+
86+
def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
87+
"""Filter the sources and copy them to the destination directory.
88+
89+
Args:
90+
deps: a list of paths to files.
91+
srcs_dir: target directory where files are copied to.
92+
"""
93+
94+
for file in deps:
95+
print(file)
96+
if not (file.startswith("bazel-out") or file.startswith("external")):
97+
copy_file(file, srcs_dir)
98+
99+
100+
def build_wheel(
101+
dir_path: str,
102+
cwd: str,
103+
platform: str,
104+
) -> None:
105+
"""Build the wheel in the target directory.
106+
107+
Args:
108+
dir_path: directory where the wheel will be stored
109+
cwd: path to directory with wheel source files
110+
platform: platform name to pass to setup.py.
111+
"""
112+
113+
subprocess.run(
114+
[
115+
sys.executable,
116+
"setup.nightly.py",
117+
"bdist_wheel",
118+
f"--dist-dir={dir_path}",
119+
f"--plat-name={platform}",
120+
],
121+
check=True,
122+
cwd=cwd,
123+
)
124+
125+
126+
tmpdir = tempfile.TemporaryDirectory(prefix="tensorflow_text")
127+
sources_path = tmpdir.name
128+
129+
try:
130+
os.makedirs(args.output_path, exist_ok=True)
131+
prepare_srcs(args.srcs, pathlib.Path(sources_path))
132+
build_wheel(
133+
os.path.join(os.getcwd(), args.output_path),
134+
tmpdir.path,
135+
args.platform,
136+
)
137+
finally:
138+
if tmpdir:
139+
tmpdir.cleanup()
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
setuptools==70.0.0
22
dm-tree==0.1.8 # Limit for macos support.
33
numpy
4-
protobuf==4.25.3 # b/397977335 - Fix crash on python 3.9, 3.10.
5-
tensorflow
4+
#protobuf==4.25.3 # b/397977335 - Fix crash on python 3.9, 3.10.
5+
tensorflow==2.20.0
66
tf-keras
7-
tensorflow-datasets
8-
tensorflow-metadata
7+
#tensorflow-datasets
8+
#tensorflow-metadata
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
# Repository rule to generate a file with TF text wheel version.
16+
def _tensorflow_text_python_wheel_repository_impl(repository_ctx):
17+
version_source = repository_ctx.attr.version_source
18+
version_key = repository_ctx.attr.version_key
19+
version_file_content = repository_ctx.read(
20+
repository_ctx.path(version_source),
21+
)
22+
version_start_index = version_file_content.find(version_key)
23+
version_end_index = version_start_index + version_file_content[version_start_index:].find("\n")
24+
wheel_version = version_file_content[version_start_index:version_end_index].replace(
25+
version_key,
26+
"WHEEL_VERSION",
27+
)
28+
repository_ctx.file(
29+
"wheel.bzl",
30+
wheel_version,
31+
)
32+
repository_ctx.file("BUILD", "")
33+
34+
tensorflow_text_python_wheel_repository = repository_rule(
35+
implementation = _tensorflow_text_python_wheel_repository_impl,
36+
attrs = {
37+
"version_source": attr.label(mandatory = True, allow_single_file = True),
38+
"version_key": attr.string(mandatory = True),
39+
},
40+
)

0 commit comments

Comments
 (0)