-
Notifications
You must be signed in to change notification settings - Fork 115
Dataset About PyTorch
Kelang edited this page Aug 2, 2020
·
5 revisions
图片从硬盘到模型流程详细描述(mnist为例子):
- 从
MyDataset
类中初始化txt
,txt
中有图片路径和标签 - 初始化
DataLoder
时,将train_data
传入,从而使DataLoder
拥有图片的路径 - 在一个
iteration
进行时,才读取一个batch
的图片数据enumerate()
函数会返回可迭代数据的一个“元素”在这里data
是一个batch
的图片数据和标签,data
是一个list
-
class DataLoader()
中再调用class _DataLoderIter()
- 在
_DataLoderiter()
类中会跳到__next__(self)
函数,在该函数中会通过indices = next(self.sample_iter)
获取一个batch
的indices
再通过batch = self.collate_fn([self.dataset[i] for i in indices])
获取一个batch
的数据.在batch = self.collate_fn([self.dataset[i] for i in indices])
中会调用self.collate_fn
函数 -
self.collate_fn
中会调用MyDataset
类中的__getitem__()
函数,在__getitem__()
中通过Image.open(fn).convert('RGB')
读取图片 - 通过
Image.open(fn).convert('RGB')
读取图片之后,会对图片进行预处理,例如减均值,除以标准差,随机裁剪等等一系列提前设置好的操作。 具体transform
的用法将用单独一小节介绍,最后返回img
,label
,再通过self.collate_fn
来拼接成一个batch
。一个batch
是一个list
,有两个元素,第一个元素是图片数据,是一个4D的Tensor
,shape
为(64,3,32,32)
,第二个元素是标签shape
为(64)。 - 将图片数据转换成
Variable
类型(老版本需要,现在不用了),然后称为模型真正的输入inputs, labels = Variable(inputs), Variable(labels) outputs = net(inputs)
Pseudocode:
1. main.py: train_data = MyDataset(txt_path=train_txt_path, ...) --->
2. main.py: train_loader = DataLoader(dataset=train_data, ...) --->
3. main.py: for i, data in enumerate(train_loader, 0) --->
4. dataloder.py: class DataLoader(): def __iter__(self): return _DataLoaderIter(self) --->
5. dataloder.py: class _DataLoderIter(): def __next__(self): batch = self.collate_fn([self.dataset[i] for i in indices]) --->
6. tool.py: class MyDataset(): def __getitem__(): img = Image.open(fn).convert('RGB') --->
7. tool.py: class MyDataset(): img = self.transform(img) --->
8. main.py: inputs, labels = inputs, labels
outputs = net(inputs)
直接通过keras.dataset
加载mnist
数据集,不能自动下载的话可以手动下载.npz
并保存至相应目录下。保存的时候一行为一个图像信息,便于后续读取。
由于
mnist
数据集其实是灰度图,这里用matplotlib
保存的图像是伪彩色图像。如果用scipy.misc.imsave
的话保存的则是灰度图像。
xxx_img.txt
文件中存放的是每张图像的名字。
xxx_label.txt
文件中存放的是类别标记。
def LoadData(root_path, base_path, training_path, test_path):
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_baseset = np.concatenate((x_train, x_test))
y_baseset = np.concatenate((y_train, y_test))
train_num = len(x_train)
test_num = len(x_test)
# baseset
file_img = open((os.path.join(root_path, base_path) + 'baseset_img.txt'), 'w')
file_label = open((os.path.join(root_path, base_path) + 'baseset_label.txt'), 'w')
for i in range(train_num + test_num):
file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') # name
file_label.write(str(y_baseset[i]) + '\n') # label
# scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
matplotlib.image.imsave(root_path + base_path + 'img/' + str(i) + '.png', x_baseset[i])
file_img.close()
file_label.close()
# trainingset
file_img = open((os.path.join(root_path, training_path) + 'trainingset_img.txt'), 'w')
file_label = open((os.path.join(root_path, training_path) + 'trainingset_label.txt'), 'w')
for i in range(train_num):
file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') # name
file_label.write(str(y_train[i]) + '\n') # label
# scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
matplotlib.image.imsave(root_path + training_path + 'img/' + str(i) + '.png', x_train[i])
file_img.close()
file_label.close()
# testset
file_img = open((os.path.join(root_path, test_path) + 'testset_img.txt'), 'w')
file_label = open((os.path.join(root_path, test_path) + 'testset_label.txt'), 'w')
for i in range(test_num):
file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') # name
file_label.write(str(y_test[i]) + '\n') # label
# scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
matplotlib.image.imsave(root_path + test_path + 'img/' + str(i) + '.png', x_test[i])
file_img.close()
file_label.close()
定义自己的Dataset
类,PyTorch
训练数据时需要数据集为Dataset
类,便于迭代等等,这里将加载保存之后的数
据封装成Dataset
类,继承该类需要写初始化方法__init__
,获取指定下标数据的方法__getitem__
,
获取数据个数的方法__len__
。这里尤其需要注意的是要把label
转为LongTensor
类型的。
class DataProcessingMnist(Dataset):
def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform=None):
self.root_path = root_path
self.transform = transform
self.imagedata_path = imgdata_path
img_file = open((root_path + imgfile_path), 'r')
self.image_name = [x.strip() for x in img_file]
img_file.close()
label_file = open((root_path + labelfile_path), 'r')
label = [int(x.strip()) for x in label_file]
label_file.close()
self.label = torch.LongTensor(label) # 这句很重要,一定要把label转为LongTensor类型的
def __getitem__(self, idx):
image = Image.open(str(self.image_name[idx]))
image = image.convert('RGB')
if self.transform is not None:
image = self.transform(image)
label = self.label[idx]
return image, label
def __len__(self):
return len(self.image_name)
__getitem__
接收一个index
,然后返回图片数据和标签,这个index
通常指的是一个list
的index
,这个list
的每个元素就包含了图片数据的路径和标签信息。然而,如何制作这个list
呢,通常的方法是将图片的路径和标签信息存储在一个txt
中.
那么读取自己数据的基本流程就是:
- 制作存储了图片的路径和标签信息的
txt
- 将这些信息转化为
list
,该list
每一个元素对应一个样本 - 通过
__getitem__
函数,读取数据和标签,并返回数据和标签
import os
import matplotlib
import matplotlib.image as image
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import torch
import scipy.misc
import tensorflow as tf
root_path = './mnist_np2dataset/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'
# LoadData(root_path, base_path, training_path, test_path)
training_imgfile = training_path + 'trainingset_img.txt'
training_labelfile = training_path + 'trainingset_label.txt'
training_imgdata = training_path + 'img/'
#实例化一个类
dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)
name = dataset.image_name
print(name[0])
# 获取固定下标的图像
im, label = dataset.__getitem__(0)
print("type im:",type(im))
print("type label:",type(label))