-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathdataloader.py
33 lines (27 loc) · 1.13 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torch.utils import data
import os
from PIL import Image
class EvalDataset(data.Dataset):
    """Pair prediction images with ground-truth images for evaluation.

    Both roots are scanned with a two-level layout, ``<root>/<subdir>/<file>``.
    Only ``<subdir>/<file>`` paths that exist under *both* roots are kept;
    unmatched subdirectories and files are skipped without error.

    Attributes:
        image_path: Full paths of the matched prediction images, in sorted
            ``<subdir>/<file>`` order.
        label_path: Full paths of the matched ground-truth images, aligned
            index-by-index with ``image_path``.
    """

    def __init__(self, pred_root, label_root):
        """Collect the ``<subdir>/<file>`` paths common to both roots.

        Args:
            pred_root: Directory whose subdirectories contain prediction images.
            label_root: Directory whose subdirectories contain the matching
                ground-truth images under the same relative names.
        """
        # Only real directories can be descended into; a same-named plain file
        # at the top level of both roots would otherwise make os.listdir raise.
        # Sets turn the membership tests into O(1) lookups.
        label_dirs = set(self._list_entries(label_root, os.path.isdir))
        dir_name_list = []
        for idir in self._list_entries(pred_root, os.path.isdir):
            if idir not in label_dirs:
                continue
            label_names = set(
                self._list_entries(os.path.join(label_root, idir), os.path.isfile))
            dir_name_list.extend(
                os.path.join(idir, iname)
                for iname in self._list_entries(
                    os.path.join(pred_root, idir), os.path.isfile)
                if iname in label_names
            )
        self.image_path = [os.path.join(pred_root, x) for x in dir_name_list]
        self.label_path = [os.path.join(label_root, x) for x in dir_name_list]

    @staticmethod
    def _list_entries(root, keep):
        """Return the sorted names in *root* whose full path satisfies *keep*."""
        # Sorted because os.listdir order is arbitrary; this makes dataset
        # indices reproducible across runs and filesystems.
        return sorted(
            name for name in os.listdir(root) if keep(os.path.join(root, name)))

    def __getitem__(self, item):
        """Load the *item*-th (prediction, ground truth) pair as grayscale.

        When the sizes differ, the prediction is bilinearly resized to the
        ground truth's size, so both returned images have the same dimensions.

        Returns:
            A ``(pred, gt)`` tuple of PIL images in mode ``'L'``.
        """
        # ``with`` closes the file handles promptly instead of leaving them
        # to the garbage collector; ``convert`` loads the pixels and returns
        # an independent image, so it must happen inside the block.
        with Image.open(self.image_path[item]) as img:
            pred = img.convert('L')
        with Image.open(self.label_path[item]) as img:
            gt = img.convert('L')
        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)
        return pred, gt

    def __len__(self):
        """Return the number of matched (prediction, ground truth) pairs."""
        return len(self.image_path)