text to image(三):《Learning What and Where to Draw》

       继续介绍文本生成图像的工作,本文给出的是发表于NIPS 2016的文章《Learning What and Where to Draw》。这篇文章的源码是用torch写的,不是很熟悉,所以就不配合源码解析了.这篇博客主要是参考https://zhuanlan.zhihu.com/p/34379810给出的部分文章翻译,加一些自己的理解.

        论文地址:https://arxiv.org/abs/1610.02454

        源码地址:https://github.com/reedscot/nips2016

 

一、相关工作

        本文是《Generative Adversarial Text to Image Synthesis》和《Learning Deep Representations of Fine-Grained Visual Descriptions》的续作。

        对GAN的相关理解:https://blog.csdn.net/zlrai5895/article/details/80648898

 

二、基本思想及成果

        文章提出了一个新模型,即Generative Adversarial What-Where Network(GAWWN),该网络通过给出的在哪个位置绘制什么内容的说明来生成图像。

        以文本描述和对象位置为条件在Caltech-UCSD Birds数据集上展示生成高质量的128×128图像。 系统还能够以部分为条件(例如,只有喙和尾部)。

        需要注意的是,实验中在训练文本编码器的时候使用到了Caltech-UCSD Birds的标签(鸟的类别)。但是现实生活中大多图片(COCO数据集)不会给每张图片中的场景配一个标签。所以一方面该模型有一定的局限性,另一方面,生成的图像分辨率也比较低(128*128)

 

三、 数据集

本次实验使用的数据集是加利福尼亚理工学院鸟类数据库-2011(CUB_200_2011)。

 

四、模型结构:

    1、文本和图像编码器

    我们通过优化下面这个损失来得到最终的sentence embedding。

 

                                             

是训练集。是0-1损失。是图像集,是图像集对应的文本描述,是类别标签。被如下定义:

可以看到它们最终的输出是预测的标签。

是图像编码器(比如一个神经网络)是文本编码器。是y类别文本描述的集合,是y类别图像的集合。为了得到一个更好的text_encoder,对Reed et al.之前的工作(《Learning Deep Representations of Fine-Grained Visual Descriptions》)做了一些改变。

(1)不再使用char-CNN-RNN,而是使用char-CNN-GRU

(2)对于每幅图像提取出来的4个caption向量,求它们的平均值。

 

图像编码器和文本编码器可以和后面的GAN一起训练,但是单独训练可以使用从224*224的图像所提取的特征(分辨率比较高),并且可以加速GAN的训练。

2、Bounding-box-conditional text-to-image model

(1)生成器

由文本编码器器,可以得到对应的text_embedding

首先,将文本嵌入(以绿色显示)在空间上复制以形成M×M×T特征映射,然后在空间上进行扭曲以适应归一化的边界框坐标。 框外的功能图条目全部为零。 该图显示了单个对象,但在多个caption的情况下,这些特征地图取平均。 然后,应用卷积和合并操作将空间维度降低到1×1。直观地,该特征向量对图像中的粗糙空间结构进行编码,并且将其与噪声向量z连接起来。

下面分为两分支:

       1)Global处理

全局路径只是一系列stride-2反卷积,将空间维数从1×1增加到M×M。

        2)local处理

在局部路径中,当达到空间维度M×M时,应用掩蔽操作,使得对象边界框外的区域被设置为0.

最后,局部路径和全局路径通过深度级联合并。最后一系列的去卷积层用于达到最终的空间维度。在最后一层,我们应用Tanh非线性将输出限制为[-1,1]。

(2)鉴别器

首先是文本嵌入向量(t)在空间上被类似地复制以形成M×M×T张量。

图像的处理仍然包括了gobal和local两部分。

        1)local

       在局部路径中,图像通过stride-2卷积被馈送到M×M空间维度,在该点处它与刚才得到的张量级联。得到的张量被空间裁剪到边界框坐标内,并进一步进行卷积处理,直到空间维度为1×1。

         2)global

       全局路径简单地包含向下卷积,附加文本嵌入t。

最后,局部和全局路径输出向量相加合并并馈送到最终层,产生标量判别器分数。

3、Keypoint-conditional text-to-image model

 

(1)生成器

如图左半部分所示:

位置关键点被编码成M×M×K的空间特征图(关键点张量),其中通道对应于该部分; 即通道1中的头部,通道2中的左脚,等等。

关键点张量被输入到网络的几个阶段。如图所示:

首先,它通过stride-2卷积馈送以产生与噪声z和文本嵌入t串联的向量。 结果矢量提供关于内容和part位置的粗略信息。

其次,关键点张量被平化为二进制矩阵,其中1表示在特定空间位置存在的任何部分,然后在深度方向上复制为尺寸为M×M×H的张量。留着在后续使用。

之前的得到的串联向量仍然被分为两个阶段:

        1)local

噪声文本特征点矢量通过反卷积来产生另一个M×M×H张量。通过与相同尺寸的关键点张量逐点相乘来对局部通路激活进行门控。得到局部张量。

        2)global

噪声文本特征点矢量通过反卷积来产生另一个M×M×H张量。得到全局张量。

最后,关键点张量、局部张量、全局张量级联并进一步进行反卷积处理以产生最终图像。再次应用Tanh非线性。

 

(2)鉴别器

首先,文本嵌入向量(t)被空间复制到M×M备用。

生成器生成的图像仍然分为local 路径和global路径

        1)local

图像stride2卷积到M*M,然后与刚才得到的文本嵌入t复制得到的M*M块级联,得到一个局部张量。这个局部张量然后与发生器中的二元关键点掩模(之前M*M*H的那个)进行乘法门控,并且得到的张量与M×M×T关键点深度级联。做进一步的stride2卷积以产生矢量。

          2)global

图像stride2卷积到M*M,做卷积到矢量。

最后,local和global的输出矢量加法组合,然后进入最终层,产生标量鉴别器得分。

 

4、Conditional keypoint generation model

对每只鸟,给出所有的关键点是比较困难的。由于我们通常要指定1或2个关键点,因此在实验中,我们将“开”概率设置为0.1。 也就是说,15个鸟类部分中的每一个只有10%的几率作为给定训练图像的条件变量。 也就是  每个关键点只有0.1的概率作为训练图像的条件变量。

 

五、其他相关

当关键点为0时,无法生成出任何合理的图像。最好的使用合成或者手工标记的关键点。