在上一篇文章《实战生成对抗网络[2]:生成手写数字》中,咱们使用了简单的神经网络来生成手写数字,能够看出手写数字字形,但不够完美,生成的手写数字有些毛糙,边缘不够平滑。git
生成对抗网络中,生成器和判别器是一对冤家。要提升生成器的水平,就要提升判别器的识别能力。在《一步步提升手写数字的识别率(3)》系列文章中,咱们探讨了如何提升手写数字的识别率,发现卷积神经网络在图像处理方面优点巨大,最后采用卷积神经网络模型,达到一个不错的识别率。天然的,为了提升生成对抗网络的手写数字生成质量,咱们是否也能够采用卷积神经网络呢?github
答案是确定的,不过和《一步步提升手写数字的识别率(3)》中随便采用一个卷积神经网络结构是不够的,由于生成对抗网络中,有两个神经网络模型互相对抗,随便选择网络结构,容易在迭代过程当中引发振荡,难以收敛。web
好在有专家学者进行了这方面的研究,下面就介绍一篇由Alec Radford、Luke Metz和Soumith Chintala合做完成的论文 arXiv: 1511.06434, 《利用深度卷积生成对抗网络进行无监督表征学习(Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks)》。bash
论文给出了生成器的模型结构,以下图所示:网络
从图中能够看,该网络采用100x1噪声向量(随机输入),表示为z,并将其映射到G(Z)输出,即64x64x3,其变换过程为:架构
100x1 → 1024x4x4 → 512x8x8 → 256x16x16 → 128x32x32 → 64x64x3post
若是采用keras实现上述模型,很是简单。不过须要注意的是,在本文中探讨的手写数字生成,其最终输出是28 x 28 x 1的灰度图片,因此咱们沿袭上面的模型架构,但在具体实现上作一些调整:学习
100x1 → 1024x1 → 128x7x7 → 128x14x14 → 14x14x64 → 28x28x64 → 8x28x1ui
代码以下:spa
def generator_model():
model = Sequential()
model.add(Dense(input_dim=100, output_dim=1024))
model.add(Activation('tanh'))
model.add(Dense(128 * 7 * 7))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Reshape((7, 7, 128), input_shape=(128 * 7 * 7,)))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(Activation('tanh'))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(1, (5, 5), padding='same'))
model.add(Activation('tanh'))
return model
复制代码
代码中引入了批量规则化(BatchNormalization),在实践中被证明能够在许多场合提高训练速度,减小初始化不佳带来的问题而且一般能产生准确的结果。上采样则是用来扩大维度。
判别器的实现差很少是将上述生成器模型倒过来实现,但使用最大池化代替了上采样,代码以下:
def discriminator_model():
model = Sequential()
model.add(
Conv2D(64, (5, 5),
padding='same',
input_shape=(28, 28, 1))
)
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (5, 5)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation('tanh'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model
复制代码
在论文中,做者建议经过下面一些架构性的约束来固化网络:
上述代码并无彻底遵照做者的建议,可见在面对不一样的场景,开发者能够有本身的发挥。事实上,在GANs in Action这本书中,做者也给出了手写数字生成的另一种DCGAN模型,代码可参考:github.com/GANs-in-Act…
通过100个epoch的迭代,咱们的代码生成的手写数字以下图所示,虽然有些数字生成得不太准确,不过相对于上一篇文章的输出,边缘仍是要平滑一些,效果也有所改进:
本文所演示内容的完整代码,请参考:github.com/mogoweb/aie…