-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
48 lines (39 loc) · 1.42 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torchvision
from torch.utils.data import Dataset
import pandas as pd
import os
from data_augment import *
class LeafDataset(Dataset):
    """Labelled image dataset driven by a CSV index file.

    Each CSV row holds the image file name in its first column and the target
    values in its last ``num_labels`` columns.

    Args:
        csv_file: Path to the CSV index file (read with ``pandas.read_csv``).
        imgs_path: Directory containing the image files named in the CSV.
            A trailing path separator is optional.
        transform: Optional callable applied to the loaded image tensor. When
            given, it is applied *instead of* augmentation.
        augment: Whether to apply ``augmentfunc`` to images when no
            ``transform`` is given.
        num_labels: Number of trailing CSV columns forming the target vector.
            Defaults to 6, the value previously hard-coded.

    Raises:
        ValueError: If ``num_labels`` is smaller than 1.
    """

    def __init__(self, csv_file, imgs_path, transform=None, augment=True, num_labels=6):
        if num_labels < 1:
            # A zero-length slice ``[-0:]`` would select every column,
            # including the file name, so reject it explicitly.
            raise ValueError(f"num_labels must be >= 1, got {num_labels}")
        self.df = pd.read_csv(csv_file)
        self.imgs_path = imgs_path
        self.transform = transform
        self.len = self.df.shape[0]  # cached row count, returned by __len__
        self.augmentflag = augment
        self.num_labels = num_labels

    def __len__(self):
        return self.len

    def _image_path(self, row):
        """Return the on-disk path of the image named in the first column of *row*."""
        # os.path.join works whether or not imgs_path ends with a separator;
        # plain string concatenation produced a wrong path without one.
        return os.path.join(self.imgs_path, str(row.iloc[0]))

    def _target_from_row(self, row):
        """Return the last ``num_labels`` columns of *row* as a float32 tensor."""
        # .iloc makes the access explicitly positional (the positional fallback
        # of Series.__getitem__ is deprecated). The row is object dtype because
        # it mixes the file name with numbers, so convert to a float array first.
        values = row.iloc[-self.num_labels:].to_numpy(dtype="float32")
        return torch.tensor(values, dtype=torch.float)

    def __getitem__(self, index):
        """Load sample *index* and return ``(image, target)``.

        The image is read with ``torchvision.io.read_image`` and cast to float.
        If a transform was supplied, it is applied and augmentation is skipped.
        Otherwise ``augmentfunc`` is applied when augmentation is enabled.
        """
        row = self.df.iloc[index]
        image = torchvision.io.read_image(self._image_path(row)).float()
        target = self._target_from_row(row)
        # NOTE(review): transform and augmentation are mutually exclusive here
        # (behavior kept from the original). Confirm this is intended.
        if self.transform:
            return self.transform(image), target
        if self.augmentflag:
            image = augmentfunc(image)  # data augmentation
        return image, target
class TestDataSet(Dataset):
    """Unlabelled image dataset over every file in a directory.

    Items are returned as ``(image, file_name)`` so results can be matched
    back to their source files.

    Args:
        main_dir: Directory containing the images; both listed and read from.
        transform: Optional callable applied to each loaded image tensor.
    """

    def __init__(self, main_dir, transform=None):
        self.main_dir = main_dir
        self.transform = transform
        # List the same directory that __getitem__ reads from (this used to be
        # a hard-coded "test_images/", which ignored main_dir). Sorting gives a
        # stable, platform-independent order; subdirectories are skipped since
        # they are not image files.
        self.total_imgs = sorted(
            name
            for name in os.listdir(main_dir)
            if os.path.isfile(os.path.join(main_dir, name))
        )

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Load image *idx* (cast to float) and return ``(image, file_name)``."""
        name = self.total_imgs[idx]
        img_loc = os.path.join(self.main_dir, name)
        image = torchvision.io.read_image(img_loc).float()
        if self.transform:
            return self.transform(image), name
        return image, name