4040
4141flags .DEFINE_string ('precision' , 'fp32' , 'precision' )
4242
43+ flags .DEFINE_bool ('tune' , False , 'if use tune' )
44+
4345flags .DEFINE_string ('config' , None , 'yaml configuration for tuning' )
4446
4547FLAGS = flags .FLAGS
@@ -124,7 +126,6 @@ def main(args=None):
124126 sess .run (tf .global_variables_initializer ())
125127 saver .restore (sess , FLAGS .input_model )
126128 graph_def = sess .graph .as_graph_def ()
127- _parse_ckpt_bn_input (graph_def )
128129
129130 replace_style = 'style_image_processing/ResizeBilinear_2'
130131 replace_content = 'batch_processing/batch'
@@ -136,6 +137,8 @@ def main(args=None):
136137 if replace_style == input_name :
137138 node .input [idx ] = 'style_input'
138139
140+ if FLAGS .tune :
141+ _parse_ckpt_bn_input (graph_def )
139142 output_name = 'transformer/expand/conv3/conv/Sigmoid'
140143 frozen_graph = tf .graph_util .convert_variables_to_constants (sess , graph_def , [output_name ])
141144 # use frozen pb instead
@@ -147,7 +150,7 @@ def main(args=None):
147150 print ("not supported model format" )
148151 exit (- 1 )
149152
150- if FLAGS .precision == 'fp32' :
153+ if FLAGS .tune :
151154 with tf .Graph ().as_default () as graph :
152155 tf .import_graph_def (frozen_graph )
153156 tuner = ilit .Tuner (FLAGS .config )
@@ -157,6 +160,8 @@ def main(args=None):
157160 with tf .io .gfile .GFile (FLAGS .output_model , "wb" ) as f :
158161 f .write (quantized_model .as_graph_def ().SerializeToString ())
159162
163+ frozen_graph = quantized_model .as_graph_def ()
164+
160165 # validate the quantized model here
161166 with tf .Graph ().as_default (), tf .Session () as sess :
162167 # create dataloader using default style_transfer dataset and generate stylized images
@@ -166,7 +171,7 @@ def main(args=None):
166171 resize_shape = (256 , 256 ))
167172 dataloader = ilit .data .DataLoader ('tensorflow' , dataset = dataset )
168173 tf .import_graph_def (frozen_graph )
169- style_transfer (sess , dataloader , 'quantized' )
174+ style_transfer (sess , dataloader , FLAGS . precision )
170175
171176def add_import_to_name (sess , name , try_cnt = 2 ):
172177 for i in range (0 , try_cnt ):
0 commit comments