众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是
O(n2)级别的,
n是序列长度,因此当
n比较大时Transformer模型的计算量难以承受。近来,也有很多工做致力于下降Transformer模型的计算量,好比模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能下降到
O(nlogn)甚至
O(n)php
论文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》当中提到一种线性化Attention(Linear Attention)的方法,由此引起了个人兴趣,继而阅读了一些相关博客,有一些不错的收获,最后将本身对线性化Attention的理解汇总在此文中html
Attention
当前最流行的Attention机制当属Scaled-Dot Attention,即python
Attention(Q,K,V)=softmax(QK⊤)V(1)
这里的
Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv,简单起见我就没显示的写出Attention的缩放因子
d
1了。本文咱们主要关心Self Attention的场景,因此为了介绍上的方便,统一设
Q,K,V∈Rn×dmarkdown
摘掉Softmax
读者也许想不到,制约Attention性能的关键因素,实际上是定义里边的Softmax!事实上,简单地推导一下就能够获得这个结论。
QKT这一步咱们获得一个
n×n的矩阵,以后还要作一个Softmax网络
对一个
1×n的行向量进行Softmax,时间复杂度是
O(n),可是对一个
n×n矩阵的每一行作一个Softmax,时间复杂度就是
O(n2)app
若是没有Softmax,那么Attention的公式就变为三个矩阵连乘
QK⊤V,而矩阵乘法是知足结合率的,因此咱们能够先算
K⊤V,获得一个
d×d的矩阵(这一步的时间复杂度是
O(d2n)),而后再用
Q左乘它(这一步的时间复杂度是
O(d2n)),因为
d≪n,因此这样算大体的时间复杂度只是
O(n)ide
对于BERT base来讲,
d=64而不是768,why?由于768其实是经过Multi-Head拼接获得的,而每一个head的
d=64svg
也就是说,去掉Softmax的Attention复杂度能够降到最理想的线性级别
O(n)!这显然就是咱们的终极追求:Linear Attention函数
通常的定义
问题是,直接去掉Softmax还能算是Attention吗?他还能有标准的Attention的效果吗?为了回答这个问题,咱们先将Scaled-Dot Attention的定义等价的改写为(本文的向量都是列向量)oop
Attention(Q,K,V)i=j=1∑neqi⊤kjj=1∑neqi⊤kjvj(2)
这里稍微解释下,首先咱们知道
Q,K∈Rn×d,令
M=Q×K⊤,由矩阵乘法法则可知,
M的第一行是由
Q的第一行乘以
K⊤的全部列获得的
Attention(Q,K,V)i表示最终输出结果矩阵的第
i行
qi⊤表示
Q∈Rn×d矩阵的第
i行(行向量)
kj表示
K⊤∈Rd×n矩阵的第
j列(列向量)
vj表示
V⊤∈Rd×n矩阵的的第
j列(列向量)
因此,Scaled-Dot Attention其实就是以
eqi⊤kj为权重对
vj作加权平均。因此咱们能够提出一个Attention的通常化定义
Attention(Q,K,V)i=j=1∑nsim(qi,kj)j=1∑nsim(qi,kj)vj(3)
也就是把
eqi⊤kj换成
qi,ki的通常函数
sim(qi,kj),为了保留Attention类似的分布特性,咱们要求
sim(qi,kj)≥0恒成立。也就是说,咱们若是要定义新的Attention,必需要保留式(3)的形式,而且知足
sim(qi,kj)≥0
这种通常形式的Attention在CV中也被称为Non-Local网络,出自论文《Non-local Neural Networks》
几个例子
若是直接去掉Softmax,那么就是
sim(qi,kj)=qi⊤kj,问题是内积没法保证非负性,因此这还不是一个合理的选择。下面咱们介绍几种可取的方案
值得一提的是,下面介绍的这几种Linear Attention,前两种来自CV领域,第三种是苏剑林大佬构思的(除了下面的介绍外,还有EMANet等CV领域对Attention的改进工做)
核函数形式
一个天然的想法是:若是
qi,kj的每一个元素都是非负的,那么内积天然也是非负的。为了完成这点,咱们能够给
qi,kj各自加个激活函数
ϕ,φ,即
sim(qi,kj)=ϕ(qi)⊤φ(kj)(4)
其中
ϕ(⋅),φ(⋅)是值域非负的激活函数。本文开头提到的论文《Transformers are RNNs》选择的是
ϕ(x)=φ(x)=elu(x)+1,其中
elu(x)={xα(ex−1)if x>0if x<0
常见的
α取值为
[0.1,0.3]
非要讲故事的话,式(4)能够联想到"核方法",尤为是
ϕ=φ时,
ϕ就至关于一个核函数,而
⟨ϕ(qi),ϕ(kj)⟩就是经过核函数所定义的内积。这方面的思考能够参考论文《Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel》,此处不作过多延伸
妙用Softmax
另外一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在
QK⊤中,
Q,K∈Rn×d,若是“
Q在
d那一维是归一化的,而且
K在
n那一维是归一化的”,那么
QK⊤就是自动知足归一化了,因此它给出的选择是
Attention(Q,K,V)=softmax2(Q)softmax1(K)⊤V(5)
其中
softmax1、
softmax2分别表示在第一个
(n)、第二个维度
(d)进行Softmax运算。也就是说,这时候咱们是各自给
Q,K加Softmax,而不是算完
QK⊤以后再加Softmax
其实能够证实这个形式也是式(4)的一个特例,此时对应于
ϕ(qi)=softmax(qi),φ(kj)=ekj,读者能够自行推导一下
苏神的构思
在这里,苏神给出了一种构思。这个构思的出发点再也不是式(4),而是源于咱们对原始定义(2)的泰勒展开。由泰勒展开咱们有
eqi⊤kj≈1+qi⊤kj(6)
若是
qi⊤kj≥−1,那么就能够保证右端的非负性,从而可让
sim(qi,kj)=1+qi⊤kj。到这里读者可能已经想到了,想要保证
qi⊤kj≥−1,只须要分别对
qi,kj作
l2归一化。因此,苏神最终提出的方案就是:
sim(qi,kj)=1+(∥qi∥qi)⊤(∥kj∥kj)(7)
若
x=[x1,x2,...,xn],则
∥x∥=x12+x22+⋅⋅⋅+xn2
这不一样于式(4),但理论上它更加接近原始的Scaled-Dot Attention
实现
这里主要是针对苏神所提出的方法进行实现,可是因为笔者本人水平有限,所以最终实现的代码当中其实存在一些问题,主要是:
- 从测试结果来看,改进后的计算速度并无提高
- 没法作到求和为1
代码实现主要是针对BERT的PyTorch实现这篇文章的代码,更具体的说,其实仅修改了ScaledDotProductAttention
这个函数,所以下面只放出这部分代码
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
def forward(self, Q, K, V, attn_mask):
Q = F.normalize(Q, dim=3)
K = F.normalize(K, dim=3)
M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2)))
M_sum = torch.sum(M, dim=3)
M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
attn = M.masked_fill(attn_mask, 0)
context = torch.matmul(attn, V)
return context
复制代码
若是您有更好的实现方法,还望不吝赐教
Reference