如下内容都是针对Pytorch 1.0-1.1介绍。
不少文章都是从Dataset等对象自下往上进行介绍,可是对于初学者而言,其实这并很差理解,由于有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,因此本文将会自上而下地对Pytorch数据读取方法进行介绍。python
首先咱们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的状况(num_works简单理解就是可以并行化地读取数据)。git
class DataLoader(object): ... def __next__(self): if self.num_workers == 0: indices = next(self.sample_iter) # Sampler batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset if self.pin_memory: batch = _utils.pin_memory.pin_memory_batch(batch) return batch
在阅读上面代码前,咱们能够假设咱们的数据是一组图像,每一张图像对应一个index,那么若是咱们要读取数据就只须要对应的index便可,即上面代码中的indices
,而选取index的方式有多种,有按顺序的,也有乱序的,因此这个工做须要Sampler
完成,如今你不须要具体的细节,后面会介绍,你只须要知道DataLoader和Sampler在这里产生关系。github
那么Dataset和DataLoader在何时产生关系呢?没错就是下面一行。咱们已经拿到了indices,那么下一步咱们只须要根据index对数据进行读取便可了。dom
再下面的if
语句的做用简单理解就是,若是pin_memory=True
,那么Pytorch会采起一系列操做把数据拷贝到GPU,总之就是为了加速。ide
综上能够知道DataLoader,Sampler和Dataset三者关系以下:
函数
在阅读后文的过程当中,你始终须要将上面的关系记在内心,这样能帮助你更好地理解。ui
要更加细致地理解Sampler原理,咱们须要先阅读一下DataLoader 的源代码,以下:this
class DataLoader(object): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
能够看到初始化参数里有两种sampler:sampler
和batch_sampler
,都默认为None
。前者的做用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,获得一个又一个batch的index。例以下面示例中,BatchSampler
将SequentialSampler
生成的index按照指定的batch size分组。spa
>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) >>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
Pytorch中已经实现的Sampler
有以下几种:3d
SequentialSampler
RandomSampler
WeightedSampler
SubsetRandomSampler
须要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你能够经过阅读源码更深地理解,这里只作总结:
batch_sampler
,那么这些参数都必须使用默认值:batch_size
, shuffle
,sampler
,drop_last
.sampler
,那么shuffle
须要设置为False
sampler
和batch_sampler
都为None
,那么batch_sampler
使用Pytorch已经实现好的BatchSampler
,而sampler
分两种状况:
shuffle=True
,则sampler=RandomSampler(dataset)
shuffle=False
,则sampler=SequentialSampler(dataset)
仔细查看源代码其实能够发现,全部采样器其实都继承自同一个父类,即Sampler
,其代码定义以下:
class Sampler(object): r"""Base class for all Samplers. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a way to iterate over indices of dataset elements, and a :meth:`__len__` method that returns the length of the returned iterators. .. note:: The :meth:`__len__` method isn't strictly required by :class:`~torch.utils.data.DataLoader`, but is expected in any calculation involving the length of a :class:`~torch.utils.data.DataLoader`. """ def __init__(self, data_source): pass def __iter__(self): raise NotImplementedError def __len__(self): return len(self.data_source)
因此你要作的就是定义好__iter__(self)
函数,不过要注意的是该函数的返回值须要是可迭代的。例如SequentialSampler
返回的是iter(range(len(self.data_source)))
。
另外BatchSampler
与其余Sampler的主要区别是它须要将Sampler做为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程当中使用的都是batch sampler。
Dataset定义方式以下:
class Dataset(object): def __init__(self): ... def __getitem__(self, index): return ... def __len__(self): return ...
上面三个方法是最基本的,其中__getitem__
是最主要的方法,它规定了如何读取数据。可是它又不一样于通常的方法,由于它是python built-in方法,其主要做用是能让该类能够像list同样经过索引值对数据进行访问。假如你定义好了一个dataset,那么你能够直接经过dataset[0]
来访问第一个数据。在此以前我一直没弄清楚__getitem__
是什么做用,因此一直不知道该怎么进入到这个函数进行调试。如今若是你想对__getitem__
方法进行调试,你能够写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb
这个库,很是实用!!!之后有时间再写一篇ipdb的使用教程。另外,其实咱们经过最前面的Dataloader的__next__
函数能够看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,以下:
class DataLoader(object): ... def __next__(self): if self.num_workers == 0: indices = next(self.sample_iter) batch = self.collate_fn([self.dataset[i] for i in indices]) # this line if self.pin_memory: batch = _utils.pin_memory.pin_memory_batch(batch) return batch
咱们仔细看能够发现,前面还有一个self.collate_fn
方法,这个是干吗用的呢?在介绍前咱们须要知道每一个参数的意义:
indices
: 表示每个iteration,sampler返回的indices,即一个batch size大小的索引列表self.dataset[i]
: 前面已经介绍了,这里就是对第i个数据进行读取操做,通常来讲self.dataset[i]=(img, label)
看到这不难猜出collate_fn
的做用就是将一个batch的数据进行合并操做。默认的collate_fn
是将img和label分别合并成imgs和labels,因此若是你的__getitem__
方法只是返回 img, label
,那么你可使用默认的collate_fn
方法,可是若是你每次读取的数据有img, box, label
等等,那么你就须要自定义collate_fn
来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。