图像翻译——pix2pix模型

1.介绍

图像处理、计算机图形学和计算机视觉中的许多问题均可以被视为将输入图像“翻译”成相应的输出图像。 “翻译”经常使用于语言之间的翻译,好比中文和英文的之间的翻译。但图像翻译的意思是图像与图像之间以不一样形式的转换。好比:一个图像场景能够以RGB图像、梯度场、边缘映射、语义标签映射等形式呈现,其效果以下图。html

image.png

传统图像转换过程当中都是针对具体问题采用特定算法去解决;而这些过程的本质都是根据像素点(输入信息)对像素点作出预测(predict from pixels to pixels),Pix2pix的目标就是创建一个通用的架构去解决以上全部的图像翻译问题,使得咱们没必要要为每一个功能都从新设计一个损失函数。前端

2. 核心思想

2.1 图像建模的结构化损失

图像到图像的翻译问题一般是根据像素分类或回归来解决的。这些公式将输出空间视为**“非结构化”**,即在给定输入图像的状况下,每一个输出像素被视为与全部其余像素有条件地独立。而cGANs( conditional-GAN)的不一样之处在于学习结构化损失,而且理论上能够惩罚输出和目标之间的任何可能结构。python

2.2 cGAN

在此以前,许多研究者使用 GAN 在修复、将来状态预测、用户约束引导的图像处理、风格迁移和超分辨率方面取得了使人瞩目的成果,但每种方法都是针对特定应用而定制的。Pix2pix框架不一样之处在于没有特定应用。它在生成器和判别器的几种架构选择中也与先前的工做不一样。对于生成器,咱们使用基于“U-Net”的架构;对于鉴别器,咱们使用卷积“PatchGAN”分类器,其仅在image patches(图片小块)的尺度上惩罚结构。git

Pix2pix 是借鉴了 cGAN 的思想。cGAN 在输入 G 网络的时候不光会输入噪音,还会输入一个条件(condition),G 网络生成的 fake images 会受到具体的 condition 的影响。那么若是把一副图像做为 condition,则生成的 fake images  就与这个 condition images 有对应关系,从而实现了一个 Image-to-Image Translation  的过程。Pixpix 原理图以下:github

image.png

Pix2pix 的网络结构如上图所示,生成器 G 用到的是 U-Net 结构,输入的轮廓图x编码再解码成真实图片,判别器 D 用到的是做者本身提出来的条件判别器 PatchGAN ,判别器 D 的做用是在轮廓图 x的条件下,对于生成的图片G(x)判断为假,对于真实图片判断为真。算法

2.3 cGAN 与 Pix2pix 对比

image.png

2.4 损失函数

通常的 cGANs 的目标函数以下:网络

L_{cGAN}(G, D) =E_{x,y}[log D(x, y)]+E_{x,z}[log(1 − D(x, G(x, z))]

其中 G 试图最小化目标而 D 则试图最大化目标,即:\rm G^∗ =arg; min_G; max_D ;L_{cGAN}(G, D)架构

为了作对比,同时再去训练一个普通的 GAN ,即只让 D 判断是否为真实图像。app

\rm L_{cGAN}(G, D) = E_y[log D(y)]+ E_{x,z}[log(1 − D(G(x, z))]

对于图像翻译任务而言,G 的输入和输出之间其实共享了不少信息,好比图像上色任务、输入和输出之间就共享了边信息。于是为了保证输入图像和输出图像之间的类似度、还加入了 L1 Loss:框架

\rm L_{L1}(G) = E_{x,y,z}[||y − G(x, z)||_1]

生成的 fake images 与 真实的 real images 之间的 L1 距离,(imgB**'** 和imgB)保证了输入和输出图像的类似度。

最终的损失函数:

\rm G^∗ = arg;\underset{G}{min};\underset{D}{max}; L_{cGAN}(G, D) + λL_{L1}(G)

3.网络架构(网络体系结构)

生成器和判别器都使用模块 convolution-BatchNorm-ReLu

3.1 生成网络G

图像到图像翻译问题的一个定义特征是它们将高分辨率输入网格映射到高分辨率输出网格。 另外,对于咱们考虑的问题,输入和输出的表面外观不一样,但二者应该共享一些信息。 所以,输入中的结构与输出中的结构大体对齐。 咱们围绕这些考虑设计了生成器架构。

image.png

U-Net 结构基于 Encoder-Decoder 模型,而 Encoder 和 Decoder 是对称结构。 U-Net 的不一样之处是将第 i 层和第 n-i 层链接起来,其中 n 是层的总数,这种链接方式称为跳过链接(skip connections)。第 i 层和第 n-i 层的图像大小是一致的,能够认为他们承载着相似的信息 。

3.2 判别网络 D

用损失函数 L1 和 L2 重建的图像很模糊,也就是说L1和L2并不能很好的恢复图像的高频部分(图像中的边缘等),但能较好地恢复图像的低频部分(图像中的色块)。

图像的高低频是对图像各个位置之间强度变化的一种度量方法,低频份量:主要对整副图像的强度的综合度量。高频份量:主要是对图像边缘和轮廓的度量。若是一副图像的各个位置的强度大小相等,则图像只存在低频份量,从图像的频谱图上看,只有一个主峰,且位于频率为零的位置。若是一副图像的各个位置的强度变化剧烈,则图像不只存在低频份量,同时也存在多种高频份量,从图像的频谱上看,不只有一个主峰,同时也存在多个旁峰。

为了能更好得对图像的局部作判断,Pix2pix 判别网络采用 patchGAN 结构,也就是说把图像等分红多个固定大小的 Patch,分别判断每一个Patch的真假,最后再取平均值做为 D 最后的输出。这样作的好处:

  • D  的输入变小,计算量小,训练速度快。
  • 由于 G 自己是全卷积的,对图像尺度没有限制。而D若是是按照Patch去处理图像,也对图像大小没有限制。就会让整个 Pix2pix 框架对图像大小没有限制,增大了框架的扩展性。

论文中将 PatchGAN 当作另外一种形式的纹理损失或样式损失。在具体实验时,采用不一样尺寸的 patch,发现 70x70 的尺寸比较合适。

3.3 优化和推理

训练使用的是标准的方法:交替训练 D 和 G;并使用了 minibatch SGD 和 Adam 优化器。

在推理的时候,咱们用训练阶段相同的方式来运行生成器。在测试阶段使用 dropout 和 batch normalization,这里咱们使用 test batch 的统计值而不是 train batch 的。

4.源码解读

该部分主要是解读论文源码:github.com/junyanz/pyt… 。

image.png

  • 文件 train:

通用的训练脚本,能够经过传参指定训练不一样的模型和不一样的数据集。

--model: e.g.,pix2pix,cyclegan,colorization

--dataset_mode: e.g.,aligned,unaligned,single,colorization)

  • 文件test:

通用的测试脚本,经过传参来加载模型 -- checkpoints_dir,保存输出的结果 --results_dir

4.1 文件夹data:

该目录中的文件包含数据的加载和处理以及用户可制做本身的数据集。下面详细说明data下的文件:

  • __init__.py: 实现包和train、test脚本之间的接口。train.py 和 test.py 根据给定的 opt 选项调包来建立数据集 from data import create_datasetdataset = create_dataset(opt)
  • **base_dataset.py:**继承了 torch 的 dataset 类和抽象基类,该文件还包括了一些经常使用的图片转换方法,方便后续子类使用。
  • **image_folder.py:**更改了官方pytorch的image folder的代码,使得从当前目录和子目录都能加载图片。
  • **template_dataset.py:**为制做本身数据集提供了模板和参考,里面注释一些细节信息。
  • aligned_dataset.py 和 **unaligned_dataset.py:**区别在于前者从同一个文件夹中加载的是一对图片 {A,B} ,后者是从两个不一样的文件夹下分别加载 {A},{B} 。
  • **single_dataset.py:**只加载指定路径下的一张图片。
  • **colorization_dataset.py:**加载一张 RGB 图片并转化成(L,ab)对在 Lab 彩色空间,pix2pix用来绘制彩色模型。

4.2 文件夹models:

models 包含的模块有:目标函数,优化器,网络架构。下面详细说明models下的文件:

  • __init__.py: 为了实现包和train、test脚本之间的接口。train.py 和 test.py 根据给定的 opt 选项调包来建立模型 from models import create_model 和 model = create_model(opt)
  • base_model.py: 继承了抽象类,也包括一些其余经常使用的函数:setuptestupdate_learning_ratesave_networksload_networks,在子类中会被使用。
  • template_model.py: 实现本身模型的一个模板,里面注释了一些细节。
  • pix2pix_model.py: 实现了pix2pix 模型,模型训练数据集--dataset_mode aligned,默认状况下--netG unet256 --netD basicdiscriminator (PatchGAN)。 --gan_mode vanillaGAN loss (标准交叉熵)。
  • **colorization_model.py:**继承了pix2pix_model,模型所作的是:将黑白图片映射为彩色图片。-dataset_model colorization dataset。默认状况下,colorization dataset会自动设置--input_nc 1and--output_nc 2
  • **cycle_gan_model.py:**来实现cyclegan模型。--dataset_mode unaligneddataset,--netG resnet_9blocksResNet generator,--netD basicdiscriminator (PatchGAN introduced by pix2pix),a least-square GANsobjective(--gan_mode lsgan )
  • **networks.py:**包含生成器和判别器的网络架构,normalization layers,初始化方法,优化器结构(learning rate policy)GAN的目标函数(vanilla,lsgan,wgangp)。
  • **test_model.py:**用来生成cyclegan的结果,该模型自动设置--dataset_mode single

4.3 文件夹options:

包含训练模块,测试模块的设置TrainOptions和TestOptions都是 BaseOptions的子类。详细说明options下的文件。

  • **init.py:**该文件起到让python解释器将options文件夹当作包来处理。
  • **base_options.py:**除了training,test都用到的option,还有一些helper 方法:parsing,printing,saving options。
  • **train_options.py:**训练须要的options。
  • test_options.py:测试须要的options。

4.4 文件夹utils:

主要包含一些有用的工具,如数据的可视化。详细说明utils下的文件:

  • **init.py:**该文件起到让python解释器将utils文件夹当作包来处理。
  • **get_data.py:**用来下载数据集的脚本。
  • **html.py:**保存图片写成html。基于diminate中的DOM API。
  • **image_pool.py:**实现一个缓冲来存放以前生成的图片。
  • **visualizer.py:**保存图片,展现图片。
  • **utils.py:**包含一些辅助函数:tensor2numpy转换,mkdir诊断网络梯度等。

5. 总结与展望

5.1 pix2pix的优缺点

Pix2pix模型是 x到y之间的一对一映射**。也就说,pix2pix就是对ground truth的重建:输入轮廓图→通过Unet编码解码成对应的向量→解码成真实图。这种一对一映射的应用范围十分有限,当咱们输入的数据与训练集中的数据差距较大时,生成的结果极可能就没有意义,这就要求咱们的数据集中要尽可能涵盖各类类型。

本文将Pix2Pix论文中的全部要点都表述了出来,主要包括:

  • cGAN,输入为图像而不是随机向量
  • U-Net,使用skip-connection来共享更多的信息
  • Pair输入到D来保证映射
  • Patch-D来下降计算量提高效果
  • L1损失函数的加入来保证输入和输出之间的一致性

5.2 总结

目前,您能够在  Mo 平台的应用中心中找到 pix2pixGAN,能够体验论文实验部分图像建筑标签→照片( Architectural labels→photo),即将您绘制的建筑图片草图生成为你心目中的小屋 。您在学习的过程当中,遇到困难或者发现咱们的错误,能够随时联系咱们。

经过本文,您应该初步了解Pix2pix模型的网络结构和实现原理,以及关键部分代码的初步实现。若是您对深度学习tensorflow比较了解,能够参考tensorflow版实现Pix2pix;若是您对pytorch框架比较熟悉,能够参考pytorch实现Pix2pix;若是您想更深刻的学习了解starGAN原理,能够参考论文

6.参考:

1.论文:arxiv.org/pdf/1611.07…

2.Pix2pix官网:phillipi.github.io/pix2pix/

3.代码PyTorch版本:github.com/phillipi/pi…

4.代码tensorflow版本:github.com/yenchenlin/…

5.代码tensorflow版本:github.com/affinelayer…

6.知乎:zhuanlan.zhihu.com/p/38411618

7.知乎:zhuanlan.zhihu.com/p/55059359

8.博客:blog.csdn.net/qq_16137569…

9.博客:blog.csdn.net/infinita_LV…

10.博客:blog.csdn.net/weixin_3647…

关于咱们

Mo(网址:momodel.cn)是一个支持 Python 的人工智能在线建模平台,能帮助你快速开发、训练并部署模型。


Mo 人工智能俱乐部 是由网站的研发与产品设计团队发起、致力于下降人工智能开发与使用门槛的俱乐部。团队具有大数据处理分析、可视化与数据建模经验,已承担多领域智能项目,具有从底层到前端的全线设计开发能力。主要研究方向为大数据管理分析与人工智能技术,并以此来促进数据驱动的科学研究。

目前俱乐部每周六在杭州举办以机器学习为主题的线下技术沙龙活动,不按期进行论文分享与学术交流。但愿能汇聚来自各行各业对人工智能感兴趣的朋友,不断交流共同成长,推进人工智能民主化、应用普及化。

image.png
相关文章
相关标签/搜索