From 6949c94f07e83bb6d675103797c395a23b0f73c1 Mon Sep 17 00:00:00 2001 From: epsilon-deltta Date: Tue, 23 Aug 2022 16:15:10 +0900 Subject: [PATCH] fix(loss type in tr): dataset y: float > longtensor nn.CrossEntropyLoss gets longtensor label related to #2 --- dataset.py | 9 +- train.ipynb | 357 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 363 insertions(+), 3 deletions(-) create mode 100644 train.ipynb diff --git a/dataset.py b/dataset.py index df22bf9..bc31963 100644 --- a/dataset.py +++ b/dataset.py @@ -108,7 +108,8 @@ def __getitem__(self,idx): y = F.one_hot(y,num_classes=2) vital = vital.float() - y = y.float() + y = y.type(torch.LongTensor) + return vital, y def __len__(self): @@ -141,7 +142,7 @@ def getitem(self,idx): y = item['nrs_2'] - + # y = torch.tensor(y) # y = F.one_hot(y,num_classes=2) # @@ -177,8 +178,10 @@ def __getitem__(self,idx): vital = torch.tensor(vital) vital = vital.float() + y = torch.tensor(y) - y = y.float() + y = y.type(torch.LongTensor) + return vital, y def __len__(self): diff --git a/train.ipynb b/train.ipynb new file mode 100644 index 0000000..9090498 --- /dev/null +++ b/train.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "id": "56c278bd-9e67-48b7-b11e-2c495744b6c5", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from dataset import VitalDataset\n", + "idx = 3\n", + "root_dir = '../data/pd_gy'\n", + "trdt = VitalDataset(root_dir,f'../data/pd_gy/train_{idx}.json')\n", + "valdt = VitalDataset(root_dir,f'../data/pd_gy/val_{idx}.json') \n", + "tedt = VitalDataset(root_dir,f'../data/pd_gy/test_{idx}.json') \n", + "\n", + "trdl = torch.utils.data.DataLoader(trdt,4)\n", + "valdl = torch.utils.data.DataLoader(valdt,4)\n", + "tedl = torch.utils.data.DataLoader(valdt,4)\n", + "\n", + "from models import PrePostNet\n", + "\n", + "model = PrePostNet()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "46956a6f-556c-42a5-88cd-99676be9db54", + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "batch_size = 4\n", + "loss = nn.CrossEntropyLoss()\n", + "lr = 0.01 \n", + "\n", + "\n", + "params = [p for p in model.parameters() if p.requires_grad]\n", + "opt = torch.optim.Adam(params,lr=lr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94ebd30f-7bf0-4366-983d-fe63ce0b7049", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "best_model,val_losses = run(trdl,valdl,model,loss,opt,device=device, exist_acc=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7be05afe-bc50-4e10-8e36-ab356f1ecd14", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "def train(dl,model,lossf,opt,device='cuda'):\n", + " model.train()\n", + " for x,y in dl:\n", + " x,y = x.to(device),y.to(device)\n", + " pre = model(x)\n", + " loss = lossf(pre,y)\n", + "\n", + " opt.zero_grad()\n", + " loss.backward()\n", + " opt.step()\n", + "\n", + "def test(dl,model,lossf,epoch=None,exist_acc=True,device='cuda'):\n", + " model.eval()\n", + " size, acc , losses = len(dl.dataset) ,0,0\n", + " with torch.no_grad():\n", + " for x,y in dl:\n", + " x,y = x.to(device),y.to(device)\n", + " pre = model(x)\n", + " loss = lossf(pre,y)\n", + " \n", + " if exist_acc: \n", + " acc += (pre.argmax(1)==y).type(torch.float).sum().item()\n", + " losses += loss.item()\n", + " if exist_acc:\n", + " accuracy = round(acc/size,4)\n", + " else:\n", + " accuracy = None\n", + " val_loss = round(losses/size,6)\n", + " print(f'[{epoch}] acc/loss: {accuracy}/{val_loss}')\n", + " return accuracy,val_loss\n", + "\n", + "import copy\n", + "def run(trdl, valdl, model, loss, opt, epoch=100,patience = 5, exist_acc=True, device='cuda'):\n", + " val_losses = {0:1}\n", + " model = model.to(device)\n", + " for i in range(epoch):\n", + " train(trdl,model,loss,opt,device=device)\n", + " acc,val_loss = test(valdl,model,loss,epoch=i,exist_acc=exist_acc,device=device)\n", + "\n", + "\n", + " if min(val_losses.values() ) > val_loss:\n", + " val_losses[i] = val_loss\n", + " best_model = copy.deepcopy(model)\n", + " if i == min(val_losses,key=val_losses.get)+patience:\n", + " break\n", + " return best_model, val_losses" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c499503a-56f2-46d7-ade5-49d5a4805494", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "387ee984-315f-40d9-8821-ec9b5e497268", + "metadata": {}, + "source": [ + "## for fastset " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "cd3c0723-98ac-4b1d-9d40-7e381f65eac0", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from dataset import VitalDataset_fs\n", + "idx = 3\n", + "root_dir = f'../data/all_{idx}'\n", + "trdt = VitalDataset_fs(root_dir,f'../data/pd_gy/train_{idx}.json')\n", + "valdt = VitalDataset_fs(root_dir,f'../data/pd_gy/val_{idx}.json') \n", + "tedt = VitalDataset_fs(root_dir,f'../data/pd_gy/test_{idx}.json') \n", + "\n", + "trdl = torch.utils.data.DataLoader(trdt,4)\n", + "valdl = torch.utils.data.DataLoader(valdt,4)\n", + "tedl = torch.utils.data.DataLoader(valdt,4)\n", + "\n", + "from models import PrePostNet\n", + "\n", + "model = PrePostNet()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "122351a3-bd12-4556-8cdd-32602cfd154e", + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "batch_size = 4\n", + "loss = nn.CrossEntropyLoss()\n", + "lr = 0.01 \n", + "\n", + "\n", + "params = [p for p in model.parameters() if p.requires_grad]\n", + "opt = torch.optim.Adam(params,lr=lr)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2b7b7aab-bba0-45c8-aa74-c1846640a836", + "metadata": {}, + "outputs": [], + "source": [ + "for x,y in trdl:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "86c75c24-179b-4965-99c8-ae2324fda79b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0] acc/loss: 0.83/0.114464\n", + "[1] acc/loss: 0.83/0.126263\n", + "[2] acc/loss: 0.83/0.147722\n", + "[3] acc/loss: 0.83/0.117496\n", + "[4] acc/loss: 0.83/0.189358\n", + "[5] acc/loss: 0.83/21.716948\n" + ] + } + ], + "source": [ + "best_model,val_losses = run(trdl,valdl,model,loss,opt,device=device, exist_acc=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d9df029f-0baa-4368-9f35-120eae3d63fc", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "def train(dl,model,lossf,opt,device='cuda'):\n", + " model.train()\n", + " for x,y in dl:\n", + " x,y = x.to(device),y.to(device)\n", + " pre = model(x)\n", + " loss = lossf(pre,y)\n", + "\n", + " opt.zero_grad()\n", + " loss.backward()\n", + " opt.step()\n", + "\n", + "def test(dl,model,lossf,epoch=None,exist_acc=True,device='cuda'):\n", + " model.eval()\n", + " size, acc , losses = len(dl.dataset) ,0,0\n", + " with torch.no_grad():\n", + " for x,y in dl:\n", + " x,y = x.to(device),y.to(device)\n", + " pre = model(x)\n", + " loss = lossf(pre,y)\n", + " \n", + " if exist_acc: \n", + " acc += (pre.argmax(1)==y).type(torch.float).sum().item()\n", + " losses += loss.item()\n", + " if exist_acc:\n", + " accuracy = round(acc/size,4)\n", + " else:\n", + " accuracy = None\n", + " val_loss = round(losses/size,6)\n", + " print(f'[{epoch}] acc/loss: {accuracy}/{val_loss}')\n", + " return accuracy,val_loss\n", + "\n", + "import copy\n", + "def run(trdl, valdl, model, loss, opt, epoch=100,patience = 5, exist_acc=True, device='cuda'):\n", + " val_losses = {0:1}\n", + " model = model.to(device)\n", + " for i in range(epoch):\n", + " train(trdl,model,loss,opt,device=device)\n", + " acc,val_loss = test(valdl,model,loss,epoch=i,exist_acc=exist_acc,device=device)\n", + "\n", + "\n", + " if min(val_losses.values() ) > val_loss:\n", + " val_losses[i] = val_loss\n", + " best_model = copy.deepcopy(model)\n", + " if i == min(val_losses,key=val_losses.get)+patience:\n", + " break\n", + " return best_model, val_losses" + ] + }, + { + "cell_type": "markdown", + "id": "d6f9972b-b4b4-430b-afda-a118c32bdc29", + "metadata": {}, + "source": [ + "### loss error " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "50348362-7f3b-4340-820c-eea44a1b4dea", + "metadata": {}, + "outputs": [], + "source": [ + "pre = model(x.to(device))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1051b321-89c4-4856-8287-aad230776f33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 2])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pre.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "02ce8191-656c-41d4-ad0e-e2a943490a1d", + "metadata": {}, + "outputs": [], + "source": [ + "y = y.type(torch.LongTensor).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "b5832d6c-9c6b-4300-9481-3af12c6c7048", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.4678, device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss(pre,y)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch0", + "language": "python", + "name": "torch0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}