做者:曹庆庆(Stony Brook University 在读 PhD,关注Efficient NLP,QA方向,详见awk.ai)
背景
BERT、XLNet、RoBERTa等基于Transformer[^transfomer]的预训练模型推出后,天然语言理解任务都得到了大幅提高。问答任务(Question Answering,QA)[^qa-note]也一样取得了很大的进步。git
用BERT类模型来作问答或阅读理解任务,一般须要将问题和问题相关文档拼接一块儿做为输入文本,而后用自注意力机制对输入文本进行多层交互编码,以后用线性分类器判别文档中可能的答案序列。以下图:github


虽然这种片断拼接的输入方式可让自注意力机制对所有的token进行交互,获得的文档表示是问题相关的(反之亦然),但相关文档每每很长,token数量通常可达问题文本的10~20倍[^length],这样就形成了大量的计算。架构
在实际场景下,考虑到设备的运算速度和内存大小,每每会对模型进行压缩,好比经过蒸馏(distillation)小模型、剪枝(pruning)、量化(quantization)和低轶近似/权重共享等方法。性能
但模型压缩仍是会带来必定的精度损失。所以咱们思考,是否是能够参考双塔模型的结构,提早进行一些计算,从而提高模型的推理速度?测试
若是这种思路可行,会有几个很大的优点:编码
- 它不须要大幅修改原来的模型架构
- 也不须要从新预训练,能够继续使用标准Transformer初始化+目标数据集fine-tune的精调方式
- 还能够叠加模型压缩技术
通过不断地尝试,咱们提出了《Deformer:Decomposing Pre-trained Transformers for Faster Question Answering》[1],在小幅修改模型架构且不更换预训练模型的状况下提高推理速度。下面将为你们介绍咱们的思考历程。url
论文连接:https://awk.ai/assets/deformer.pdf spa
代码连接:https://github.com/StonyBrookNLP/deformer.net
模型结构
在开篇的介绍中,咱们指出了QA任务的计算瓶颈主要在于自注意力机制须要交互编码的token太多了。所以咱们猜测,是否能让文档和问题在编码阶段尽量地独立?设计
这样的话,就能够提早将最难计算的文档编码算好,只须要实时编码较短的问题文本,从而加速整个QA过程。
部分研究代表,Transformer 的低层(lower layers)编码主要关注一些局部的语言表层特征(词形、语法等等),到高层(upper layers)才开始逐渐编码与下游任务相关的全局语义信息。所以咱们猜测,至少在模型的某些部分,“文档编码可以不依赖于问题”的假设是成立的。 具体来讲能够在 Transformer 开始的低层分别对问题和文档各自编码,而后再在高层部分拼接问题和文档的表征进行交互编码,如图所示:


为了验证上述猜测,咱们设计了一个实验,测量文档在和不一样问题交互时编码的变化程度。下图为各层输出的文档向量和它们中心点cosine距离的方差:


能够看到,对于BERT-Based的QA模型,若是编码的文档不变而问题变化,模型的低层表征每每变化不大。这意味着并不是全部Transformer编码层都须要对整个输入文本的所有token序列进行自注意力交互。
所以,咱们提出Transformer模型的一种变形计算方式(称做 DeFormer):在前层对文档编码离线计算获得第
层表征,问题的第
层表征经过实时计算,而后拼接问题和文档的表征输入到后面
到
层。下面这幅图示意了DeFormer的计算过程:


值得一提的是,这种方式在有些QA任务(好比SQuAD)上有较大的精度损失,因此咱们添加了两个蒸馏损失项,目的是最小化DeFormer的高层表征和分类层logits与原始BERT模型的差别,这样能控制精度损失在1个点左右。
实验
这里简要描述下四组关键的实验结果:
(1)在三个QA任务上,BERT和XLNet采用DeFormer分解后,取得了2.7-3.5倍的加速,节省内存65.8-72.0%,效果损失只有0.6-1.8%。BERT-base()在SQuAD上,设置
能加快推理3.2倍,节省内存70%。


(2)实测了原模型和DeFormer在三种不一样硬件上的推理延迟。DeFormer均达到3倍以上的加速。


(3)消融实验证实,添加的两个蒸馏损失项能起到弥补精度损失的效果。


(4)测试DeFormer分解的层数(对应折线图横轴)对推理加速比和性能损失的影响。这个实验在SQuAD上进行,且没有使用蒸馏trick。


总结
这篇文章提主要提出了一种变形的计算方式DeFormer,使问题和文档编码在低层独立编码再在高层交互,从而使得能够离线计算文档编码来加速QA推理和节省内存。
创新之处在于它对原始模型并无太大修改。部署简单,且效果显著。 实验结果代表基于BERT和XLNet的DeFormer均能取得很好的表现。笔者推测对其余的Transformer模型应该也一样有效,而且其余模型压缩方法和技术应该也能够叠加使用到DeFormer上来进一步加速模型推理。
[^qa-note]: 严格来讲是机器阅读理解,即给出问题从相关文章中提取答案,通常 QA 系统还包括检索阶段来找到问题相关的文档 [^transfomer]: 论文方面能够参考邱老师组的文献综述:Pre-trained Models for Natural Language Processing: A Survey,实例代码能够参见 huggingface 的 transformer 库 [^length]: 好比 SQuAD 问题平均 10 个 token,但文档平均有 116 个 token
参考资料
[1] Deformer:Decomposing Pre-trained Transformers for Faster Question Answering: https://awk.ai/assets/deformer.pdf