变分自编码器(Variational Auto-Encoder,VAE)

最近看论文看到变分自编码器,发现它也可以用于数据增强,就仔细了解了一下,把比较好的讲解资料和自己的想法整理一下,以备用。

  • 经典论文
    Auto-Encoding Variational Bayes(还没看,据说很经典)
  • 详细介绍
    Tutorial - What is a variational autoencoder?(从神经网络和图模型两个方面来讲解)
    变分自编码器(一):原来是这么一回事(写的特别好,看完这篇基本可以了解)
  • 变分编码器在NLU数据增强中的应用
    AAAI2019:Data Augmentation for Spoken Language Understanding via
    Joint Variational Generation
  • 自己的理解
    变分编码器类似自编码器,只不过是在中间添加了噪音。从数据增强的角度来说,增加噪音可以提高生成的数据的多样性,所以变分编码器自身的特点非常适用于数据增强。下面从一张图来解读变分编码器:
    在这里插入图片描述我们想利用现有的样本去生成一些新的样本,从而扩大数据量。如果我们能知道原始数据集x的分布p(x),那么我们直接从这个分布中采样就可以得到新的样本。而在现实中,我们无法得到数据集的真实分布,于是变分编码器利用了一个隐分布z,通过z来生成新的样本,从概率角度来说就是p(x) = p(z)×p(x|z)。通过构造隐变量,让我们可以近似等于从原始分布中采样来组成新样本,而效果的好坏就靠模型来实现了。
    首先想到的是用神经网络来训练参数,那么当训练好模型后,我们输入x,可以生成新的样本x’,且他们近似都是从原始数据集分布中采样出来的,故可以参与训练其他模型,也就完成了数据增强。那么第一个损失函数就应运而生,即比较生成的x’和原始样本x的差距,差距越小效果越好。
    但是这样自然是不可行的,因为既然训练的目的是使输出样本和原样本相似度最高,那么我们直接拷贝一份原数据就行了,就不需要训练了。也就是说,数据增强的目的不仅要保证增强的数据都服从和源数据一样的分布,而且要有多样性,即有一些源数据没有的特征,这样才能达到数据增强的目的。
    于是我们考虑采样,因为采样的数据服从同一分布,但又各不相同,就可以满足多样性这一要求。
    考虑从隐分布z来进行数据采样,一个很自然的想法就是采用标准正态分布,但是如果对于所有的样本,训练的时候我们都从z中进行采样,然后用于训练,那么也就和原始的输入数据没什么关系了,即采样出的z不知道是由哪一个原始数据x生成的,这样也就没法进行损失计算,也就没法训练神经网络了。于是考虑如何把生成的数据和原始数据能一一对应起来,变分编码器让每个样本都服从一个正态分布,但这个正态分布又各不相同,即正态分布的两个参数均值和方差各不相同,那么这样输入输出就可以一一对应了。
    那么如何学习输入的每个样本服从的正态分布的均值和方差呢,这个时候就要用到神经网络了,一切不好计算的都用神经网络去拟合,于是对于每个输入样本经过一层或多层神经网络,分别输出均值和方差,然后再从服从此均值和方差的正态分布中去随机采样生成新的样本。
    看似已经可以完成数据增强的目的了,但是仔细想想问题还有很多。首先既然损失函数是输出样本和输入样本的相似度,那么网络肯定往减少输出样本和输入样本的差异上去拟合,我们记得输出样本是通过在正态分布上采样得到的,我们知道正态分布的方差越大,则数据越分散,采样得到的数据噪声越大,导致输出样本和输入样本的相似度很小,于是网络训练会尽量减小正态分布的方差,直到接近0,那么还是那个问题,方差接近0,则生成的数据的多样性又减少了。所以这个问题还是没有解决。
    变分编码器使神经网络学到的正态分布尽量趋近于标准正态分布,即学到的均值和方差分别趋近于0和1。如何做到呢,那就是再增加一个损失函数,这里也可以看出,损失函数对于神经网络训练的重要性。增加的损失函数是衡量网路学习到的正态分布和标准正态分布的差异性,即使用了KL散度。这样,神经网路一方面要学习输入输出的相似性,还要考虑增加一定的噪音,即不使正态分布的方差趋紧于0,也就增加了输出的多样性。
    这就是变分编码器的大体思想,还有一些细节可以参考上述资料。最后再说一下变分编码器的一个实现技巧–重参数技巧(reparameterization trick)。
    重参数技巧
    上面说到,新的样本是从一个正态分布中采样出来的,但是神经网络训练的时候,采样这个操作是没法反向传播的,即无法对参数求导。举个例子,z = wx+b,y对x求导是可以的,但如果 μ \mu = wx+b是正态分布的一个参数,例如均值,然后我从服从这些参数的正态分布中采样出一个值z,这时,y怎么对w求导呢?关键就在于中间缺少了等号。于是就用到了重参数技巧。
    我们都知道,z如果服从正态分布 N ( μ , σ 2 ) N(\mu,{\sigma}^2) ,那么 ε = z μ σ \varepsilon =\frac{z-\mu}{\sigma} 就服从标准正态分布。于是, z = ε σ + μ z=\varepsilon *\sigma+\mu ,这样,我们就有一个等式了, μ \mu σ \sigma 都是神经网络训练出来的,即包含了训练的参数w,我们只需每次从标准正态分布中采样 ε \varepsilon 即可,这样也就可以进行求导和反向传播了。

其实变分编码器里面也是有对抗的思想,这一点和对抗神经网络特别像。
变分编码器是生成样本和输入样本尽可能相似,和隐变量分布尽可能靠近标准正态分布两者的对抗。对于隐变量z的构造,其实就是给源数据添加了噪音,正态分布方差越大,采样的噪音越大,生成的数据可能越不像原始数据,而生成的数据和原始数据越相似,则正态分布方差又会减小,从而导致和标准正态分布差异变大。通过两种损失函数相加做为最终损失函数,变分编码器能生成和输入样本尽可能同分布又尽可能多样化的数据。具体看参考资料,讲解得已经很详细了。