Skip to content

Commit 567754e

Browse files
authored
Add asset zipping functionality to TFJS converter (#6915)
* Add asset zipping functionality to TFJS converter * Add TFDF to converter requirements * Add TFDF dependency * Fix assets overwrite bug * Make copy assets conditional on TFDF input
1 parent 2cc528b commit 567754e

File tree

7 files changed

+117
-5
lines changed

7 files changed

+117
-5
lines changed

tfjs-converter/python/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ py_wheel(
5252
"importlib_resources>=5.9.0",
5353
"jax>=0.3.16",
5454
"protobuf<3.20,>=3.9.2",
55-
"tensorflow>=2.1.0,<3",
55+
"tensorflow>=2.10.0,<3",
56+
"tensorflow-decision-forests>=1.0.1",
5657
"six>=1.12.0,<2",
5758
"tensorflow-hub>=0.7.0,<0.13",
5859
"packaging~=20.9",

tfjs-converter/python/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ jax>=0.3.16
33
importlib_resources>=5.9.0
44
protobuf<3.20,>=3.9.2
55
tensorflow>=2.10.0,<3
6+
tensorflow-decision-forests>=1.0.1
67
six>=1.12.0,<2
78
tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
89
packaging~=20.9

tfjs-converter/python/tensorflowjs/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,15 @@ py_library(
8080
deps = [requirement("tensorflow")],
8181
)
8282

83+
py_library(
84+
name = "expect_tensorflow_decision_forests_installed",
85+
# This is a dummy rule used as a tensorflow-decision-forests dependency in open-source.
86+
# We expect tensorflow-decision-forests to already be installed on
87+
# the system, e.g. via
88+
# `pip install tensorflow-decision-forests`.
89+
deps = [requirement("tensorflow-decision-forests")],
90+
)
91+
8392
py_library(
8493
name = "expect_tensorflow_hub_installed",
8594
# This is a dummy rule used as a tensorflow_hub dependency in open-source.

tfjs-converter/python/tensorflowjs/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ py_library(
188188
":graph_rewrite_util",
189189
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
190190
"//tfjs-converter/python/tensorflowjs:expect_packaging_installed",
191+
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_decision_forests_installed",
191192
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_hub_installed",
192193
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
193194
"//tfjs-converter/python/tensorflowjs:resource_loader",

tfjs-converter/python/tensorflowjs/converters/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
# File name for the indexing JSON file in an artifact directory.
1919
ARTIFACT_MODEL_JSON_FILE_NAME = 'model.json'
20+
ASSETS_DIRECTORY_NAME = 'assets'
2021

2122
# JSON string keys for fields of the indexing JSON.
2223
ARTIFACT_MODEL_TOPOLOGY_KEY = 'modelTopology'

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,19 @@
2020

2121
import json
2222
import os
23+
import shutil
24+
import tempfile
25+
from zipfile import ZipFile
2326

27+
# Required to load saved models that use TFDF.
28+
import tensorflow_decision_forests
2429
import tensorflow as tf
2530
from tensorflow.core.framework import graph_pb2
2631
from tensorflow.core.framework import node_def_pb2
2732
from tensorflow.core.protobuf import config_pb2
2833
from tensorflow.core.protobuf import device_properties_pb2
2934
from tensorflow.core.protobuf import meta_graph_pb2
35+
from tensorflow.io import gfile
3036
from tensorflow.python.checkpoint.trackable_view import TrackableView
3137
from tensorflow.python.eager import context
3238
from tensorflow.python.framework import convert_to_constants
@@ -399,7 +405,7 @@ def write_artifacts(topology,
399405
assert isinstance(weights_manifest, list)
400406
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest
401407

402-
with tf.io.gfile.GFile(output_graph, 'w') as f:
408+
with gfile.GFile(output_graph, 'w') as f:
403409
json.dump(model_json, f)
404410

405411
def _remove_unused_control_flow_inputs(input_graph_def):
@@ -421,6 +427,49 @@ def _check_signature_in_model(saved_model, signature_name):
421427
"are available: %s" % (signature_name,
422428
saved_model.signatures.keys()))
423429

430+
def _copy_assets(saved_model_dir, output_dir):
  """Zips the SavedModel's assets directory into the output directory.

  Looks for `<saved_model_dir>/<common.ASSETS_DIRECTORY_NAME>`. If that path
  exists and is a directory, every file found while walking it is written to
  a zip archive, using the file's path relative to the assets directory as
  the entry name. The archive is then copied to
  `<output_dir>/<common.ASSETS_DIRECTORY_NAME>.zip`, overwriting any existing
  file. Nothing happens when the assets directory is absent.

  Args:
    saved_model_dir: Path of the SavedModel directory to read assets from.
    output_dir: Path of the directory that receives the zip archive.
  """
  input_assets_path = os.path.join(saved_model_dir,
                                   common.ASSETS_DIRECTORY_NAME)

  if not (gfile.exists(input_assets_path) and gfile.isdir(input_assets_path)):
    return

  zip_name = common.ASSETS_DIRECTORY_NAME + '.zip'

  # The archive is assembled in a scratch directory and copied to its final
  # location afterwards. Using a context manager guarantees the scratch
  # directory is removed even if zipping or copying raises.
  with tempfile.TemporaryDirectory() as tmp_dir:
    zip_path = gfile.join(tmp_dir, zip_name)

    with ZipFile(zip_path, 'w') as archive:
      for input_dir_path, _, file_names in gfile.walk(input_assets_path):
        # Entry names are relative to the assets directory so the archive
        # does not embed the SavedModel's location.
        relative_dir_path = os.path.relpath(input_dir_path, input_assets_path)

        for file_name in file_names:
          input_file_path = gfile.join(input_dir_path, file_name)
          relative_file_path = gfile.join(relative_dir_path, file_name)

          with gfile.GFile(input_file_path, 'rb') as input_file:
            with archive.open(relative_file_path, 'w') as relative_file:
              shutil.copyfileobj(input_file, relative_file)

    gfile.copy(zip_path, gfile.join(output_dir, zip_name), overwrite=True)
458+
# TFDF stores the necessary files for its binary in the assets folder.
# NOTE: every entry needs its own trailing comma. Adjacent string literals
# are implicitly concatenated by Python, which would silently merge two op
# names into a single bogus one.
ASSET_REQUIRING_OPS = {
    'SimpleMLCreateModelResource',
    'SimpleMLLoadModelFromPathWithHandle',
    'SimpleMLInferenceOpWithHandle',
}


def _is_assets_required(model_ops):
  """Returns whether any op in `model_ops` is listed in ASSET_REQUIRING_OPS.

  Args:
    model_ops: Iterable of op-type name strings used by the model.

  Returns:
    True if at least one element of `model_ops` is in ASSET_REQUIRING_OPS,
    False otherwise (including when `model_ops` is empty).
  """
  return not ASSET_REQUIRING_OPS.isdisjoint(model_ops)
468+
def _get_frozen_graph_ops(frozen_graph):
469+
if frozen_graph is None:
470+
return []
471+
return [node.op for node in frozen_graph.as_graph_def().node]
472+
424473

425474
def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
426475
output_node_names):
@@ -745,8 +794,8 @@ def _convert_tf_saved_model(output_dir,
745794
if signature_def is None:
746795
signature_def = 'serving_default'
747796

748-
if not tf.io.gfile.exists(output_dir):
749-
tf.io.gfile.makedirs(output_dir)
797+
if not gfile.exists(output_dir):
798+
gfile.makedirs(output_dir)
750799
output_graph = os.path.join(
751800
output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)
752801

@@ -852,6 +901,12 @@ def _convert_tf_saved_model(output_dir,
852901
# tensorflow version.
853902
tf_version = tf.__version__
854903

904+
if saved_model_dir:
905+
model_ops = set(_get_frozen_graph_ops(frozen_graph)) |\
906+
set(_get_frozen_graph_ops(frozen_initializer_graph))
907+
if _is_assets_required(model_ops):
908+
_copy_assets(saved_model_dir, output_dir)
909+
855910
optimize_graph(frozen_graph, signature,
856911
output_graph, tf_version,
857912
quantization_dtype_map=quantization_dtype_map,
@@ -1137,7 +1192,7 @@ def convert_tf_hub_module(module_handle, output_dir,
11371192
# TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1
11381193
# modules is fixed on the TF side, or once the modules we cannot load become
11391194
# replaced with newer versions.
1140-
if tf.io.gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
1195+
if gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
11411196
print("Loading the module using TF 1.X interface from %s." % module_path)
11421197
convert_tf_hub_module_v1(module_path, output_dir, signature,
11431198
quantization_dtype_map,

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
import shutil
2222
import tempfile
2323
import unittest
24+
import numpy as np
2425

2526
import tensorflow.compat.v2 as tf
27+
from tensorflow_decision_forests.keras import GradientBoostedTreesModel
2628
from tensorflow.python.eager import def_function
2729
from tensorflow.python.framework import constant_op
2830
from tensorflow.python.framework import dtypes
@@ -35,6 +37,7 @@
3537
from tensorflowjs import version
3638
from tensorflowjs.converters import graph_rewrite_util
3739
from tensorflowjs.converters import tf_saved_model_conversion_v2
40+
from tensorflowjs.converters.common import ASSETS_DIRECTORY_NAME
3841

3942
SAVED_MODEL_DIR = 'saved_model'
4043
HUB_MODULE_DIR = 'hub_module'
@@ -246,6 +249,22 @@ def find_next_odd(v):
246249
save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
247250
save(root, save_dir, to_save)
248251

252+
def _create_saved_model_with_tfdf(self):
  """Fits a small TFDF gradient-boosted-trees model and saves it.

  The model is trained on randomly generated features, boolean labels and
  per-example weights, then saved under `self._tmp_dir` in the
  SAVED_MODEL_DIR sub-directory.
  """
  positive_weight = 5
  num_examples = 10
  num_features = 4

  features = np.random.uniform(size=(num_examples, num_features))
  labels = np.random.uniform(size=num_examples) > 0.5
  # Weight is 1 for a False label and `positive_weight` for a True label.
  sample_weights = labels * (positive_weight - 1) + 1

  tfdf_model = GradientBoostedTreesModel()
  tfdf_model.fit(x=features, y=labels, sample_weight=sample_weights)

  tfdf_model.save(os.path.join(self._tmp_dir, SAVED_MODEL_DIR))
267+
249268
def _create_unsupported_saved_model(self):
250269
root = tracking.AutoTrackable()
251270
root.w = variables.Variable(tf.random.uniform([2, 2]))
@@ -936,6 +955,31 @@ def test_convert_saved_model_with_control_flow_v2(self):
936955
glob.glob(
937956
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))
938957

958+
def test_convert_saved_model_with_tfdf(self):
  """Converting a TFDF SavedModel keeps the TFDF ops and emits the assets zip."""
  self._create_saved_model_with_tfdf()

  # The conversion is done in place: the SavedModel directory is passed as
  # both the input and the output directory.
  tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
  tf_saved_model_conversion_v2.convert_tf_saved_model(
      tfjs_path, tfjs_path, skip_op_check=True
  )

  # Load the generated model.json.
  with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
    model_json = json.load(f)

  # Check TFDF ops are present. assertIn reports the searched list on
  # failure, unlike assertTrue(x in y).
  model_ops = [node['op'] for node in model_json['modelTopology']['node']]
  self.assertIn('SimpleMLInferenceOpWithHandle', model_ops)

  initializer_ops = [
      node['op'] for node in model_json['modelInitializer']['node']]
  self.assertIn('SimpleMLCreateModelResource', initializer_ops)
  self.assertIn('SimpleMLLoadModelFromPathWithHandle', initializer_ops)

  # Check assets containing TFDF files were copied over.
  self.assertTrue(
      os.path.exists(
          os.path.join(tfjs_path, ASSETS_DIRECTORY_NAME + '.zip')))
982+
939983
def test_convert_saved_model_sharded(self):
940984
self._create_saved_model()
941985
model_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)

0 commit comments

Comments
 (0)