【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

前言

深度学习做为人工智能的重要手段,迎来了爆发,在NLP、CV、物联网、无人机等多个领域都发挥了很是重要的做用。最近几年,各类深度学习算法层出不穷, Generative Adverarial Network(GAN)自2014年提出以来,引发普遍关注,身为深度学习三巨头之一的Yan Lecun对GAN的评价颇高,认为GAN是近年来在深度学习上最大的突破,是近十年来机器学习上最有意思的工做。围绕GAN的论文数量也迅速增多,各类版本的GAN出现,主要在CV领域带来了一些贡献,以下图所示。python

咱们能够利用GAN生成一些咱们须要的图像或者文本,好比二次元头像。git

GAN简介

GAN主要的应用是自动生成一些东西,包括图像和文本等,好比随机给一个向量做为输入,经过GAN的Generator生成一张图片,或者生成一串语句。Conditional GAN的应用更多一些,好比数据集是一段文字和图像的数据对,经过训练,GAN能够经过给定一段文字生成对应的图像。github

GAN主要能够分为Generator(生成器)和Discriminator(判别器)两个部分,其中Generator其实就是一个神经网络,输入一个向量,能够输出一张图像(即一个高维的向量表示),以下图示。算法

​Discriminator也是一个神经网络,输入为一张图像,输出为一个数值,输出的数值用于判断输入的图像是不是真的,数值越大,说明图像是真的,数值越小,说明图像为假的,以下图示。网络

​Generator负责生成图像,Discriminator负责对Generator生成的图像和真实图像去进行对比,区别出真假,Generator须要不断优化来欺骗Discriminator,以假乱真;而Discriminator也不断优化,来提升识别能力,可以识别出Generator的把戏。两者的这种关系能够形象地经过下图展现。框架

Generator和Discriminator链接起来,造成一个比较大的深层网络,即为GAN网络。机器学习

场景描述

深度学习的各类算法在PAI上能够经过PAI-DSW进行实现,在PAI-DSW上进行训练数据,利用GAN自动生成二次元头像。学习

数据准备

首先须要准备真实的二次元头像做为数据集,这里从网上找到一些共享的资源,存储在了钉钉钉盘中,钉盘地址 ,提取密码: c2pz,数据集以下图示,约5万多张:大数据

算法实践

利用PAI-DSW进行GAN算法实践,首先须要安装准备好环境。优化

首先进入到Notebook建模,建立新实例,以后打开实例,进入Terminal,在Terminal下用户能够像在本身本地同样安装相应的依赖包,进行操做。

准备好环境以后,咱们能够经过以下图示方法,将基于Tensorflow的DCGAN代码和数据集上传上去。 ​

用于训练的DCGAN代码地址:https://github.com/carpedm20/DCGAN-tensorflow,关于DCGAN的网络框架图以下,详细介绍能够参考论文:https://arxiv.org/abs/1511.06434,这里咱们不作详述。

数据集和代码上传成功,以下图示。

其中,data目录下的faces即为数据集,该文件夹下为对应的5万多张真实二次元头像。DCGAN-tensorflow为整个代码路径,其中最主要的两个代码文件是main.py和model.py,其中最主要的核心代码以下。

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)
    else:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)
else:
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ self.inputs: batch_images, self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})

一切就绪以后,咱们执行命令进行训练,调用命令以下:

​python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"

其中,参数dateset指定数据集的目录,epoch指定循环迭代的次数,input_height、input_width用于指定输入文件的大小,输出文件的大小一样也须要参数设定,代码执行过程以下图示:​

咱们来看下执行结果,分别看一下epoch为1,30,100的时候生成的二次元头像效果图。

epoch=1

epoch=30

epoch=100​

咱们发现,随着不断迭代,生成的二次元头像也愈来愈逼真。

总结

经过上面的实践,咱们领略到了GAN的魅力,GAN的变种有不少,除此以外咱们还能够利用GAN作很是多的有意思的事情,好比经过文字生成图像,经过简单文字生成宣传海报等。PAI-DSW像是一个练武场,为咱们准备好了深度学习所须要的环境和条件,让咱们能够尽情享受大数据和深度学习的乐趣,除了GAN,像比较火热的Bert等模型,咱们也均可以试一试。

 



本文做者:不等_赵振才

原文连接

本文为云栖社区原创内容,未经容许不得转载。

相关文章
相关标签/搜索