pytorch :: Dataloader中的迭代器和生成器应用

在使用pytorch训练模型,常常须要加载大量图片数据,所以pytorch提供了好用的数据加载工具Dataloader。
为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。
这一应用场景正是python中迭代器模式的意义所在,所以本文对Dataloader中代码进行解读,能够更好的理解python中迭代器和生成器的概念。html

本文的内容主要有:python

  1. 解释python中的迭代器和生成器概念
  2. 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数据加载

python迭代基础

python中围绕着迭代有如下概念:git

  1. 可迭代对象 iterables
  2. 迭代器 iterator
  3. 生成器 generator

这三个概念互相关联,并非孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。
学习这些概念的名词解释没有多大意义。编程中不少的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。
所以,要理解它们,须要探究概念背后的逻辑,为何这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?github

迭代模式首先要解决的基础问题是,须要按必定顺序获取集合内部数据,好比循环某个list。
当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,所以想出如下方法:编程

  • 把大的数据分红几个小块,分批处理
  • 惰性的取值方式,按需取值

循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:app

  1. for x in container: 为了遍历python内部序列容器(如list), 这些类型内部实现了__getitem__() 方法,能够从0开始按顺序遍历序列容器中的元素。
  2. for x in iterator: 为了循环用户自定义的迭代器,须要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在 __next__或next方法里
  3. for x in generator: 为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是 range(5)

代码示例:dom

# 普通循环 for x in list
numbers = [1, 2, 3,]
for n in numbers:
    print(n) # 1,2,3

# for循环实际干的事情
# iter输入一个可迭代对象list,返回迭代器
# next方法取数据
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception

# 迭代器循环 for x in iterator
for i,n in enumerate(numbers):
    print(i,n) # 0,1 / 1,3 / 2,3

# 生成器循环 for x in generator
for i in range(3):
    print(i) # 0,1,2

上面示例代码中python内置函数iter和next的用法:ide

  • iter函数,调用__iter__,返回一个迭代器
  • next函数,输入迭代器,调用__next__,取出数据

比较容易混淆的是__iter__和__next__两个方法。它们的区别是:函数

  1. __iter__是为了能够迭代,真正执行取数据的逻辑是__next__方法实现的,实际调用是经过next(iterator)完成
  2. __iter__能够返回自身(return self),实际读取数据的实现放在__next__方法
  3. __iter__能够和yield搭配,返回生成器对象

__iter__返回自身的作法有点相似 python中的类型系统。为了保持一致性,python中一切皆对象。
每一个对象建立后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。工具

生成器,是在__iter__方法中加入yield语句,好处有:

  1. 减小循环判断逻辑的复杂度
  2. 惰性取值,节省内存和时间

yield做用:

  1. 代替函数中的return语句
  2. 记住上一次循环迭代器内部元素的位置

三种循环模式经常使用函数

for x in container方法:

  • list, deque, …
  • set, frozensets, …
  • dict, defaultdict, OrderedDict, Counter, …
  • tuple, namedtuple, …
  • str

for x in iterator方法:

  • enumerate() # 加上list的index
  • sorted() # 排序list
  • reversed() # 倒序list
  • zip() # 合并list

for x in generator方法:

  • range()
  • map()
  • filter()
  • reduce()
  • [x for x in list(...)]

Dataloder源码分析

pytorch采用for x in iterator模式,从Dataloader类中读取数据。

  1. 为了实现该迭代模式,在Dataloader内部实现__iter__方法,实际返回的是_DataLoaderIter类。
  2. _DataLoaderIter类里面,实现了 __iter__方法,返回自身,具体执行读数据的逻辑,在__next__方法中。

如下代码只截取了单线程下的数据读取。

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, ...):
        self.dataset = dataset
        self.batch_sampler = batch_sampler
        ...
    
    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
    def __init__(self, loader):
        self.sample_iter = iter(self.batch_sampler)
        ...

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch
        ...

    def __iter__(self):
        return self

Dataloader类中读取数据Index的方法,采用了 for x in generator方式,可是调用采用iter和next函数

  1. 构建随机采样类RandomSampler,内部实现了 __iter__方法
  2. __iter__方法内部使用了 yield,循环遍历数据集,当数量达到batch_size大小时,就返回
  3. 实例化随机采样类,传入iter函数,返回一个迭代器
  4. next会调用随机采样类中生成器,返回相应的index数据
class RandomSampler(object):
    """random sampler to yield a mini-batch of indices."""
    def __init__(self, batch_size, dataset, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_imgs = len(dataset)
        self.drop_last = drop_last

    def __iter__(self):
        indices = np.random.permutation(self.num_imgs)
        batch = []
        for i in indices:
            batch.append(i)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        ## if images not to yield a batch
        if len(batch)>0 and not self.drop_last:
            yield batch


    def __len__(self):
        if self.drop_last:
            return self.num_imgs // self.batch_size
        else:
            return (self.num_imgs + self.batch_size - 1) // self.batch_size

batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)

总结

本文总结了python中循环的三种模式:

  1. for x in container 可迭代对象
  2. for x in iterator 迭代器
  3. for x in generator 生成器

pytorch中的数据加载模块 Dataloader,使用生成器来返回数据的索引,使用迭代器来返回须要的张量数据,能够在大量数据状况下,实现小批量循环迭代式的读取,避免了内存不足问题。

参考文章

相关文章
相关标签/搜索