摘要: 最通俗的GAN网络介绍!后端
在本教程中,你将了解什么是生成敌对网络(GAN),而且在整个过程当中不涉及负责的数学细节。以后,你还将学习如何编写一个能够建立数字的简单GAN!网络
理解GAN的最简单方法是经过一个简单的比喻:dom
假设有一家商店它们从顾客那里购买某些种类的葡萄酒,用于之后再销售。ide
然而,有些恶意的顾客为了得到金钱而出售假酒。在这种状况下,店主必须可以区分假酒和正品葡萄酒。函数
你能够想象,最初,伪造者在尝试出售假酒时可能会犯不少错误,而且店主很容易认定该酒不是真的。因为这些失败,伪造者会继续尝试使用不一样的技术来模拟真正的葡萄酒,最终才有可能成功。如今,伪造者知道某些技术已经超过了店主的认识假酒的能力,他能够开始进一步生产基于这些技术的假酒。学习
同时,店主可能会从其余店主或葡萄酒专家那里获得一些反馈,说明他拥有的一些葡萄酒不是原装的。这意味着店主必须改善他是如何肯定葡萄酒是伪造的仍是真实的。伪造者的目标是制造与真实葡萄酒没法区分的葡萄酒,而店主的目标是准确地分辨葡萄酒是否真实。优化
这种来回的竞争博弈就是GAN网络背后的主要思想。ui
用上面的例子,咱们能够想出一个GAN的体系结构。阿里云
GAN网络中有两个主要组件:生成器和鉴别器。这个例子中的店主被称为鉴别器网络,而且一般是卷积神经网络(由于GAN主要用于图像任务),其主要功能是判断图像是真实的几率。编码
伪造者被称为生成网络,而且一般也是卷积神经网络(具备解卷积层)。该网络须要一些噪声矢量并输出图像。在训练生成网络时,它会学习图像的哪些区域进行改进/更改,以便鉴别器将难以将其生成的图像与真实图像区分开来。
生成网络不断生成更接近真实图像的图像,而辨别网络试图肯定真实图像和假图像之间的差别。最终的目标是创建一个可生成与真实图像没法区分的图像的生成网络。
如今你已经了解了GAN是什么以及它们的主要组成部分,如今咱们能够开始编写一个很是简单的代码。本教程将使用Keras,若是你不熟悉此Python库,则应在继续以前阅读翻译小组其余文章。本教程是基于这里开发的很是酷且易于理解的GAN。
你须要作的第一件事是经过如下方式安装如下软件包pip:
- keras - matplotlib - tensorflow - tqdm
你将matplotlib用于绘制tensorflow——Keras后端库,并用tqdm为每一个时期(迭代)显示一个奇特的进度条。
下一步是建立一个Python脚本。在这个脚本中,你首先须要导入你将要使用的全部模块和函数,在使用它们时将给出每一个解释。
import os import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm from keras.layers import Input from keras.models import Model, Sequential from keras.layers.core import Dense, Dropout from keras.layers.advanced_activations import LeakyReLU from keras.datasets import mnist from keras.optimizers import Adam from keras import initializers
你如今想要设置一些变量值:
# Let Keras know that we are using tensorflow as our backend engine os.environ["KERAS_BACKEND"] = "tensorflow" # To make sure that we can reproduce the experiment and get the same results np.random.seed(10) # The dimension of our random noise vector. random_dim = 100
在开始构建鉴别器和生成器以前,你应该首先收集并预处理数据。你将使用如今最流行的MNIST数据集,该数据集具备一组从0到9范围内的单个数字的图像。
def load_minst_data(): # load the data (x_train, y_train), (x_test, y_test) = mnist.load_data() # normalize our inputs to be in the range[-1, 1] x_train = (x_train.astype(np.float32) - 127.5)/127.5 # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have # 784 columns per row x_train = x_train.reshape(60000, 784) return (x_train, y_train, x_test, y_test)
请注意,mnist.load_data()这个函数是Keras的一部分,它容许你轻松将MNIST数据集导入你的工做区。
如今,你能够建立你的生成器和鉴别器网络。你能够为这两个网络使用Adam优化器。对于生成器和鉴别器,你将建立一个带有三个隐藏层的神经网络,激活函数为Leaky Relu。你还应该为鉴别器添加Drop-out图层,以提升其对未见图像的鲁棒性。
def get_optimizer(): return Adam(lr=0.0002, beta_1=0.5) def get_generator(optimizer): generator = Sequential() generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02))) generator.add(LeakyReLU(0.2)) generator.add(Dense(512)) generator.add(LeakyReLU(0.2)) generator.add(Dense(1024)) generator.add(LeakyReLU(0.2)) generator.add(Dense(784, activation='tanh')) generator.compile(loss='binary_crossentropy', optimizer=optimizer) return generator def get_discriminator(optimizer): discriminator = Sequential() discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02))) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(512)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(256)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(1, activation='sigmoid')) discriminator.compile(loss='binary_crossentropy', optimizer=optimizer) return discriminator
终于到了将生成器和鉴别器放在一块儿的时候了!
def get_gan_network(discriminator, random_dim, generator, optimizer): # We initially set trainable to False since we only want to train either the # generator or discriminator at a time discriminator.trainable = False # gan input (noise) will be 100-dimensional vectors gan_input = Input(shape=(random_dim,)) # the output of the generator (an image) x = generator(gan_input) # get the output of the discriminator (probability if the image is real or not) gan_output = discriminator(x) gan = Model(inputs=gan_input, outputs=gan_output) gan.compile(loss='binary_crossentropy', optimizer=optimizer) return gan
为了保持整个过程的完整性,你能够建立一个功能,每20个纪元保存你生成的图像。因为这不是本教程的核心,因此你不须要彻底理解该功能。
def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)): noise = np.random.normal(0, 1, size=[examples, random_dim]) generated_images = generator.predict(noise) generated_images = generated_images.reshape(examples, 28, 28) plt.figure(figsize=figsize) for i in range(generated_images.shape[0]): plt.subplot(dim[0], dim[1], i+1) plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r') plt.axis('off') plt.tight_layout() plt.savefig('gan_generated_image_epoch_%d.png' % epoch)
你如今已经编码了大部分网络,剩下的就是训练这个网络,并看看你建立的图像。
def train(epochs=1, batch_size=128): # Get the training and testing data x_train, y_train, x_test, y_test = load_minst_data() # Split the training data into batches of size 128 batch_count = x_train.shape[0] / batch_size # Build our GAN netowrk adam = get_optimizer() generator = get_generator(adam) discriminator = get_discriminator(adam) gan = get_gan_network(discriminator, random_dim, generator, adam) for e in xrange(1, epochs+1): print '-'*15, 'Epoch %d' % e, '-'*15 for _ in tqdm(xrange(batch_count)): # Get a random set of input noise and images noise = np.random.normal(0, 1, size=[batch_size, random_dim]) image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)] # Generate fake MNIST images generated_images = generator.predict(noise) X = np.concatenate([image_batch, generated_images]) # Labels for generated and real data y_dis = np.zeros(2*batch_size) # One-sided label smoothing y_dis[:batch_size] = 0.9 # Train discriminator discriminator.trainable = True discriminator.train_on_batch(X, y_dis) # Train generator noise = np.random.normal(0, 1, size=[batch_size, random_dim]) y_gen = np.ones(batch_size) discriminator.trainable = False gan.train_on_batch(noise, y_gen) if e == 1 or e % 20 == 0: plot_generated_images(e, generator) if __name__ == '__main__': train(400, 128)
训练400个纪元后,你能够查看生成的图像。查看第一个纪元后产生的图像,能够看到它没有任何真实的结构,在40个纪元后查看图像,数字开始成形,最后,400个纪元后产生的图像显示出清晰的数字,尽管是一对夫妇仍然没法辨认。
1纪元(左)后的结果40个纪元后(中)的结果400个时代后的结果(右)
此代码在CPU上每一个纪元大约须要2分钟,这是选择此代码的主要缘由。你能够尝试使用更多的纪元,并经过向生成器和鉴别器添加更多(和不一样的)图层。可是,当使用更复杂和更深的体系结构时,若是仅使用CPU,则运行时也会增长。
恭喜,你已经完成了本教程的最后部分,你已经以直观的方式学习生成敌对网络(GAN)的基础知识!
本文由@阿里云云栖社区组织翻译。
文章原标题《demystifying-generative-adversarial-networks》,
译者:虎说八道,审校:袁虎。
本文为云栖社区原创内容,未经容许不得转载