2
2
import sys
3
3
import os
4
4
import tensorflow as tf
5
+ from tensorflow .python .tools import freeze_graph as freeze
5
6
from tensorflow .examples .tutorials .mnist import input_data
6
7
from datetime import *
7
8
@@ -26,7 +27,7 @@ def main(_):
26
27
27
28
# Parameters
28
29
learning_rate = 0.01
29
- training_epochs = 50
30
+ training_epochs = 10
30
31
batch_size = 100
31
32
display_epoch = 1
32
33
unique = datetime .now ().strftime ('%m-%d_%H_%M' )
@@ -40,15 +41,16 @@ def main(_):
40
41
y = tf .placeholder (tf .float32 , [None , 10 ], name = 'LabelData' )
41
42
42
43
# Set model weights
43
- W = tf .Variable (tf .zeros ([784 , 10 ]), name = 'Weights ' )
44
- b = tf .Variable (tf .zeros ([10 ]), name = 'Bias ' )
44
+ W = tf .Variable (tf .zeros ([784 , 10 ]), name = 'weights ' )
45
+ b = tf .Variable (tf .zeros ([10 ]), name = 'bias ' )
45
46
46
47
# Construct model and encapsulating all ops into scopes, making
47
48
# Tensorboard's Graph visualization more convenient
48
49
with tf .name_scope ('Model' ):
49
50
# Model
50
- #pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
51
- pred = tf .matmul (x , W ) + b # Softmax
51
+ #pred = tf.nn.softmax(tf.matmul(x, W) + b, name="model") # Softmax
52
+ pred = tf .add (tf .matmul (x , W ), b , name = "linear" ) # linear combination
53
+
52
54
with tf .name_scope ('Loss' ):
53
55
# Minimize error using cross entropy
54
56
# cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
@@ -112,16 +114,29 @@ def main(_):
112
114
print ("Accuracy:" , acc .eval ({x : mnist .test .images , y : mnist .test .labels }))
113
115
114
116
# saving model
115
- builder = tf .saved_model .builder .SavedModelBuilder (export_path )
116
- builder .add_meta_graph_and_variables (
117
- sess ,
118
- [tf .saved_model .tag_constants .SERVING ],
119
- signature_def_map = {
120
- "model" : tf .saved_model .signature_def_utils .predict_signature_def (
121
- inputs = { "x" : x },
122
- outputs = { "prediction" : pred })
123
- })
124
- builder .save ()
117
+ checkpoint = os .path .join (export_path , "model.ckpt" )
118
+ saver = tf .train .Saver ()
119
+ # checkpoint - variables
120
+ saver .save (sess , checkpoint )
121
+ # graph
122
+ tf .train .write_graph (sess .graph_def , export_path , "model.pb" , as_text = False )
123
+ # freeze
124
+ # python "Python\Lib\site-packages\tensorflow\python\tools\freeze_graph.py" --input_graph=.\Profile.pb --input_checkpoint=.\Profile.ckpt --output_node_names=Output/Predictions,Output/Loss --output_graph=frozen.pb
125
+ g = os .path .join (export_path , "model.pb" )
126
+ frozen = os .path .join (export_path , "frozen.pb" )
127
+
128
+ freeze .freeze_graph (
129
+ input_graph = g ,
130
+ input_saver = "" ,
131
+ input_binary = True ,
132
+ input_checkpoint = checkpoint ,
133
+ output_node_names = "Model/linear" ,
134
+ restore_op_name = "" ,
135
+ filename_tensor_name = "" ,
136
+ output_graph = frozen ,
137
+ clear_devices = True ,
138
+ initializer_nodes = ""
139
+ )
125
140
print ("Model saved!" )
126
141
exit (0 )
127
142
0 commit comments