diff --git a/train_sagan.py b/train_sagan.py index e0ca583..59a54c5 100644 --- a/train_sagan.py +++ b/train_sagan.py @@ -74,7 +74,7 @@ def validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoc imgs = (imgs / 127.5 - 1) # mask is 1 on masked region # forward - coarse_imgs, recon_imgs, attention = netG(imgs, masks) + coarse_imgs, recon_imgs = netG(imgs, masks) complete_imgs = recon_imgs * masks + imgs * (1 - masks)