diff --git a/Experiments/CV/LeafClassification/.gitignore b/Experiments/CV/LeafClassification/.gitignore
new file mode 100644
index 0000000..676284b
--- /dev/null
+++ b/Experiments/CV/LeafClassification/.gitignore
@@ -0,0 +1,6 @@
+dataset/
+splitted
+venv/
+.DS_Store
+
+*.pt
diff --git a/Experiments/CV/LeafClassification/Plant_Leaf_Classification.ipynb b/Experiments/CV/LeafClassification/Plant_Leaf_Classification.ipynb
new file mode 100644
index 0000000..bd73e97
--- /dev/null
+++ b/Experiments/CV/LeafClassification/Plant_Leaf_Classification.ipynb
@@ -0,0 +1,1097 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "jOQ3XcF0wq-Y"
+   },
+   "source": [
+    "# Plant Leaf Classification"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "LASRoTG5wwTO"
+   },
+   "source": [
+    "data:\n",
+    "\n",
+    "\n",
+    ""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "UrVLKEHnxiC9",
+    "outputId": "2511501e-e181-4cd6-c46c-87623b8cd030"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Mounted at /content/drive/\n"
+     ]
+    }
+   ],
+   "source": [
+    "from google.colab import drive\n",
+    "\n",
+    "drive.mount('/content/drive/')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "id": "c89W9d5cxibz"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "# move into the project folder on the mounted drive\n",
+    "os.chdir('drive/My Drive/Tutorials/DeepLearningProject/LeafDiseaseDetection')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "-8rwOLBVwq-e"
+   },
+   "source": [
+    "## Data Preparation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "id": "N3lNKbzSwq-e"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import shutil\n",
+    "\n",
+    "original_dataset_dir = './dataset'\n",
+    "# skip hidden entries such as .DS_Store when listing the class folders\n",
+    "classes_list = [d for d in os.listdir(original_dataset_dir) if not d.startswith('.')]\n",
+    "\n",
+    "base_dir = './splitted'\n",
+    "# only create the per-class folders on the first run\n",
+    "ALREADY_SPLITTED = os.path.exists(base_dir)\n",
+    "\n",
+    "train_dir = os.path.join(base_dir, 'train')\n",
+    "validation_dir = os.path.join(base_dir, 'val')\n",
+    "test_dir = os.path.join(base_dir, 'test')\n",
+    "for d in (base_dir, train_dir, validation_dir, test_dir):\n",
+    "    os.makedirs(d, exist_ok=True)\n",
+    "\n",
+    "if not ALREADY_SPLITTED:\n",
+    "    for cls in classes_list:\n",
+    "        os.mkdir(os.path.join(train_dir, cls))\n",
+    "        os.mkdir(os.path.join(validation_dir, cls))\n",
+    "        os.mkdir(os.path.join(test_dir, cls))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "3YZViAoKwq-f",
+    "outputId": "58639bad-ae1b-404f-ada3-fccbf89f00f8"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Train size( Strawberry___healthy ): 600\n",
+      "Validation size( Strawberry___healthy ): 200\n",
+      "Test size( Strawberry___healthy ): 200\n",
+      "Train size( Grape___Black_rot ): 708\n",
+      "Validation size( Grape___Black_rot ): 236\n",
+      "Test size( Grape___Black_rot ): 236\n",
+      "Train size( Potato___Early_blight ): 600\n",
+      "Validation size( Potato___Early_blight ): 200\n",
+      "Test size( Potato___Early_blight ): 200\n",
+      "Train size( Blueberry___healthy ): 901\n",
"Validation size( Blueberry___healthy ): 300\n", + "Test size( Blueberry___healthy ): 300\n", + "Train size( Cherry___Powdery_mildew ): 631\n", + "Validation size( Cherry___Powdery_mildew ): 210\n", + "Test size( Cherry___Powdery_mildew ): 210\n", + "Train size( Tomato___Target_Spot ): 842\n", + "Validation size( Tomato___Target_Spot ): 280\n", + "Test size( Tomato___Target_Spot ): 280\n", + "Train size( Peach___healthy ): 600\n", + "Validation size( Peach___healthy ): 200\n", + "Test size( Peach___healthy ): 200\n", + "Train size( Potato___Late_blight ): 600\n", + "Validation size( Potato___Late_blight ): 200\n", + "Test size( Potato___Late_blight ): 200\n", + "Train size( Tomato___Late_blight ): 1145\n", + "Validation size( Tomato___Late_blight ): 381\n", + "Test size( Tomato___Late_blight ): 381\n", + "Train size( Tomato___Tomato_mosaic_virus ): 600\n", + "Validation size( Tomato___Tomato_mosaic_virus ): 200\n", + "Test size( Tomato___Tomato_mosaic_virus ): 200\n", + "Train size( Pepper,_bell___healthy ): 886\n", + "Validation size( Pepper,_bell___healthy ): 295\n", + "Test size( Pepper,_bell___healthy ): 295\n", + "Train size( Orange___Haunglongbing_(Citrus_greening) ): 3304\n", + "Validation size( Orange___Haunglongbing_(Citrus_greening) ): 1101\n", + "Test size( Orange___Haunglongbing_(Citrus_greening) ): 1101\n", + "Train size( Tomato___Leaf_Mold ): 600\n", + "Validation size( Tomato___Leaf_Mold ): 200\n", + "Test size( Tomato___Leaf_Mold ): 200\n", + "Train size( Grape___Leaf_blight_(Isariopsis_Leaf_Spot) ): 645\n", + "Validation size( Grape___Leaf_blight_(Isariopsis_Leaf_Spot) ): 215\n", + "Test size( Grape___Leaf_blight_(Isariopsis_Leaf_Spot) ): 215\n", + "Train size( Apple___Cedar_apple_rust ): 600\n", + "Validation size( Apple___Cedar_apple_rust ): 200\n", + "Test size( Apple___Cedar_apple_rust ): 200\n", + "Train size( Tomato___Bacterial_spot ): 1276\n", + "Validation size( Tomato___Bacterial_spot ): 425\n", + "Test size( Tomato___Bacterial_spot ): 425\n", + "Train size( Grape___healthy ): 600\n", + "Validation size( Grape___healthy ): 200\n", + "Test size( Grape___healthy ): 200\n", + "Train size( Corn___Cercospora_leaf_spot Gray_leaf_spot ): 600\n", + "Validation size( Corn___Cercospora_leaf_spot Gray_leaf_spot ): 200\n", + "Test size( Corn___Cercospora_leaf_spot Gray_leaf_spot ): 200\n", + "Train size( Tomato___Early_blight ): 600\n", + "Validation size( Tomato___Early_blight ): 200\n", + "Test size( Tomato___Early_blight ): 200\n", + "Train size( Grape___Esca_(Black_Measles) ): 829\n", + "Validation size( Grape___Esca_(Black_Measles) ): 276\n", + "Test size( Grape___Esca_(Black_Measles) ): 276\n", + "Train size( Raspberry___healthy ): 600\n", + "Validation size( Raspberry___healthy ): 200\n", + "Test size( Raspberry___healthy ): 200\n", + "Train size( Tomato___healthy ): 954\n", + "Validation size( Tomato___healthy ): 318\n", + "Test size( Tomato___healthy ): 318\n", + "Train size( Corn___Northern_Leaf_Blight ): 600\n", + "Validation size( Corn___Northern_Leaf_Blight ): 200\n", + "Test size( Corn___Northern_Leaf_Blight ): 200\n", + "Train size( Tomato___Tomato_Yellow_Leaf_Curl_Virus ): 3214\n", + "Validation size( Tomato___Tomato_Yellow_Leaf_Curl_Virus ): 1071\n", + "Test size( Tomato___Tomato_Yellow_Leaf_Curl_Virus ): 1071\n", + "Train size( Cherry___healthy ): 600\n", + "Validation size( Cherry___healthy ): 200\n", + "Test size( Cherry___healthy ): 200\n", + "Train size( Apple___Apple_scab ): 600\n", + "Validation size( Apple___Apple_scab ): 200\n", + "Test size( 
+      "Train size( Tomato___Spider_mites Two-spotted_spider_mite ): 1005\n",
+      "Validation size( Tomato___Spider_mites Two-spotted_spider_mite ): 335\n",
+      "Test size( Tomato___Spider_mites Two-spotted_spider_mite ): 335\n",
+      "Train size( Corn___Common_rust ): 715\n",
+      "Validation size( Corn___Common_rust ): 238\n",
+      "Test size( Corn___Common_rust ): 238\n",
+      "Train size( Background_without_leaves ): 685\n",
+      "Validation size( Background_without_leaves ): 228\n",
+      "Test size( Background_without_leaves ): 228\n",
+      "Train size( Peach___Bacterial_spot ): 1378\n",
+      "Validation size( Peach___Bacterial_spot ): 459\n",
+      "Test size( Peach___Bacterial_spot ): 459\n",
+      "Train size( Pepper,_bell___Bacterial_spot ): 600\n",
+      "Validation size( Pepper,_bell___Bacterial_spot ): 200\n",
+      "Test size( Pepper,_bell___Bacterial_spot ): 200\n",
+      "Train size( Tomato___Septoria_leaf_spot ): 1062\n",
+      "Validation size( Tomato___Septoria_leaf_spot ): 354\n",
+      "Test size( Tomato___Septoria_leaf_spot ): 354\n",
+      "Train size( Corn___healthy ): 697\n",
+      "Validation size( Corn___healthy ): 232\n",
+      "Test size( Corn___healthy ): 232\n",
+      "Train size( Squash___Powdery_mildew ): 1101\n",
+      "Validation size( Squash___Powdery_mildew ): 367\n",
+      "Test size( Squash___Powdery_mildew ): 367\n",
+      "Train size( Apple___Black_rot ): 600\n",
+      "Validation size( Apple___Black_rot ): 200\n",
+      "Test size( Apple___Black_rot ): 200\n",
+      "Train size( Apple___healthy ): 987\n",
+      "Validation size( Apple___healthy ): 329\n",
+      "Test size( Apple___healthy ): 329\n",
+      "Train size( Strawberry___Leaf_scorch ): 665\n",
+      "Validation size( Strawberry___Leaf_scorch ): 221\n",
+      "Test size( Strawberry___Leaf_scorch ): 221\n",
+      "Train size( Potato___healthy ): 600\n",
+      "Validation size( Potato___healthy ): 200\n",
+      "Test size( Potato___healthy ): 200\n",
+      "Train size( Soybean___healthy ): 3054\n",
+      "Validation size( Soybean___healthy ): 1018\n",
+      "Test size( Soybean___healthy ): 1018\n"
+     ]
+    }
+   ],
+   "source": [
+    "import math\n",
+    "\n",
+    "\n",
+    "# copy each class into 60% train / 20% val / 20% test (sizes are floored,\n",
+    "# so a few files per class may be left out of every split)\n",
+    "for cls in classes_list:\n",
+    "    path = os.path.join(original_dataset_dir, cls)\n",
+    "    fnames = os.listdir(path)\n",
+    "\n",
+    "    train_size = math.floor(len(fnames) * 0.6)\n",
+    "    validation_size = math.floor(len(fnames) * 0.2)\n",
+    "    test_size = math.floor(len(fnames) * 0.2)\n",
+    "\n",
+    "    train_fnames = fnames[:train_size]\n",
+    "    print(\"Train size(\", cls, \"): \", len(train_fnames))\n",
+    "    for fname in train_fnames:\n",
+    "        src = os.path.join(path, fname)\n",
+    "        dst = os.path.join(train_dir, cls, fname)\n",
+    "        shutil.copyfile(src, dst)\n",
+    "\n",
+    "    validation_fnames = fnames[train_size:(train_size + validation_size)]\n",
+    "    print(\"Validation size(\", cls, \"): \", len(validation_fnames))\n",
+    "    for fname in validation_fnames:\n",
+    "        src = os.path.join(path, fname)\n",
+    "        dst = os.path.join(validation_dir, cls, fname)\n",
+    "        shutil.copyfile(src, dst)\n",
+    "\n",
+    "    test_fnames = fnames[(train_size + validation_size):(train_size + validation_size + test_size)]\n",
+    "    print(\"Test size(\", cls, \"): \", len(test_fnames))\n",
+    "    for fname in test_fnames:\n",
+    "        src = os.path.join(path, fname)\n",
+    "        dst = os.path.join(test_dir, cls, fname)\n",
+    "        shutil.copyfile(src, dst)"
+   ]
+  },
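+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: `os.listdir` returns names in an arbitrary (but not random) order, so the slices above are not a random sample. A minimal sketch of a seeded, reproducible alternative -- the helper name and `seed` value are illustrative, not part of the original run:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "import random\n",
+    "\n",
+    "\n",
+    "def split_fnames(fnames, seed=42):\n",
+    "    # Illustrative only: sort to normalize os.listdir order, then shuffle\n",
+    "    # with a fixed seed so the 60/20/20 split is random yet reproducible.\n",
+    "    shuffled = sorted(fnames)\n",
+    "    random.Random(seed).shuffle(shuffled)\n",
+    "    n_train = math.floor(len(shuffled) * 0.6)\n",
+    "    n_val = math.floor(len(shuffled) * 0.2)\n",
+    "    return (shuffled[:n_train],\n",
+    "            shuffled[n_train:n_train + n_val],\n",
+    "            shuffled[n_train + n_val:n_train + 2 * n_val])"
+   ]
+  },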
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "bswekMi5wq-f"
+   },
+   "source": [
+    "## Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "id": "6WJaekTUwq-f"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import os\n",
+    "\n",
+    "# prefer CUDA, fall back to Apple's MPS backend, then CPU\n",
+    "USE_CUDA = torch.cuda.is_available()\n",
+    "DEVICE = torch.device(\"cuda\" if USE_CUDA else \"cpu\")\n",
+    "if torch.backends.mps.is_available():\n",
+    "    DEVICE = torch.device(\"mps\")\n",
+    "\n",
+    "BATCH_SIZE = 256\n",
+    "EPOCH = 30"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CUDA Available: False\n",
+      "Device: mps\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"CUDA Available: \", USE_CUDA)\n",
+    "print(\"Device: \", DEVICE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 408
+    },
+    "id": "aBFvg9I_wq-f",
+    "outputId": "fe41e222-edce-4a57-de30-35aea1d48cc0"
+   },
+   "outputs": [],
+   "source": [
+    "import torchvision.transforms as transforms\n",
+    "from torchvision.datasets import ImageFolder\n",
+    "\n",
+    "\n",
+    "transform_base = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])\n",
+    "train_dataset = ImageFolder(root='./splitted/train', transform=transform_base)\n",
+    "val_dataset = ImageFolder(root='./splitted/val', transform=transform_base)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "id": "JRbRMo0wwq-f"
+   },
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)\n",
+    "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "id": "-2CD4tNVwq-g"
+   },
+   "outputs": [],
+   "source": [
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim\n",
+    "\n",
+    "\n",
+    "class Net(nn.Module):\n",
+    "\n",
+    "    def __init__(self):\n",
+    "\n",
+    "        super(Net, self).__init__()\n",
+    "\n",
+    "        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)\n",
+    "        self.pool = nn.MaxPool2d(2, 2)\n",
+    "        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)\n",
+    "        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)\n",
+    "\n",
+    "        # three 2x2 poolings reduce a 64x64 input to 8x8 with 64 channels,\n",
+    "        # so the flattened feature vector has 64 * 8 * 8 = 4096 entries\n",
+    "        self.fc1 = nn.Linear(4096, 512)\n",
+    "        self.fc2 = nn.Linear(512, 39)  # one output per class\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "\n",
+    "        x = self.conv1(x)\n",
+    "        x = F.relu(x)\n",
+    "        x = self.pool(x)\n",
+    "        x = F.dropout(x, p=0.25, training=self.training)\n",
+    "\n",
+    "        x = self.conv2(x)\n",
+    "        x = F.relu(x)\n",
+    "        x = self.pool(x)\n",
+    "        x = F.dropout(x, p=0.25, training=self.training)\n",
+    "\n",
+    "        x = self.conv3(x)\n",
+    "        x = F.relu(x)\n",
+    "        x = self.pool(x)\n",
+    "        x = F.dropout(x, p=0.25, training=self.training)\n",
+    "\n",
+    "        x = x.view(-1, 4096)\n",
+    "        x = self.fc1(x)\n",
+    "        x = F.relu(x)\n",
+    "        x = F.dropout(x, p=0.5, training=self.training)\n",
+    "        x = self.fc2(x)\n",
+    "\n",
+    "        # return raw logits; F.cross_entropy below applies log_softmax itself\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "model_base = Net().to(DEVICE)\n",
+    "optimizer = optim.Adam(model_base.parameters(), lr=0.001)"
+   ]
+  },
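+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `4096` in `fc1` can be checked directly: each `padding=1` convolution preserves the spatial size and each 2x2 pooling halves it, so a 64x64 input leaves 64 channels of 8x8 features. A quick sanity check with a dummy batch (illustrative, not part of the original run):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Trace a zero tensor through the conv/pool stack to confirm the\n",
+    "# flattened feature size that fc1 expects (dropout is shape-neutral).\n",
+    "with torch.no_grad():\n",
+    "    dummy = torch.zeros(1, 3, 64, 64).to(DEVICE)\n",
+    "    feats = model_base.pool(F.relu(model_base.conv1(dummy)))\n",
+    "    feats = model_base.pool(F.relu(model_base.conv2(feats)))\n",
+    "    feats = model_base.pool(F.relu(model_base.conv3(feats)))\n",
+    "    print(feats.shape)                # torch.Size([1, 64, 8, 8])\n",
+    "    print(feats.view(1, -1).size(1))  # 64 * 8 * 8 = 4096"
+   ]
+  },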
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "id": "yHECMdAewq-g"
+   },
+   "outputs": [],
+   "source": [
+    "def train(model, train_loader, optimizer):\n",
+    "    model.train()\n",
+    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
+    "        data, target = data.to(DEVICE), target.to(DEVICE)\n",
+    "        optimizer.zero_grad()\n",
+    "        output = model(data)\n",
+    "        # cross_entropy expects raw logits, as returned by Net.forward\n",
+    "        loss = F.cross_entropy(output, target)\n",
+    "        loss.backward()\n",
+    "        optimizer.step()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "id": "qR0yh4VVwq-g"
+   },
+   "outputs": [],
+   "source": [
+    "def evaluate(model, test_loader):\n",
+    "    model.eval()\n",
+    "    test_loss = 0\n",
+    "    correct = 0\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        for data, target in test_loader:\n",
+    "            data, target = data.to(DEVICE), target.to(DEVICE)\n",
+    "            output = model(data)\n",
+    "\n",
+    "            # sum the batch losses so the mean over the dataset is exact\n",
+    "            test_loss += F.cross_entropy(output, target, reduction='sum').item()\n",
+    "\n",
+    "            # the class with the highest logit is the prediction\n",
+    "            pred = output.max(1, keepdim=True)[1]\n",
+    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
+    "\n",
+    "    test_loss /= len(test_loader.dataset)\n",
+    "    test_accuracy = 100. * correct / len(test_loader.dataset)\n",
+    "    return test_loss, test_accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "id": "K4mpQgd3wq-g"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "-------------- epoch 1 ----------------\n",
+      "train Loss: 1.3961, Accuracy: 58.68%\n",
+      "val Loss: 1.4116, Accuracy: 58.01%\n",
+      "Completed in 1m 30s\n",
+      "-------------- epoch 2 ----------------\n",
+      "train Loss: 0.8478, Accuracy: 74.44%\n",
+      "val Loss: 0.8872, Accuracy: 73.43%\n",
+      "Completed in 1m 24s\n",
+      "-------------- epoch 3 ----------------\n",
+      "train Loss: 0.6518, Accuracy: 80.85%\n",
+      "val Loss: 0.6932, Accuracy: 79.27%\n",
+      "Completed in 1m 24s\n",
+      "-------------- epoch 4 ----------------\n",
+      "train Loss: 0.5244, Accuracy: 83.21%\n",
+      "val Loss: 0.5876, Accuracy: 81.26%\n",
+      "Completed in 1m 24s\n",
+      "-------------- epoch 5 ----------------\n",
+      "train Loss: 0.4623, Accuracy: 85.44%\n",
+      "val Loss: 0.5306, Accuracy: 83.38%\n",
+      "Completed in 1m 26s\n",
+      "-------------- epoch 6 ----------------\n",
+      "train Loss: 1.1426, Accuracy: 67.00%\n",
+      "val Loss: 1.2404, Accuracy: 64.76%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 7 ----------------\n",
+      "train Loss: 0.3366, Accuracy: 89.78%\n",
+      "val Loss: 0.4235, Accuracy: 86.69%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 8 ----------------\n",
+      "train Loss: 0.3423, Accuracy: 89.11%\n",
+      "val Loss: 0.4372, Accuracy: 86.21%\n",
+      "Completed in 1m 29s\n",
+      "-------------- epoch 9 ----------------\n",
+      "train Loss: 0.3032, Accuracy: 90.30%\n",
+      "val Loss: 0.4043, Accuracy: 87.08%\n",
+      "Completed in 1m 29s\n",
+      "-------------- epoch 10 ----------------\n",
+      "train Loss: 0.2211, Accuracy: 93.20%\n",
+      "val Loss: 0.3237, Accuracy: 89.86%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 11 ----------------\n",
+      "train Loss: 0.2328, Accuracy: 92.85%\n",
+      "val Loss: 0.3422, Accuracy: 89.39%\n",
+      "Completed in 1m 27s\n",
+      "-------------- epoch 12 ----------------\n",
+      "train Loss: 0.1921, Accuracy: 94.15%\n",
+      "val Loss: 0.3036, Accuracy: 90.50%\n",
+      "Completed in 1m 27s\n",
+      "-------------- epoch 13 ----------------\n",
+      "train Loss: 0.1920, Accuracy: 93.99%\n",
+      "val Loss: 0.3074, Accuracy: 90.29%\n",
+      "Completed in 1m 27s\n",
+      "-------------- epoch 14 ----------------\n",
+      "train Loss: 0.1432, Accuracy: 95.77%\n",
+      "val Loss: 0.2695, Accuracy: 91.68%\n",
+      "Completed in 1m 27s\n",
+      "-------------- epoch 15 ----------------\n",
+      "train Loss: 0.2044, Accuracy: 93.54%\n",
+      "val Loss: 0.3375, Accuracy: 89.50%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 16 ----------------\n",
+      "train Loss: 0.1471, Accuracy: 95.82%\n",
+      "val Loss: 0.2732, Accuracy: 91.48%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 17 ----------------\n",
+      "train Loss: 0.1097, Accuracy: 96.97%\n",
+      "val Loss: 0.2399, Accuracy: 92.79%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 18 ----------------\n",
+      "train Loss: 0.1456, Accuracy: 95.47%\n",
+      "val Loss: 0.2865, Accuracy: 91.07%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 19 ----------------\n",
+      "train Loss: 0.1139, Accuracy: 96.58%\n",
+      "val Loss: 0.2565, Accuracy: 91.90%\n",
+      "Completed in 1m 27s\n",
+      "-------------- epoch 20 ----------------\n",
+      "train Loss: 0.0924, Accuracy: 97.35%\n",
+      "val Loss: 0.2334, Accuracy: 92.73%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 21 ----------------\n",
+      "train Loss: 0.0850, Accuracy: 97.73%\n",
+      "val Loss: 0.2214, Accuracy: 93.24%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 22 ----------------\n",
+      "train Loss: 0.0856, Accuracy: 97.58%\n",
+      "val Loss: 0.2254, Accuracy: 93.17%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 23 ----------------\n",
+      "train Loss: 0.0704, Accuracy: 98.23%\n",
+      "val Loss: 0.2170, Accuracy: 93.56%\n",
+      "Completed in 1m 29s\n",
+      "-------------- epoch 24 ----------------\n",
+      "train Loss: 0.0842, Accuracy: 97.46%\n",
+      "val Loss: 0.2390, Accuracy: 92.86%\n",
+      "Completed in 1m 27s\n",
+      "-------------- epoch 25 ----------------\n",
+      "train Loss: 0.0568, Accuracy: 98.63%\n",
+      "val Loss: 0.1993, Accuracy: 93.95%\n",
+      "Completed in 1m 26s\n",
+      "-------------- epoch 26 ----------------\n",
+      "train Loss: 0.0548, Accuracy: 98.58%\n",
+      "val Loss: 0.2070, Accuracy: 93.80%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 27 ----------------\n",
+      "train Loss: 0.0673, Accuracy: 98.07%\n",
+      "val Loss: 0.2215, Accuracy: 93.11%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 28 ----------------\n",
+      "train Loss: 0.0671, Accuracy: 98.23%\n",
+      "val Loss: 0.2231, Accuracy: 93.06%\n",
+      "Completed in 1m 28s\n",
+      "-------------- epoch 29 ----------------\n",
+      "train Loss: 0.0679, Accuracy: 98.12%\n",
+      "val Loss: 0.2357, Accuracy: 93.06%\n",
+      "Completed in 1m 27s\n",
+      "-------------- epoch 30 ----------------\n",
+      "train Loss: 0.0398, Accuracy: 99.01%\n",
+      "val Loss: 0.1871, Accuracy: 94.40%\n",
+      "Completed in 1m 27s\n"
+     ]
+    }
+   ],
+   "source": [
+    "import time\n",
+    "import copy\n",
+    "\n",
+    "\n",
+    "def train_baseline(model, train_loader, val_loader, optimizer, num_epochs=30):\n",
+    "    best_acc = 0.0\n",
+    "    best_model_wts = copy.deepcopy(model.state_dict())\n",
+    "\n",
+    "    for epoch in range(1, num_epochs + 1):\n",
+    "        since = time.time()\n",
+    "        train(model, train_loader, optimizer)\n",
+    "        train_loss, train_acc = evaluate(model, train_loader)\n",
+    "        val_loss, val_acc = evaluate(model, val_loader)\n",
+    "\n",
+    "        # keep a copy of the weights that score best on the validation set\n",
+    "        if val_acc > best_acc:\n",
+    "            best_acc = val_acc\n",
+    "            best_model_wts = copy.deepcopy(model.state_dict())\n",
+    "\n",
+    "        time_elapsed = time.time() - since\n",
+    "        print('-------------- epoch {} ----------------'.format(epoch))\n",
+    "        print('train Loss: {:.4f}, Accuracy: {:.2f}%'.format(train_loss, train_acc))\n",
+    "        print('val Loss: {:.4f}, Accuracy: {:.2f}%'.format(val_loss, val_acc))\n",
+    "        print('Completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
+    "    model.load_state_dict(best_model_wts)\n",
+    "    return model\n",
+    "\n",
+    "\n",
+    "base = train_baseline(model_base, train_loader, val_loader, optimizer, EPOCH)\n",
+    "torch.save(base, 'baseline.pt')"
+   ]
+  },
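+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`torch.save(base, 'baseline.pt')` pickles the whole `Net` object, so loading it later requires this exact class definition to be importable. Saving only the `state_dict` is the more portable convention; a sketch (the `baseline_state.pt` filename is illustrative):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Alternative to pickling the whole model: persist only the weights,\n",
+    "# then restore them into a freshly constructed Net.\n",
+    "torch.save(base.state_dict(), 'baseline_state.pt')\n",
+    "\n",
+    "restored = Net().to(DEVICE)\n",
+    "restored.load_state_dict(torch.load('baseline_state.pt', map_location=DEVICE))\n",
+    "restored.eval()"
+   ]
+  },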
"markdown", + "metadata": { + "id": "FDb3yCp8wq-g" + }, + "source": [ + "## Transfer Learning" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "0O_hciCLwq-g" + }, + "outputs": [], + "source": [ + "data_transforms = {\n", + " 'train': transforms.Compose([transforms.Resize([64,64]),\n", + " transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(),\n", + " transforms.RandomCrop(52), transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]),\n", + "\n", + " 'val': transforms.Compose([transforms.Resize([64,64]),\n", + " transforms.RandomCrop(52), transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "BaLBbx91wq-g" + }, + "outputs": [], + "source": [ + "data_dir = './splitted'\n", + "image_datasets = {x: ImageFolder(root=os.path.join(data_dir, x), transform=data_transforms[x]) for x in ['train', 'val']}\n", + "dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True, num_workers=4) for x in ['train', 'val']}\n", + "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n", + "\n", + "class_names = image_datasets['train'].classes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ElpYXb4zwq-g" + }, + "source": [ + "* Load Pre-Trained Model" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "import ssl\n", + "ssl._create_default_https_context = ssl._create_unverified_context" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "id": "Kf2mUleRwq-g" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading: \"https://download.pytorch.org/models/resnet50-0676ba61.pth\" to /Users/ywsung/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth\n", + "100.0%\n" + ] + } + ], + "source": [ + "from torchvision import models\n", + "from torch.optim import lr_scheduler\n", + "\n", + "\n", + "resnet = models.resnet50(pretrained=True)\n", + "num_ftrs = resnet.fc.in_features\n", + "resnet.fc = nn.Linear(num_ftrs, 33)\n", + "resnet = resnet.to(DEVICE)\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, resnet.parameters()), lr=0.001)\n", + "\n", + "\n", + "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lC9mool_wq-g" + }, + "source": [ + "* Freeze first few layers of the pre-trained model" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "JB54whUpwq-g" + }, + "outputs": [], + "source": [ + "ct = 0\n", + "for child in resnet.children():\n", + " ct += 1\n", + " if ct < 6:\n", + " for param in child.parameters():\n", + " param.requires_grad = False" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "id": "c1RNE8X0wq-g" + }, + "outputs": [], + "source": [ + "def train_resnet(model, criterion, optimizer, scheduler, num_epochs=25):\n", + "\n", + " best_model_wts = copy.deepcopy(model.state_dict())\n", + " best_acc = 0.0\n", + "\n", + " for epoch in range(num_epochs):\n", + " print('-------------- epoch {} ----------------'.format(epoch+1))\n", + " since = time.time()\n", + " for phase in ['train', 'val']:\n", + " if phase == 'train':\n", + " model.train()\n", + " else:\n", + " 
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {
+    "id": "c1RNE8X0wq-g"
+   },
+   "outputs": [],
+   "source": [
+    "def train_resnet(model, criterion, optimizer, scheduler, num_epochs=25):\n",
+    "\n",
+    "    best_model_wts = copy.deepcopy(model.state_dict())\n",
+    "    best_acc = 0.0\n",
+    "\n",
+    "    for epoch in range(num_epochs):\n",
+    "        print('-------------- epoch {} ----------------'.format(epoch + 1))\n",
+    "        since = time.time()\n",
+    "        # each epoch has a training phase followed by a validation phase\n",
+    "        for phase in ['train', 'val']:\n",
+    "            if phase == 'train':\n",
+    "                model.train()\n",
+    "            else:\n",
+    "                model.eval()\n",
+    "\n",
+    "            running_loss = 0.0\n",
+    "            running_corrects = 0\n",
+    "\n",
+    "            for inputs, labels in dataloaders[phase]:\n",
+    "                inputs = inputs.to(DEVICE)\n",
+    "                labels = labels.to(DEVICE)\n",
+    "\n",
+    "                optimizer.zero_grad()\n",
+    "\n",
+    "                # track gradients only while training\n",
+    "                with torch.set_grad_enabled(phase == 'train'):\n",
+    "                    outputs = model(inputs)\n",
+    "                    _, preds = torch.max(outputs, 1)\n",
+    "                    loss = criterion(outputs, labels)\n",
+    "\n",
+    "                    if phase == 'train':\n",
+    "                        loss.backward()\n",
+    "                        optimizer.step()\n",
+    "\n",
+    "                running_loss += loss.item() * inputs.size(0)\n",
+    "                running_corrects += torch.sum(preds == labels.data)\n",
+    "            if phase == 'train':\n",
+    "                scheduler.step()\n",
+    "\n",
+    "            epoch_loss = running_loss / dataset_sizes[phase]\n",
+    "            epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
+    "\n",
+    "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
+    "\n",
+    "            if phase == 'val' and epoch_acc > best_acc:\n",
+    "                best_acc = epoch_acc\n",
+    "                best_model_wts = copy.deepcopy(model.state_dict())\n",
+    "\n",
+    "        time_elapsed = time.time() - since\n",
+    "        print('Completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
+    "    print('Best val Acc: {:4f}'.format(best_acc))\n",
+    "\n",
+    "    model.load_state_dict(best_model_wts)\n",
+    "\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "id": "cBJyjlNjwq-h"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "-------------- epoch 1 ----------------\n",
+      "train Loss: 0.1702 Acc: 0.7396\n",
+      "val Loss: 0.1334 Acc: 0.7484\n",
+      "Completed in 1m 45s\n",
+      "-------------- epoch 2 ----------------\n",
+      "train Loss: 0.1043 Acc: 0.7576\n",
+      "val Loss: 0.1154 Acc: 0.7554\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 3 ----------------\n",
+      "train Loss: 0.1101 Acc: 0.7576\n",
+      "val Loss: 0.0975 Acc: 0.7617\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 4 ----------------\n",
+      "train Loss: 0.0976 Acc: 0.7614\n",
+      "val Loss: 0.0918 Acc: 0.7637\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 5 ----------------\n",
+      "train Loss: 0.0765 Acc: 0.7682\n",
+      "val Loss: 0.0751 Acc: 0.7687\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 6 ----------------\n",
+      "train Loss: 0.0570 Acc: 0.7737\n",
+      "val Loss: 0.0987 Acc: 0.7629\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 7 ----------------\n",
+      "train Loss: 0.0388 Acc: 0.7792\n",
+      "val Loss: 0.0292 Acc: 0.7825\n",
+      "Completed in 1m 43s\n",
+      "-------------- epoch 8 ----------------\n",
+      "train Loss: 0.0188 Acc: 0.7863\n",
+      "val Loss: 0.0249 Acc: 0.7835\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 9 ----------------\n",
+      "train Loss: 0.0158 Acc: 0.7870\n",
+      "val Loss: 0.0240 Acc: 0.7843\n",
+      "Completed in 1m 43s\n",
+      "-------------- epoch 10 ----------------\n",
+      "train Loss: 0.0135 Acc: 0.7875\n",
+      "val Loss: 0.0253 Acc: 0.7842\n",
+      "Completed in 1m 46s\n",
+      "-------------- epoch 11 ----------------\n",
+      "train Loss: 0.0129 Acc: 0.7877\n",
+      "val Loss: 0.0216 Acc: 0.7848\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 12 ----------------\n",
+      "train Loss: 0.0117 Acc: 0.7876\n",
+      "val Loss: 0.0224 Acc: 0.7843\n",
+      "Completed in 1m 43s\n",
+      "-------------- epoch 13 ----------------\n",
+      "train Loss: 0.0137 Acc: 0.7875\n",
+      "val Loss: 0.0219 Acc: 0.7855\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 14 ----------------\n",
+      "train Loss: 0.0103 Acc: 0.7888\n",
+      "val Loss: 0.0201 Acc: 0.7850\n",
+      "Completed in 1m 43s\n",
+      "-------------- epoch 15 ----------------\n",
+      "train Loss: 0.0101 Acc: 0.7888\n",
+      "val Loss: 0.0207 Acc: 0.7856\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 16 ----------------\n",
+      "train Loss: 0.0091 Acc: 0.7887\n",
+      "val Loss: 0.0202 Acc: 0.7861\n",
+      "Completed in 1m 43s\n",
+      "-------------- epoch 17 ----------------\n",
+      "train Loss: 0.0099 Acc: 0.7887\n",
+      "val Loss: 0.0210 Acc: 0.7852\n",
+      "Completed in 1m 43s\n",
+      "-------------- epoch 18 ----------------\n",
+      "train Loss: 0.0092 Acc: 0.7889\n",
+      "val Loss: 0.0200 Acc: 0.7857\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 19 ----------------\n",
+      "train Loss: 0.0093 Acc: 0.7891\n",
+      "val Loss: 0.0203 Acc: 0.7856\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 20 ----------------\n",
+      "train Loss: 0.0088 Acc: 0.7891\n",
+      "val Loss: 0.0190 Acc: 0.7860\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 21 ----------------\n",
+      "train Loss: 0.0089 Acc: 0.7890\n",
+      "val Loss: 0.0201 Acc: 0.7855\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 22 ----------------\n",
+      "train Loss: 0.0083 Acc: 0.7891\n",
+      "val Loss: 0.0210 Acc: 0.7857\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 23 ----------------\n",
+      "train Loss: 0.0081 Acc: 0.7895\n",
+      "val Loss: 0.0172 Acc: 0.7857\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 24 ----------------\n",
+      "train Loss: 0.0092 Acc: 0.7891\n",
+      "val Loss: 0.0197 Acc: 0.7854\n",
+      "Completed in 1m 43s\n",
+      "-------------- epoch 25 ----------------\n",
+      "train Loss: 0.0098 Acc: 0.7887\n",
+      "val Loss: 0.0196 Acc: 0.7858\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 26 ----------------\n",
+      "train Loss: 0.0085 Acc: 0.7891\n",
+      "val Loss: 0.0190 Acc: 0.7861\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 27 ----------------\n",
+      "train Loss: 0.0088 Acc: 0.7893\n",
+      "val Loss: 0.0192 Acc: 0.7860\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 28 ----------------\n",
+      "train Loss: 0.0084 Acc: 0.7892\n",
+      "val Loss: 0.0197 Acc: 0.7853\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 29 ----------------\n",
+      "train Loss: 0.0082 Acc: 0.7893\n",
+      "val Loss: 0.0183 Acc: 0.7858\n",
+      "Completed in 1m 44s\n",
+      "-------------- epoch 30 ----------------\n",
+      "train Loss: 0.0082 Acc: 0.7893\n",
+      "val Loss: 0.0189 Acc: 0.7860\n",
+      "Completed in 1m 44s\n",
+      "Best val Acc: 0.786069\n"
+     ]
+    }
+   ],
+   "source": [
+    "model_resnet50 = train_resnet(resnet, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=EPOCH)\n",
+    "\n",
+    "torch.save(model_resnet50, 'resnet50.pt')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "x--G4TZ6wq-h"
+   },
+   "source": [
+    "## Evaluation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {
+    "id": "FkxLY9K_wq-h"
+   },
+   "outputs": [],
+   "source": [
+    "transform_base = transforms.Compose([transforms.Resize([64,64]), transforms.ToTensor()])\n",
+    "test_base = ImageFolder(root='./splitted/test', transform=transform_base)\n",
+    "\n",
+    "# shuffling is unnecessary for evaluation\n",
+    "test_loader_base = torch.utils.data.DataLoader(\n",
+    "    test_base,\n",
+    "    batch_size=BATCH_SIZE,\n",
+    "    shuffle=False,\n",
+    "    num_workers=4\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {
+    "id": "xQhewGCQwq-h"
+   },
+   "outputs": [],
+   "source": [
+    "transform_resNet = transforms.Compose([\n",
+    "    transforms.Resize([64,64]),\n",
+    "    transforms.CenterCrop(52),  # deterministic crop for test-time evaluation\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+    "])\n",
+    "\n",
+    "test_resNet = ImageFolder(root='./splitted/test', transform=transform_resNet)\n",
+    "test_loader_resNet = torch.utils.data.DataLoader(test_resNet, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {
+    "id": "l29EZyTlwq-h"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "baseline test acc: 94.37708519814468\n"
+     ]
+    }
+   ],
+   "source": [
+    "# unpickling the saved model requires the Net class defined above\n",
+    "baseline = torch.load('baseline.pt')\n",
+    "baseline.eval()\n",
+    "test_loss, test_accuracy = evaluate(baseline, test_loader_base)\n",
+    "\n",
+    "print('baseline test acc: ', test_accuracy)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {
+    "id": "SO3KmL7Owq-h"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ResNet test acc: 78.56619741231997\n"
+     ]
+    }
+   ],
+   "source": [
+    "resnet50 = torch.load('resnet50.pt')\n",
+    "resnet50.eval()\n",
+    "test_loss, test_accuracy = evaluate(resnet50, test_loader_resNet)\n",
+    "\n",
+    "print('ResNet test acc: ', test_accuracy)"
+   ]
+  },
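+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The dataset is strongly imbalanced (e.g. Orange___Haunglongbing_(Citrus_greening) has roughly five times as many images as most classes), so a single accuracy number can hide weak classes. A sketch of per-class test accuracy for the baseline, assuming the `baseline` model and `test_loader_base` defined above:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from collections import Counter\n",
+    "\n",
+    "# Tally per-class hits; class indices follow ImageFolder's\n",
+    "# alphabetical ordering in test_base.classes.\n",
+    "correct_per_class = Counter()\n",
+    "total_per_class = Counter()\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    for data, target in test_loader_base:\n",
+    "        data, target = data.to(DEVICE), target.to(DEVICE)\n",
+    "        pred = baseline(data).argmax(dim=1)\n",
+    "        for t, p in zip(target.tolist(), pred.tolist()):\n",
+    "            total_per_class[t] += 1\n",
+    "            correct_per_class[t] += int(t == p)\n",
+    "\n",
+    "for idx, cls in enumerate(test_base.classes):\n",
+    "    acc = 100.0 * correct_per_class[idx] / max(total_per_class[idx], 1)\n",
+    "    print(f'{cls}: {acc:.2f}% ({total_per_class[idx]} images)')"
+   ]
+  },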
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "gEdGsZ8_4LRy"
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "T4",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}