from:https://www.leiphone.com/news/201707/1JEkcUZI1leAFq5L.html |
Generative Adversarial Network,就是你们耳熟能详的 GAN,由 Ian Goodfellow 首先提出,在这两年更是深度学习中最热门的东西,仿佛什么东西都能由 GAN 作出来。我最近刚入门 GAN,看了些资料,作一些笔记。html
什么是生成(generation)?就是模型经过学习一些数据,而后生成相似的数据。让机器看一些动物图片,而后本身来产生动物的图片,这就是生成。网络
之前就有不少能够用来生成的技术了,好比 auto-encoder(自编码器),结构以下图:iphone
你训练一个 encoder,把 input 转换成 code,而后训练一个 decoder,把 code 转换成一个 image,而后计算获得的 image 和 input 之间的 MSE(mean square error),训练完这个 model 以后,取出后半部分 NN Decoder,输入一个随机的 code,就能 generate 一个 image。机器学习
可是 auto-encoder 生成 image 的效果,固然看着很别扭啦,一眼就能看出真假。因此后来还提出了好比VAE这样的生成模型,我对此也不是很了解,在这就不细说。函数
上述的这些生成模型,其实有一个很是严重的弊端。好比 VAE,它生成的 image 是但愿和 input 越类似越好,可是 model 是如何来衡量这个类似呢?model 会计算一个 loss,采用的大可能是 MSE,即每个像素上的均方差。loss 小真的表示类似嘛?学习
好比这两张图,第一张,咱们认为是好的生成图片,第二张是差的生成图片,可是对于上述的 model 来讲,这两张图片计算出来的 loss 是同样大的,因此会认为是同样好的图片。编码
这就是上述生成模型的弊端,用来衡量生成图片好坏的标准并不能很好的完成想要实现的目的。因而就有了下面要讲的 GAN。3d
大名鼎鼎的 GAN 是如何生成图片的呢?首先你们都知道 GAN 有两个网络,一个是 generator,一个是 discriminator,从二人零和博弈中受启发,经过两个网络互相对抗来达到最好的生成效果。流程以下:code
主要流程相似上面这个图。首先,有一个一代的 generator,它能生成一些不好的图片,而后有一个一代的 discriminator,它能准确的把生成的图片,和真实的图片分类,简而言之,这个 discriminator 就是一个二分类器,对生成的图片输出 0,对真实的图片输出 1。视频
接着,开始训练出二代的 generator,它能生成稍好一点的图片,可以让一代的 discriminator 认为这些生成的图片是真实的图片。而后会训练出一个二代的 discriminator,它能准确的识别出真实的图片,和二代 generator 生成的图片。以此类推,会有三代,四代。。。n 代的 generator 和 discriminator,最后 discriminator 没法分辨生成的图片和真实图片,这个网络就拟合了。
这就是 GAN,运行过程就是这么的简单。这就结束了嘛?显然没有,下面还要介绍一下 GAN 的原理。
首先咱们知道真实图片集的分布 Pdata(x),x 是一个真实图片,能够想象成一个向量,这个向量集合的分布就是 Pdata。咱们须要生成一些也在这个分布内的图片,若是直接就是这个分布的话,怕是作不到的。
咱们如今有的 generator 生成的分布能够假设为 PG(x;θ),这是一个由 θ 控制的分布,θ 是这个分布的参数(若是是高斯混合模型,那么 θ 就是每一个高斯分布的平均值和方差)
假设咱们在真实分布中取出一些数据,{x1, x2, ... , xm},咱们想要计算一个似然 PG(xi; θ)。
对于这些数据,在生成模型中的似然就是
咱们想要最大化这个似然,等价于让 generator 生成那些真实图片的几率最大。这就变成了一个最大似然估计的问题了,咱们须要找到一个 θ* 来最大化这个似然。
寻找一个 θ* 来最大化这个似然,等价于最大化 log 似然。由于此时这 m 个数据,是从真实分布中取的,因此也就约等于,真实分布中的全部 x 在 PG 分布中的 log 似然的指望。
真实分布中的全部 x 的指望,等价于求几率积分,因此能够转化成积分运算,由于减号后面的项和 θ 无关,因此添上以后仍是等价的。而后提出共有的项,括号内的反转,max 变 min,就能够转化为 KL divergence 的形式了,KL divergence 描述的是两个几率分布之间的差别。
因此最大化似然,让 generator 最大几率的生成真实图片,也就是要找一个 θ 让 PG 更接近于 Pdata。
那如何来找这个最合理的 θ 呢?咱们能够假设 PG(x; θ) 是一个神经网络。
首先随机一个向量 z,经过 G(z)=x 这个网络,生成图片 x,那么咱们如何比较两个分布是否类似呢?只要咱们取一组 sample z,这组 z 符合一个分布,那么经过网络就能够生成另外一个分布 PG,而后来比较与真实分布 Pdata。
你们都知道,神经网络只要有非线性激活函数,就能够去拟合任意的函数,那么分布也是同样,因此能够用一直正态分布,或者高斯分布,取样去训练一个神经网络,学习到一个很复杂的分布。
如何来找到更接近的分布,这就是 GAN 的贡献了。先给出 GAN 的公式:
这个式子的好处在于,固定 G,max V(G,D) 就表示 PG 和 Pdata 之间的差别,而后要找一个最好的 G,让这个最大值最小,也就是两个分布之间的差别最小。
表面上看这个的意思是,D 要让这个式子尽量的大,也就是对于 x 是真实分布中,D(x) 要接近与 1,对于 x 来自于生成的分布,D(x) 要接近于 0,而后 G 要让式子尽量的小,让来自于生成分布中的 x,D(x) 尽量的接近 1。
如今咱们先固定 G,来求解最优的 D:
对于一个给定的 x,获得最优的 D 如上图,范围在 (0,1) 内,把最优的 D 带入
能够获得:
JS divergence 是 KL divergence 的对称平滑版本,表示了两个分布之间的差别,这个推导就代表了上面所说的,固定 G。
表示两个分布之间的差别,最小值是 -2log2,最大值为 0。
如今咱们须要找个 G,来最小化
观察上式,当 PG(x)=Pdata(x) 时,G 是最优的。
有了上面推导的基础以后,咱们就能够开始训练 GAN 了。结合咱们开头说的,两个网络交替训练,咱们能够在起初有一个 G0 和 D0,先训练 D0 找到 :
而后固定 D0 开始训练 G0, 训练的过程均可以使用 gradient descent,以此类推,训练 D1,G1,D2,G2,...
可是这里有个问题就是,你可能在 D0* 的位置取到了:
而后更新 G0 为 G1,可能
了,可是并不保证会出现一个新的点 D1* 使得
这样更新 G 就没达到它原来应该要的效果,以下图所示:
避免上述状况的方法就是更新 G 的时候,不要更新 G 太多。
知道了网络的训练顺序,咱们还须要设定两个 loss function,一个是 D 的 loss,一个是 G 的 loss。下面是整个 GAN 的训练具体步骤:
上述步骤在机器学习和深度学习中也是很是常见,易于理解。
可是上面 G 的 loss function 仍是有一点小问题,下图是两个函数的图像:
log(1-D(x)) 是咱们计算时 G 的 loss function,可是咱们发现,在 D(x) 接近于 0 的时候,这个函数十分平滑,梯度很是的小。这就会致使,在训练的初期,G 想要骗过 D,变化十分的缓慢,而上面的函数,趋势和下面的是同样的,都是递减的。可是它的优点是在 D(x) 接近 0 的时候,梯度很大,有利于训练,在 D(x) 愈来愈大以后,梯度减少,这也很符合实际,在初期应该训练速度更快,到后期速度减慢。
因此咱们把 G 的 loss function 修改成
这样能够提升训练的速度。
还有一个问题,在其余 paper 中提出,就是通过实验发现,通过许屡次训练,loss 一直都是平的,也就是
JS divergence 一直都是 log2,PG 和 Pdata 彻底没有交集,可是实际上两个分布是有交集的,形成这个的缘由是由于,咱们没法真正计算指望和积分,只能使用 sample 的方法,若是训练的过拟合了,D 仍是可以彻底把两部分的点分开,以下图:
对于这个问题,咱们是否应该让 D 变得弱一点,减弱它的分类能力,可是从理论上讲,为了让它可以有效的区分真假图片,咱们又但愿它可以 powerful,因此这里就产生了矛盾。
还有可能的缘由是,虽然两个分布都是高维的,可是两个分布都十分的窄,可能交集至关小,这样也会致使 JS divergence 算出来 =log2,约等于没有交集。
解决的一些方法,有添加噪声,让两个分布变得更宽,可能能够增大它们的交集,这样 JS divergence 就能够计算,可是随着时间变化,噪声须要逐渐变小。
还有一个问题叫 Mode Collapse,以下图:
这个图的意思是,data 的分布是一个双峰的,可是学习到的生成分布却只有单峰,咱们能够看到模型学到的数据,可是殊不知道它没有学到的分布。
形成这个状况的缘由是,KL divergence 里的两个分布写反了
这个图很清楚的显示了,若是是第一个 KL divergence 的写法,为了防止出现无穷大,因此有 Pdata 出现的地方都必需要有 PG 覆盖,就不会出现 Mode Collapse。
这是对 GAN 入门学习作的一些笔记和理解,后来太懒了,不想打公式了,主要是参考了李宏毅老师的视频: