在使用pytorch训练模型,常常须要加载大量图片数据,所以pytorch提供了好用的数据加载工具Dataloader。
为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。
这一应用场景正是python中迭代器模式的意义所在,所以本文对Dataloader中代码进行解读,能够更好的理解python中迭代器和生成器的概念。html
本文的内容主要有:python
python中围绕着迭代有如下概念:git
这三个概念互相关联,并非孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。
学习这些概念的名词解释没有多大意义。编程中不少的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。
所以,要理解它们,须要探究概念背后的逻辑,为何这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?github
迭代模式首先要解决的基础问题是,须要按必定顺序获取集合内部数据,好比循环某个list。
当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,所以想出如下方法:编程
循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:app
for x in container
: 为了遍历python内部序列容器(如list), 这些类型内部实现了__getitem__() 方法,能够从0开始按顺序遍历序列容器中的元素。for x in iterator
: 为了循环用户自定义的迭代器,须要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在 __next__或next方法里for x in generator
: 为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是 range(5)代码示例:dom
# 普通循环 for x in list numbers = [1, 2, 3,] for n in numbers: print(n) # 1,2,3 # for循环实际干的事情 # iter输入一个可迭代对象list,返回迭代器 # next方法取数据 my_iterator = iter(numbers) next(my_iterator) # 1 next(my_iterator) # 2 next(my_iterator) # 3 next(my_iterator) # StopIteration exception # 迭代器循环 for x in iterator for i,n in enumerate(numbers): print(i,n) # 0,1 / 1,3 / 2,3 # 生成器循环 for x in generator for i in range(3): print(i) # 0,1,2
上面示例代码中python内置函数iter和next的用法:ide
比较容易混淆的是__iter__和__next__两个方法。它们的区别是:函数
__iter__返回自身的作法有点相似 python中的类型系统。为了保持一致性,python中一切皆对象。
每一个对象建立后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。工具
生成器,是在__iter__方法中加入yield语句,好处有:
yield做用:
for x in container
方法:
list, deque, …
set, frozensets, …
dict, defaultdict, OrderedDict, Counter, …
tuple, namedtuple, …
str
for x in iterator
方法:
enumerate()
# 加上list的indexsorted()
# 排序listreversed()
# 倒序listzip()
# 合并listfor x in generator
方法:
range()
map()
filter()
reduce()
[x for x in list(...)]
pytorch采用for x in iterator
模式,从Dataloader类中读取数据。
如下代码只截取了单线程下的数据读取。
class DataLoader(object): r""" Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset. """ def __init__(self, dataset, batch_size=1, shuffle=False, ...): self.dataset = dataset self.batch_sampler = batch_sampler ... def __iter__(self): return _DataLoaderIter(self) def __len__(self): return len(self.batch_sampler) class _DataLoaderIter(object): r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" def __init__(self, loader): self.sample_iter = iter(self.batch_sampler) ... def __next__(self): if self.num_workers == 0: # same-process loading indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: batch = pin_memory_batch(batch) return batch ... def __iter__(self): return self
Dataloader类中读取数据Index的方法,采用了 for x in generator
方式,可是调用采用iter和next函数
class RandomSampler(object): """random sampler to yield a mini-batch of indices.""" def __init__(self, batch_size, dataset, drop_last=False): self.dataset = dataset self.batch_size = batch_size self.num_imgs = len(dataset) self.drop_last = drop_last def __iter__(self): indices = np.random.permutation(self.num_imgs) batch = [] for i in indices: batch.append(i) if len(batch) == self.batch_size: yield batch batch = [] ## if images not to yield a batch if len(batch)>0 and not self.drop_last: yield batch def __len__(self): if self.drop_last: return self.num_imgs // self.batch_size else: return (self.num_imgs + self.batch_size - 1) // self.batch_size batch_sampler = RandomSampler(batch_size. dataset) sample_iter = iter(batch_sampler) indices = next(sample_iter)
本文总结了python中循环的三种模式:
for x in container
可迭代对象for x in iterator
迭代器for x in generator
生成器pytorch中的数据加载模块 Dataloader,使用生成器来返回数据的索引,使用迭代器来返回须要的张量数据,能够在大量数据状况下,实现小批量循环迭代式的读取,避免了内存不足问题。