本节将使用torchvision包,它是服务于pytorch深度学习框架的,主要用来构建计算机视觉模型。
torchvision主要由如下几个部分构成:python
导入本节须要的包或者模块算法
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import time import sys sys.path.append('..') # 为了导入上层目录的d2lzh_pytorch import d2lzh_pytorch as d2l
经过调用torchvision的torchvision.datasets来下载这个数据集
能够经过train参数获取指定的训练集或者测试集、
测试集只用了评估模型,并不用来训练模型数组
同时指定了参数transform = transform.ToTensor()使全部数据转化为Tensor,若是不进行转化,则返回的是PIL照片。
transform.ToTensor()将尺寸为(H,W,C)且数据位于[0,255]的PIL图片或者数据类型为np.unit8的Numpy数组转化为(CxHxW)且数据类型为torch.float32且位于[0.0,1.0]的Tensor。app
mnist_train= torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=True,transform=transforms.ToTensor()) mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=False,transform=transforms.ToTensor())
print(type(mnist_train)) print(len(mnist_train), len(mnist_test))
<class 'torchvision.datasets.mnist.FashionMNIST'> 60000 10000
feature,label = mnist_train[0] print(feature.shape,label) # channel * height* width
torch.Size([1, 28, 28]) tensor(9)
feature对应的高和宽均为28像素的图像,因为咱们使用了transforms.ToTensor(),因此每一个像素的数值为[0,1]的32位浮点数。须要注意的是,feature的尺寸是(CxHxW)的,而不是(HxWxC)。第一维是通道数,由于数据集中是灰度图像,因此通道数为1,后面两维分别是图像的高和宽。框架
Fashion_MNIST中一共包括了10个类别,分别是t-shirt(T恤),trouser(裤子),pullover(套衫),dress(连衣裙),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包)和ankle boot(短靴)svg
import d2lzh_pytorch as d2l
def get_fashion_mnist_labels(labels): text_labels = ['t-shirt','trouser','pullover','dress','coat','sandal', 'shirt','sneaker','bag','ankle boost' ] return [text_labels[int(i)] for i in labels] def show_fashion_mnist(images,labels): d2l.use_svg_display() _,figs = plt.subplots(1,len(images),figsize=(12,12)) # 1行10列 for f ,img,lbl in zip(figs,images,labels): f.imshow(img.view((28,28)).numpy()) f.set_title(lbl) f.axes.get_xaxis().set_visible(False) f.axes.get_yaxis().set_visible(False) plt.show()
X,y = [],[] for i in range(10): X.append(mnist_train[i][0]) y.append(mnist_test[i][1]) show_fashion_mnist(X,get_fashion_mnist_labels(y))
咱们将在训练集上训练模型,并将训练好的模型预测测试集上评估模型的表现。
能够用torch.utils.data.Dataloader来建立一个读取小批量样本的DataLoader实例。函数
在实际中,数据读取常常是训练的性能瓶颈,特别是当模型较为简单或者计算硬件性能较高时,pytorch的DataLoader中一个很方便的功能是容许使用多进程来加速数据读取。这里咱们经过参数num_workers来设置进程数来加速读取数据性能
batch_size= 256 if sys.platform.startswith('win'): num_worker=0 # 表示不用额外的进程来加速读取数据 else: num_worker=4 train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_worker) test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=num_worker)
start = time.time() for X,y in train_iter: continue print('%.2f sec' % (time.time()-start))
1.28 sec