Skip to content

Commit 43d69df

Browse files
committed
finalized model import and run
1 parent 6a58911 commit 43d69df

File tree

3 files changed

+64
-38
lines changed

3 files changed

+64
-38
lines changed

Digitz/MainWindow.xaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
</Grid.RowDefinitions>
5151

5252
<TextBlock Name="numberLabel"
53-
FontSize="150"
53+
FontSize="100"
5454
Grid.Row="0" />
5555

5656
<Grid Grid.Row="1">

Digitz/MainWindow.xaml.cs

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@ public MainWindow()
2727
InitializeComponent();
2828
}
2929

30+
private void clearButton_Click(object sender, RoutedEventArgs e)
31+
{
32+
inkCanvas.Strokes.Clear();
33+
numberLabel.Text = "";
34+
}
35+
36+
private string Stringify(float[] data)
37+
{
38+
StringBuilder sb = new StringBuilder();
39+
for (int i = 0; i < data.Length; i++)
40+
{
41+
if (i == 0) sb.Append("{\r\n\t");
42+
else if (i % 28 == 0)
43+
sb.Append("\r\n\t");
44+
sb.Append($"{data[i],3:##0}, ");
45+
46+
}
47+
sb.Append("\r\n}\r\n");
48+
return sb.ToString();
49+
}
50+
3051
private TFTensor GetWrittenDigit(int size)
3152
{
3253
RenderTargetBitmap b = new RenderTargetBitmap(
@@ -51,36 +72,26 @@ private TFTensor GetWrittenDigit(int size)
5172
// sanity check
5273
Console.Write(Stringify(data));
5374

54-
return new TFTensor(data);
75+
return TFTensor.FromBuffer(new TFShape(1, data.Length), data, 0, data.Length);
5576
}
5677

5778
private void recognizeButton_Click(object sender, RoutedEventArgs e)
5879
{
5980
var tensor = GetWrittenDigit(28);
60-
string modelFile = "saved_model.pb";
61-
var model = File.ReadAllBytes(modelFile);
62-
var graph = new TFGraph();
63-
graph.Import(model, "");
64-
}
6581

66-
private string Stringify(float[] data)
67-
{
68-
StringBuilder sb = new StringBuilder();
69-
for(int i = 0; i < data.Length; i++)
82+
using (var graph = new TFGraph())
7083
{
71-
if (i == 0) sb.Append("{\r\n\t");
72-
else if (i % 28 == 0)
73-
sb.Append("\r\n\t");
74-
sb.Append($"{data[i],3:##0}, ");
75-
84+
graph.Import(File.ReadAllBytes("frozen.pb"));
85+
var session = new TFSession(graph);
86+
var runner = session.GetRunner();
87+
runner.AddInput(graph["InputData"][0], tensor);
88+
runner.Fetch(graph["Model/linear"][0]);
89+
var output = runner.Run();
90+
TFTensor result = output[0];
91+
float[] p = ((float[][])result.GetValue(true))[0];
92+
int guess = Array.IndexOf(p, p.Max());
93+
numberLabel.Text = guess.ToString();
7694
}
77-
sb.Append("\r\n}\r\n");
78-
return sb.ToString();
79-
}
80-
81-
private void clearButton_Click(object sender, RoutedEventArgs e)
82-
{
83-
inkCanvas.Strokes.Clear();
8495
}
8596
}
8697
}

LearnDigitz/LearnDigitz.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import os
44
import tensorflow as tf
5+
from tensorflow.python.tools import freeze_graph as freeze
56
from tensorflow.examples.tutorials.mnist import input_data
67
from datetime import *
78

@@ -26,7 +27,7 @@ def main(_):
2627

2728
# Parameters
2829
learning_rate = 0.01
29-
training_epochs = 50
30+
training_epochs = 10
3031
batch_size = 100
3132
display_epoch = 1
3233
unique = datetime.now().strftime('%m-%d_%H_%M')
@@ -40,15 +41,16 @@ def main(_):
4041
y = tf.placeholder(tf.float32, [None, 10], name='LabelData')
4142

4243
# Set model weights
43-
W = tf.Variable(tf.zeros([784, 10]), name='Weights')
44-
b = tf.Variable(tf.zeros([10]), name='Bias')
44+
W = tf.Variable(tf.zeros([784, 10]), name='weights')
45+
b = tf.Variable(tf.zeros([10]), name='bias')
4546

4647
# Construct model and encapsulating all ops into scopes, making
4748
# Tensorboard's Graph visualization more convenient
4849
with tf.name_scope('Model'):
4950
# Model
50-
#pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
51-
pred = tf.matmul(x, W) + b # Softmax
51+
#pred = tf.nn.softmax(tf.matmul(x, W) + b, name="model") # Softmax
52+
pred = tf.add(tf.matmul(x, W), b, name="linear") # linear combination
53+
5254
with tf.name_scope('Loss'):
5355
# Minimize error using cross entropy
5456
# cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
@@ -112,16 +114,29 @@ def main(_):
112114
print("Accuracy:", acc.eval({x: mnist.test.images, y: mnist.test.labels}))
113115

114116
# saving model
115-
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
116-
builder.add_meta_graph_and_variables(
117-
sess,
118-
[tf.saved_model.tag_constants.SERVING],
119-
signature_def_map = {
120-
"model": tf.saved_model.signature_def_utils.predict_signature_def(
121-
inputs = { "x": x },
122-
outputs = { "prediction": pred })
123-
})
124-
builder.save()
117+
checkpoint = os.path.join(export_path, "model.ckpt")
118+
saver = tf.train.Saver()
119+
# checkpoint - variables
120+
saver.save(sess, checkpoint)
121+
# graph
122+
tf.train.write_graph(sess.graph_def, export_path, "model.pb", as_text=False)
123+
# freeze
124+
# python "Python\Lib\site-packages\tensorflow\python\tools\freeze_graph.py" --input_graph=.\Profile.pb --input_checkpoint=.\Profile.ckpt --output_node_names=Output/Predictions,Output/Loss --output_graph=frozen.pb
125+
g = os.path.join(export_path, "model.pb")
126+
frozen = os.path.join(export_path, "frozen.pb")
127+
128+
freeze.freeze_graph(
129+
input_graph = g,
130+
input_saver = "",
131+
input_binary = True,
132+
input_checkpoint = checkpoint,
133+
output_node_names = "Model/linear",
134+
restore_op_name = "",
135+
filename_tensor_name = "",
136+
output_graph = frozen,
137+
clear_devices = True,
138+
initializer_nodes = ""
139+
)
125140
print("Model saved!")
126141
exit(0)
127142

0 commit comments

Comments
 (0)