-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataSet.py
More file actions
80 lines (77 loc) · 2.89 KB
/
DataSet.py
File metadata and controls
80 lines (77 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
import nibabel as nib
import numpy as np
import random
import skimage.transform as skTrans
class CrossDataSet(Dataset):
    """Dataset of 3D volumes that yields random slabs of consecutive slices.

    The text file ``txt_name`` lists one volume path per line, relative to
    ``root``. In training mode the lines are read as consecutive pairs: line
    ``2*i`` is an image and line ``2*i + 1`` is its label. Otherwise every
    line is an image.
    """

    def __init__(self, root, txt_name, train_flag=False, transformer=None, slice=16, downsample=4):
        '''
        :param root: directory that the paths listed in ``txt_name`` are relative to
        :param txt_name: path of the text file listing the volume file names
        :param train_flag: if true, lines are (image, label) pairs and
            ``__getitem__`` also returns the label
        :param transformer: optional callable applied to the cropped image
            (not to the label) before downsampling
        :param slice: number of consecutive slices cropped along the last axis
        :param downsample: integer factor by which the first two axes are shrunk
        '''
        # Original note: source images are [512, 512, 120] (not checked here).
        self.images_list = self.get_file(txt_name)
        self.train_flag = train_flag
        self.root = root
        self.transform = transformer
        self.slice = slice
        self.downsample = downsample

    def get_file(self, txt_name):
        """Return the non-empty, whitespace-stripped lines of ``txt_name``."""
        # The context manager closes the file even on error. Blank lines
        # (e.g. a trailing newline) are dropped so they are neither counted
        # by __len__ nor shift the image/label pairing.
        with open(txt_name, 'r') as fh:
            return [line.strip() for line in fh if line.strip()]

    @staticmethod
    def _random_start(n_slices, depth):
        """Return a uniformly random start index for a ``depth``-slice window.

        Every start in ``[0, n_slices - depth]`` can be drawn, including the
        window that ends on the last slice.

        :raises ValueError: if the volume has fewer than ``depth`` slices.
        """
        if n_slices < depth:
            raise ValueError(
                f"volume has {n_slices} slices, fewer than the requested {depth}")
        # random.randint is inclusive on both ends.
        return random.randint(0, n_slices - depth)

    def _resize(self, volume, order):
        """Shrink the first two axes by ``self.downsample``; keep the last axis."""
        target = (volume.shape[0] // self.downsample,
                  volume.shape[1] // self.downsample,
                  volume.shape[2])
        return skTrans.resize(volume, target, order=order, preserve_range=True)

    def __getitem__(self, index):
        '''
        :param index: sample index in ``[0, len(self))``
        :return: ``(image, label, image_path, label_path)`` in training mode,
            otherwise ``(image, image_path)``.
            image: float32 array [1, H // downsample, W // downsample, slice]
            label: int64 array of the same shape with values in {0, 1}
        :raises ValueError: if the volume has fewer than ``slice`` slices.
        '''
        train = bool(self.train_flag)
        if train:
            # Image is on the even line, its label on the following line.
            index = index * 2
        file_name = self.root + '/' + self.images_list[index].strip()
        img = nib.load(file_name).get_fdata()
        r = self._random_start(img.shape[-1], self.slice)
        img = img[:, :, r:r + self.slice]
        if self.transform is not None:
            # NOTE(review): the transform is not applied to the label; a
            # spatial transform would misalign image and label.
            img = self.transform(img)
        img = self._resize(img, order=1).astype(np.float32)
        img = np.expand_dims(img, axis=0)
        if not train:
            # No label exists in inference mode.
            return img, file_name

        label_name = self.root + '/' + self.images_list[index + 1].strip()
        label = nib.load(label_name).get_fdata()
        # Use the same slab as the image so voxels stay aligned.
        label = label[:, :, r:r + self.slice]
        # NOTE(review): linear interpolation (with skimage's default
        # anti-aliasing when downsampling) followed by ">0" thresholding
        # dilates the mask; order=0, anti_aliasing=False would preserve it.
        # Kept as-is to avoid changing training targets — confirm intent.
        label = self._resize(label, order=1)
        label[label > 0] = 1
        label = label.astype(np.int64)
        label = np.expand_dims(label, axis=0)
        return img, label, file_name, label_name

    def __len__(self):
        """Return the number of samples: pairs in training mode, images otherwise."""
        if self.train_flag:
            return len(self.images_list) // 2
        return len(self.images_list)
if __name__ == '__main__':
    # Manual smoke test: load the first training pair and print its metadata.
    demo_set = CrossDataSet(root='D:\\Download\\cross\\source_training', txt_name='./data/train_source.txt', train_flag=True)
    volume, mask, volume_path, mask_path = demo_set[0]
    for item in (volume_path, mask_path, volume.shape, type(volume), volume.dtype, mask.shape, type(mask)):
        print(item)
    # Count of background voxels in the label.
    print(np.sum(mask == 0))