From adc1d2bfc236d721ea955adc734178ecc9c05f78 Mon Sep 17 00:00:00 2001
From: YeonwooSung
Date: Sun, 21 Jul 2024 22:46:11 +0900
Subject: [PATCH] feat: Add notebook for anime translation

---
 .../AnimeTranslation/anime_translation.ipynb  | 885 ++++++++++++++++++
 1 file changed, 885 insertions(+)
 create mode 100644 Experiments/CV/AnimeTranslation/anime_translation.ipynb

diff --git a/Experiments/CV/AnimeTranslation/anime_translation.ipynb b/Experiments/CV/AnimeTranslation/anime_translation.ipynb
new file mode 100644
index 0000000..aa44b60
--- /dev/null
+++ b/Experiments/CV/AnimeTranslation/anime_translation.ipynb
@@ -0,0 +1,885 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "nje822klcXS7"
+   },
+   "source": [
+    "# Selfie to Anime Project"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "Dataset: "
+   ],
+   "metadata": {
+    "id": "5_g_ciDAdl7x"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "from google.colab import drive\n",
+    "\n",
+    "drive.mount('/content/drive/')"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "0hjA1yBIcmIs",
+    "outputId": "893a69ea-1e51-49ba-93d8-9204de4c19f9"
+   },
+   "execution_count": 1,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount(\"/content/drive/\", force_remount=True).\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "import os\n",
+    "\n",
+    "os.chdir('drive/My Drive/Tutorials/DeepLearningProject/AnimeTranslation')"
+   ],
+   "metadata": {
+    "id": "fFELKtj3camy"
+   },
+   "execution_count": 2,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "2dJ53ubccXS9"
+   },
+   "source": [
+    "**Code 0. Import Library**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "id": "W6LUI6a1cXS9"
+   },
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "import random\n",
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "import numpy as np\n",
+    "import math\n",
+    "import itertools\n",
+    "import datetime\n",
+    "import time\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "from torch.autograd import Variable\n",
+    "from torchvision.utils import save_image, make_grid\n",
+    "from torchvision import datasets\n",
+    "import torchvision.transforms as transforms\n",
+    "from PIL import Image"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "lMCT-ZIrcXS-"
+   },
+   "source": [
+    "**Code 1. Dataset**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "dataset_name = \"selfie2anime\""
+   ],
+   "metadata": {
+    "id": "Y4nmWqMMfBhw"
+   },
+   "execution_count": 4,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "!unzip -q \"archive.zip\" -d \"selfie2anime/\""
+   ],
+   "metadata": {
+    "id": "oPhcrCfOfB4S"
+   },
+   "execution_count": 6,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "!ls selfie2anime"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "P5EJJFZxfljA",
+    "outputId": "2c73c5b0-d0b8-4b3e-be16-b6f18d2fc6e6"
+   },
+   "execution_count": 9,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "testA testB trainA trainB\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "id": "jKSXVIS1cXS-"
+   },
+   "outputs": [],
+   "source": [
+    "def to_rgb(image):\n",
+    "    rgb_image = Image.new(\"RGB\", image.size)\n",
+    "    rgb_image.paste(image)\n",
+    "    return rgb_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "id": "jNgOPXHbcXS-"
+   },
+   "outputs": [],
+   "source": [
+    "class ImageDataset(Dataset):\n",
+    "    def __init__(self, root, transforms_=None, unaligned=False, mode=\"train\"):\n",
+    "        self.transform = transforms.Compose(transforms_)\n",
+    "        self.unaligned = unaligned\n",
+    "        # In train mode, load the images from the trainA and trainB directories.\n",
+    "        if mode == \"train\":\n",
+    "            # Use glob to collect the list of image paths in each directory.\n",
+    "            self.files_A = sorted(glob.glob(os.path.join(root, \"trainA\") + \"/*.*\"))\n",
+    "            self.files_B = sorted(glob.glob(os.path.join(root, \"trainB\") + \"/*.*\"))\n",
+    "        else:\n",
+    "            self.files_A = sorted(glob.glob(os.path.join(root, \"testA\") + \"/*.*\"))\n",
+    "            self.files_B = sorted(glob.glob(os.path.join(root, \"testB\") + \"/*.*\"))\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        # Load one image from domain A by index.\n",
+    "        image_A = Image.open(self.files_A[index % len(self.files_A)])\n",
+    "        # If unaligned, pick the domain-B image at random so the training pair is unpaired.\n",
+    "        if self.unaligned:\n",
+    "            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])\n",
+    "        else:\n",
+    "            image_B = Image.open(self.files_B[index % len(self.files_B)])\n",
+    "\n",
+    "        # Convert grayscale images to RGB\n",
+    "        if image_A.mode != \"RGB\":\n",
+    "            image_A = to_rgb(image_A)\n",
+    "        if image_B.mode != \"RGB\":\n",
+    "            image_B = to_rgb(image_B)\n",
+    "\n",
+    "        item_A = self.transform(image_A)\n",
+    "        item_B = self.transform(image_B)\n",
+    "        return {\"A\": item_A, \"B\": item_B}\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return max(len(self.files_A), len(self.files_B))"
+   ]
+  },
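+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before building the models, it is worth loading a single pair to confirm the dataset wiring: each item should come back as a `3 x 256 x 256` tensor. This is only a quick sanity-check sketch; the `_check_*` names are throwaway helpers, and the plain resize below is a minimal stand-in for the full training transforms defined in Code 3."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal transforms just for this check; the real training augmentations\n",
+    "# (resize + random crop/flip + normalization) are defined later in Code 3.\n",
+    "_check_transforms = [\n",
+    "    transforms.Resize((256, 256)),\n",
+    "    transforms.ToTensor(),\n",
+    "]\n",
+    "\n",
+    "# Expect: dataset length = max(len(trainA), len(trainB)), tensors of shape [3, 256, 256]\n",
+    "_check_dataset = ImageDataset(dataset_name, transforms_=_check_transforms, unaligned=True)\n",
+    "_sample = _check_dataset[0]\n",
+    "print(len(_check_dataset), _sample[\"A\"].shape, _sample[\"B\"].shape)"
+   ]
+  },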
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Ywy2bdWFcXS_"
+   },
+   "source": [
+    "**Code 2. Generator & Discriminator**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "id": "0-wK6Bx5cXS_"
+   },
+   "outputs": [],
+   "source": [
+    "def weights_init_normal(m):\n",
+    "    classname = m.__class__.__name__\n",
+    "    if classname.find(\"Conv\") != -1:\n",
+    "        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
+    "        if hasattr(m, \"bias\") and m.bias is not None:\n",
+    "            torch.nn.init.constant_(m.bias.data, 0.0)\n",
+    "    elif classname.find(\"BatchNorm2d\") != -1:\n",
+    "        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
+    "        torch.nn.init.constant_(m.bias.data, 0.0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "id": "EYy_TRplcXS_"
+   },
+   "outputs": [],
+   "source": [
+    "class ResidualBlock(nn.Module):\n",
+    "    def __init__(self, in_features):\n",
+    "        super(ResidualBlock, self).__init__()\n",
+    "\n",
+    "        self.block = nn.Sequential(\n",
+    "            nn.ReflectionPad2d(1),\n",
+    "            nn.Conv2d(in_features, in_features, 3),\n",
+    "            nn.InstanceNorm2d(in_features),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.ReflectionPad2d(1),\n",
+    "            nn.Conv2d(in_features, in_features, 3),\n",
+    "            nn.InstanceNorm2d(in_features),\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return x + self.block(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "id": "-ZgCUpfBcXS_"
+   },
+   "outputs": [],
+   "source": [
+    "class GeneratorResNet(nn.Module):\n",
+    "    def __init__(self, input_shape, num_residual_blocks):\n",
+    "        super(GeneratorResNet, self).__init__()\n",
+    "\n",
+    "        channels = input_shape[0]\n",
+    "\n",
+    "        # Initial convolution block\n",
+    "        out_features = 64\n",
+    "        model = [\n",
+    "            nn.ReflectionPad2d(channels),\n",
+    "            nn.Conv2d(channels, out_features, 7),\n",
+    "            nn.InstanceNorm2d(out_features),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "        ]\n",
+    "        in_features = out_features\n",
+    "\n",
+    "        # Downsampling\n",
+    "        for _ in range(2):\n",
+    "            out_features *= 2\n",
+    "            model += [\n",
+    "                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),\n",
+    "                nn.InstanceNorm2d(out_features),\n",
+    "                nn.ReLU(inplace=True),\n",
+    "            ]\n",
+    "            in_features = out_features\n",
+    "\n",
+    "        # Residual blocks\n",
+    "        for _ in range(num_residual_blocks):\n",
+    "            model += [ResidualBlock(out_features)]\n",
+    "\n",
+    "        # Upsampling\n",
+    "        for _ in range(2):\n",
+    "            out_features //= 2\n",
+    "            model += [\n",
+    "                nn.Upsample(scale_factor=2),\n",
+    "                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),\n",
+    "                nn.InstanceNorm2d(out_features),\n",
+    "                nn.ReLU(inplace=True),\n",
+    "            ]\n",
+    "            in_features = out_features\n",
+    "\n",
+    "        # Output layer\n",
+    "        model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]\n",
+    "\n",
+    "        self.model = nn.Sequential(*model)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.model(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "id": "t6c7wJfxcXS_"
+   },
+   "outputs": [],
+   "source": [
+    "class Discriminator(nn.Module):\n",
+    "    def __init__(self, input_shape):\n",
+    "        super(Discriminator, self).__init__()\n",
+    "\n",
+    "        channels, height, width = input_shape\n",
+    "\n",
+    "        # Calculate output shape of image discriminator (PatchGAN)\n",
+    "        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)\n",
+    "\n",
+    "        def discriminator_block(in_filters, out_filters, normalize=True):\n",
+    "            \"\"\"Returns downsampling layers of each discriminator block\"\"\"\n",
+    "            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]\n",
+    "            if normalize:\n",
+    "                layers.append(nn.InstanceNorm2d(out_filters))\n",
+    "            layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
+    "            return layers\n",
+    "\n",
+    "        self.model = nn.Sequential(\n",
+    "            *discriminator_block(channels, 64, normalize=False),\n",
+    "            *discriminator_block(64, 128),\n",
+    "            *discriminator_block(128, 256),\n",
+    "            *discriminator_block(256, 512),\n",
+    "            nn.ZeroPad2d((1, 0, 1, 0)),\n",
+    "            nn.Conv2d(512, 1, 4, padding=1)\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, img):\n",
+    "        return self.model(img)"
+   ]
+  },
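+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The generator is shape-preserving (its `Tanh` output stays in `[-1, 1]`), while the PatchGAN discriminator downsamples by `2 ** 4`, so a `256 x 256` input yields a `1 x 16 x 16` grid of real/fake scores. The cell below is a lightweight shape check using throwaway CPU instances (only 2 residual blocks instead of the 9 used for training, just to keep it fast)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Shape check with throwaway instances (CPU): the generator must map\n",
+    "# 3x256x256 -> 3x256x256, and the PatchGAN discriminator must map\n",
+    "# 3x256x256 -> 1x16x16 (256 // 2**4 = 16).\n",
+    "_g = GeneratorResNet((3, 256, 256), num_residual_blocks=2)\n",
+    "_d = Discriminator((3, 256, 256))\n",
+    "_x = torch.randn(1, 3, 256, 256)\n",
+    "print(_g(_x).shape, _d(_x).shape, _d.output_shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "LcPNxgDDcXS_"
+   },
+   "source": [
+    "**Code 3. Training**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "id": "NiNjiXNmcXS_"
+   },
+   "outputs": [],
+   "source": [
+    "channels = 3\n",
+    "img_height = 256\n",
+    "img_width = 256\n",
+    "n_residual_blocks = 9\n",
+    "lr = 0.0002\n",
+    "b1 = 0.5\n",
+    "b2 = 0.999\n",
+    "n_epochs = 200\n",
+    "init_epoch = 0\n",
+    "decay_epoch = 100\n",
+    "lambda_cyc = 10.0\n",
+    "lambda_id = 5.0\n",
+    "n_cpu = 4\n",
+    "batch_size = 3\n",
+    "sample_interval = 100\n",
+    "checkpoint_interval = 5"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "id": "6iWrleNXcXTA"
+   },
+   "outputs": [],
+   "source": [
+    "# Create sample and checkpoint directories\n",
+    "os.makedirs(\"images/%s\" % dataset_name, exist_ok=True)\n",
+    "os.makedirs(\"saved_models/%s\" % dataset_name, exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "id": "TGdqB61TcXTA"
+   },
+   "outputs": [],
+   "source": [
+    "# Losses\n",
+    "criterion_GAN = torch.nn.MSELoss()\n",
+    "criterion_cycle = torch.nn.L1Loss()\n",
+    "criterion_identity = torch.nn.L1Loss()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "id": "uKCIRoracXTA"
+   },
+   "outputs": [],
+   "source": [
+    "input_shape = (channels, img_height, img_width)\n",
+    "\n",
+    "# Initialize generator and discriminator\n",
+    "G_AB = GeneratorResNet(input_shape, n_residual_blocks)\n",
+    "G_BA = GeneratorResNet(input_shape, n_residual_blocks)\n",
+    "D_A = Discriminator(input_shape)\n",
+    "D_B = Discriminator(input_shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "id": "GBd6OD2HcXTA"
+   },
+   "outputs": [],
+   "source": [
+    "cuda = torch.cuda.is_available()\n",
+    "\n",
+    "if cuda:\n",
+    "    G_AB = G_AB.cuda()\n",
+    "    G_BA = G_BA.cuda()\n",
+    "    D_A = D_A.cuda()\n",
+    "    D_B = D_B.cuda()\n",
+    "    criterion_GAN.cuda()\n",
+    "    criterion_cycle.cuda()\n",
+    "    criterion_identity.cuda()"
+   ]
+  },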
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "yXQPrFxzcXTA",
+    "outputId": "72aa21ea-0552-47f1-e5a2-72747b6c6492"
+   },
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "Discriminator(\n",
+       "  (model): Sequential(\n",
+       "    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+       "    (1): LeakyReLU(negative_slope=0.2, inplace=True)\n",
+       "    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+       "    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n",
+       "    (4): LeakyReLU(negative_slope=0.2, inplace=True)\n",
+       "    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+       "    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n",
+       "    (7): LeakyReLU(negative_slope=0.2, inplace=True)\n",
+       "    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+       "    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n",
+       "    (10): LeakyReLU(negative_slope=0.2, inplace=True)\n",
+       "    (11): ZeroPad2d((1, 0, 1, 0))\n",
+       "    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "metadata": {},
+     "execution_count": 16
+    }
+   ],
+   "source": [
+    "# Initialize weights\n",
+    "G_AB.apply(weights_init_normal)\n",
+    "G_BA.apply(weights_init_normal)\n",
+    "D_A.apply(weights_init_normal)\n",
+    "D_B.apply(weights_init_normal)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "id": "ituqSiAacXTA"
+   },
+   "outputs": [],
+   "source": [
+    "# Optimizers\n",
+    "optimizer_G = torch.optim.Adam(\n",
+    "    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)\n",
+    ")\n",
+    "optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))\n",
+    "optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {
+    "id": "HJ02pSwdcXTA"
+   },
+   "outputs": [],
+   "source": [
+    "class LambdaLR:\n",
+    "    def __init__(self, n_epochs, offset, decay_start_epoch):\n",
+    "        assert (n_epochs - decay_start_epoch) > 0, \"Decay must start before the training session ends!\"\n",
+    "        self.n_epochs = n_epochs\n",
+    "        self.offset = offset\n",
+    "        self.decay_start_epoch = decay_start_epoch\n",
+    "\n",
+    "    def step(self, epoch):\n",
+    "        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {
+    "id": "XIjLxUcfcXTA"
+   },
+   "outputs": [],
+   "source": [
+    "# Learning rate update schedulers\n",
+    "lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(\n",
+    "    optimizer_G, lr_lambda=LambdaLR(n_epochs, init_epoch, decay_epoch).step\n",
+    ")\n",
+    "lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(\n",
+    "    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, init_epoch, decay_epoch).step\n",
+    ")\n",
+    "lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(\n",
+    "    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, init_epoch, decay_epoch).step\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {
+    "id": "dLFzJFaZcXTA"
+   },
+   "outputs": [],
+   "source": [
+    "Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {
+    "id": "L5ABTHJjcXTA"
+   },
+   "outputs": [],
+   "source": [
+    "class ReplayBuffer:\n",
+    "    def __init__(self, max_size=50):\n",
+    "        assert max_size > 0, \"Empty buffer or trying to create a black hole. Be careful.\"\n",
+    "        self.max_size = max_size\n",
+    "        self.data = []\n",
+    "\n",
+    "    def push_and_pop(self, data):\n",
+    "        to_return = []\n",
+    "        for element in data.data:\n",
+    "            element = torch.unsqueeze(element, 0)\n",
+    "            if len(self.data) < self.max_size:\n",
+    "                self.data.append(element)\n",
+    "                to_return.append(element)\n",
+    "            else:\n",
+    "                if random.uniform(0, 1) > 0.5:\n",
+    "                    i = random.randint(0, self.max_size - 1)\n",
+    "                    to_return.append(self.data[i].clone())\n",
+    "                    self.data[i] = element\n",
+    "                else:\n",
+    "                    to_return.append(element)\n",
+    "        return Variable(torch.cat(to_return))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {
+    "id": "GtxgLUT_cXTA"
+   },
+   "outputs": [],
+   "source": [
+    "# Buffers of previously generated samples\n",
+    "fake_A_buffer = ReplayBuffer()\n",
+    "fake_B_buffer = ReplayBuffer()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {
+    "id": "vprcgSTycXTA"
+   },
+   "outputs": [],
+   "source": [
+    "# Image transformations\n",
+    "transforms_ = [\n",
+    "    transforms.Resize(int(img_height * 1.12), Image.BICUBIC),\n",
+    "    transforms.RandomCrop((img_height, img_width)),\n",
+    "    transforms.RandomHorizontalFlip(),\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "rReV2yFlcXTA",
+    "outputId": "08c9010a-5229-4820-9748-53449a9b7f3c"
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": [
+      "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
+      "  warnings.warn(_create_warning_msg(\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Training data loader\n",
+    "dataloader = DataLoader(\n",
+    "    ImageDataset(dataset_name, transforms_=transforms_, unaligned=True),\n",
+    "    batch_size=batch_size,\n",
+    "    shuffle=True,\n",
+    "    num_workers=n_cpu,\n",
+    ")\n",
+    "# Test data loader\n",
+    "val_dataloader = DataLoader(\n",
+    "    ImageDataset(dataset_name, transforms_=transforms_, unaligned=True, mode=\"test\"),\n",
+    "    batch_size=6,\n",
+    "    shuffle=True,\n",
+    "    num_workers=1,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {
+    "id": "Zxbt2FCicXTA"
+   },
+   "outputs": [],
+   "source": [
+    "def sample_images(batches_done):\n",
+    "    \"\"\"Saves a generated sample from the test set\"\"\"\n",
+    "    imgs = next(iter(val_dataloader))\n",
+    "    G_AB.eval()\n",
+    "    G_BA.eval()\n",
+    "    real_A = Variable(imgs[\"A\"].type(Tensor))\n",
+    "    fake_B = G_AB(real_A)\n",
+    "    real_B = Variable(imgs[\"B\"].type(Tensor))\n",
+    "    fake_A = G_BA(real_B)\n",
+    "    # Arrange images along x-axis\n",
+    "    real_A = make_grid(real_A, nrow=5, normalize=True)\n",
+    "    real_B = make_grid(real_B, nrow=5, normalize=True)\n",
+    "    fake_A = make_grid(fake_A, nrow=5, normalize=True)\n",
+    "    fake_B = make_grid(fake_B, nrow=5, normalize=True)\n",
+    "    # Arrange images along y-axis\n",
+    "    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)\n",
+    "    save_image(image_grid, \"images/%s/%s.png\" % (dataset_name, batches_done), normalize=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "j_kl7qfccXTB",
+    "outputId": "a3b06f46-c5be-4fe4-ef0f-585386cc49c9"
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": [
+      ":10: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:78.)\n",
+      "  valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)\n"
+     ]
+    },
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "[Epoch 0/200] [Batch 53/1134] [D loss: 0.301228] [G loss: 3.974601, adv: 0.370753, cycle: 0.249056, identity: 0.222657] ETA: 4 days, 11:05:51.596526"
+     ]
+    }
+   ],
+   "source": [
+    "prev_time = time.time()\n",
+    "for epoch in range(init_epoch, n_epochs):\n",
+    "    for i, batch in enumerate(dataloader):\n",
+    "\n",
+    "        # (1) Set model input\n",
+    "        real_A = Variable(batch[\"A\"].type(Tensor))\n",
+    "        real_B = Variable(batch[\"B\"].type(Tensor))\n",
+    "\n",
+    "        # (2) Adversarial ground truths\n",
+    "        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)\n",
+    "        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)\n",
+    "\n",
+    "        # (3) Train Generators\n",
+    "\n",
+    "        G_AB.train()\n",
+    "        G_BA.train()\n",
+    "\n",
+    "        optimizer_G.zero_grad()\n",
+    "\n",
+    "        # (4) Identity loss\n",
+    "        loss_id_A = criterion_identity(G_BA(real_A), real_A)\n",
+    "        loss_id_B = criterion_identity(G_AB(real_B), real_B)\n",
+    "\n",
+    "        loss_identity = (loss_id_A + loss_id_B) / 2\n",
+    "\n",
+    "        # (5) GAN loss\n",
+    "        fake_B = G_AB(real_A)\n",
+    "        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)\n",
+    "        fake_A = G_BA(real_B)\n",
+    "        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)\n",
+    "\n",
+    "        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2\n",
+    "\n",
+    "        # (6) Cycle loss\n",
+    "        recov_A = G_BA(fake_B)\n",
+    "        loss_cycle_A = criterion_cycle(recov_A, real_A)\n",
+    "        recov_B = G_AB(fake_A)\n",
+    "        loss_cycle_B = criterion_cycle(recov_B, real_B)\n",
+    "\n",
+    "        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2\n",
+    "\n",
+    "        # (7) Total loss\n",
+    "        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity\n",
+    "\n",
+    "        loss_G.backward()\n",
+    "        optimizer_G.step()\n",
+    "\n",
+    "        # (8) Train Discriminator A\n",
+    "\n",
+    "        optimizer_D_A.zero_grad()\n",
+    "\n",
+    "        # (9) Real loss\n",
+    "        loss_real = criterion_GAN(D_A(real_A), valid)\n",
+    "        # (10) Fake loss (on batch of previously generated samples)\n",
+    "        fake_A_ = fake_A_buffer.push_and_pop(fake_A)\n",
+    "        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)\n",
+    "        # (11) Total loss\n",
+    "        loss_D_A = (loss_real + loss_fake) / 2\n",
+    "\n",
+    "        loss_D_A.backward()\n",
+    "        optimizer_D_A.step()\n",
+    "\n",
+    "        # (12) Train Discriminator B\n",
+    "\n",
+    "        optimizer_D_B.zero_grad()\n",
+    "\n",
+    "        # (13) Real loss\n",
+    "        loss_real = criterion_GAN(D_B(real_B), valid)\n",
+    "        # (14) Fake loss (on batch of previously generated samples)\n",
+    "        fake_B_ = fake_B_buffer.push_and_pop(fake_B)\n",
+    "        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)\n",
+    "        # (15) Total loss\n",
+    "        loss_D_B = (loss_real + loss_fake) / 2\n",
+    "\n",
+    "        loss_D_B.backward()\n",
+    "        optimizer_D_B.step()\n",
+    "\n",
+    "        loss_D = (loss_D_A + loss_D_B) / 2\n",
+    "\n",
+    "        # (16) Determine approximate time left\n",
+    "        batches_done = epoch * len(dataloader) + i\n",
+    "        batches_left = n_epochs * len(dataloader) - batches_done\n",
+    "        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))\n",
+    "        prev_time = time.time()\n",
+    "\n",
+    "        # (17) Print log\n",
+    "        sys.stdout.write(\n",
+    "            \"\\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s\"\n",
+    "            % (\n",
+    "                epoch,\n",
+    "                n_epochs,\n",
+    "                i,\n",
+ " len(dataloader),\n", + " loss_D.item(),\n", + " loss_G.item(),\n", + " loss_GAN.item(),\n", + " loss_cycle.item(),\n", + " loss_identity.item(),\n", + " time_left,\n", + " )\n", + " )\n", + "\n", + " # (18) If at sample interval save image\n", + " if batches_done % sample_interval == 0:\n", + " sample_images(batches_done)\n", + "\n", + " # (19) Update learning rates\n", + " lr_scheduler_G.step()\n", + " lr_scheduler_D_A.step()\n", + " lr_scheduler_D_B.step()\n", + " # (20) Save model checkpoints\n", + " if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:\n", + " torch.save(G_AB.state_dict(), \"saved_models/%s/G_AB_%d.pth\" % (dataset_name, epoch))\n", + " torch.save(G_BA.state_dict(), \"saved_models/%s/G_BA_%d.pth\" % (dataset_name, epoch))\n", + " torch.save(D_A.state_dict(), \"saved_models/%s/D_A_%d.pth\" % (dataset_name, epoch))\n", + " torch.save(D_B.state_dict(), \"saved_models/%s/D_B_%d.pth\" % (dataset_name, epoch))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AfZNCuEBcXTB" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env_lsh", + "language": "python", + "name": "env_lsh" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file