Skip to content

Commit 88dd402

Browse files
Revert "Add resource initializer support (#6826)" (#6900)
This is causing regression e2e tests to fail: 1) saved_model_v1_with_hashtable. #REGRESSION convert_predict webgl {"WEBGL_VERSION":2,"WEBGL_CPU_FORWARD":false,"WEBGL_SIZE_UPLOAD_UNIFORM":0} Error: Arrays differ: actual[0] = -1, expected[0] = 3. To reproduce this, use node 16 in e2e/ and run `NIGHTLY=true ./scripts/test-ci.sh`, or, after running that to generate the required files, run `yarn karma start --tags '#REGRESSION'`. This reverts commit 42dee16.
1 parent c5eb1d3 commit 88dd402

18 files changed

Lines changed: 256 additions & 1092 deletions

e2e/integration_tests/constants.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ export const CONVERT_PREDICT_MODELS = {
3737
'saved_model_v1', 'saved_model_v2', 'saved_model_v2_with_control_flow',
3838
'saved_model_with_conv2d', 'saved_model_with_prelu',
3939
'saved_model_v2_complex64', 'saved_model_v2_with_control_flow_v2',
40-
'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable',
41-
'saved_model_v2_with_hashtable'
40+
'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable'
4241
],
4342
layers_model: ['mobilenet']
4443
};

e2e/integration_tests/convert_predict.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -427,47 +427,6 @@ def _create_saved_model_v1_with_hashtable(save_dir):
427427
}
428428
}
429429

430-
def _create_saved_model_v2_with_hashtable(save_dir):
431-
"""Test a TF V2 model with HashTable Ops.
432-
433-
Args:
434-
save_dir: directory name of where the saved model will be stored.
435-
"""
436-
class Table(tf.Module):
437-
def __init__(self):
438-
super(Table, self).__init__()
439-
keys = tf.constant(['a', 'b'])
440-
vals= tf.constant([0, 1])
441-
init = tf.lookup.KeyValueTensorInitializer(keys, vals)
442-
self.table = tf.lookup.StaticHashTable(init, -1)
443-
444-
def initializeTable(self):
445-
@tf.function
446-
def lookup(input):
447-
return self.table.lookup(input)
448-
449-
return lookup
450-
451-
model = Table()
452-
concrete_fn = model.initializeTable().get_concrete_function(
453-
input=tf.TensorSpec([None], tf.string))
454-
455-
tf.saved_model.save(model, save_dir, signatures={"serving_default": concrete_fn})
456-
457-
return {
458-
"async": False,
459-
"inputs": {
460-
"Placeholder:0": {
461-
"value": ["a", "b", "c"], "shape": [3], "dtype": "string"
462-
}
463-
},
464-
"outputs": {
465-
"StatefulPartitionedCall/None_Lookup/LookupTableFindV2:0": {
466-
"value": [0, 1, -1], "shape": [3], "dtype": "int32"
467-
}
468-
}
469-
}
470-
471430
def _layers_mobilenet():
472431
model = tf.keras.applications.MobileNetV2()
473432
model_path = 'mobilenet'
@@ -512,8 +471,6 @@ def main():
512471
'saved_model_v2_with_tensorlist_ops', control_flow_v2=True)
513472
_save_and_convert_model(_create_saved_model_v1_with_hashtable,
514473
'saved_model_v1_with_hashtable')
515-
_save_and_convert_model(_create_saved_model_v2_with_hashtable,
516-
'saved_model_v2_with_hashtable')
517474

518475
_layers_mobilenet()
519476
if __name__ == '__main__':

e2e/yarn.lock

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,11 @@
10121012
dependencies:
10131013
detect-browser "*"
10141014

1015+
"@types/emscripten@~0.0.34":
1016+
version "0.0.34"
1017+
resolved "https://registry.yarnpkg.com/@types/emscripten/-/emscripten-0.0.34.tgz#12b4a344274fb102ff2f6c877b37587bc3e46008"
1018+
integrity sha512-QSb9ojDincskc+uKMI0KXp8e1NALFINCrMlp8VGKGcTSxeEyRTTKyjWw75NYrCZHUsVEEEpr1tYHpbtaC++/sQ==
1019+
10151020
"@types/jasmine@~3.0.0":
10161021
version "3.0.0"
10171022
resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-3.0.0.tgz#9a6b6755a02fcd6baa088a767557709c79728f98"

tfjs-converter/python/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ flax>=0.5.3
22
jax>=0.3.16
33
importlib_resources>=5.9.0
44
protobuf<3.20,>=3.9.2
5-
tensorflow>=2.10.0,<3
5+
tensorflow>=2.1.0,<3
66
six>=1.12.0,<2
77
tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
88
packaging~=20.9

tfjs-converter/python/tensorflowjs/converters/common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@
3131
CONVERTED_BY_KEY = 'convertedBy'
3232

3333
SIGNATURE_KEY = 'signature'
34-
INITIALIZER_SIGNATURE_KEY = 'initializerSignature'
3534
USER_DEFINED_METADATA_KEY = 'userDefinedMetadata'
3635
STRUCTURED_OUTPUTS_KEYS_KEY = 'structuredOutputKeys'
37-
RESOURCE_ID_KEY = 'resourceId'
3836

3937
# Model formats.
4038
KERAS_SAVED_MODEL = 'keras_saved_model'

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py

Lines changed: 3 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from tensorflow.core.protobuf import config_pb2
2828
from tensorflow.core.protobuf import device_properties_pb2
2929
from tensorflow.core.protobuf import meta_graph_pb2
30-
from tensorflow.python.checkpoint.trackable_view import TrackableView
3130
from tensorflow.python.eager import context
3231
from tensorflow.python.framework import convert_to_constants
3332
from tensorflow.python.grappler import cluster as gcluster
@@ -39,7 +38,6 @@
3938
from tensorflow.python.saved_model import loader
4039
from tensorflow.python.training.saver import export_meta_graph
4140
from tensorflow.python.tools.saved_model_cli import get_signature_def_map
42-
from tensorflow.saved_model.experimental import TrackableResource
4341
from google.protobuf.json_format import MessageToDict
4442
import tensorflow_hub as hub
4543
from packaging import version
@@ -127,7 +125,6 @@ def optimize_graph(graph, signature_def, output_graph,
127125
weight_shard_size_bytes=1024 * 1024 * 4,
128126
experiments=False,
129127
initializer_graph=None,
130-
resource_ids_maps=None,
131128
metadata=None):
132129
"""Takes a Python Graph object and optimizes the graph.
133130
@@ -144,9 +141,6 @@ def optimize_graph(graph, signature_def, output_graph,
144141
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
145142
The size of each weight file will be <= this value.
146143
initializer_graph: The frozen graph for initializers.
147-
resource_ids_maps: Tuple of two dictionaries, one
148-
mapping inference input names to resource id, and the other
149-
mapping initializer output names to resource id.
150144
metadata: User defined metadata map.
151145
"""
152146

@@ -217,17 +211,13 @@ def optimize_graph(graph, signature_def, output_graph,
217211
', '.join(unsupported))
218212

219213
initializer_graph_def = None
220-
initializer_signature_def = None
221214
if initializer_graph:
222215
initializer_graph_def = initializer_graph.as_graph_def()
223-
if hasattr(initializer_graph, 'outputs'):
224-
initializer_signature_def = _build_signature_def(initializer_graph, [], initializer_graph.outputs)
225216

226217
extract_weights(
227218
optimized_graph, output_graph, tf_version,
228219
signature_def, quantization_dtype_map, weight_shard_size_bytes,
229-
initializer_graph_def, initializer_signature_def,
230-
resource_ids_maps=resource_ids_maps, metadata=metadata)
220+
initializer_graph_def, metadata=metadata)
231221

232222
def extract_const_nodes(nodes):
233223
"""Takes a list of nodes and extract the weights. Return weight manifest
@@ -266,8 +256,6 @@ def extract_weights(graph_def,
266256
quantization_dtype_map=None,
267257
weight_shard_size_bytes=1024 * 1024 * 4,
268258
initializer_graph_def=None,
269-
initializer_signature_def=None,
270-
resource_ids_maps=None,
271259
metadata=None):
272260
"""Takes a Python GraphDef object and extract the weights.
273261
@@ -283,10 +271,6 @@ def extract_weights(graph_def,
283271
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
284272
The size of each weight file will be <= this value.
285273
initializer_graph_def: tf.GraphDef proto object for initializer graph.
286-
initializer_signature_def: the SignatureDef of the initializer graph.
287-
resource_ids_maps: Tuple of two dictionaries, one
288-
mapping inference input names to resource id, and the other
289-
mapping initializer output names to resource id.
290274
metadata: User defined metadata map.
291275
"""
292276
global_manifest = extract_const_nodes(graph_def.node)
@@ -314,8 +298,6 @@ def extract_weights(graph_def,
314298
quantization_dtype_map=quantization_dtype_map,
315299
weight_shard_size_bytes=weight_shard_size_bytes,
316300
initializer_graph_def=initializer_graph_def,
317-
initializer_signature_def=initializer_signature_def,
318-
resource_ids_maps=resource_ids_maps,
319301
metadata=metadata)
320302

321303
def write_artifacts(topology,
@@ -326,8 +308,6 @@ def write_artifacts(topology,
326308
quantization_dtype_map=None,
327309
weight_shard_size_bytes=1024 * 1024 * 4,
328310
initializer_graph_def=None,
329-
initializer_signature_def=None,
330-
resource_ids_maps=None,
331311
metadata=None):
332312
"""Writes weights and topology to the output_dir.
333313
@@ -346,10 +326,6 @@ def write_artifacts(topology,
346326
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
347327
The size of each weight file will be <= this value.
348328
initializer_graph_def: tf.GraphDef proto object for initializer graph.
349-
initializer_signature_def: the SignatureDef of the initializer graph.
350-
resource_ids_maps: Tuple of two dictionaries, one
351-
mapping inference input names to resource id, and the other
352-
mapping initializer output names to resource id.
353329
metadata: User defined metadata map.
354330
"""
355331
model_json = {
@@ -367,30 +343,6 @@ def write_artifacts(topology,
367343
if initializer_graph_def and initializer_graph_def.node:
368344
model_json[common.ARTIFACT_MODEL_INITIALIZER] = MessageToDict(
369345
initializer_graph_def)
370-
if initializer_signature_def:
371-
model_json[common.INITIALIZER_SIGNATURE_KEY] = MessageToDict(
372-
initializer_signature_def)
373-
374-
# Assign resource ids to inference inputs and initializer outputs. In
375-
# TensorFlow, both inference and initializer graphs have a reference
376-
# to the common resource (so initializer runs on reference, and then inference
377-
# graph uses it). We are doing something similar but instead of assigning
378-
# a reference to the resource in the serialized graph, we assign the id
379-
# of the resource, and then we can recreate the common reference in javascript
380-
# by matching resource ids.
381-
if resource_ids_maps is not None:
382-
model_input_to_resource_id, init_output_to_resource_id = resource_ids_maps
383-
signature_inputs = model_json[common.SIGNATURE_KEY]['inputs']
384-
initializer_signature_outputs = model_json[common.INITIALIZER_SIGNATURE_KEY]['outputs']
385-
386-
for (input, resource_id) in model_input_to_resource_id.items():
387-
if input in signature_inputs:
388-
signature_inputs[input][common.RESOURCE_ID_KEY] = resource_id
389-
390-
for (output, resource_id) in init_output_to_resource_id.items():
391-
if output in initializer_signature_outputs:
392-
initializer_signature_outputs[output][common.RESOURCE_ID_KEY] = resource_id
393-
394346

395347
weights_manifest = write_weights.write_weights(
396348
weights, os.path.dirname(output_graph), write_manifest=False,
@@ -598,108 +550,6 @@ def _find_signature(saved_model_dir, saved_model_tags, signature_def):
598550

599551
return signature_def_map[signature_def]
600552

601-
def _get_resource_initializer_concrete_function(model):
602-
"""Create a tf.function that creates and initializes all the resources used by the model.
603-
For more information on resources, please see the TensorFlow code:
604-
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/trackable/resource.py#L232
605-
Args:
606-
model: Loaded saved model.
607-
608-
Returns:
609-
Nullable. A concrete function.
610-
"""
611-
trackable_view = TrackableView(model)
612-
model_resources = [obj for obj in trackable_view.descendants() if isinstance(obj, TrackableResource)]
613-
614-
if not model_resources:
615-
return None
616-
617-
# A list holding tuples of (TrackableResource, captured_input_index) where
618-
# TrackableResource represents one resource in the model
619-
# (a hash table for example), and captured_input_index is the resource
620-
# initialization function's captured input index corresponding
621-
# to the TrackableResource. Captured inputs are simply inputs not provided
622-
directly by the user, but by the model.
623-
model_resources_with_captured_input_index = []
624-
for model_resource in model_resources:
625-
# A runtime id that is unique across different resources, and constant
626-
# across graphs.
627-
resource_handle_id = model_resource.resource_handle._id
628-
# The _initialize function initializes the resource, so one of its captured
629-
# inputs must be the resource, so search for that input.
630-
captured_inputs = model_resource._initialize.get_concrete_function()._captured_inputs
631-
for captured_input_index in range(len(captured_inputs)):
632-
if captured_inputs[captured_input_index]._id == resource_handle_id:
633-
model_resources_with_captured_input_index.append((model_resource, captured_input_index))
634-
635-
@tf.function()
636-
def resource_initializer():
637-
# Recreate resources to capture them in this tf.function.
638-
new_resources = []
639-
for (model_resource, captured_input_index) in model_resources_with_captured_input_index:
640-
# Make a new resource (that is identical to the old, but captured in
641-
# this function only).
642-
new_resource = model_resource._create_resource()
643-
new_resources.append(new_resource)
644-
645-
# Since we precomputed the captured input corresponding to this resource,
646-
# we can directly replace it with the copy new_resource. If we don't do
647-
# this, then _initialize will not get capture in this graph since the
648-
# old resource was already initialized in TF model load.
649-
model_resource._initialize.get_concrete_function()._captured_inputs[captured_input_index] = new_resource
650-
model_resource._initialize()
651-
652-
return new_resources
653-
654-
# Add resource_initializer to the output graph.
655-
return resource_initializer.get_concrete_function()
656-
657-
def _get_resource_ids_maps(model, concrete_func, resource_init_concrete_func):
658-
"""Generates dictionaries that map tensor names to the loaded saved model resource id,
659-
allowing for matching of initializer outputs to inference inputs.
660-
661-
Args:
662-
model: Loaded saved model.
663-
concrete_func: Concrete function of the inference graph.
664-
resource_init_concrete_func: Concrete function of the initializer graph.
665-
666-
Returns:
667-
A dictionary mapping inference input names to resource id.
668-
A dictionary mapping initializer output names to resource id.
669-
"""
670-
trackable_view = TrackableView(model)
671-
model_resources = [obj for obj in trackable_view.descendants() if isinstance(obj, TrackableResource)]
672-
673-
674-
# Each resource has a unique runtime resource id associated with it which
675-
# can be used across graphs, so we extract it here from inference
676-
# graph for use later.
677-
resource_id_to_captured_input_index = {
678-
captured_input._id : captured_input_index for \
679-
captured_input_index, captured_input in \
680-
enumerate(concrete_func._captured_inputs)
681-
}
682-
# Captured inputs always come after user provided inputs.
683-
captured_input_index_offset = len(concrete_func.inputs) - len(concrete_func._captured_inputs)
684-
685-
model_input_to_resource_id = {}
686-
init_output_to_resource_id = {}
687-
for i, resource in enumerate(model_resources):
688-
_id = resource.resource_handle._id
689-
# Get input from inference graph corresponding to this resource.
690-
captured_input_index = resource_id_to_captured_input_index[_id]
691-
model_input = concrete_func.inputs[captured_input_index + captured_input_index_offset]
692-
693-
# Get output from initializer graph corresponding to this resource.
694-
init_output = resource_init_concrete_func.outputs[i]
695-
696-
# Match both with the same id (initializer output will be passed in to
697-
# corresponding input in inference input).
698-
model_input_to_resource_id[model_input.name] = _id
699-
init_output_to_resource_id[init_output.name] = _id
700-
701-
return (model_input_to_resource_id, init_output_to_resource_id)
702-
703553
def _convert_tf_saved_model(output_dir,
704554
saved_model_dir=None,
705555
keras_model=None,
@@ -813,15 +663,8 @@ def _convert_tf_saved_model(output_dir,
813663
# reliable way. Try to freeze the graph using V2 utils. If that fails, freeze
814664
# the graph using V1 utils.
815665
frozen_initializer_graph = None
816-
resource_ids_maps = None
817666
try:
818667
frozen_graph = _freeze_saved_model_v2(concrete_func, control_flow_v2)
819-
resource_initializer_concrete_func = _get_resource_initializer_concrete_function(model)
820-
821-
if resource_initializer_concrete_func:
822-
frozen_initializer_graph = _freeze_saved_model_v2(resource_initializer_concrete_func, control_flow_v2)
823-
resource_ids_maps = _get_resource_ids_maps(model, concrete_func, resource_initializer_concrete_func)
824-
825668
except BaseException:
826669
if saved_model_dir:
827670
(frozen_graph,
@@ -839,8 +682,9 @@ def _convert_tf_saved_model(output_dir,
839682
with tf.compat.v1.gfile.GFile(frozen_file, 'wb') as f:
840683
f.write(frozen_graph.as_graph_def().SerializeToString())
841684

685+
inputs = [x for x in concrete_func.inputs if not x.dtype == 'resource']
842686
signature = _build_signature_def(
843-
frozen_graph, concrete_func.inputs, concrete_func.outputs, saved_model_sigature)
687+
frozen_graph, inputs, concrete_func.outputs, saved_model_sigature)
844688

845689
define_transform_graph_func()
846690

@@ -860,7 +704,6 @@ def _convert_tf_saved_model(output_dir,
860704
weight_shard_size_bytes=weight_shard_size_bytes,
861705
experiments=experiments,
862706
initializer_graph=frozen_initializer_graph,
863-
resource_ids_maps=resource_ids_maps,
864707
metadata=metadata)
865708

866709
def define_transform_graph_func():

0 commit comments

Comments
 (0)