Skip to content

Commit 5b3f106

Browse files
ClarkChin08ftian1
authored andcommitted
while not use ilit_tune, do not parse bn input to avoid tf1.15/2.1 bn crash
1 parent 0ee25e6 commit 5b3f106

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
intel_tensorflow==2.2.0
21
scikit-image
32
Pillow

examples/tensorflow/style_transfer/run_benchmark.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function run_benchmark {
5353
--style_images_paths "${style_images}" \
5454
--content_images_paths "${content_images}" \
5555
--config "./conf.yaml" \
56-
--precision "int8" \
56+
--tune=False \
5757
--output_model "${output_model}"
5858

5959
}

examples/tensorflow/style_transfer/run_tuning.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function run_tuning {
4444
--style_images_paths "${style_images}" \
4545
--content_images_paths "${content_images}" \
4646
--config "./conf.yaml" \
47-
--precision "fp32" \
47+
--tune=True \
4848
--output_model "${output_model}"
4949
}
5050

examples/tensorflow/style_transfer/style_tune.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
flags.DEFINE_string('precision', 'fp32', 'precision')
4242

43+
flags.DEFINE_bool('tune', False, 'if use tune')
44+
4345
flags.DEFINE_string('config', None, 'yaml configuration for tuning')
4446

4547
FLAGS = flags.FLAGS
@@ -124,7 +126,6 @@ def main(args=None):
124126
sess.run(tf.global_variables_initializer())
125127
saver.restore(sess, FLAGS.input_model)
126128
graph_def = sess.graph.as_graph_def()
127-
_parse_ckpt_bn_input(graph_def)
128129

129130
replace_style = 'style_image_processing/ResizeBilinear_2'
130131
replace_content = 'batch_processing/batch'
@@ -136,6 +137,8 @@ def main(args=None):
136137
if replace_style == input_name:
137138
node.input[idx] = 'style_input'
138139

140+
if FLAGS.tune:
141+
_parse_ckpt_bn_input(graph_def)
139142
output_name = 'transformer/expand/conv3/conv/Sigmoid'
140143
frozen_graph = tf.graph_util.convert_variables_to_constants(sess, graph_def, [output_name])
141144
# use frozen pb instead
@@ -147,7 +150,7 @@ def main(args=None):
147150
print("not supported model format")
148151
exit(-1)
149152

150-
if FLAGS.precision == 'fp32':
153+
if FLAGS.tune:
151154
with tf.Graph().as_default() as graph:
152155
tf.import_graph_def(frozen_graph)
153156
tuner = ilit.Tuner(FLAGS.config)
@@ -157,6 +160,8 @@ def main(args=None):
157160
with tf.io.gfile.GFile(FLAGS.output_model, "wb") as f:
158161
f.write(quantized_model.as_graph_def().SerializeToString())
159162

163+
frozen_graph= quantized_model.as_graph_def()
164+
160165
# validate the quantized model here
161166
with tf.Graph().as_default(), tf.Session() as sess:
162167
# create dataloader using default style_transfer dataset and generate stylized images
@@ -166,7 +171,7 @@ def main(args=None):
166171
resize_shape=(256, 256))
167172
dataloader = ilit.data.DataLoader('tensorflow', dataset=dataset)
168173
tf.import_graph_def(frozen_graph)
169-
style_transfer(sess, dataloader, 'quantized')
174+
style_transfer(sess, dataloader, FLAGS.precision)
170175

171176
def add_import_to_name(sess, name, try_cnt=2):
172177
for i in range(0, try_cnt):

0 commit comments

Comments
 (0)