送书 | AI插画师:如何用基于PyTorch的生成对抗网络生成动漫头像?

本文由 「AI前线」原创,原文连接: 送书 | AI插画师:如何用基于PyTorch的生成对抗网络生成动漫头像?
做者|陈云
编辑|Natalie

AI 前线导读:”2016 年是属于 TensorFlow 的一年,凭借谷歌的大力推广,TensorFlow 占据了各大媒体的头条。2017 年年初,PyTorch 的横空出世吸引了研究人员极大的关注,PyTorch 简洁优雅的设计、统一易用的接口、追风逐电的速度和变化无方的灵活性给人留下深入的印象。程序员

本文节选自《深度学习框架 PyTorch 入门与实践》第 7 章,为读者讲解当前最火爆的生成对抗网络(GAN),带领读者从零开始实现一个动漫头像生成器,可以利用 GAN 生成风格多变的动漫头像。注意啦,文末有送书福利!”小程序


生成对抗网络(Generative Adversarial Net,GAN)是近年来深度学习中一个十分热门的方向,卷积网络之父、深度学习元老级人物 LeCun Yan 就曾说过“GAN is the most interesting idea in the last 10 years in machine learning”。尤为是近两年,GAN 的论文呈现井喷的趋势,GitHub 上有人收集了各类各样的 GAN 变种、应用、研究论文等,其中有名称的多达数百篇。做者还统计了 GAN 论文发表数目随时间变化的趋势,如图 7-1 所示,足见 GAN 的火爆程度。网络

图 7-1 GAN 的论文数目逐月累加图框架


GAN 的原理简介dom

GAN 的开山之做是被称为“GAN 之父”的 Ian Goodfellow 发表于 2014 年的经典论文 Generative Adversarial Networks ,在这篇论文中他提出了生成对抗网络,并设计了第一个 GAN 实验——手写数字生成。机器学习

GAN 的产生来自于一个灵机一动的想法:ide

“What I cannot create, I do not understand.”(那些我所不能创造的,我也没有真正地理解它。)
—Richard Feynman

相似地,若是深度学习不能创造图片,那么它也没有真正地理解图片。当时深度学习已经开始在各种计算机视觉领域中攻城略地,在几乎全部任务中都取得了突破。可是人们一直对神经网络的黑盒模型表示质疑,因而愈来愈多的人从可视化的角度探索卷积网络所学习的特征和特征间的组合,而 GAN 则从生成学习角度展现了神经网络的强大能力。GAN 解决了非监督学习中的著名问题:给定一批样本,训练一个系统可以生成相似的新样本函数

生成对抗网络的网络结构如图 7-2 所示,主要包含如下两个子网络。工具

  • 生成器(generator):输入一个随机噪声,生成一张图片。
  • 判别器(discriminator):判断输入的图片是真图片仍是假图片。

图 7-2 生成对抗网络结构图学习

训练判别器时,须要利用生成器生成的假图片和来自真实世界的真图片;训练生成器时,只用噪声生成假图片。判别器用来评估生成的假图片的质量,促使生成器相应地调整参数。

生成器的目标是尽量地生成以假乱真的图片,让判别器觉得这是真的图片;判别器的目标是将生成器生成的图片和真实世界的图片区分开。能够看出这两者的目标相反,在训练过程当中互相对抗,这也是它被称为生成对抗网络的缘由。

上面的描述可能有点抽象,让咱们用收藏齐白石做品(齐白石做品如图 7-3 所示)的书画收藏家和假画贩子的例子来讲明。假画贩子至关因而生成器,他们但愿可以模仿大师真迹伪造出以假乱真的假画,骗过收藏家,从而卖出高价;书画收藏家则但愿将赝品和真迹区分开,让真迹流传于世,销毁赝品。这里假画贩子和收藏家所交易的画,主要是齐白石画的虾。齐白石画虾能够说是画坛一绝,从来为世人所追捧。

图 7-3 齐白石画虾图真迹

在这个例子中,一开始假画贩子和书画收藏家都是新手,他们对真迹和赝品的概念都很模糊。假画贩子仿造出来的假画几乎都是随机涂鸦,而书画收藏家的鉴定能力不好,有很多赝品被他当成真迹,也有许多真迹被当成赝品。

首先,书画收藏家收集了一大堆市面上的赝品和齐白石大师的真迹,仔细研究对比,初步学习了画中虾的结构,明白画中的生物形状弯曲,而且有一对相似钳子的“螯足”,对于不符合这个条件的假画所有过滤掉。当收藏家用这个标准到市场上进行鉴定时,假画基本没法骗过收藏家,假画贩子损失惨重。可是假画贩子本身仿造的赝品中,仍是有一些蒙骗过关,这些蒙骗过关的赝品中都有弯曲的形状,而且有一对相似钳子的“螯足”。因而假画贩子开始修改仿造的手法,在仿造的做品中加入弯曲的形状和一对相似钳子的“螯足”。除了这些特色,其余地方例如颜色、线条都是随机画的。假画贩子制造出的初版赝品如图 7-4 所示。

图 7-4 假画贩子制造的初版赝品

当假画贩子把这些画拿到市面上去卖时,很容易就骗过了收藏家,由于画中有一只弯曲的生物,生物前面有一对相似钳子的东西,符合收藏家认定的真迹的标准,因此收藏家就把它当成真迹买回来。随着时间的推移,收藏家买回愈来愈多的假画,损失惨重,因而他又闭门研究赝品和真迹之间的区别,通过反复比较对比,他发现齐白石画虾的真迹中除了有弯曲的形状,虾的触须蔓长,通身做半透明状,而且画的虾的细节十分丰富,虾的每一节之间均呈白色状。

收藏家学成以后,从新出山,而假画贩子的仿造技法没有提高,所制造出来的赝品被收藏家轻松识破。因而假画贩子也开始尝试不一样的画虾手法,大多都是徒劳无功,不过在众多尝试之中,仍是有一些赝品骗过了收藏家的眼睛。假画贩子发现这些仿制的赝品触须蔓长,通身做半透明状,而且画的虾的细节十分丰富,如图 7-5 所示。因而假画贩子开始大量仿造这种画,并拿到市面上销售,许多都成功地骗过了收藏家。

图 7-5 假画贩子制造的第二版赝品

收藏家再度损失惨重,被迫关门研究齐白石的真迹和赝品之间的区别,学习齐白石真迹的特色,提高本身的鉴定能力。就这样,经过收藏家和假画贩子之间的博弈,收藏家从零开始慢慢提高了本身对真迹和赝品的鉴别能力,而假画贩子也不断地提升本身仿造齐白石真迹的水平。收藏家利用假画贩子提供的赝品,做为和真迹的对比,对齐白石画虾真迹有了更好的鉴赏能力;而假画贩子也不断尝试,提高仿造水平,提高仿造假画的质量,即便最后制造出来的仍属于赝品,可是和真迹相比也很接近了。收藏家和假画贩子两者之间互相博弈对抗,同时又不断促使着对方学习进步,达到共同提高的目的。

在这个例子中,假画贩子至关于一个生成器,收藏家至关于一个判别器。一开始生成器和判别器的水平都不好,由于两者都是随机初始化的。训练过程分为两步交替进行,第一步是训练判别器(只修改判别器的参数,固定生成器),目标是把真迹和赝品区分开;第二步是训练生成器(只修改生成器的参数,固定判别器),为的是生成的假画可以被判别器判别为真迹(被收藏家认为是真迹)。这两步交替进行,进而分类器和判别器都达到了一个很高的水平。训练到最后,生成器生成的虾的图片(如图 7-6 所示)和齐白石的真迹几乎没有差异。

图 7-6 生成器生成的虾

下面咱们来思考网络结构的设计。判别器的目标是判断输入的图片是真迹仍是赝品,因此能够当作是一个二分类网络,参考第 6 章中 Dog vs. Cat 的实验,咱们能够设计一个简单的卷积网络。生成器的目标是从噪声中生成一张彩色图片,这里咱们采用普遍使用的 DCGAN(Deep Convolutional Generative Adversarial Networks)结构,即采用全卷积网络,其结构如图 7-7 所示。网络的输入是一个 100 维的噪声,输出是一个 3×64×64 的图片。这里的输入能够当作是一个 100×1×1 的图片,经过上卷积慢慢增大为 4×四、8×八、16×1六、32×32 和 64×64。上卷积,或称转置卷积,是一种特殊的卷积操做,相似于卷积操做的逆运算。当卷积的 stride 为 2 时,输出相比输入会下采样到一半的尺寸;而当上卷积的 stride 为 2 时,输出会上采样到输入的两倍尺寸。这种上采样的作法能够理解为图片的信息保存于 100 个向量之中,神经网络根据这 100 个向量描述的信息,前几步的上采样先勾勒出轮廓、色调等基础信息,后几步上采样慢慢完善细节。网络越深,细节越详细。

图 7-7 DCGAN 中生成器网络结构图

在 DCGAN 中,判别器的结构和生成器对称:生成器中采用上采样的卷积,判别器中就采用下采样的卷积,生成器是根据噪声输出一张 64×64×3 的图片,而判别器则是根据输入的 64×64×3 的图片输出图片属于正负样本的分数(几率)。


用 GAN 生成动漫头像

本节将用 GAN 实现一个生成动漫人物头像的例子。在日本的技术博客网站上 有个博主(估计是一位二次元的爱好者),利用 DCGAN 从 20 万张动漫头像中学习,最终可以利用程序自动生成动漫头像,生成的图片效果如图 7-8 所示。源程序是利用 Chainer 框架实现的,本节咱们尝试利用 PyTorch 实现。

图 7-8 DCGAN 生成的动漫头像

原始的图片是从网站中爬取的,并利用 OpenCV 从中截取头像,处理起来比较麻烦。这里咱们使用知乎用户何之源爬取并通过处理的 5 万张图片。能够从本书配套程序的 README.MD 的百度网盘连接下载全部的图片压缩包,并解压缩到指定的文件夹中。须要注意的是,这里图片的分辨率是 3×96×96,而不是论文中的 3×64×64,所以须要相应地调整网络结构,使生成图像的尺寸为 96。

咱们首先来看本实验的代码结构。

接着来看 model.py 中是如何定义生成器的。

能够看出生成器的搭建相对比较简单,直接使用 nn.Sequential 将上卷积、激活、池化等操做拼接起来便可,这里须要注意上卷积 ConvTransposed2d 的使用。当 kernel size 为 四、stride 为 二、padding 为 1 时,根据公式 H_out=(H_in-1)*stride-2*padding+kernel_size,输出尺寸恰好变成输入的两倍。最后一层采用 kernel size 为 五、stride 为 三、padding 为 1,是为了将 32×32 上采样到 96×96,这是本例中图片的尺寸,与论文中 64×64 的尺寸不同。最后一层用 Tanh 将输出图片的像素归一化至 -1~1,若是但愿归一化至 0~1,则需使用 Sigmoid。接着咱们来看判别器的网络结构。

能够看出判别器和生成器的网络结构几乎是对称的,从卷积核大小到 padding、stride 等设置,几乎如出一辙。例如生成器的最后一个卷积层的尺度是(5,3,1),判别器的第一个卷积层的尺度也是(5,3,1)。另外,这里须要注意的是生成器的激活函数用的是 ReLU,而判别器使用的是 LeakyReLU,两者并没有本质区别,这里的选择更可能是经验总结。每个样本通过判别器后,输出一个 0~1 的数,表示这个样本是真图片的几率。在开始写训练函数前,先来看看模型的配置参数。

这些只是模型的默认参数,还能够利用 Fire 等工具经过命令行传入,覆盖默认值。另外,咱们也能够直接使用 opt.attr,还能够利用 IDE/IPython 提供的自动补全功能,十分方便。这里的超参数设置大可能是照搬 DCGAN 论文的默认值,做者通过大量实验,发现这些参数可以更快地训练出一个不错的模型。

当咱们下载完数据以后,须要将全部图片放在一个文件夹,而后将该文件夹移动至 data 目录下(请确保 data 下没有其余的文件夹)。这种处理方式是为了可以直接使用 torchvision 自带的 ImageFolder 读取图片,而没必要本身写 Dataset。数据读取与加载的代码以下:

可见,用 ImageFolder 配合 DataLoader 加载图片十分方便。

在进行训练以前,咱们还须要定义几个变量:模型、优化器、噪声等。


在加载预训练模型时,最好指定 map_location。由于若是程序以前在 GPU 上运行,那么模型就会被存成 torch.cuda.Tensor,这样加载时会默认将数据加载至显存。若是运行该程序的计算机中没有 GPU,加载就会报错,故经过指定 map_location 将 Tensor 默认加载入内存中,待有须要时再移至显存中。

下面开始训练网络,训练步骤以下。

(1)训练判别器。

  • 固定生成器
  • 对于真图片,判别器的输出几率值尽量接近 1
  • 对于生成器生成的假图片,判别器尽量输出 0

(2)训练生成器。

  • 固定判别器
  • 生成器生成图片,尽量让判别器输出 1

(3)返回第一步,循环交替训练。

这里须要注意如下几点。

  • 训练生成器时,无须调整判别器的参数;训练判别器时,无须调整生成器的参数。
  • 在训练判别器时,须要对生成器生成的图片用 detach 操做进行计算图截断,避免反向传播将梯度传到生成器中。由于在训练判别器时咱们不须要训练生成器,也就不须要生成器的梯度。
  • 在训练分类器时,须要反向传播两次,一次是但愿把真图片判为 1,一次是但愿把假图片判为 0。也能够将这二者的数据放到一个 batch 中,进行一次前向传播和一次反向传播便可。可是人们发现,在一个 batch 中只包含真图片或只包含假图片的作法最好。
  • 对于假图片,在训练判别器时,咱们但愿它输出为 0;而在训练生成器时,咱们但愿它输出为 1。所以能够看到一对看似矛盾的代码:error_d_fake = criterion(fake_output, fake_labels) 和 error_g = criterion(fake_output, true_labels)。其实这也很好理解,判别器但愿可以把假图片判别为 fake_label,而生成器则但愿能把它判别为 true_label,判别器和生成器互相对抗提高。

接下来就是一些可视化的代码。每次可视化使用的噪声都是固定的 fix_noises,由于这样便于咱们比较对于相同的输入,生成器生成的图片是如何一步步提高的。另外,因为咱们对输入的图片进行了归一化处理(-1~1),在可视化时则须要将它还原成原来的 scale(0~1) 。

除此以外,还提供了一个函数,能加载预训练好的模型,并利用噪声随机生成图片。

完整的代码请参考本书的附带样例代码 chapter7/AnimeGAN。参照 README.MD 中的指南配置环境,并准备好数据,然后用以下命令便可开始训练:

若是使用 visdom 的话,此时打开 http://[your ip]:8097 就能看到生成的图像。

训练完成后,咱们能够利用生成网络随机生成动漫头像,输入命令以下:


实验结果分析

实验结果如图 7-9 所示,分别是训练 1 个、10 个、20 个、30 个、40 个、200 个 epoch 以后神经网络生成的动漫头像。须要注意的是,每次生成器输入的噪声都是同样的,因此咱们能够对比在相同的输入下,生成图片的质量是如何慢慢改善的。

刚开始生成的图像比较模糊(1 个 epoch),可是能够看出图像已经有面部轮廓。

继续训练 10 个 epoch 以后,生成的图多了不少细节信息,包括头发、颜色等,可是整体仍是很模糊。

训练 20 个 epoch 以后,细节继续完善,包括头发的纹理、眼睛的细节等,但仍是有很多涂抹的痕迹。

训练到第 40 个 epoch 时,已经能看出明显的面部轮廓和细节,但仍是有涂抹现象,而且有些细节不够合理,例如眼睛一大一小,面部的轮廓扭曲严重。

当训练到 200 个 epoch 以后,图片的细节已经十分完善,线条更流畅,轮廓更清晰,虽然还有一些不合理之处,可是已经有很多图片可以以假乱真了。

图 7-9 GAN 生成的动漫头像

相似的生成动漫头像的项目还有“用 DRGAN 生成高清的动漫头像”,效果如图 7-10 所示。但遗憾的是,因为论文中使用的数据涉及版权问题,未能公开。这篇论文的主要改进包括使用了更高质量的图片数据和更深、更复杂的模型。

图 7-10 用 DRGAN 生成的动漫头像

本章讲解的样例程序还能够应用到不一样的生成图片场景中,只要将训练图片改为其余类型的图片便可,例如 LSUN 客房图片集、MNIST 手写数据集或 CIFAR10 数据集等。事实上,上述模型还有很大的改进空间。在这里,咱们使用的全卷积网络只有四层,模型比较浅,而在 ResNet 的论文发表以后,也有很多研究者尝试在 GAN 的网络结构中引入 Residual Block 结构,并取得了不错的视觉效果。感兴趣的读者能够尝试将示例代码中的单层卷积修改成 Residual Block,相信能够取得不错的效果。

近年来,GAN 的一个重大突破在于理论研究。论文 Towards Principled Methods for Training Generative Adversarial Networks 从理论的角度分析了 GAN 为什么难以训练,做者随后在另外一篇论文 Wasserstein GAN 中针对性地提出了一个更好的解决方案。可是 Wasserstein GAN 这篇论文在部分技术细节上的实现过于随意,因此随后又有人有针对性地提出 Improved Training of Wasserstein GANs,更好地训练 WGAN。后面两篇论文分别用 PyTorch 和 TensorFlow 实现,代码能够从 GitHub 上搜索到。笔者当初也尝试用 100 行左右的代码实现了 Wasserstein GAN,感兴趣的读者能够去了解 。

随着 GAN 研究的逐渐成熟,人们也尝试把 GAN 用于工业实际问题之中,而在众多相关论文中,最使人印象深入的就是 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks ,论文中提出了一种新的 GAN 结构称为 CycleGAN。CycleGAN 利用 GAN 实现风格迁移、黑白图像彩色化,以及马和斑马相互转化等,效果十分出众。论文的做者用 PyTorch 实现了全部代码,并开源在 GitHub 上,感兴趣的读者能够自行查阅。

本章主要介绍 GAN 的基本原理,并带领读者利用 GAN 生成动漫头像。GAN 有许多变种,GitHub 上有许多利用 PyTorch 实现的各类 GAN,感兴趣的读者能够自行查阅。

做者介绍

陈云,Python 程序员、Linux 爱好者和 PyTorch 源码贡献者。主要研究方向包括计算机视觉和机器学习。“2017 知乎看山杯机器学习挑战赛”一等奖,“2017 天池医疗 AI 大赛”第八名。热衷于推广 PyTorch,并有丰富的使用经验,活跃于 PyTorch 论坛和知乎相关板块。

福利!福利!咱们将给 AI 前线的粉丝送出《深度学习框架 PyTorch 入门与实践》纸质书籍 10 本!在本文下方留言给出你想要这本书的理由,咱们会邀请你加入赠书群,本次获奖名单由抽奖小程序随机抽取,2 月 6 日(周二)上午 10 点开奖,获奖者每人得到一本。另附京东购买地址,戳「阅读原文」!

更多干货内容,可关注AI前线,ID:ai-front,后台回复「AI」、「TF」、「大数据」可得到《AI前线》系列PDF迷你书和技能图谱。

相关文章
相关标签/搜索