-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
32 lines (26 loc) · 1.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from config import *
from model import *
import tensorflow as tf
from trainfeed import *
from util import *
# Training script for the recognition model restored from `model_dir`.
# Relies on names provided by the star imports above: loss, logging (a
# summary op -- it shadows the stdlib `logging` module), logits, images,
# sparse_label, seq_len, dense_decodes, writer, datagen, epochs, batch_size,
# model_dir, loadindex2char. NOTE(review): exact defining modules not visible
# here.
import numpy as np

# Lookup from decoded class index to its character.
ix2char = loadindex2char()

# The Saver and the exported meta graph are created BEFORE the optimizer, so
# they only cover the model variables, not Adam's slot variables. Keep this
# order: presumably the checkpoint at model_dir was written without optimizer
# state, and a Saver that expects Adam slots would fail to restore it.
saver = tf.train.Saver()
tf.train.export_meta_graph(filename=model_dir + ".meta")
ops = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)

# Decode and print predictions every DECODE_EVERY steps (1 == every step).
DECODE_EVERY = 1
# Sequence length fed to the decoder for each batch item.
# NOTE(review): 35 is assumed to equal the number of logit time steps, and
# every batch is assumed to hold exactly batch_size samples -- confirm
# against the model definition and datagen().
DECODE_SEQ_LEN = 35

global_step = 0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Replace the freshly initialised model weights with the checkpoint;
    # variables not covered by `saver` (Adam slots) keep their init values.
    saver.restore(sess, model_dir)
    for i in range(epochs):
        print('--------epoch:' + str(i))
        for j, (imgs, spar_lb) in enumerate(datagen()):
            # Fetch `ops` as well: without it the optimizer never steps and
            # the loss/summary/logits are only evaluated, never trained.
            _, ctc_err, _log, _logits = sess.run(
                [ops, loss, logging, logits],
                feed_dict={images: imgs, sparse_label: spar_lb},
            )
            # Pass the step so successive summaries are not all stacked on
            # step 0 in TensorBoard.
            writer.add_summary(_log, global_step=global_step)
            global_step += 1
            print(ctc_err)
            if j % DECODE_EVERY == 0:
                # Feed the already-computed logits back in so the decoder
                # does not re-run the network.
                _preds = sess.run(
                    dense_decodes,
                    feed_dict={
                        logits: _logits,
                        seq_len: np.full([batch_size], DECODE_SEQ_LEN),
                    },
                )
                # NOTE(review): _preds is assumed to be a dense
                # [batch_size, max_decoded_length] index matrix padded with
                # -1; padding entries are dropped from the decoded text.
                strarr = [
                    ''.join(ix2char[char_index]
                            for char_index in line if char_index != -1)
                    for line in _preds
                ]
                print(strarr)
        # Persist progress once per epoch; the meta graph was already
        # exported above, so do not rewrite it.
        saver.save(sess, model_dir, write_meta_graph=False)
    writer.flush()