Skip to content

Commit

Permalink
fix(loss type in tr):
Browse files Browse the repository at this point in the history
dataset y: float > longtensor
nn.CrossEntropyLoss gets longtensor label

related to #2
  • Loading branch information
epsilon-deltta committed Aug 23, 2022
1 parent 1b26cb5 commit 6949c94
Show file tree
Hide file tree
Showing 2 changed files with 363 additions and 3 deletions.
9 changes: 6 additions & 3 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def __getitem__(self,idx):
y = F.one_hot(y,num_classes=2)

vital = vital.float()
y = y.float()
y = y.type(torch.LongTensor)

return vital, y

def __len__(self):
Expand Down Expand Up @@ -141,7 +142,7 @@ def getitem(self,idx):


y = item['nrs_2']

# y = torch.tensor(y)
# y = F.one_hot(y,num_classes=2)
#
Expand Down Expand Up @@ -177,8 +178,10 @@ def __getitem__(self,idx):

vital = torch.tensor(vital)
vital = vital.float()

y = torch.tensor(y)
y = y.float()
y = y.type(torch.LongTensor)

return vital, y

def __len__(self):
Expand Down
357 changes: 357 additions & 0 deletions train.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,357 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "56c278bd-9e67-48b7-b11e-2c495744b6c5",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from dataset import VitalDataset\n",
"idx = 3\n",
"root_dir = '../data/pd_gy'\n",
"trdt = VitalDataset(root_dir,f'../data/pd_gy/train_{idx}.json')\n",
"valdt = VitalDataset(root_dir,f'../data/pd_gy/val_{idx}.json') \n",
"tedt = VitalDataset(root_dir,f'../data/pd_gy/test_{idx}.json') \n",
"\n",
"trdl = torch.utils.data.DataLoader(trdt,4)\n",
"valdl = torch.utils.data.DataLoader(valdt,4)\n",
"tedl = torch.utils.data.DataLoader(valdt,4)\n",
"\n",
"from models import PrePostNet\n",
"\n",
"model = PrePostNet()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "46956a6f-556c-42a5-88cd-99676be9db54",
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"batch_size = 4\n",
"loss = nn.CrossEntropyLoss()\n",
"lr = 0.01 \n",
"\n",
"\n",
"params = [p for p in model.parameters() if p.requires_grad]\n",
"opt = torch.optim.Adam(params,lr=lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94ebd30f-7bf0-4366-983d-fe63ce0b7049",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"best_model,val_losses = run(trdl,valdl,model,loss,opt,device=device, exist_acc=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7be05afe-bc50-4e10-8e36-ab356f1ecd14",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"def train(dl,model,lossf,opt,device='cuda'):\n",
" model.train()\n",
" for x,y in dl:\n",
" x,y = x.to(device),y.to(device)\n",
" pre = model(x)\n",
" loss = lossf(pre,y)\n",
"\n",
" opt.zero_grad()\n",
" loss.backward()\n",
" opt.step()\n",
"\n",
"def test(dl,model,lossf,epoch=None,exist_acc=True,device='cuda'):\n",
" model.eval()\n",
" size, acc , losses = len(dl.dataset) ,0,0\n",
" with torch.no_grad():\n",
" for x,y in dl:\n",
" x,y = x.to(device),y.to(device)\n",
" pre = model(x)\n",
" loss = lossf(pre,y)\n",
" \n",
" if exist_acc: \n",
" acc += (pre.argmax(1)==y).type(torch.float).sum().item()\n",
" losses += loss.item()\n",
" if exist_acc:\n",
" accuracy = round(acc/size,4)\n",
" else:\n",
" accuracy = None\n",
" val_loss = round(losses/size,6)\n",
" print(f'[{epoch}] acc/loss: {accuracy}/{val_loss}')\n",
" return accuracy,val_loss\n",
"\n",
"import copy\n",
"def run(trdl, valdl, model, loss, opt, epoch=100,patience = 5, exist_acc=True, device='cuda'):\n",
" val_losses = {0:1}\n",
" model = model.to(device)\n",
" for i in range(epoch):\n",
" train(trdl,model,loss,opt,device=device)\n",
" acc,val_loss = test(valdl,model,loss,epoch=i,exist_acc=exist_acc,device=device)\n",
"\n",
"\n",
" if min(val_losses.values() ) > val_loss:\n",
" val_losses[i] = val_loss\n",
" best_model = copy.deepcopy(model)\n",
" if i == min(val_losses,key=val_losses.get)+patience:\n",
" break\n",
" return best_model, val_losses"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c499503a-56f2-46d7-ade5-49d5a4805494",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "387ee984-315f-40d9-8821-ec9b5e497268",
"metadata": {},
"source": [
"## for fastset "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cd3c0723-98ac-4b1d-9d40-7e381f65eac0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from dataset import VitalDataset_fs\n",
"idx = 3\n",
"root_dir = f'../data/all_{idx}'\n",
"trdt = VitalDataset_fs(root_dir,f'../data/pd_gy/train_{idx}.json')\n",
"valdt = VitalDataset_fs(root_dir,f'../data/pd_gy/val_{idx}.json') \n",
"tedt = VitalDataset_fs(root_dir,f'../data/pd_gy/test_{idx}.json') \n",
"\n",
"trdl = torch.utils.data.DataLoader(trdt,4)\n",
"valdl = torch.utils.data.DataLoader(valdt,4)\n",
"tedl = torch.utils.data.DataLoader(valdt,4)\n",
"\n",
"from models import PrePostNet\n",
"\n",
"model = PrePostNet()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "122351a3-bd12-4556-8cdd-32602cfd154e",
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"batch_size = 4\n",
"loss = nn.CrossEntropyLoss()\n",
"lr = 0.01 \n",
"\n",
"\n",
"params = [p for p in model.parameters() if p.requires_grad]\n",
"opt = torch.optim.Adam(params,lr=lr)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2b7b7aab-bba0-45c8-aa74-c1846640a836",
"metadata": {},
"outputs": [],
"source": [
"for x,y in trdl:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "86c75c24-179b-4965-99c8-ae2324fda79b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0] acc/loss: 0.83/0.114464\n",
"[1] acc/loss: 0.83/0.126263\n",
"[2] acc/loss: 0.83/0.147722\n",
"[3] acc/loss: 0.83/0.117496\n",
"[4] acc/loss: 0.83/0.189358\n",
"[5] acc/loss: 0.83/21.716948\n"
]
}
],
"source": [
"best_model,val_losses = run(trdl,valdl,model,loss,opt,device=device, exist_acc=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d9df029f-0baa-4368-9f35-120eae3d63fc",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"def train(dl,model,lossf,opt,device='cuda'):\n",
" model.train()\n",
" for x,y in dl:\n",
" x,y = x.to(device),y.to(device)\n",
" pre = model(x)\n",
" loss = lossf(pre,y)\n",
"\n",
" opt.zero_grad()\n",
" loss.backward()\n",
" opt.step()\n",
"\n",
"def test(dl,model,lossf,epoch=None,exist_acc=True,device='cuda'):\n",
" model.eval()\n",
" size, acc , losses = len(dl.dataset) ,0,0\n",
" with torch.no_grad():\n",
" for x,y in dl:\n",
" x,y = x.to(device),y.to(device)\n",
" pre = model(x)\n",
" loss = lossf(pre,y)\n",
" \n",
" if exist_acc: \n",
" acc += (pre.argmax(1)==y).type(torch.float).sum().item()\n",
" losses += loss.item()\n",
" if exist_acc:\n",
" accuracy = round(acc/size,4)\n",
" else:\n",
" accuracy = None\n",
" val_loss = round(losses/size,6)\n",
" print(f'[{epoch}] acc/loss: {accuracy}/{val_loss}')\n",
" return accuracy,val_loss\n",
"\n",
"import copy\n",
"def run(trdl, valdl, model, loss, opt, epoch=100,patience = 5, exist_acc=True, device='cuda'):\n",
" val_losses = {0:1}\n",
" model = model.to(device)\n",
" for i in range(epoch):\n",
" train(trdl,model,loss,opt,device=device)\n",
" acc,val_loss = test(valdl,model,loss,epoch=i,exist_acc=exist_acc,device=device)\n",
"\n",
"\n",
" if min(val_losses.values() ) > val_loss:\n",
" val_losses[i] = val_loss\n",
" best_model = copy.deepcopy(model)\n",
" if i == min(val_losses,key=val_losses.get)+patience:\n",
" break\n",
" return best_model, val_losses"
]
},
{
"cell_type": "markdown",
"id": "d6f9972b-b4b4-430b-afda-a118c32bdc29",
"metadata": {},
"source": [
"### loss error "
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "50348362-7f3b-4340-820c-eea44a1b4dea",
"metadata": {},
"outputs": [],
"source": [
"pre = model(x.to(device))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1051b321-89c4-4856-8287-aad230776f33",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 2])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pre.shape"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "02ce8191-656c-41d4-ad0e-e2a943490a1d",
"metadata": {},
"outputs": [],
"source": [
"y = y.type(torch.LongTensor).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "b5832d6c-9c6b-4300-9481-3af12c6c7048",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.4678, device='cuda:0', grad_fn=<NllLossBackward0>)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss(pre,y)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch0",
"language": "python",
"name": "torch0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 6949c94

Please sign in to comment.