机器学习分享——手把手带你写一个GAN

GANgit

今天让咱们从这几方面来探索:
GAN能用来作什么
GAN的原理
GAN的代码实现网络

用途ide

GAN自2014年诞生以来, 就一直备受关注, 著名的应用也随即产出, 好比比较著名的GAN的应用有Pix2Pix,CycleGAN等, 你们也将它用于各个地方。函数

  1. 缺失/模糊像素的补充
  2. 图片修复
  3. ……

我以为还有一个比较重要的用途, 不少人都会缺乏数据集, 那么就经过GAN去生成数据集了, 经过调节部分参数来进行数据集的产生的类似度。学习

原理编码

GAN的基本原理其实很是简单,这里以生成图片为例进行说明。假设咱们有两个网络,G(Generator) 和 D(Discriminator)。正如它的名字所暗示的那样, 它们的功能分别是:人工智能

G是一个生成图片的网络, 它接收一个随机的噪声(随机生成的图片)z, 经过这个噪声生成图片,记作G(z). D是一个判别网络, 判别一张图片是否是“真实的”。它的输入参数是x, x表明一张图片,输出D(x)表明x为真实图片的几率,若是为>0.5,就表明是真实(类似)的图片,反之,就表明不是真实的图片。spa

图片描述
咱们经过一个假产品宣传的例子来理解:code

首先, 咱们来定义一下角色:orm

  1. 进行宣传的'专家'(生成网络)
  2. 正在听讲的'咱们'(判别网络)

'专家'的手里面拿着一堆高仿的产品, 正在进行宣讲, 咱们是熟知真品的相关信息的, 经过去对比两个产品之间的差距, 来判断是赝品的可能性.

这时, 咱们就能够引出来一个概念, 若是'专家'团队比较厉害, 完美的仿造了咱们的判断依据, 好比说产出方, 发明日期, 说明文等等, 那么咱们就会以为他是真的, 那么他就是一个好的生成网络, 反之, 咱们会判断他是赝品.

从咱们(判别网络)出发, 咱们的判断条件越苛刻, 赝品和真品之间的差距会愈来愈小, 这样的最后的产出就是真假难分, 彻底被模仿了.

相关资源

深层的原理推荐你们能够去阅读Generative Adversarial Networks这篇论文 损失函数等相关细节咱们在实现里介绍。

实现

接下来咱们就以 mnist来实现 GAN吧.

  1. 首先, 咱们先下载数据集.

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./Dataset/datasets/MNIST_data', one_hot=False)

咱们经过tensorflow去下载mnist的数据集, 而后加载到内存, one-hot参数决定咱们的label是否要通过编码(mnist数据集是有10个类别), 可是咱们判别网络是对比真实的和生成的之间的区别以及类似的可能性, 因此不须要执行one-hot编码了.

这里读取出来的图片已经归一化到[0, 1]之间了.

  1. 俗话说, 知己知彼, 百战百胜, 那咱们拿到数据集, 就先来看看它长什么样.

def show_images(images):

images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)
sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)

for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))

图片描述

这里有一个小问题, 若是是在Notebook中执行, 记得加上这句话, 不然须要执行两次才会绘制.

%matplotlib inline

  1. 数据看过了, 咱们该对它进行必定的处理了, 这里咱们只是将数据缩放到[-1, 1]之间.

def preprocess_img(x):

return 2 * x - 1.0

def deprocess_img(x):

return (x + 1.0) / 2.0
  1. 数据处理完了, 接下来咱们要开始搭建模型了, 这一部分咱们有两个模型, 一个生成网络, 一个判别网络.

生成网络

def generator(z):

with tf.variable_scope("generator"):

    fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
    bn1 = tf.layers.batch_normalization(inputs=fc1, training=True)
    fc2 = tf.layers.dense(inputs=bn1, units=7*7*128, activation=tf.nn.relu)
    bn2 = tf.layers.batch_normalization(inputs=fc2, training=True)
    reshaped = tf.reshape(bn2, shape=[-1, 7, 7, 128])
    conv_transpose1 = tf.layers.conv2d_transpose(inputs=reshaped, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu,
                                                padding='same')
    bn3 = tf.layers.batch_normalization(inputs=conv_transpose1, training=True)
    conv_transpose2 = tf.layers.conv2d_transpose(inputs=bn3, filters=1, kernel_size=4, strides=2, activation=tf.nn.tanh,
                                    padding='same')

    img = tf.reshape(conv_transpose2, shape=[-1, 784])
    return img

判别网络

def discriminator(x):

with tf.variable_scope("discriminator"):

    unflatten = tf.reshape(x, shape=[-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(inputs=unflatten, kernel_size=5, strides=1, filters=32 ,activation=leaky_relu)
    maxpool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2)
    conv2 = tf.layers.conv2d(inputs=maxpool1, kernel_size=5, strides=1, filters=64,activation=leaky_relu)
    maxpool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=2, strides=2)
    flatten = tf.reshape(maxpool2, shape=[-1, 1024])
    fc1 = tf.layers.dense(inputs=flatten, units=1024, activation=leaky_relu)
    logits = tf.layers.dense(inputs=fc1, units=1)

    return logits

激活函数咱们使用了leaky_relu, 他的代码实现是

def leaky_relu(x, alpha=0.01):

activation = tf.maximum(x,alpha*x)
return activation

它和 relu的区别就是, 小于0的值也会给与一点小的权重进行保留.

  1. 创建损失函数

def gan_loss(logits_real, logits_fake):

# Target label vector for generator loss and used in discriminator loss.
true_labels = tf.ones_like(logits_fake)

# DISCRIMINATOR loss has 2 parts: how well it classifies real images and how well it
# classifies fake images.
real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1-true_labels)

# Combine and average losses over the batch
D_loss = real_image_loss + fake_image_loss
D_loss = tf.reduce_mean(D_loss)

# GENERATOR is trying to make the discriminator output 1 for all its images.
# So we use our target label vector of ones for computing generator loss.
G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)

# Average generator loss over the batch.
G_loss = tf.reduce_mean(G_loss)

return D_loss, G_loss

损失咱们分为两部分, 一部分是生成网络的, 一部分是判别网络的.

生成网络的损失定义为, 生成图像的类别与真实标签(全是1)的交叉熵损失。
图片描述

判别网络的损失定义为, 咱们将真实图片的标签设置为1, 生成图片的标签设置为0, 而后由真实图片的输出以及生成图片的输出的交叉熵损失和.

T: True, G: Generate 生成损失

生成损失

图片描述
真实图片损失

图片描述
总损失
图片描述

  1. 训练

def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,\

show_every=250, print_every=50, batch_size=128, num_epoch=10):
# compute the number of iterations we need
max_iter = int(mnist.train.num_examples*num_epoch/batch_size)
for it in range(max_iter):
    # every show often, show a sample result
    if it % show_every == 0:
        samples = sess.run(G_sample)
        fig = show_images(samples[:16])

plt.show()

print()
    # run a batch of data through the network
    minibatch,minbatch_y = mnist.train.next_batch(batch_size)
    _1, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
    _2, G_loss_curr = sess.run([G_train_step, G_loss])
    if it % show_every == 0:
        print(_1,_2)
    # print loss every so often.
    # We want to make sure D_loss doesn't go to 0
    if it % print_every == 0:
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(it,D_loss_curr,G_loss_curr))
print('Final images')
samples = sess.run(G_sample)

fig = show_images(samples[:16])

plt.show()

这里就是开始训练了, 并展现训练的结果.

  1. 查看结果 刚开始的时候, 还没学会怎么模仿:

图片描述
通过学习改进:

图片描述

图片描述

项目地址
查看源码(请在PC端打开)

声明
该文章参考了天雨粟:生成对抗网络(GAN)之MNIST数据生成。

————————————————————————————————————————————
Mo (网址:http://momodel.cn)是一个支持 Python 的人工智能建模平台,能帮助你快速开发训练并部署 AI 应用。

相关文章
相关标签/搜索