虽然torchvision.datasets中已经封装了好多通用的数据集,可是咱们在使用Pytorch作深度学习任务的时候,会面临着自定义数据库来知足本身的任务须要。如咱们要训练一我的脸关键点检测算法,提供的训练数据标注以下形式,存在CSV文件中:算法
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y 0805personali01.jpg,27,83,27,98, ... 84,134 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
在本次教程中,咱们须要用到两个额外的包:数据库
首先学习如何使用pandas库解析csv文件学习
import pandas as pd
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv') n = 65 img_name = landmarks_frame.iloc[n, 0] landmarks = landmarks_frame.iloc[n, 1:].as_matrix() landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name)) print('Landmarks shape: {}'.format(landmarks.shape)) print('First 4 Landmarks: {}'.format(landmarks[:4]))
torch.utils.data.Dataset
是一个表示数据库的抽象类,自定义数据库须要继承这个类,而且重写其如下方法:spa
__len__ :返回数据库的大小. __getitem__ :支持使用下标的方式 如dataset[i] 来获取第i个样本
如下建立人脸特征点检测的数据库。咱们将在__init__中解析csv文件,而在__getitem__中读取图片。这样能够在须要图片是才加载,内存效率高。此外,咱们还能够先将数据集封装成lmdb数据库,读取速度更快。code
import torch.utils.data.Dataset as Dataset class FaceLandmarksDataset(Dataset): """Face Landmarks dataset.""" def __init__(self, csv_file, root_dir, transform=None): """ Args: csv_file (string): 到达标注文件cvs的路径. root_dir (string): 全部图片的根目录. transform (callable, optional): (可选参数)对每个样本进行转换. """ self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transform def __len__(self): return len(self.landmarks_frame) def __getitem__(self, idx): img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0]) #第idx条数据的第一个字段,即文件名称 image = io.imread(img_name) #读取图像数据 landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix() #读取第idx条数据的第二个字段及其以后的全部字段,即全部关键点的坐标。而后转成矩阵形式 landmarks = landmarks.astype('float').reshape(-1, 2) #将矩阵reshape成n行两列矩阵 sample = {'image': image, 'landmarks': landmarks} #封装数据 if self.transform: sample = self.transform(sample) #数据转换 return sample #返回数据
注:__getitem__每次只返回一个条数据,至于batch的封装能够在DataLoader中设置batchsize,至于读取速度能够设置num_worker。orm