Skip to content

Commit 0bb19df

Browse files
committed
fix cwgangp and add train_all.py
1 parent 8cef592 commit 0bb19df

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

recitation-10/mnist_cwgangp.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,31 @@ def generate(self, latent):
174174
return generated
175175

176176

177+
class CGeneratorTrainingCallback(GeneratorTrainingCallback):
178+
# Callback periodically trains the generator
179+
def __init__(self, args, parameters, criterion):
180+
super(CGeneratorTrainingCallback, self).__init__(args, parameters, criterion)
181+
182+
def train_generator(self):
183+
# Train the generator
184+
# Generate latent samples
185+
if self.trainer.is_cuda():
186+
latent = torch.cuda.FloatTensor(self.batch_size, self.latent_dim)
187+
else:
188+
latent = torch.FloatTensor(self.batch_size, self.latent_dim)
189+
torch.randn(*latent.size(), out=latent)
190+
latent = Variable(latent)
191+
# Calculate yfake
192+
y = Variable(torch.rand(latent.size(0), out=latent.data.new()) * 10).long()
193+
yfake = self.trainer.model.y_fake(latent, y)
194+
# Calculate loss
195+
loss = self.criterion(yfake)
196+
# Perform update
197+
self.opt.zero_grad()
198+
loss.backward()
199+
self.opt.step()
200+
201+
177202
def run(args):
178203
save_args(args) # save command line to a file for reference
179204
train_loader = mnist_cgan_data_loader(args) # get the data
@@ -190,7 +215,7 @@ def run(args):
190215
trainer.save_to_directory(args.save_directory)
191216
trainer.set_max_num_epochs(args.epochs)
192217
trainer.register_callback(CGenerateDataCallback(args))
193-
trainer.register_callback(GeneratorTrainingCallback(
218+
trainer.register_callback(CGeneratorTrainingCallback(
194219
args,
195220
parameters=model.generator.parameters(),
196221
criterion=WGANGeneratorLoss()))

recitation-10/mnist_gan.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,7 @@ def train_generator(self):
221221
torch.randn(*latent.size(), out=latent)
222222
latent = Variable(latent)
223223
# Calculate yfake
224-
y = Variable(torch.rand(latent.size(0), out=latent.data.new()) * 10).long()
225-
yfake = self.trainer.model.y_fake(latent, y)
224+
yfake = self.trainer.model.y_fake(latent)
226225
# Calculate loss
227226
loss = self.criterion(yfake)
228227
# Perform update

recitation-10/train_all.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import cifar10_wgangp
2+
import mnist_cwgangp
3+
import mnist_gan
4+
import mnist_wgangp
5+
6+
if __name__ == '__main__':
7+
mnist_gan.main(['--save-directory=output/mnist_gan/freq5'])
8+
mnist_gan.main(['--save-directory=output/mnist_gan/freq1', '--generator-frequency=1'])
9+
mnist_wgangp.main(['--save-directory=output/mnist_wgangp'])
10+
mnist_cwgangp.main(['--save-directory=output/mnist_cwgangp'])
11+
cifar10_wgangp.main(['--save-directory=output/cifar10_wgangp'])

0 commit comments

Comments
 (0)