Skip to content

Commit da0c991

Browse files
committed
better keras impl
1 parent 5b9eecc commit da0c991

File tree

2 files changed

+27
-31
lines changed

2 files changed

+27
-31
lines changed

LearnDigitz/keras_train.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from tensorflow.keras import Sequential
66
from tensorflow.keras import backend as K
77
from tensorflow.keras.utils import to_categorical
8-
from misc.helpers import print_info, print_args, check_dir
8+
from misc.helpers import print_info, print_args, check_dir, info
99
from tensorflow.python.framework import graph_util, graph_io
1010
from tensorflow.keras.layers import Reshape, Flatten, Dense, Conv2D, MaxPooling2D
1111

@@ -14,11 +14,10 @@ def save(model, model_dir):
1414
print('\nSaving h5 model to {}'.format(m))
1515
model.save(m)
1616
print('Saving pb model to {}'.format(os.path.join(model_dir, 'digits.pb')))
17-
1817

1918
input_node = model.input.name.split(':')[0]
2019
output_node = model.output.name.split(':')[0]
21-
print("Input Tensor:", input_node)
20+
print("\nInput Tensor:", input_node)
2221
print("Output Tensor:", output_node)
2322

2423
K.set_learning_phase(0)
@@ -37,24 +36,18 @@ def load_digits(data_dir):
3736
return (x_train, y_train), (x_test, y_test)
3837

3938
###################################################################
40-
# Simple (W.T * X + b) #
39+
# shapes #
4140
###################################################################
4241
def linear():
4342
return Sequential([Dense(10)])
4443

45-
###################################################################
46-
# Neural Network #
47-
###################################################################
4844
def mlp():
4945
return Sequential([
5046
Dense(512, activation='relu'),
5147
Dense(512, activation='relu'),
5248
Dense(10, activation='softmax')
5349
])
5450

55-
###################################################################
56-
# Convolutional Neural Network #
57-
###################################################################
5851
def cnn():
5952
return Sequential([
6053
Reshape((28, 28, 1)),
@@ -67,28 +60,37 @@ def cnn():
6760
Dense(10, activation='softmax')
6861
])
6962

63+
64+
@print_info
7065
def run(data_dir, model_dir, epochs):
7166
# get data
7267
(x_train, y_train), (x_test, y_test) = load_digits(data_dir)
73-
68+
7469
# create model structure
7570
model = cnn()
76-
71+
7772
# compile model
7873
model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
79-
74+
8075
# run model
8176
model.fit(x_train, y_train, epochs=epochs)
8277
model.summary()
83-
model.evaluate(x_test, y_test)
84-
78+
evaluation = model.evaluate(x_test, y_test)
79+
80+
# save model
81+
info('Output...')
8582
save(model, model_dir)
83+
84+
# metrics
85+
info('Metrics...')
86+
print('Loss: {}'.format(evaluation[0]))
87+
print('Accuracy: {}'.format(evaluation[1]))
8688

8789

8890
if __name__ == "__main__":
8991
data_dir = check_dir(os.path.abspath('data'))
9092
output_dir = os.path.abspath('output')
9193
unique = datetime.now().strftime('%m.%d_%H.%M')
9294
model_dir = check_dir(os.path.join(output_dir, 'models', 'model_{}'.format(unique)))
93-
epochs = 5
95+
epochs = 10
9496
run(data_dir, model_dir, epochs)

LearnDigitz/train.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,14 @@ def linear_better(x, init=tf.zeros):
105105
return tf.identity(pred, name="prediction")
106106

107107
@print_info
108-
def mlp_better(x, hidden=[512, 512]):
109-
last_output = x
110-
for i in range(len(hidden)):
111-
# layer n
112-
last_output = tf.layers.dense(inputs=last_output, units=hidden[i], activation=tf.nn.relu)
108+
def mlp_better(x):
109+
# hidden layers
110+
h1 = tf.layers.dense(inputs=x, units=512, activation=tf.nn.relu)
111+
h2 = tf.layers.dense(inputs=h1, units=512, activation=tf.nn.relu)
113112

114113
# output layer
115114
with tf.name_scope("Model"):
116-
pred = tf.layers.dense(inputs=last_output, units=10, activation=tf.nn.softmax)
115+
pred = tf.layers.dense(inputs=h2, units=10, activation=tf.nn.softmax)
117116
return tf.identity(pred, name="prediction")
118117

119118
@print_info
@@ -136,7 +135,7 @@ def cnn_better(x):
136135
dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
137136

138137
with tf.name_scope('Model'):
139-
pred = tf.layers.dense(inputs=dense, units=10)
138+
pred = tf.layers.dense(inputs=dense, units=10, activation=tf.nn.softmax)
140139
return tf.identity(pred, name="prediction")
141140

142141
###################################################################
@@ -168,18 +167,13 @@ def train_model(x, y, cost, optimizer, accuracy, learning_rate, batch_size, epoc
168167
acc = 0.
169168
info('Training')
170169
# epochs to run
171-
with trange(epochs, desc="{:<10}".format("Training"),
172-
bar_format='{l_bar}{bar}|{postfix}',
173-
postfix=" acc: 0.0000") as t:
170+
with trange(epochs, desc="{:<10}".format("Training"), bar_format='{l_bar}{bar}|{postfix}', postfix=" acc: 0.0000") as t:
174171
for epoch in t:
175172
avg_cost = 0.
176173
t.postfix = ' acc: {:.4f}'.format(acc)
177174
t.update()
178175
# loop over all batches
179-
with tqdm(enumerate(digits),
180-
total=digits.total,
181-
desc="{:<10}".format("Epoch {}".format(epoch + 1)),
182-
bar_format='{l_bar}{bar}|{postfix}') as progress:
176+
with tqdm(enumerate(digits), total=digits.total, desc="{:<10}".format("Epoch {}".format(epoch + 1)), bar_format='{l_bar}{bar}|{postfix}') as progress:
183177
for i, (train_x, train_y) in progress:
184178
# Run optimization, cost, and summary
185179
_, c, summary = sess.run([optimizer, cost, merged_summary_op],
@@ -239,7 +233,7 @@ def main(settings):
239233

240234
args.data = check_dir(os.path.abspath(args.data))
241235
args.output = os.path.abspath(args.output)
242-
unique = datetime.now().strftime('%m.%d_%H.%M')
236+
unique = datetime.now().strftime('%m.%d_%H.%M')
243237
args.log = check_dir(os.path.join(args.output, 'logs', 'log_{}'.format(unique)))
244238
args.model = check_dir(os.path.join(args.output, 'models', 'model_{}'.format(unique)))
245239

0 commit comments

Comments
 (0)