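Introduction
The most commonly used image classification dataset is the handwritten digit dataset MNIST (1), but most models already exceed 95% classification accuracy on it. To make the differences between algorithms easier to see, we will use a dataset with more complex image content, Fashion-MNIST (2).
The following sections use the torchvision package, which is mainly used to build computer vision models. It consists of the following four parts: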
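Component | Function |
---|---|
torchvision.datasets | Functions for loading data and interfaces to commonly used datasets |
torchvision.models | Commonly used model architectures (including pretrained models) |
torchvision.transforms | Common image transformations, such as cropping and rotation |
torchvision.utils | Other utilities |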
The code has been uploaded to GitHub:
https://github.com/InkiInki/Python/blob/master/Python1/deepLearning/ImageMnist.py
1 Getting the dataset
The required imports are as follows:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import time
    import sys
    from IPython import display

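Next, we download the dataset through torchvision.datasets. On the first call, the data is fetched from the internet automatically (if the download is slow, see the notes further below). The train parameter selects the training set or the test set. Passing transform=transforms.ToTensor() converts the data to a Tensor; without it, PIL images are returned.
transforms.ToTensor() converts a PIL image of size $(H \times W \times C)$ with values in [0, 255], or a NumPy array of dtype np.uint8, into a Tensor of size $(C \times H \times W)$ with dtype torch.float32 and values in [0.0, 1.0].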
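To make the layout and scaling change concrete, here is a minimal pure-Python sketch of the same conversion on a tiny nested-list image. The helper name `hwc_to_chw_unit` is hypothetical, and this is not torchvision's implementation; it only mirrors the behaviour described above (channel axis moved to the front, values divided by 255).

```python
def hwc_to_chw_unit(image):
    """Convert an H x W x C nested list of 0..255 ints to C x H x W floats in [0.0, 1.0].

    Hypothetical helper for illustration only; ToTensor performs the
    equivalent conversion on PIL images and uint8 arrays.
    """
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [
        [[image[h][w][c] / 255 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


# A 2 x 2 image with 3 channels, stored as H x W x C.
tiny = [
    [[0, 51, 255], [255, 0, 0]],
    [[102, 102, 102], [0, 255, 51]],
]
converted = hwc_to_chw_unit(tiny)
print(len(converted), len(converted[0]), len(converted[0][0]))  # 3 2 2
print(converted[0])  # first channel: [[0.0, 1.0], [0.4, 0.0]]
```

For a Fashion-MNIST sample, the same reordering is what turns a 28 × 28 single-channel image into a Tensor of shape [1, 28, 28].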
The code is as follows:

    class ImageMnist():
        def __init__(self):
            # Download (on first use) and load the training and test splits as Tensors.
            self.mnist_train = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=True,
                download=True, transform=transforms.ToTensor())
            self.mnist_test = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=False,
                download=True, transform=transforms.ToTensor())

    if __name__ == "__main__":
        test = ImageMnist()  # __init__ already runs on construction
        print(test.mnist_train)
        print(len(test.mnist_train), len(test.mnist_test))

Output:

    Dataset FashionMNIST
        Number of datapoints: 60000
        Root location: C:\Users\Administrator/DataSets/FashionMNIST
        Split: Train
        StandardTransform
    Transform: ToTensor()
    60000 10000

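Notes:
1) If image data is represented by pixel values, always set its dtype to uint8 to avoid unnecessary bugs, since functions such as transforms.ToTensor() assume uint8 input;
2) The first download may be slow. In that case, run the following code in a Python session started from cmd, then copy the HTTP links that appear and download the files manually:

    import torchvision
    import torchvision.transforms as transforms

    torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=True,
                                      download=True, transform=transforms.ToTensor())
    torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=False,
                                      download=True, transform=transforms.ToTensor())
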
2 Basic operations
Any sample can be accessed by index:

    if __name__ == "__main__":
        test = ImageMnist()
        data, label = test.mnist_train[0]  # each sample is an (image, label) pair
        print(data.shape)
        print(label)

Output:

    torch.Size([1, 28, 28])  # number of channels, image height, image width
    9

Fashion-MNIST has 10 classes: t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, and ankle boot. The following function converts numeric labels into the corresponding text labels:

    ...
        # (continuing class ImageMnist)
        def get_text_labels(self, labels):
            text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                           'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
            return [text_labels[int(i)] for i in labels]

    if __name__ == "__main__":
        test = ImageMnist()
        data, label = test.mnist_train[0]
        print(test.get_text_labels([label]))

Output:

    ['ankle boot']

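The mapping can also be inverted when a text label needs to be turned back into its numeric class. The sketch below assumes the same label order as `get_text_labels`; the names `TEXT_LABELS`, `LABEL_TO_INDEX`, and `get_numeric_labels` are hypothetical and not part of the original code.

```python
TEXT_LABELS = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

# Inverse lookup built once: text label -> numeric label.
LABEL_TO_INDEX = {name: index for index, name in enumerate(TEXT_LABELS)}


def get_numeric_labels(names):
    """Map text labels back to their numeric class indices."""
    return [LABEL_TO_INDEX[name] for name in names]


print(get_numeric_labels(['ankle boot', 't-shirt']))  # [9, 0]
```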
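Now we define a function that draws several images and their labels in a single row:

    ...
        # (continuing class ImageMnist)
        def show_mnist(self, images, labels):
            # Render figures as SVG. Newer IPython versions deprecate this call in favour of
            # matplotlib_inline.backend_inline.set_matplotlib_formats.
            display.set_matplotlib_formats('svg')
            _, figs = plt.subplots(1, len(images), figsize=(12, 12))
            # zip() takes several iterables and packs their corresponding elements into
            # tuples; in Python 3 it returns an iterator over those tuples.
            for f, img, lbl in zip(figs, images, labels):
                f.imshow(img.view((28, 28)).numpy())  # drop the channel dimension
                f.set_title(lbl)
                f.axis('off')
            plt.show()

    if __name__ == "__main__":
        test = ImageMnist()
        x, y = [], []
        for i in range(10):
            x.append(test.mnist_train[i][0])
            y.append(test.mnist_train[i][1])
        test.show_mnist(x, test.get_text_labels(y))
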
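Output (figure not preserved): a single row showing the first ten training images, each titled with its text label.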
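The loop in `show_mnist` relies on `zip` to pair each subplot with its image and label. As a small standalone illustration (the strings below are placeholder stand-ins, not real axes or tensors), `zip` yields one tuple per position and stops at the shortest input:

```python
axes = ['ax0', 'ax1', 'ax2']          # stand-ins for matplotlib axes
images = ['img0', 'img1', 'img2']     # stand-ins for image tensors
labels = ['ankle boot', 't-shirt']    # deliberately one element short

# zip() is lazy in Python 3, so wrap it in list() to inspect the tuples.
pairs = list(zip(axes, images, labels))
print(pairs)
# [('ax0', 'img0', 'ankle boot'), ('ax1', 'img1', 't-shirt')]
```

Because of this truncation, passing fewer labels than images would silently leave the remaining subplots empty rather than raise an error.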
3 Reading minibatches
A convenient feature of PyTorch's DataLoader is that it allows multiple processes to be used to speed up data loading. Here the num_workers parameter sets up 4 worker processes (or 0 on Windows).

    ...
        # (continuing class ImageMnist)
        def data_iter(self, batch_size=256):
            if sys.platform.startswith('win'):
                num_workers = 0  # 0 means no extra processes are used to load data
            else:
                num_workers = 4
            train_iter = torch.utils.data.DataLoader(
                self.mnist_train, batch_size=batch_size,
                shuffle=True, num_workers=num_workers)
            test_iter = torch.utils.data.DataLoader(
                self.mnist_test, batch_size=batch_size,
                shuffle=False, num_workers=num_workers)
            return train_iter, test_iter

    if __name__ == "__main__":
        start = time.time()
        test = ImageMnist()
        train_iter, test_iter = test.data_iter()
        for x, y in train_iter:  # one full pass over the training set
            continue
        print("%.2f sec" % (time.time() - start))

Output:

    6.65 sec

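To see what the loader does conceptually, the following pure-Python sketch shuffles sample indices and cuts them into batches. It only illustrates minibatching under simple assumptions (no worker processes, no tensor collation) and is not how `DataLoader` is implemented; the function name `batch_indices` and its `seed` parameter are hypothetical.

```python
import math
import random


def batch_indices(num_samples, batch_size, shuffle=True, seed=0):
    """Yield lists of sample indices, one list per minibatch."""
    indices = list(range(num_samples))
    if shuffle:
        # Reorder the indices; every index still appears exactly once.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


batches = list(batch_indices(60000, 256))
print(len(batches), math.ceil(60000 / 256))  # 235 235
print(len(batches[0]), len(batches[-1]))     # 256 96
```

With the 60000 training samples and a batch size of 256, one pass therefore yields 235 batches, the last of which holds the remaining 96 samples.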
4 Complete code

    '''
    @(#)test.py
    The class of test.
    Author: Yu-Xuan Zhang
    Email: inki.yinji@qq.com
    Created on May 05, 2020
    Last Modified on May 05, 2020

    @author: inki
    '''
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import time
    import sys
    from IPython import display


    class ImageMnist():
        def __init__(self):
            # Download (on first use) and load the training and test splits as Tensors.
            self.mnist_train = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=True,
                download=True, transform=transforms.ToTensor())
            self.mnist_test = torchvision.datasets.FashionMNIST(
                root='~/DataSets/FashionMNIST', train=False,
                download=True, transform=transforms.ToTensor())

        def get_text_labels(self, labels):
            text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                           'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
            return [text_labels[int(i)] for i in labels]

        def show_mnist(self, images, labels):
            display.set_matplotlib_formats('svg')
            _, figs = plt.subplots(1, len(images), figsize=(12, 12))
            for f, img, lbl in zip(figs, images, labels):
                f.imshow(img.view((28, 28)).numpy())
                f.set_title(lbl)
                f.axis('off')
            plt.show()

        def data_iter(self, batch_size=256):
            if sys.platform.startswith('win'):
                num_workers = 0  # no extra worker processes on Windows
            else:
                num_workers = 4
            train_iter = torch.utils.data.DataLoader(
                self.mnist_train, batch_size=batch_size,
                shuffle=True, num_workers=num_workers)
            test_iter = torch.utils.data.DataLoader(
                self.mnist_test, batch_size=batch_size,
                shuffle=False, num_workers=num_workers)
            return train_iter, test_iter


    if __name__ == "__main__":
        start = time.time()
        test = ImageMnist()
        train_iter, test_iter = test.data_iter()
        for x, y in train_iter:
            continue
        print("%.2f sec" % (time.time() - start))

Acknowledgments
Special thanks to Mu Li, Aston Zhang, and their co-authors for the book Dive into Deep Learning.
This article was shared from CSDN - 因吉.