-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
102 lines (89 loc) · 3.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import transformnet
import extractnet
import numpy as np
import tensorflow as tf
import PIL.Image as Image
import os
from config import *
import uuid
import shutil
import config
def traingen(dirpath):
    """Yield shuffled batches of images from *dirpath* as float arrays.

    Each yielded array has shape (<=batchsize, 224, 224, 3) with values
    scaled to [0, 1]; the final batch may be smaller than ``batchsize``
    (``batchsize`` comes from the star-import of ``config``).  Files are
    shuffled once per generator instantiation.
    """
    flist = os.listdir(dirpath)
    np.random.shuffle(flist)
    cursor = 0
    while cursor < len(flist):
        batchfn = flist[cursor:cursor + batchsize]
        cursor += batchsize
        ims = []
        for fn in batchfn:
            im = Image.open(os.path.join(dirpath, fn))
            # FIX: Image.ANTIALIAS was deprecated and removed in Pillow 10;
            # LANCZOS is the very same filter under its canonical name.
            im = im.resize((224, 224), Image.LANCZOS)
            im = np.array(im) / 255.
            ims.append(im)
        yield np.array(ims)
sess = tf.Session()

# ---- Graph construction (TF1 style) -----------------------------------
# NOTE(review): `input` shadows the builtin, but train() below references
# it as a module global, so it is kept for compatibility.
# model1: transform net, then feature extractor under the "A" prefix.
input = tf.placeholder(dtype='float32', shape=(None, None, None, 3), name="input")
synthesizetensor = transformnet.transform(input)
predcontenttensor, predstyletensors = extractnet.extract(synthesizetensor, prefix="A")
# model2: feature extractor applied directly to the input ("B" prefix).
contenttensor, styletensors = extractnet.extract(input, prefix="B")

# Precompute the style image's per-layer Gram matrices once, up front.
styleim = Image.open(styleimg)
#styleim = styleim.resize((224, 224), Image.ANTIALIAS)
styleim = np.expand_dims(np.array(styleim) / 255., axis=0)
styles = sess.run(styletensors, feed_dict={input: styleim})
grams = []
for style in styles:
    style = np.squeeze(style)
    style = np.transpose(style, [2, 0, 1])  # HWC -> CHW
    channelnum = style.shape[0]
    # BUG FIX: np.resize does not accept -1 in the target shape (negative
    # dimensions raise); np.reshape does, and since the element count is
    # unchanged it yields exactly the intended C x (H*W) flattening.
    style = np.reshape(style, (channelnum, -1))
    styletranspose = np.transpose(style, [1, 0])
    gram = np.matmul(style, styletranspose) / style.shape[1]
    grams.append(gram)

# Content loss: squared feature difference, normalised per example.
loss = tf.reduce_sum((predcontenttensor - contenttensor) ** 2) / tf.cast(tf.size(contenttensor[0]), dtype=tf.float32)
# Style loss: squared Gram distance per layer, weight 30, averaged over layers.
for predstyletensor, truthgram in zip(predstyletensors, grams):
    predstyletensor = tf.transpose(predstyletensor, [0, 3, 1, 2])
    shape = tf.shape(predstyletensor)
    predstyletensor = tf.reshape(predstyletensor, shape=[shape[0], shape[1], -1])
    grampred = tf.matmul(predstyletensor, predstyletensor, transpose_b=True) / tf.cast(tf.shape(predstyletensor)[2], dtype=tf.float32)
    loss += 30 * tf.reduce_sum((grampred - truthgram) ** 2) / truthgram.size / len(config.styletensors)
loss /= batchsize

trainop = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
saver = tf.train.Saver()
#sess.run(tf.global_variables_initializer())
saver.restore(sess, modelpath)  # resume from an existing checkpoint
def train():
    """Run the optimisation loop with simple early stopping.

    For each of ``epoch`` epochs, one fixed validation batch is drawn and
    the model is trained over ``trainpath``.  Every 5 batches the model is
    validated: the synthesised test images replace the previews in
    ``savepath`` and the checkpoint is saved whenever the validation loss
    improves.  Training stops early after 100 consecutive validations
    without a new best loss.

    NOTE(review): nesting reconstructed from an indentation-stripped
    source; this is the natural structure implied by the statement order.
    """
    for ep in range(epoch):
        testims = next(traingen(testpath))  # one fixed validation batch per epoch
        minerr = float("inf")               # best validation loss so far (was a magic 999999999)
        argminerr = -1                      # index in errorrec of that best loss
        errorrec = []
        for i, batchims in enumerate(traingen(trainpath)):
            _, losstrain = sess.run([trainop, loss], feed_dict={input: batchims})
            print(losstrain)
            if i % 5 == 0:
                synthesizeimgs, lossval = sess.run([synthesizetensor, loss], feed_dict={input: testims})
                errorrec.append(lossval)
                print("validate loss:" + str(lossval))
                # Early stop: 100 validations with no improvement.
                if len(errorrec) - argminerr > 100:
                    return
                # Replace the preview images with fresh ones (explicit loop
                # instead of the side-effecting list(map(os.unlink, ...))).
                for f in os.listdir(savepath):
                    os.unlink(os.path.join(savepath, f))
                for im in synthesizeimgs:
                    im = np.uint8(im * 255)
                    Image.fromarray(im).save(os.path.join(savepath, str(uuid.uuid4()) + ".jpg"))
                if lossval < minerr:
                    minerr = lossval
                    argminerr = len(errorrec) - 1
                    saver.save(sess, modelpath)
# Entry point: training starts at import/run time — there is no
# `if __name__ == "__main__":` guard, so importing this module trains.
train()
# Dead code below: a one-off TensorBoard summary experiment, kept
# commented out by the original author (includes a machine-local log path).
# tf.summary.scalar('cc',loss)
# merged = tf.summary.merge_all()
# train_writer = tf.summary.FileWriter(r"D:\Users\yl_gong\Desktop\log",
# sess.graph)
# sess.run(tf.global_variables_initializer())
# im = Image.open(testimg)
# im = np.expand_dims(np.array(im)/255.,axis=0)
# summary = sess.run(merged,feed_dict={input:im})
# train_writer.add_summary(summary, 0)