@@ -174,6 +174,31 @@ def generate(self, latent):
         return generated
 
 
+class CGeneratorTrainingCallback(GeneratorTrainingCallback):
+    # Callback periodically trains the generator
+    def __init__(self, args, parameters, criterion):
+        super(CGeneratorTrainingCallback, self).__init__(args, parameters, criterion)
+
+    def train_generator(self):
+        # Train the generator
+        # Generate latent samples
+        if self.trainer.is_cuda():
+            latent = torch.cuda.FloatTensor(self.batch_size, self.latent_dim)
+        else:
+            latent = torch.FloatTensor(self.batch_size, self.latent_dim)
+        torch.randn(*latent.size(), out=latent)
+        latent = Variable(latent)
+        # Calculate yfake
+        y = Variable(torch.rand(latent.size(0), out=latent.data.new()) * 10).long()
+        yfake = self.trainer.model.y_fake(latent, y)
+        # Calculate loss
+        loss = self.criterion(yfake)
+        # Perform update
+        self.opt.zero_grad()
+        loss.backward()
+        self.opt.step()
+
+
 def run(args):
     save_args(args)  # save command line to a file for reference
     train_loader = mnist_cgan_data_loader(args)  # get the data
@@ -190,7 +215,7 @@ def run(args):
     trainer.save_to_directory(args.save_directory)
     trainer.set_max_num_epochs(args.epochs)
     trainer.register_callback(CGenerateDataCallback(args))
-    trainer.register_callback(GeneratorTrainingCallback(
+    trainer.register_callback(CGeneratorTrainingCallback(
         args,
         parameters=model.generator.parameters(),
         criterion=WGANGeneratorLoss()))
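
Note on the label sampling in `train_generator` above: `torch.rand(latent.size(0), out=latent.data.new()) * 10` draws uniform values in [0, 10), and the trailing `.long()` truncates them to integer class labels 0-9, one per latent sample, matching the ten MNIST digit classes the conditional generator is conditioned on. A minimal sketch of a more direct equivalent on current PyTorch (an assumption for illustration, not part of this diff; only `latent` is reused from the code above):

    # Hypothetical equivalent of the label sampling above: draw integer
    # class labels 0-9 directly, on the same device as the latent batch.
    y = torch.randint(0, 10, (latent.size(0),), device=latent.device)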