Pytorch数据读取框架

训练一个模型须要有一个数据库,一个网络,一个优化函数。数据读取是训练的第一步,如下是pytorch数据输入框架。html

1)实例化一个数据库

假设咱们已经定义了一个FaceLandmarksDataset数据库,此数据库将在如下创建。数据库

import FaceLandmarksDataset
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/',
                                    transform=transforms.Compose([ Rescale(256), RandomCrop(224), ToTensor()]) )

 

或者使用torchvision.datasets里封装的数据集(MNIST、Fashion-MNIST、KMNIST、EMNIST、COCO、LSUN、ImageFolder、DatasetFolder、Imagenet-十二、CIFAR、STL十、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes)网络

import torchvision.datasets
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')

2)建立一个数据加载器

import torch.utils.data.DataLoader
imagenet_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,  
                                          shuffle=True,
                                          num_workers=4)
#or

facelandmark_loader = torch.utils.data.DataLoader(face_dataset,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4) 

可见,数据加载器是通用的,只有数据库实例不同,其它的都参数都同样,参数值能够根据任务须要本身调。框架

3)使用数据库

数据加载器可迭代的,咱们能够使用数据库:dom

for item in facelandmark_loader:
     images,labels = item
do_somethi

固然, 咱们也能够直接对数据库实例face_dataset进行下标操做,但这样只可以每次获取一条数据。函数

sample = face_dataset[index]
相关文章
相关标签/搜索