机器学习分享——手把手带你写一个GAN

GANgit

今天让咱们从这几方面来探索:网络

GAN能用来作什么 GAN的原理 GAN的代码实现ide

用途函数

4年诞生以来, 就一直备受关注, 著名的应用也随即产出, 好比比较著名的GAN的应用有Pix2Pix,CycleGAN等, 你们也将它用于各个地方。学习

  1. 缺失/模糊像素的补充编码

  2. 图片修复人工智能

  3. ……spa

我以为还有一个比较重要的用途, 不少人都会缺乏数据集, 那么就经过GAN去生成数据集了, 经过调节部分参数来进行数据集的产生的类似度。code

原理orm

GAN的基本原理其实很是简单,这里以生成图片为例进行说明。假设咱们有两个网络,G(Generator) 和 D(Discriminator)。正如它的名字所暗示的那样, 它们的功能分别是:

G是一个生成图片的网络, 它接收一个随机的噪声(随机生成的图片)z, 经过这个噪声生成图片,记作G(z). D是一个判别网络, 判别一张图片是否是“真实的”。它的输入参数是x, x表明一张图片,输出D(x)表明x为真实图片的几率,若是为>0.5,就表明是真实(类似)的图片,反之,就表明不是真实的图片。

咱们经过一个假产品宣传的例子来理解:

首先, 咱们来定义一下角色:

  1. 进行宣传的'专家'(生成网络)

  2. 正在听讲的'咱们'(判别网络)

'专家'的手里面拿着一堆高仿的产品, 正在进行宣讲, 咱们是熟知真品的相关信息的, 经过去对比两个产品之间的差距, 来判断是赝品的可能性.

这时, 咱们就能够引出来一个概念, 若是'专家'团队比较厉害, 完美的仿造了咱们的判断依据, 好比说产出方, 发明日期, 说明文等等, 那么咱们就会以为他是真的, 那么他就是一个好的生成网络, 反之, 咱们会判断他是赝品.

从咱们(判别网络)出发, 咱们的判断条件越苛刻, 赝品和真品之间的差距会愈来愈小, 这样的最后的产出就是真假难分, 彻底被模仿了.

相关资源

深层的原理推荐你们能够去阅读Generative Adversarial Networks这篇论文 损失函数等相关细节咱们在实现里介绍。

实现

接下来咱们就以 mnist来实现 GAN吧.

  1. 首先, 咱们先下载数据集.

from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('./Dataset/datasets/MNIST_data', one_hot=False)

咱们经过tensorflow去下载mnist的数据集, 而后加载到内存, one-hot参数决定咱们的label是否要通过编码(mnist数据集是有10个类别), 可是咱们判别网络是对比真实的和生成的之间的区别以及类似的可能性, 因此不须要执行one-hot编码了.

这里读取出来的图片已经归一化到[0, 1]之间了.

  1. 俗话说, 知己知彼, 百战百胜, 那咱们拿到数据集, 就先来看看它长什么样.

def show_images(images): images = np.reshape(images, [images.shape[0], -1]) # images reshape to (batch_size, D) sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)

for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
复制代码

这里有一个小问题, 若是是在Notebook中执行, 记得加上这句话, 不然须要执行两次才会绘制.

%matplotlib inline

  1. 数据看过了, 咱们该对它进行必定的处理了, 这里咱们只是将数据缩放到[-1, 1]之间.

def preprocess_img(x): return 2 * x - 1.0

def deprocess_img(x): return (x + 1.0) / 2.0

  1. 数据处理完了, 接下来咱们要开始搭建模型了, 这一部分咱们有两个模型, 一个生成网络, 一个判别网络.

生成网络

def generator(z):

with tf.variable_scope("generator"):

    fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
    bn1 = tf.layers.batch_normalization(inputs=fc1, training=True)
    fc2 = tf.layers.dense(inputs=bn1, units=7*7*128, activation=tf.nn.relu)
    bn2 = tf.layers.batch_normalization(inputs=fc2, training=True)
    reshaped = tf.reshape(bn2, shape=[-1, 7, 7, 128])
    conv_transpose1 = tf.layers.conv2d_transpose(inputs=reshaped, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu,
                                                padding='same')
    bn3 = tf.layers.batch_normalization(inputs=conv_transpose1, training=True)
    conv_transpose2 = tf.layers.conv2d_transpose(inputs=bn3, filters=1, kernel_size=4, strides=2, activation=tf.nn.tanh,
                                    padding='same')

    img = tf.reshape(conv_transpose2, shape=[-1, 784])
    return img
复制代码

判别网络

def discriminator(x):

with tf.variable_scope("discriminator"):

    unflatten = tf.reshape(x, shape=[-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(inputs=unflatten, kernel_size=5, strides=1, filters=32 ,activation=leaky_relu)
    maxpool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2)
    conv2 = tf.layers.conv2d(inputs=maxpool1, kernel_size=5, strides=1, filters=64,activation=leaky_relu)
    maxpool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=2, strides=2)
    flatten = tf.reshape(maxpool2, shape=[-1, 1024])
    fc1 = tf.layers.dense(inputs=flatten, units=1024, activation=leaky_relu)
    logits = tf.layers.dense(inputs=fc1, units=1)

    return logits
复制代码

激活函数咱们使用了leaky_relu, 他的代码实现是

def leaky_relu(x, alpha=0.01): activation = tf.maximum(x,alpha*x) return activation 它和 relu的区别就是, 小于0的值也会给与一点小的权重进行保留.

  1. 创建损失函数

def gan_loss(logits_real, logits_fake):

# Target label vector for generator loss and used in discriminator loss.
true_labels = tf.ones_like(logits_fake)

# DISCRIMINATOR loss has 2 parts: how well it classifies real images and how well it
# classifies fake images.
real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1-true_labels)

# Combine and average losses over the batch
D_loss = real_image_loss + fake_image_loss
D_loss = tf.reduce_mean(D_loss)

# GENERATOR is trying to make the discriminator output 1 for all its images.
# So we use our target label vector of ones for computing generator loss.
G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)

# Average generator loss over the batch.
G_loss = tf.reduce_mean(G_loss)

return D_loss, G_loss
复制代码

损失咱们分为两部分, 一部分是生成网络的, 一部分是判别网络的.

生成网络的损失定义为, 生成图像的类别与真实标签(全是1)的交叉熵损失。

判别网络的损失定义为, 咱们将真实图片的标签设置为1, 生成图片的标签设置为0, 而后由真实图片的输出以及生成图片的输出的交叉熵损失和.

T: True, G: Generate 生成损失

生成损失

真实图片损失

总损失

6. 训练

def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,
show_every=250, print_every=50, batch_size=128, num_epoch=10): # compute the number of iterations we need max_iter = int(mnist.train.num_examples*num_epoch/batch_size) for it in range(max_iter): # every show often, show a sample result if it % show_every == 0: samples = sess.run(G_sample) fig = show_images(samples[:16]) plt.show() print() # run a batch of data through the network minibatch,minbatch_y = mnist.train.next_batch(batch_size) _1, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch}) _2, G_loss_curr = sess.run([G_train_step, G_loss]) if it % show_every == 0: print(_1,_2) # print loss every so often. # We want to make sure D_loss doesn't go to 0 if it % print_every == 0: print('Iter: {}, D: {:.4}, G:{:.4}'.format(it,D_loss_curr,G_loss_curr)) print('Final images') samples = sess.run(G_sample)

fig = show_images(samples[:16])
复制代码

plt.show()

这里就是开始训练了, 并展现训练的结果.

  1. 查看结果 刚开始的时候, 还没学会怎么模仿:

通过学习改进:

项目地址 查看源码(请在PC端打开)

声明 该文章参考了天雨粟:生成对抗网络(GAN)之MNIST数据生成。

————————————————————————————————————————

Mo (网址:momodel.cn)是一个支持 Python 的人工智能建模平台,能帮助你快速开发训练并部署 AI 应用。

相关文章
相关标签/搜索