不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

先来梳理一下咱们以前所写的代码,原始的生成对抗网络,所要优化的目标函数为:git

 此目标函数能够分为两部分来看:网络

①固定生成器 G,优化判别器 D, 则上式能够写成以下形式: 函数

 

能够转化为最小化形式: 优化

咱们编写的代码中,d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logits, labels = tf.ones_like(D))),因为咱们判别器最后一层是 sigmoid ,因此能够看出来 d_loss_real 是上式中的第一项(舍去常数几率 1/2),d_loss_fake 为上式中的第二项。spa

②固定判别器 D,优化生成器 G,舍去前面的常数,至关于最小化:code

也至关于最小化:blog

咱们的代码中,g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logits_, labels = tf.ones_like(D))),完美对应上式。it

 

 

接下来开始咱们的 WGAN 之旅,正如 https://zhuanlan.zhihu.com/p/25071913 所介绍的,咱们要构建一个判别器 D,使得 D 的参数不超过某个固定的常数,最后一层是非线性层,而且使式子:class

达到最大,那么 L 就能够做为咱们的 Wasserstein 距离,生成器的目标是最小化这个距离,去掉第一项与生成器无关的项,获得咱们生成器的损失函数。咱们能够把上式加个负号,做为 D 的损失函数,其中加负号后的第一项,是 d_loss_real,加负号后的第二项,是 d_loss_fake。循环

 

 

下面开始码代码:

为了方便,咱们直接在上一节咱们的 none_cond_DCGAN.py 文件中修改相应的代码:

在开头的宏定义中加入:

CLIP = [-0.01, 0.01]
CRITIC_NUM = 5

 如图:

注释掉原来 discriminator 的 return,从新输入一个 return 以下:

在 train 函数里面,修改以下地方:

在循环里面,要改以下地方,这里稍微作一下说明,idx < 25 时 D 循环更新 25 次才会更新 G,用来保证 D 的网络大体知足 Wasserstein 距离,这是一个小小的 trick。

 

改完以后点击运行进行训练,WGAN 收敛速度很快,大约一千屡次迭代的时候,生成网络生成的图像已经很像了,最后生成的图像以下,能够看到,图像仍是有些噪点和坏点的。

 

最后的最后,贴一张网络的 Graph:

 

 

参考文献:

1. https://zhuanlan.zhihu.com/p/25071913

相关文章
相关标签/搜索