@@ -105,15 +105,14 @@ def linear_better(x, init=tf.zeros):
105
105
return tf .identity (pred , name = "prediction" )
106
106
107
107
@print_info
108
- def mlp_better (x , hidden = [512 , 512 ]):
109
- last_output = x
110
- for i in range (len (hidden )):
111
- # layer n
112
- last_output = tf .layers .dense (inputs = last_output , units = hidden [i ], activation = tf .nn .relu )
108
+ def mlp_better (x ):
109
+ # hidden layers
110
+ h1 = tf .layers .dense (inputs = x , units = 512 , activation = tf .nn .relu )
111
+ h2 = tf .layers .dense (inputs = h1 , units = 512 , activation = tf .nn .relu )
113
112
114
113
# output layer
115
114
with tf .name_scope ("Model" ):
116
- pred = tf .layers .dense (inputs = last_output , units = 10 , activation = tf .nn .softmax )
115
+ pred = tf .layers .dense (inputs = h2 , units = 10 , activation = tf .nn .softmax )
117
116
return tf .identity (pred , name = "prediction" )
118
117
119
118
@print_info
@@ -136,7 +135,7 @@ def cnn_better(x):
136
135
dense = tf .layers .dense (inputs = pool2_flat , units = 1024 , activation = tf .nn .relu )
137
136
138
137
with tf .name_scope ('Model' ):
139
- pred = tf .layers .dense (inputs = dense , units = 10 )
138
+ pred = tf .layers .dense (inputs = dense , units = 10 , activation = tf . nn . softmax )
140
139
return tf .identity (pred , name = "prediction" )
141
140
142
141
###################################################################
@@ -168,18 +167,13 @@ def train_model(x, y, cost, optimizer, accuracy, learning_rate, batch_size, epoc
168
167
acc = 0.
169
168
info ('Training' )
170
169
# epochs to run
171
- with trange (epochs , desc = "{:<10}" .format ("Training" ),
172
- bar_format = '{l_bar}{bar}|{postfix}' ,
173
- postfix = " acc: 0.0000" ) as t :
170
+ with trange (epochs , desc = "{:<10}" .format ("Training" ), bar_format = '{l_bar}{bar}|{postfix}' , postfix = " acc: 0.0000" ) as t :
174
171
for epoch in t :
175
172
avg_cost = 0.
176
173
t .postfix = ' acc: {:.4f}' .format (acc )
177
174
t .update ()
178
175
# loop over all batches
179
- with tqdm (enumerate (digits ),
180
- total = digits .total ,
181
- desc = "{:<10}" .format ("Epoch {}" .format (epoch + 1 )),
182
- bar_format = '{l_bar}{bar}|{postfix}' ) as progress :
176
+ with tqdm (enumerate (digits ), total = digits .total , desc = "{:<10}" .format ("Epoch {}" .format (epoch + 1 )), bar_format = '{l_bar}{bar}|{postfix}' ) as progress :
183
177
for i , (train_x , train_y ) in progress :
184
178
# Run optimization, cost, and summary
185
179
_ , c , summary = sess .run ([optimizer , cost , merged_summary_op ],
@@ -239,7 +233,7 @@ def main(settings):
239
233
240
234
args .data = check_dir (os .path .abspath (args .data ))
241
235
args .output = os .path .abspath (args .output )
242
- unique = datetime .now ().strftime ('%m.%d_%H.%M' )
236
+ unique = datetime .now ().strftime ('%m.%d_%H.%M' )
243
237
args .log = check_dir (os .path .join (args .output , 'logs' , 'log_{}' .format (unique )))
244
238
args .model = check_dir (os .path .join (args .output , 'models' , 'model_{}' .format (unique )))
245
239
0 commit comments